Skip to content

Setting up model and training loop

Model architecture

For this classification task we choose the EfficientNetV2-S model, which is a convolutional neural network architecture that has been shown to achieve state-of-the-art performance on image classification tasks while being computationally efficient. We will use the implementation of EfficientNetV2-S provided by torchvision. We will also use torchvision to load the ImageNet pretrained model weights.

# models/efficientnet.py

class EfficientNet_V2_S(nn.Module):
    def __init__(self,
                 num_classes:int):
        super().__init__()
        self.num_classes = num_classes
        self.model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
            )
        # modify model for number of classes
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1280,
                       out_features=num_classes,
                       bias=True)
        )

    def forward(self, x):
        output = self.model(x)
        return output

Setting up training loop

Here, we first set up a pytorch Dataset class and define a dataloder that will be used in the training loop.

class PlantVillageDataset(Dataset):
    def __init__(...):
        train_path = Path("PlantVillage") / "train.txt"
        # read path to the images and corresponding labels
        self.images, self.labels = read_data(train_path)
        # define data augmentation transform
        self.transform = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225]),
                transforms.RandomCrop(size=224),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip()
            ])
    ...

    def __getitem__(self, idx):
        image_path, label = self.images[idx], self.labels[idx]
        image = read_8bit_image(image_path)
        image = self.transform(image)
        return image, label

def loader(dataset:Dataset,
           batch_size:int,
           num_workers:int=1,
           ):
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=dataset.shuffle,
                        num_workers=num_workers)
    return loader

Once a loader is set up, we can define the training loop. We will use the Adam optimizer and CrossEntropyLoss for this classification task. We will also use a learning rate scheduler to adjust the learning rate during training.

class PlantTrainer(BaseTrainer):
    # BaseTrainer defines the basic functionality for training and validation loops, logging, checkpointing, etc.
    def train_one_epoch(self, epoch):
        total_loss = 0

        self.model.train()

        for k, batch in enumerate(self.train_loader):
            images, labels = batch
            images = images.to(device=self.device, memory_format=torch.channels_last)
            labels = labels.to(self.device)

            with self.amp_autocast:
                pred = self.model(images)
                loss = self.loss_fn(pred, labels)

            self.optimizer.zero_grad()
            if self.loss_scaler is None:
                loss.backward()
                self.optimizer.step()
            else:
                self.loss_scaler.scale(loss).backward()
                self.loss_scaler.step(self.optimizer)
                self.loss_scaler.update()

            if k%10 == 0:
                self.logger.info(f"Epoch {epoch}/{self.config['num_epochs']}  Iter {k:05d} : {loss:0.4f}")
            total_loss += loss
            if k == 20:
                break
        return total_loss