import torch
import torchvision
from torch import nn

from helper_functions import set_seeds


def create_effnetb2_model(output_classes: int = 3, seed: int = 42):
    """Create a pretrained EfficientNet-B2 feature extractor.

    The base (pretrained) layers are frozen and the classification head is
    replaced with a freshly initialised one sized for the target task.

    Args:
        output_classes: Number of output units of the new classifier head.
        seed: Seed passed to ``set_seeds`` right before the new head is
            created, so its randomly initialised weights are reproducible.

    Returns:
        (model, transforms)
            model: The EfficientNet-B2 feature-extractor model instance.
            transforms: The preprocessing transforms that belong to the
                pretrained weights (from ``weights.transforms()``); apply
                them to inputs before feeding the model.
    """
    # 1. Set up the pretrained EffNetB2 weights.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT

    # 2. Get the preprocessing transforms matching those weights.
    transforms = effnetb2_weights.transforms()

    # 3. Set up the pretrained model instance.
    model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

    # 4. Freeze every existing parameter so only the new head (created below,
    #    with requires_grad=True by default) is trained.
    for param in model.parameters():
        param.requires_grad = False

    # 5. Replace the classification head.
    # Read the input width from the pretrained head's final Linear layer
    # instead of hard-coding it, so it always matches the backbone.
    in_features = model.classifier[-1].in_features

    # Seed immediately before building the head so its random weight
    # initialisation is reproducible for the requested seed.
    set_seeds(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=output_classes, bias=True),
    )

    return model, transforms