Minor memory handling fixes

Changed file: modeling_ltgbert.py (+5 −2)
@@ -230,7 +230,7 @@ class Attention(nn.Module):
 230
 231     query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
 232     key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-233     value = value.  [removed line truncated in the extracted page]
+233     value = value.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
 234
 235     attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
 236

@@ -346,7 +346,10 @@ class LtgbertModel(LtgbertPreTrainedModel):
 346     if self.config.is_decoder:
 347         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
 348     else:
-349         [removed line not shown in the extracted page]
+349         if len(attention_mask.size()) == 2:
+350             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+351         elif len(attention_mask.size()) == 3:
+352             attention_mask = attention_mask.unsqueeze(1)
 353
 354     static_embeddings, relative_embedding = self.embedding(input_ids.t())
 355     contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)