beppefolder commited on
Commit
036e8b3
·
verified ·
1 Parent(s): 97e0d2f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -0
model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import timm
3
+
4
+ class Model(nn.Module):
5
+ def __init__(self, model_name, pretrained=True):
6
+ super(Model, self).__init__()
7
+
8
+ # Load the pretrained ConvNeXt model (you can choose the specific variant you want)
9
+ self.model = timm.create_model(model_name, pretrained=pretrained)
10
+ self.model.head.fc = nn.Linear(self.model.head.fc.in_features, 1) # change the last linear for classification
11
+
12
+ def forward(self, x):
13
+ return self.model(x)