"""
Contains functionality for creating PyTorch DataLoaders for
image classification data.
"""
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder


def train_test_dataloader(train_dir: str,
                          test_dir: str,
                          transform: transforms.Compose,
                          batch_size: int) -> Tuple[DataLoader, DataLoader, List[str]]:
    """Create training and testing DataLoaders from two image directories.

    Each directory is wrapped in a torchvision ``ImageFolder`` dataset (with
    the same ``transform`` applied to both) and then in a PyTorch
    ``DataLoader``.

    Args:
        train_dir: Path to training directory.
        test_dir: Path to testing directory.
        transform: torchvision transforms to perform on training and testing data.
        batch_size: Number of samples per batch in each of the DataLoaders.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names),
        where class_names is the list of target classes as reported by the
        training dataset (``ImageFolder.classes``). The test directory's
        classes are not compared against it.

    Example usage:
        train_dataloader, test_dataloader, class_names = train_test_dataloader(
            train_dir="path/to/train_dir",
            test_dir="path/to/test_dir",
            transform=some_transform,
            batch_size=32)
    """
    dataset_train = ImageFolder(root=train_dir, transform=transform)
    dataset_test = ImageFolder(root=test_dir, transform=transform)

    # The class list is taken from the training split only.
    class_names = dataset_train.classes

    # Training batches are reshuffled every epoch.
    train_dataloader = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True)
    # Evaluation gains nothing from shuffling; a fixed order keeps test
    # batches reproducible from run to run.
    test_dataloader = DataLoader(dataset_test,
                                 batch_size=batch_size,
                                 shuffle=False)

    return train_dataloader, test_dataloader, class_names