| | |
| |
|
| | import torch |
| | from torchvision import transforms |
| | from torch.utils.data import Dataset |
| |
|
| | import os |
| | import numpy as np |
| | from PIL import Image |
| |
|
| |
|
| | |
class CrosswalkDataset(Dataset):
    """Image dataset read from a single flat directory.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` file (extension matched
    case-insensitively) in ``src_dir`` is one sample. Its label is read from
    the ``.txt`` file that shares the image's stem (``foo.jpg`` -> ``foo.txt``).
    The label file is expected to hold a single integer:

    * ``1``                           -> target ``[1, 0]``
    * any other integer               -> target ``[0, 1]``
    * missing or unparsable label file -> target ``[0, 0]``

    Args:
        src_dir: Directory containing the images and their label files.
        transform: Callable applied to each opened PIL image. Defaults to
            ``transforms.ToTensor()`` when ``None``.
    """

    # Extensions treated as images; compared against the lower-cased filename.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, src_dir, transform=None):
        self.src_dir = src_dir
        # Resolve the default once here rather than mutating state in __getitem__.
        self.transform = transforms.ToTensor() if transform is None else transform

        dir_files = sorted(os.listdir(src_dir))
        self.image_paths = [
            name for name in dir_files
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        ]
        # Pair each image with the label sharing its stem. Pairing the sorted
        # image list and .txt list by position silently misaligns every later
        # sample as soon as one label is missing or a stray .txt file exists.
        self.label_paths = [
            os.path.splitext(name)[0] + ".txt" for name in self.image_paths
        ]

    def __len__(self):
        """Return the number of images found in ``src_dir``."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(transformed_image, target)`` for sample *index*.

        ``target`` is a float32 tensor of length 2 (see the class docstring).
        """
        image_path = os.path.join(self.src_dir, self.image_paths[index])
        label_path = os.path.join(self.src_dir, self.label_paths[index])

        label = self._read_label(label_path)
        # NOTE(review): the image is not converted to RGB, so grayscale or RGBA
        # files keep their own channel count — confirm inputs are all RGB.
        image = Image.open(image_path)
        return self.transform(image), torch.tensor(label, dtype=torch.float32)

    @staticmethod
    def _read_label(label_path):
        """Return the two-element target encoded by the file at *label_path*."""
        try:
            with open(label_path, encoding="utf-8") as label_file:
                value = int(label_file.read().strip())
        except (OSError, ValueError):
            # Missing or malformed label file: fall back to the all-zero
            # target (deliberate best effort, preserved from the original).
            return [0, 0]
        return [1, 0] if value == 1 else [0, 1]
| |
|
| |
|
| | |
| | |
| | |
# Preprocessing pipeline named for VGG (presumably a VGG model — confirm):
# resize to 224x224, convert to a tensor, then normalize every one of the
# three channels with mean 0.5 and std 0.3.
vgg_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.3] * 3),
    ]
)
| |
|
# Preprocessing pipeline named for ResNet (presumably a ResNet model — confirm):
# resize to 256x256, convert to a tensor, then normalize every one of the
# three channels with mean 0.5 and std 0.3.
res_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.3] * 3),
    ]
)
| |
|
# Preprocessing pipeline named for MobileNetV3 (presumably — confirm):
# resize to 224x224, convert to a tensor, then normalize with per-channel
# mean/std values that match the commonly used ImageNet statistics.
mob3_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)