Skip to content

pytorch_dedispersion.candidate_finder

CandidateFinder

Source code in pytorch_dedispersion/candidate_finder.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class CandidateFinder:
    """
    Detect candidates in boxcar-filtered data by thresholding a per-row SNR.

    ``boxcar_data`` is iterated along its first dimension. Each slice is handled as a
    2-D tensor whose rows are reported as ``'DM Index'`` and whose columns are
    reported as ``'Time Sample'``.
    """

    def __init__(self, boxcar_data: torch.Tensor, window_size: int = 50) -> None:
        """
        Initialize CandidateFinder.

        Args:
            boxcar_data (torch.Tensor): Boxcar filtered data. Iterated along its first
                dimension; each slice is treated as a 2-D ``(DM index, time sample)`` array.
            window_size (int, optional): Window size (in samples) for trend removal. Defaults to 50.
        """
        self.boxcar_data = boxcar_data
        self.window_size = window_size

    def find_candidates(
            self,
            snr_threshold: float,
            boxcar_widths: Sequence[int],
            remove_trend: bool = False,
        ) -> List[Dict[str, Any]]:
        """
        Find candidates based on SNR threshold.

        Every sample whose SNR is strictly greater than ``snr_threshold`` is reported.

        Args:
            snr_threshold (float): SNR threshold for candidate detection.
            boxcar_widths (Sequence[int]): Boxcar width for each slice of ``boxcar_data``;
                ``boxcar_widths[i]`` labels the candidates found in slice ``i``. It is only
                indexed for slices that contain at least one candidate.
            remove_trend (bool, optional): Whether to subtract a moving-average baseline
                (see ``calculate_baseline``) before computing the SNR. Defaults to False.

        Returns:
            list[dict]: One dict per candidate with keys ``'Boxcar Width'``, ``'DM Index'``,
            ``'Time Sample'`` and ``'SNR'``. The list is ordered by slice, then by DM index,
            then by time sample.
        """
        candidates: List[Dict[str, Any]] = []
        for i, data in enumerate(self.boxcar_data):
            if remove_trend:
                baseline = self.calculate_baseline(data)
                # For an even window the baseline is one sample longer than the data,
                # so trim it to the data length before subtracting.
                data = data - baseline[:, :data.shape[1]]
            snr = self.calculate_snr(data)
            above_threshold = snr > snr_threshold
            # Fetch all hits in two bulk host transfers. Calling ``.item()`` per
            # element would force one device synchronization per candidate.
            indices = torch.nonzero(above_threshold, as_tuple=False).tolist()
            if not indices:
                continue
            # Boolean-mask indexing and ``nonzero`` both walk the tensor in row-major
            # order, so ``snr_values[k]`` belongs to ``indices[k]``.
            snr_values = snr[above_threshold].tolist()
            width = boxcar_widths[i]
            candidates.extend(
                {
                    'Boxcar Width': width,
                    'DM Index': dm_index,
                    'Time Sample': time_sample,
                    'SNR': value,
                }
                for (dm_index, time_sample), value in zip(indices, snr_values)
            )
        return candidates

    def calculate_baseline(self, data: torch.Tensor) -> torch.Tensor:
        """
        Calculate the baseline using a moving average.

        Each row is padded by ``window_size // 2`` samples on both ends and averaged
        with a sliding window of ``window_size`` samples at stride 1.

        Args:
            data (torch.Tensor): The input data, a 2-D ``(rows, samples)`` tensor.

        Returns:
            torch.Tensor: The baseline, shape
            ``(rows, samples + 2 * (window_size // 2) - window_size + 1)``. That is
            ``samples`` columns for an odd window and ``samples + 1`` for an even one.
        """
        pad = self.window_size // 2
        # Pad the batched (1, rows, samples) form. 3-D input is the layout that
        # reflect/replicate padding of 1-D signals has always accepted.
        batched = data.unsqueeze(0)
        # Reflect padding requires pad < samples. For series shorter than half a
        # window, replicate the edge values instead of raising.
        mode = 'reflect' if pad < data.shape[-1] else 'replicate'
        padded_data = F.pad(batched, (pad, pad), mode=mode)
        baseline = F.avg_pool1d(padded_data, kernel_size=self.window_size, stride=1, padding=0).squeeze(0)
        return baseline

    def calculate_snr(self, data: torch.Tensor) -> torch.Tensor:
        """
        Calculate the signal-to-noise ratio (SNR).

        Each row is standardized independently as ``(x - mean) / std`` along dim 1.
        ``std`` is ``torch.std``'s default, Bessel-corrected estimator.

        Args:
            data (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The SNR of the data, same shape as ``data``.
        """
        mean = data.mean(dim=1, keepdim=True)
        std = data.std(dim=1, keepdim=True)
        # NOTE(review): a row with zero standard deviation divides by zero here
        # (NaN/inf). Consider guarding if constant rows can occur.
        snr = (data - mean) / std
        return snr

__init__

__init__(boxcar_data, window_size=50)

Initialize CandidateFinder.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `boxcar_data` | `Tensor` | Boxcar-filtered data. | *required* |
| `window_size` | `int` | Window size (in samples) used for trend removal. Defaults to 50. | `50` |
Source code in pytorch_dedispersion/candidate_finder.py
 6
 7
 8
 9
10
11
12
13
14
15
def __init__(self, boxcar_data: torch.Tensor, window_size: int = 50) -> None:
    """
    Store the inputs used by the candidate search.

    Args:
        boxcar_data (torch.Tensor): Boxcar filtered data to search.
        window_size (int, optional): Moving-average window length, in samples, used
            when trend removal is requested. Defaults to 50.
    """
    self.boxcar_data, self.window_size = boxcar_data, window_size

calculate_baseline

calculate_baseline(data)

Calculate the baseline using a moving average.

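Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `Tensor` | The input data. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The moving-average baseline of the data. |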

Source code in pytorch_dedispersion/candidate_finder.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def calculate_baseline(self, data: torch.Tensor) -> torch.Tensor:
    """
    Calculate the baseline using a moving average.

    Each row is padded by ``window_size // 2`` samples on both ends and averaged
    with a sliding window of ``window_size`` samples at stride 1.

    Args:
        data (torch.Tensor): The input data, a 2-D ``(rows, samples)`` tensor.

    Returns:
        torch.Tensor: The baseline, shape
        ``(rows, samples + 2 * (window_size // 2) - window_size + 1)``. That is
        ``samples`` columns for an odd window and ``samples + 1`` for an even one.
    """
    pad = self.window_size // 2
    # Pad the batched (1, rows, samples) form. 3-D input is the layout that
    # reflect/replicate padding of 1-D signals has always accepted.
    batched = data.unsqueeze(0)
    # Reflect padding requires pad < samples. For series shorter than half a
    # window, replicate the edge values instead of raising.
    mode = 'reflect' if pad < data.shape[-1] else 'replicate'
    padded_data = F.pad(batched, (pad, pad), mode=mode)
    baseline = F.avg_pool1d(padded_data, kernel_size=self.window_size, stride=1, padding=0).squeeze(0)
    return baseline

calculate_snr

calculate_snr(data)

Calculate the signal-to-noise ratio (SNR).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `Tensor` | The input data. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The SNR of the data. |

Source code in pytorch_dedispersion/candidate_finder.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def calculate_snr(self, data: torch.Tensor) -> torch.Tensor:
    """
    Standardize each row of ``data`` into a signal-to-noise ratio.

    Every row is centred on its own mean and divided by its own standard deviation.
    Both statistics are taken along dim 1, and the standard deviation uses
    ``torch.std``'s default estimator.

    Args:
        data (torch.Tensor): The input data.

    Returns:
        torch.Tensor: Per-row standardized values, same shape as ``data``.
    """
    centred = data - data.mean(dim=1, keepdim=True)
    return centred / data.std(dim=1, keepdim=True)

find_candidates

find_candidates(
    snr_threshold, boxcar_widths, remove_trend=False
)

Find candidates based on SNR threshold.

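Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `snr_threshold` | `float` | SNR threshold for candidate detection. | *required* |
| `boxcar_widths` | `Sequence[int]` | Boxcar width for each slice of the boxcar data. | *required* |
| `remove_trend` | `bool` | Whether to remove the trend from the data before computing the SNR. Defaults to False. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `List[Dict[str, Any]]` | List of detected candidates. |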

Source code in pytorch_dedispersion/candidate_finder.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def find_candidates(
        self,
        snr_threshold: float,
        boxcar_widths: Sequence[int],
        remove_trend: bool = False,
    ) -> List[Dict[str, Any]]:
    """
    Find candidates based on SNR threshold.

    Every sample whose SNR is strictly greater than ``snr_threshold`` is reported.

    Args:
        snr_threshold (float): SNR threshold for candidate detection.
        boxcar_widths (Sequence[int]): Boxcar width for each slice of ``self.boxcar_data``;
            ``boxcar_widths[i]`` labels the candidates found in slice ``i``. It is only
            indexed for slices that contain at least one candidate.
        remove_trend (bool, optional): Whether to subtract ``self.calculate_baseline(data)``
            from each slice before computing the SNR. Defaults to False.

    Returns:
        list[dict]: One dict per candidate with keys ``'Boxcar Width'``, ``'DM Index'``,
        ``'Time Sample'`` and ``'SNR'``. The list is ordered by slice, then by DM index,
        then by time sample.
    """
    candidates: List[Dict[str, Any]] = []
    for i, data in enumerate(self.boxcar_data):
        if remove_trend:
            baseline = self.calculate_baseline(data)
            # The baseline may be longer than the data (e.g. one extra sample for
            # an even moving-average window), so trim it to the data length.
            data = data - baseline[:, :data.shape[1]]
        snr = self.calculate_snr(data)
        above_threshold = snr > snr_threshold
        # Fetch all hits in two bulk host transfers. Calling ``.item()`` per
        # element would force one device synchronization per candidate.
        indices = torch.nonzero(above_threshold, as_tuple=False).tolist()
        if not indices:
            continue
        # Boolean-mask indexing and ``nonzero`` both walk the tensor in row-major
        # order, so ``snr_values[k]`` belongs to ``indices[k]``.
        snr_values = snr[above_threshold].tolist()
        width = boxcar_widths[i]
        candidates.extend(
            {
                'Boxcar Width': width,
                'DM Index': dm_index,
                'Time Sample': time_sample,
                'SNR': value,
            }
            for (dm_index, time_sample), value in zip(indices, snr_values)
        )
    return candidates