klemenk committed on
Commit
781adcc
·
verified ·
1 Parent(s): 6f610f2

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +18 -6
modeling_auristream.py CHANGED
@@ -72,7 +72,7 @@ class AuriStream(PreTrainedModel):
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
- def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
@@ -112,6 +112,10 @@ class AuriStream(PreTrainedModel):
112
  logits = self.coch_head(x)
113
 
114
  if tgt is not None:
 
 
 
 
115
  loss = F.cross_entropy(
116
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
117
  )
@@ -123,14 +127,22 @@ class AuriStream(PreTrainedModel):
123
  loss += F.cross_entropy(
124
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
125
  )
 
 
126
  # divide loss by number of future heads
127
  loss = loss / (len(self.future_heads) + 1)
128
-
129
  if return_dict:
130
- model_output = CausalLMOutput(
131
- loss=loss,
132
- logits=logits,
133
- )
 
 
 
 
 
 
134
  return model_output
135
 
136
  return logits, loss
 
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
+ def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
 
112
  logits = self.coch_head(x)
113
 
114
  if tgt is not None:
115
+
116
+ if output_logits:
117
+ all_logits = [logits]
118
+
119
  loss = F.cross_entropy(
120
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
121
  )
 
127
  loss += F.cross_entropy(
128
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
129
  )
130
+ if output_logits:
131
+ all_logits.append(future_logits)
132
  # divide loss by number of future heads
133
  loss = loss / (len(self.future_heads) + 1)
134
+
135
  if return_dict:
136
+ if output_logits:
137
+ model_output = CausalLMOutput(
138
+ loss=loss,
139
+ logits=all_logits,
140
+ )
141
+ else:
142
+ model_output = CausalLMOutput(
143
+ loss=loss,
144
+ logits=logits,
145
+ )
146
  return model_output
147
 
148
  return logits, loss