Spaces:
Build error
Build error
| import torch | |
| import torchvision | |
| from torch import nn | |
| from helper_functions import set_seeds | |
def create_effnetb2_model(output_classes: int = 3,
                          seed: int = 42):
    """Create a pretrained EfficientNet-B2 feature extractor with a new head.

    All layers of the pretrained model are frozen (``requires_grad=False``)
    and the classifier is replaced by a dropout + linear head sized for
    ``output_classes``. Only the new head is left trainable.

    Args:
        output_classes: Number of output units of the new classification head.
        seed: Value passed to ``set_seeds`` immediately before the new head is
            built, so the head's initial weights are reproducible.

    Returns:
        A ``(model, transforms)`` tuple:
            model: the EfficientNet-B2 model instance with frozen base layers
                and the replaced classification head.
            transforms: the preprocessing transforms provided by the
                pretrained weights (``weights.transforms()``).
    """
    # 1. Set up the default pretrained EffNetB2 weights.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    # 2. Get the preprocessing transforms that match these weights.
    transforms = effnetb2_weights.transforms()
    # 3. Set up the pretrained model instance.
    model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)
    # 4. Freeze every pretrained parameter so that only the new head trains.
    for param in model.parameters():
        param.requires_grad = False
    # 5. Replace the classification head. Take the input width from the
    # existing head's final Linear layer instead of hard-coding it
    # (it is 1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features
    # Seed right before creating the head so its initial weights are
    # reproducible; honour the caller's seed instead of a hard-coded 42.
    set_seeds(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features,
                  out_features=output_classes,
                  bias=True),
    )
    return model, transforms