| import numpy as np |
| import pickle |
| from torch.utils.data import Dataset, DataLoader |
| import os |
| import torch |
| from copy import deepcopy |
| from blimpy import Waterfall |
| from tqdm import tqdm |
| from copy import deepcopy |
| from sigpyproc.readers import FilReader |
| from torch import nn |
|
|
|
|
def load_pickled_data(file_path):
    """Read a single pickled object from ``file_path`` and return it.

    Args:
        file_path: Path of the pickle file; it is opened in binary mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)
|
|
| |
class CustomDataset(Dataset):
    """Dataset over pickled samples stored as ``data_dir/<class>/<file>``.

    Every entry of ``data_dir`` is treated as a class directory and every
    entry inside it as one pickled sample.  Each pickle is expected to be a
    mapping holding a numpy array under ``'data'`` (and under ``'8_data'``
    when ``bit8`` is set).

    Args:
        data_dir: Root directory containing one sub-directory per class.
        bit8: When true, read the ``'8_data'`` entry and cast it to
            ``float32``; otherwise read the ``'data'`` entry unchanged.
        transform: Optional callable applied to the sample tensor.  When it
            is ``None`` the tensor is returned untransformed.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8
        # NOTE(review): os.listdir returns entries in arbitrary order, so the
        # class -> index mapping is not guaranteed to be the same between
        # runs or machines.  The order is kept as-is here so that existing
        # label assignments do not silently change.
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []  # sample file paths
        self.labels = []  # class index of the sample at the same position

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of sample files found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(tensor, label)``."""
        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])

        if self.bit8:
            new_image = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            new_image = torch.from_numpy(image['data'])

        # Previously the result was only assigned inside the transform
        # branch, so the default ``transform=None`` raised UnboundLocalError.
        if self.transform is not None:
            new_image = self.transform(new_image)

        return new_image, label
|
|
| |
class CustomDataset_Masked(Dataset):
    """Dataset returning a transformed sample, its label and a binary mask.

    Samples are pickles stored as ``data_dir/<class>/<file>``.  Each pickle
    is expected to be a mapping with a ``'data'`` entry and a ``'burst'``
    numpy array.  The mask is built from ``'burst'``: it is divided by its
    maximum (unless that maximum is 0), then values above 0.1 become 1 and
    values at or below 0.1 become 0.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Callable applied to the sample tensor.  It is required by
            ``__getitem__`` because the mask is written into channels 0-2 of
            a tensor shaped like the transformed sample.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # NOTE(review): os.listdir order is arbitrary, so class indices are
        # not guaranteed to be stable across runs or machines.
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []  # sample file paths
        self.labels = []  # class index of the sample at the same position

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of sample files found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, mask)``.

        Raises:
            ValueError: If the dataset was created without a transform.
        """
        if self.transform is None:
            # Previously this case fell through to the return statement with
            # nothing assigned and failed with an opaque UnboundLocalError.
            raise ValueError(
                "CustomDataset_Masked requires a transform: the burst mask is "
                "written into channels 0-2 of the transformed image"
            )

        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])

        burst = image['burst']
        peak = burst.max()
        if peak == 0:
            # Nothing to scale by; avoid dividing by zero.
            new_burst = torch.from_numpy(burst)
        else:
            new_burst = torch.from_numpy(burst / peak)

        # Binarise in place: both index masks are computed before either
        # assignment so the second one is not affected by the first.
        ind = new_burst > 0.1
        ind_not = new_burst <= 0.1
        new_burst[ind] = 1
        new_burst[ind_not] = 0

        # presumably image['data'] is a numpy masked array, whose ``.data``
        # is the underlying ndarray -- TODO confirm against the data writer.
        new_image = self.transform(torch.from_numpy(image['data'].data))

        new_burst_arr = torch.zeros_like(new_image)
        for channel in range(3):
            new_burst_arr[channel, :, :] = new_burst
        return new_image, label, new_burst_arr
|
|
| |
class TestingDataset(Dataset):
    """Dataset that returns each sample together with its stored parameters.

    Samples are pickles stored as ``data_dir/<class>/<file>``.  Each pickle
    is expected to be a mapping with a ``'params'`` mapping (keys ``'dm'``,
    ``'freq_ref'``, ``'snr'`` and ``'boxcard'`` are read) plus a numpy array
    under ``'data'`` (or ``'8_data'`` when ``bit8`` is set).

    Args:
        data_dir: Root directory containing one sub-directory per class.
        bit8: When true, read the ``'8_data'`` entry and cast it to
            ``float32``; otherwise read the ``'data'`` entry unchanged.
        transform: Optional callable applied to the sample tensor.  When it
            is ``None`` the tensor is returned untransformed.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8
        # NOTE(review): os.listdir order is arbitrary, so class indices are
        # not guaranteed to be stable across runs or machines.
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []  # sample file paths
        self.labels = []  # class index of the sample at the same position

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of sample files found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(tensor, (label, dm, freq_ref, snr, boxcard))`` where the last
            four values are read from the sample's ``'params'`` mapping.
        """
        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])
        params = image['params']

        if self.bit8:
            new_image = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            new_image = torch.from_numpy(image['data'])

        # Previously the result was only assigned inside the transform
        # branch, so the default ``transform=None`` raised UnboundLocalError.
        if self.transform is not None:
            new_image = self.transform(new_image)

        # Only touches the freshly loaded mapping; it is not part of the
        # returned tuple.
        params['labels'] = label
        return new_image, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
|
|
| |
class SearchDataset(Dataset):
    """Dataset of fixed-length windows cut from one long 2-D array.

    The array comes either from a pickle (a mapping with ``'header'`` and
    ``'data'`` entries) or from a blimpy ``Waterfall`` opened on
    ``data_dir``.  In both cases ``data[:, 0, :]`` is cut along its first
    axis into consecutive, non-overlapping windows of ``window_size`` (2048)
    rows; a trailing partial window is dropped.

    Args:
        data_dir: Path of the pickle file or of the file given to
            ``Waterfall``.
        transform: Optional callable applied to each window (a numpy array).
            When it is ``None`` the window is returned untransformed.
        pickle_data: Select the pickle loader instead of ``Waterfall``.
    """

    def __init__(self, data_dir, transform=None, pickle_data=False):
        self.window_size = 2048
        self.transform = transform
        self.SEC_PER_DAY = 86400

        if pickle_data:
            # NOTE: unpickling can execute arbitrary code; trusted files only.
            with open(data_dir, 'rb') as f:
                self.d = pickle.load(f)
            self.header = self.d['header']
            self.images = self.crop(self.d['data'][:, 0, :], self.window_size)
        else:
            self.obs = Waterfall(data_dir, max_load=50)
            self.header = self.obs.header
            self.images = self.crop(self.obs.data[:, 0, :], self.window_size)

    def crop(self, data, window_size=2048):
        """Cut ``data`` into consecutive windows along its first axis.

        Args:
            data: 2-D array; the first axis is the one being windowed.
            window_size: Number of rows per window.

        Returns:
            A new float64 array of shape
            ``(data.shape[0] // window_size, window_size, data.shape[1])``.
            Rows that do not fill a whole window are discarded.
        """
        n_samp = data.shape[0] // window_size
        # The second dimension used to be hard-coded to 192; taking it from
        # the input gives the same result for 192-wide data and lets other
        # widths work too.
        n_chan = data.shape[1]
        usable = np.array(data[:n_samp * window_size], dtype=np.float64)
        return usable.reshape(n_samp, window_size, n_chan)

    def __len__(self):
        """Return the number of complete windows."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)`` with the window transposed.

        The window is handed to ``transform`` as a numpy array of shape
        ``(data.shape[1], window_size)``.
        """
        data = self.images[idx, :, :].T
        # Previously the result was only assigned inside the transform
        # branch, so the default ``transform=None`` raised UnboundLocalError.
        if self.transform is None:
            return data, idx
        return self.transform(data), idx
|
|
| |
class SearchDataset_Sigproc(Dataset):
    """Dataset of fixed-length windows cut from a file read with ``FilReader``.

    The whole file is read with ``read_block(0, header.nsamples)``, 1024
    entries are dropped from each end of the block's second axis, the first
    and last axes are swapped, and the result is cut along its first axis
    into consecutive windows of ``window_size`` (2048) rows.  A trailing
    partial window is dropped.

    Args:
        data_dir: Path handed to ``FilReader``.
        transform: Optional callable applied to each window (a torch
            tensor).  When it is ``None`` the tensor is returned as-is.
    """

    def __init__(self, data_dir, transform=None):
        self.window_size = 2048
        self.transform = transform
        self.SEC_PER_DAY = 86400

        fil = FilReader(data_dir)
        self.header = fil.header

        # presumably the block is (channels, samples), so this trims 1024
        # samples from each end before moving time to axis 0 -- TODO confirm
        # against the sigpyproc version in use.
        read_data = fil.read_block(0, fil.header.nsamples)[:, 1024:-1024]
        read_data = np.swapaxes(read_data, 0, -1)
        self.images = self.crop(read_data, self.window_size)

    def crop(self, data, window_size=2048):
        """Cut ``data`` into consecutive windows along its first axis.

        Args:
            data: 2-D array; the first axis is the one being windowed.
            window_size: Number of rows per window.

        Returns:
            A new float64 array of shape
            ``(data.shape[0] // window_size, window_size, data.shape[1])``.
            Rows that do not fill a whole window are discarded.
        """
        n_samp = data.shape[0] // window_size
        # The second dimension used to be hard-coded to 192; taking it from
        # the input gives the same result for 192-wide data and lets other
        # widths work too.
        n_chan = data.shape[1]
        usable = np.array(data[:n_samp * window_size], dtype=np.float64)
        return usable.reshape(n_samp, window_size, n_chan)

    def __len__(self):
        """Return the number of complete windows."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)`` with the window transposed.

        The window is converted to a torch tensor of shape
        ``(data.shape[1], window_size)`` before ``transform`` is applied.
        """
        tensor = torch.from_numpy(self.images[idx, :, :].T)
        # Previously the result was only assigned inside the transform
        # branch, so the default ``transform=None`` raised UnboundLocalError.
        if self.transform is None:
            return tensor, idx
        return self.transform(tensor), idx
|
|
| |
| |
| |
| |
|
|
def renorm(data):
    """Standardise a tensor using its global mean and standard deviation.

    Args:
        data: Tensor to standardise; statistics are taken over all elements.

    Returns:
        ``(data - mean) / std`` as a new tensor.  No guard is applied for a
        zero standard deviation.
    """
    centre = data.mean()
    spread = data.std()
    return data.sub(centre).div(spread)
|
|
def transform(data):
    """Build a stack of log-scaled, standardised planes from a 2-D tensor.

    Three planes are produced from ``data``:

    * plane 0: the input unchanged,
    * plane 1: values below ``mean + 1 * std`` set to 0,
    * plane 2: values above ``mean + 5 * std`` set to 0,

    each passed through ``log10(x + 1e-10)`` and then ``renorm``.  The stack
    is cast to ``float32``, split into 8 chunks along the last axis, and the
    chunks are folded into the leading axis.

    Args:
        data: Tensor indexed as ``(rows, cols)``; the input is not modified.

    Returns:
        ``float32`` tensor of shape ``(3 * 8, rows, cols / 8)`` when ``cols``
        divides evenly into 8 chunks.
    """
    rms = torch.std(data)
    mean = torch.mean(data)
    masks_rms = [-1, 5]

    planes = torch.zeros((len(masks_rms) + 1, data.shape[0], data.shape[1]))
    planes[0, :, :] = renorm(torch.log10(data.detach().clone() + 1e-10))

    for plane_idx, scale in enumerate(masks_rms, start=1):
        clipped = data.detach().clone()
        threshold = abs(scale) * rms + mean
        # A negative scale zeroes everything below the threshold; a
        # non-negative scale zeroes everything above it.
        if scale < 0:
            to_zero = clipped < threshold
        else:
            to_zero = clipped > threshold
        clipped[to_zero] = 0
        planes[plane_idx, :, :] = renorm(torch.log10(clipped + 1e-10))

    planes = planes.type(torch.float32)
    # Split the last axis into 8 pieces and merge the piece axis into the
    # plane axis: (3, rows, cols) -> (3, 8, rows, cols/8) -> (24, rows, cols/8).
    pieces = torch.chunk(planes, 8, dim=-1)
    stacked = torch.stack(pieces, dim=1)
    return stacked.view(-1, stacked.size(2), stacked.size(3))
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def renorm_batched(data):
    """Min-max scale each item of a batch over its last two dimensions.

    For every slice spanning the last two axes, the minimum is subtracted
    and the result is divided by its maximum, so each slice ends up in
    ``[0, 1]``.

    The earlier version expanded the statistics to a fixed
    ``(batch, 192, 2048)`` shape and therefore only accepted inputs of
    exactly that shape; broadcasting against ``keepdim=True`` reductions
    gives the same values for that shape and works for any trailing size.

    Args:
        data: Tensor with at least two dimensions.

    Returns:
        Tensor of the same shape as ``data``.  A slice whose values are all
        equal divides by zero (no guard is applied, as before).
    """
    mins = torch.amin(data, dim=(-2, -1), keepdim=True)
    shifted = data - mins
    maxs = torch.amax(shifted, dim=(-2, -1), keepdim=True)
    return shifted / maxs
| |
|
|
def transform_mask(data):
    """Min-max scale a 2-D array and repeat it over three leading planes.

    Args:
        data: Array indexed as ``(rows, cols)``; it is copied first, so the
            input is left untouched.

    Returns:
        ``float32`` numpy array of shape ``(3, rows, cols)`` in which every
        plane holds ``(data - data.min()) / (data - data.min()).max()``.
    """
    working = deepcopy(data)
    offset = working - working.min()
    scaled = offset / offset.max()

    stacked = np.zeros((3, data.shape[0], data.shape[1]))
    # Broadcasting fills all three planes with the same scaled array.
    stacked[:] = scaled
    return stacked.astype(np.float32)
|
|
|
|
| |
def Convert_ONNX(model, saveloc, input_data_mock):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The model is switched to evaluation mode (and left in that mode) before
    it is exported.

    Args:
        model: The module to export.
        saveloc: Destination passed straight to ``torch.onnx.export``.
        input_data_mock: Example input tensor used for the export.
    """
    print("Saving to ONNX")

    model.eval()

    # ``torch.autograd.Variable`` is deprecated; wrapping a tensor with it
    # yields a detached tensor that does not require grad, which
    # ``detach()`` expresses directly.
    dummy_input = input_data_mock.detach()

    torch.onnx.export(
        model,
        dummy_input,
        saveloc,
        input_names=['modelInput'],
        output_names=['modelOutput'],
        # Axis 0 of both input and output is exported as a variable-length
        # dimension named 'batch_size'.
        dynamic_axes={
            'modelInput': {0: 'batch_size'},
            'modelOutput': {0: 'batch_size'},
        },
    )
    print(" ")
    print('Model has been converted to ONNX')
|
|
|
|
|
|
|
|
|
|
|
|
|
|