yermandy commited on
Commit
5742c57
·
verified ·
1 Parent(s): fb28b8b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_gend.py +5 -10
modeling_gend.py CHANGED
@@ -6,12 +6,15 @@ from transformers import PretrainedConfig, PreTrainedModel
6
 
7
 
8
  class LinearProbe(nn.Module):
9
- def __init__(self, input_dim, num_classes, normalize_inputs=False, detach_classifier_inputs=False):
10
  super().__init__()
11
  self.linear = nn.Linear(input_dim, num_classes)
12
  self.normalize_inputs = normalize_inputs
13
 
14
  def forward(self, x: torch.Tensor, **kwargs):
 
 
 
15
  return self.linear(x)
16
 
17
 
@@ -50,10 +53,6 @@ class CLIPEncoder(nn.Module):
50
 
51
  class DINOEncoder(nn.Module):
52
  def __init__(self, model_name="facebook/dinov2-with-registers-base"):
53
- """
54
- See models in src/config.py
55
- """
56
-
57
  super().__init__()
58
 
59
  from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel
@@ -74,11 +73,7 @@ class DINOEncoder(nn.Module):
74
 
75
 
76
  class PerceptionEncoder(nn.Module):
77
- def __init__(
78
- self,
79
- model_name="vit_pe_core_large_patch14_336",
80
- img_size: None | int = None,
81
- ):
82
  super().__init__()
83
 
84
  if img_size is not None:
 
6
 
7
 
8
  class LinearProbe(nn.Module):
9
+ def __init__(self, input_dim, num_classes, normalize_inputs=False):
10
  super().__init__()
11
  self.linear = nn.Linear(input_dim, num_classes)
12
  self.normalize_inputs = normalize_inputs
13
 
14
  def forward(self, x: torch.Tensor, **kwargs):
15
+ if self.normalize_inputs:
16
+ x = F.normalize(x, p=2, dim=1)
17
+
18
  return self.linear(x)
19
 
20
 
 
53
 
54
  class DINOEncoder(nn.Module):
55
  def __init__(self, model_name="facebook/dinov2-with-registers-base"):
 
 
 
 
56
  super().__init__()
57
 
58
  from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel
 
73
 
74
 
75
  class PerceptionEncoder(nn.Module):
76
+ def __init__(self, model_name="vit_pe_core_large_patch14_336", img_size: None | int = None):
 
 
 
 
77
  super().__init__()
78
 
79
  if img_size is not None: