FoodVisionMini / model.py
Kelmoir's picture
added the files to the repo
dbe7aa3
import torch
import torchvision
from torch import nn
from helper_functions import set_seeds
def create_effnetb2_model(output_classes: int = 3,
                          seed: int = 42):
    """Create a pretrained EfficientNet-B2 feature extractor with a new head.

    Every parameter of the pretrained model is frozen, then the classifier
    head is replaced with a fresh dropout + linear layer that produces
    ``output_classes`` outputs. Because the freeze happens before the head is
    replaced, only the new linear layer's parameters are trainable.

    Args:
        output_classes: Number of output units of the new linear layer.
        seed: Seed passed to ``set_seeds`` immediately before the new head is
            built, so the head's random initialisation is reproducible.

    Returns:
        (model, transforms):
            model: The EfficientNet-B2 instance with frozen base layers and
                the replaced classifier head.
            transforms: The preprocessing transforms bundled with
                ``EfficientNet_B2_Weights.DEFAULT``.
    """
    # 1. Set up the pretrained EfficientNet-B2 weights.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    # 2. Get the preprocessing transforms that match these weights.
    transforms = effnetb2_weights.transforms()
    # 3. Instantiate the pretrained model.
    model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)
    # 4. Freeze every existing parameter so the base layers are not trained.
    for param in model.parameters():
        param.requires_grad = False
    # 5. Replace the classification head. Read the feature width from the
    # existing Linear layer (1408 for EfficientNet-B2) rather than
    # hard-coding it.
    in_features = model.classifier[1].in_features
    # Seed right before creating the new layers so their initialisation is
    # reproducible; use the caller's seed rather than a fixed constant.
    set_seeds(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features,
                  out_features=output_classes,
                  bias=True),
    )
    return model, transforms