"""
Contains PyTorch model code defining ``TrashClassificationCNNModel``, a
TinyVGG-style convolutional neural network.
"""
| import torch |
| from torch import nn |
|
|
|
|
class TrashClassificationCNNModel(nn.Module):
    """TinyVGG-style CNN: two convolutional blocks followed by a linear classifier.

    Each convolutional block is
    ``Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d(kernel_size=2)``.
    The 3x3 convolutions use stride 1 and padding 1, so they preserve the
    spatial size; each max-pool halves it (rounding down). After both blocks an
    ``image_size x image_size`` input therefore becomes
    ``(image_size // 4) x (image_size // 4)`` per channel, which determines the
    classifier's ``in_features``.

    Args:
        input_shape: Number of channels in the input images.
        hidden_units: Number of output channels of every convolution layer.
        output_shape: Number of output logits (classes).
        image_size: Height and width of the (square) input images. The default
            of 112 reproduces the original hard-coded
            ``hidden_units * 28 * 28`` classifier input size.

    Raises:
        ValueError: If ``image_size`` is smaller than 4; two successive 2x2
            max-pools cannot be applied to such inputs.
    """

    def __init__(
        self,
        input_shape: int,
        hidden_units: int,
        output_shape: int,
        *,
        image_size: int = 112,
    ):
        super().__init__()
        if image_size < 4:
            raise ValueError(
                f"image_size must be at least 4 to survive two 2x2 max-pools, "
                f"got {image_size}"
            )

        self.block_1 = nn.Sequential(
            nn.Conv2d(input_shape, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # The padded 3x3 convs keep the spatial size; only the two max-pools
        # shrink it, each by a factor of 2 with floor rounding.
        feature_side = image_size // 2 // 2
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=hidden_units * feature_side * feature_side,
                out_features=output_shape,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Input batch, expected as
                ``(batch, input_shape, image_size, image_size)``.

        Returns:
            Logits of shape ``(batch, output_shape)``.
        """
        x = self.block_1(x)
        x = self.block_2(x)
        return self.classifier(x)