Issue: not all tensors are on the same device when running the demo code
#1 · by mrdbourke · opened
Error:
RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)
Code (copied directly from the model page):
import torch
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend
model_id = "mistralai/Ministral-3-8B-Reasoning-2512"
tokenizer = MistralCommonBackend.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
            },
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]
tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda")
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
image_sizes = [tokenized["pixel_values"].shape[-2:]]
output = model.generate(
    **tokenized,
    image_sizes=image_sizes,
    max_new_tokens=8092,
)[0]
decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
print(decoded_output)
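As an aside, a possible user-side workaround (just a sketch, untested; it assumes the offending CPU tensor is another entry in tokenized, e.g. the attention mask) would be to move everything the tokenizer returns onto the model's device right after apply_chat_template, instead of moving input_ids and pixel_values individually:

# Workaround sketch, continuing the snippet above: move every tensor in the
# tokenized dict to the model's device, not just input_ids and pixel_values
# (assumes the tensor triggering the mismatch, e.g. attention_mask, is one of them).
device = model.device  # with device_map="auto" this should resolve to the first parameter's device
tokenized = {
    k: (v.to(device) if isinstance(v, torch.Tensor) else v)
    for k, v in tokenized.items()
}
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16)

The trace below is from the original snippet as posted above.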
Error trace:
tekken.json: 100% 16.8M/16.8M [00:00<00:00, 189kB/s]
config.json: 1.55k/? [00:00<00:00, 120kB/s]
model.safetensors.index.json: 52.7k/? [00:00<00:00, 4.84MB/s]
Download complete: 100% 17.8G/17.8G [00:41<00:00, 397MB/s]
Fetching 4 files: 100% 4/4 [00:41<00:00, 17.40s/it]
Loading weights: 100% 531/531 [00:04<00:00, 464.06it/s, Materializing param=model.vision_tower.transformer.layers.23.ffn_norm.weight]
The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning
generation_config.json: 100% 131/131 [00:00<00:00, 18.8kB/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipython-input-2430961163.py in <cell line: 0>()
30 image_sizes = [tokenized["pixel_values"].shape[-2:]]
31
---> 32 output = model.generate(
33 **tokenized,
34 image_sizes=image_sizes,
20 frames
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
121
122 return decorate_context
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
2682
2683 # 9. Call generation mode
-> 2684 result = decoding_method(
2685 self,
2686 input_ids,
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
2875
2876 prefill_consumed = False
-> 2877 outputs = self._prefill(input_ids, generation_config, model_kwargs)
2878
2879 while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
/usr/local/lib/python3.12/dist-packages/transformers/generation/utils.py in _prefill(self, input_ids, generation_config, model_kwargs)
3851 model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
3852 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-> 3853 return self(**model_inputs, return_dict=True)
3854 else: # Chunked prefill
3855 # Even if we are not compiling the forward, flex is always compiled when used. With chunked prefill, we may
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
762 if return_dict_passed is not None:
763 return_dict = return_dict_passed
--> 764 output = func(self, *args, **kwargs)
765 if not return_dict and not isinstance(output, tuple):
766 output = output.to_tuple()
/usr/local/lib/python3.12/dist-packages/transformers/models/mistral3/modeling_mistral3.py in forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
445 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
446
--> 447 outputs = self.model(
448 input_ids=input_ids,
449 pixel_values=pixel_values,
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
762 if return_dict_passed is not None:
763 return_dict = return_dict_passed
--> 764 output = func(self, *args, **kwargs)
765 if not return_dict and not isinstance(output, tuple):
766 output = output.to_tuple()
/usr/local/lib/python3.12/dist-packages/transformers/models/mistral3/modeling_mistral3.py in forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, image_sizes, **kwargs)
327 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
328
--> 329 outputs = self.language_model(
330 attention_mask=attention_mask,
331 position_ids=position_ids,
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
917 outputs = func(self, *args, **kwargs)
918 else:
--> 919 outputs = func(self, *args, **kwargs)
920 except TypeError as original_exception:
921 # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
/usr/local/lib/python3.12/dist-packages/transformers/models/ministral3/modeling_ministral3.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)
403
404 hidden_states = inputs_embeds
--> 405 position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
406
407 for decoder_layer in self.layers[: self.config.num_hidden_layers]:
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
121
122 return decorate_context
/usr/local/lib/python3.12/dist-packages/transformers/modeling_rope_utils.py in wrapper(self, x, position_ids, layer_type)
123 elif rope_type == "longrope":
124 longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
--> 125 return rope_forward(self, x, position_ids, **kwargs)
126
127 return wrapper
/usr/local/lib/python3.12/dist-packages/transformers/models/ministral3/modeling_ministral3.py in forward(self, x, position_ids)
335 device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
336 with torch.autocast(device_type=device_type, enabled=False): # Force float32
--> 337 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
338 emb = torch.cat((freqs, freqs), dim=-1)
339 cos = emb.cos() * self.attention_scaling
Perhaps the line:
with torch.autocast(device_type=device_type, enabled=False):  # Force float32
should be enabled=True?
Or is this a case where we specifically want to be computing on CPU/float32?
Actually, it turns out that position_ids_expanded isn't on the same device as inv_freq_expanded in transformers/models/ministral3/modeling_ministral3.py:
Before:
@torch.no_grad()
@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
    position_ids_expanded = position_ids[:, None, :].float()
    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
After:
@torch.no_grad()
@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
    inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
    position_ids_expanded = position_ids[:, None, :].float().to(x.device)  # CHANGED
    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
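For reference, a quick standalone sanity check of that one-line change (a minimal sketch of just the matmul inside the rotary forward, with made-up shapes, not the actual module):

# Sketch of the fixed rotary matmul in isolation; shapes and values are illustrative.
import torch

x = torch.zeros(1, 10, 4096, device="cuda", dtype=torch.bfloat16)  # stand-in for hidden states on the GPU
position_ids = torch.arange(10)[None, :]                           # deliberately left on CPU
inv_freq = torch.rand(64)                                          # stand-in for the inv_freq buffer

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float().to(x.device)  # the added .to(x.device)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
print(freqs.device)  # cuda:0 -- without the .to(x.device) this matmul raises the RuntimeError above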