Skip to content

pytorch_dedispersion.dedisperse_candidates_large_hdf_two_threads

cpu_loader

cpu_loader(
    file_path,
    total_freq,
    slice_size,
    bad_channels,
    out_q,
    pbar_scan,
    offset=0,
    step=1,
)

Worker thread: reads slices [offset::step].

Source code in pytorch_dedispersion/dedisperse_candidates_large_hdf_two_threads.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def cpu_loader(
    file_path: str,
    total_freq: int,
    slice_size: int,
    bad_channels: Sequence[int],
    out_q: "Queue[Tuple[torch.Tensor, torch.Tensor]]",  # runtime-only type
    pbar_scan: Any,
    offset: int = 0,
    step: int = 1,
) -> None:
    """
    Worker thread: load every ``step``-th frequency slice, starting at slice
    number ``offset``, and hand each one to ``out_q``.

    Slices are ``slice_size`` channels wide; slice ``k`` covers channels
    ``[k * slice_size, min((k + 1) * slice_size, total_freq))``. This worker
    visits ``k = offset, offset + step, offset + 2 * step, ...``.

    Args:
        file_path: Passed unchanged to ``DataProcessor``.
        total_freq: Exclusive upper bound on the channel index.
        slice_size: Number of channels per slice (the final slice may be
            shorter).
        bad_channels: Passed unchanged to ``DataProcessor``.
        out_q: Queue receiving ``(data, frequencies)`` pairs of pinned
            float32 tensors, one pair per non-empty slice.
        pbar_scan: Object with an ``update(n)`` method; called with ``1``
            once per slice visited, including slices that turn out empty.
        offset: Index of the first slice this worker handles.
        step: Distance, in slices, between consecutive slices of this worker.

    Only the worker started with ``offset == 0`` puts ``None`` on ``out_q``
    after its own loop finishes.
    NOTE(review): ``None`` is presumably the consumer's end-of-stream marker;
    it is enqueued without waiting for workers with other offsets -- confirm
    the consumer accounts for that.
    """
    first_channel = offset * slice_size
    stride = step * slice_size

    for lo in range(first_channel, total_freq, stride):
        # Clamp so the last slice never runs past total_freq.
        hi = min(lo + slice_size, total_freq)

        loader = DataProcessor(
            file_path,
            freq_slice=(lo, hi),
            bad_channels=bad_channels,
        )
        loader.load_data()

        # Progress is counted per slice read, before the emptiness check.
        pbar_scan.update(1)

        # A slice with zero rows is not forwarded
        # (presumably every channel in it was masked -- verify).
        if loader.data.shape[0] != 0:
            data_t = torch.as_tensor(loader.data, dtype=torch.float32).pin_memory()
            freq_t = torch.as_tensor(
                loader.get_frequencies(), dtype=torch.float32
            ).pin_memory()
            out_q.put((data_t, freq_t))

    if offset == 0:
        out_q.put(None)

save_dedispersed_data

save_dedispersed_data(
    original_file_path,
    summed_data,
    dm_range,
    tsamp,
    verbose=True,
)

Save the dedispersed and summed data to an HDF5 file.

Args:

- original_file_path (str): Path to the original input HDF5 file
- summed_data (torch.Tensor): The dedispersed data summed over frequency
- dm_range (torch.Tensor): Dispersion measure values corresponding to the data
- tsamp (float): Time sampling resolution in seconds
- verbose (bool): Whether to print status messages

Source code in pytorch_dedispersion/dedisperse_candidates_large_hdf_two_threads.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def save_dedispersed_data(
    original_file_path: str,
    summed_data: torch.Tensor,
    dm_range: torch.Tensor,
    tsamp: float,
    verbose: bool = True,
) -> None:
    """
    Write the dedispersed, frequency-summed data to a new HDF5 file.

    The output is named ``dedispersed_<stem>.h5``, where ``<stem>`` is the
    base name of ``original_file_path`` without its extension. The name has
    no directory part, so the file is created relative to the current
    working directory. Any file already present under that name is deleted
    first.

    File layout:
        * dataset ``summed_data`` -- ``summed_data`` (gzip, level 9)
        * dataset ``dm_values``   -- ``dm_range`` (gzip, level 9)
        * attributes ``tsamp``, ``nDM`` (``summed_data.shape[0]``) and
          ``nTime`` (``summed_data.shape[1]``)

    Args:
        original_file_path (str): Path to the original input HDF5 file
        summed_data (torch.Tensor): The dedispersed data summed over
            frequency; indexed on its first two dimensions for the
            ``nDM`` / ``nTime`` attributes
        dm_range (torch.Tensor): Dispersion measure values corresponding to
            the data
        tsamp (float): Time sampling resolution in seconds
        verbose (bool): Whether to print status messages
    """
    # Derive the output name from the input file's stem.
    stem, _ext = os.path.splitext(os.path.basename(original_file_path))
    output_filename = f"dedispersed_{stem}.h5"

    # Start from a clean slate if a previous run left a file behind.
    if os.path.exists(output_filename):
        os.remove(output_filename)

    # Move tensors to host memory as NumPy arrays before writing.
    arrays = {
        "summed_data": summed_data.cpu().numpy(),
        "dm_values": dm_range.cpu().numpy(),
    }

    with h5py.File(output_filename, 'w') as hf:
        for dataset_name, values in arrays.items():
            hf.create_dataset(
                dataset_name,
                data=values,
                compression="gzip",
                compression_opts=9,
            )

        # Metadata stored as file-level attributes.
        hf.attrs["tsamp"] = tsamp
        hf.attrs["nDM"] = summed_data.shape[0]
        hf.attrs["nTime"] = summed_data.shape[1]

    if not verbose:
        return
    print(f"Dedispersed data saved to {output_filename}")