lgcharpe committed on
Commit
88a9ba5
·
verified ·
1 Parent(s): d5e44c9

Minor memory handling fixes

Browse files
Files changed (1) hide show
  1. modeling_ltgbert.py +5 -2
modeling_ltgbert.py CHANGED
@@ -230,7 +230,7 @@ class Attention(nn.Module):
230
 
231
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
232
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
233
- value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
234
 
235
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
236
 
@@ -346,7 +346,10 @@ class LtgbertModel(LtgbertPreTrainedModel):
346
  if self.config.is_decoder:
347
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
348
  else:
349
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
 
 
 
350
 
351
  static_embeddings, relative_embedding = self.embedding(input_ids.t())
352
  contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
 
230
 
231
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
232
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
233
+ value = value.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
234
 
235
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
236
 
 
346
  if self.config.is_decoder:
347
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
348
  else:
349
+ if len(attention_mask.size()) == 2:
350
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
351
+ elif len(attention_mask.size()) == 3:
352
+ attention_mask = attention_mask.unsqueeze(1)
353
 
354
  static_embeddings, relative_embedding = self.embedding(input_ids.t())
355
  contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)