Avoid duplicate input kwargs in `_decode` (#28)
- Avoid duplicate input kwargs in `_decode` (18005e74b8257c981bb97dd4f350b06cd28f7aa6)
- avoid duplicate generate args (5d0120037703b4b70ec932f62ddb81e07b8b85c4)
- update modeling_minicpmo.py (cac55956a6efb7456cf5bbcad4e3e4f14d2e7ea9)
Co-authored-by: Zhihui He <[email protected]>
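
The motivation, in brief: `_decode` forwards caller kwargs to `self.llm.generate` while (judging by the commit title; the hunk below truncates the call after `inputs_embeds=inputs_embeds`) also passing flags such as `output_hidden_states` and `return_dict_in_generate` explicitly. If a caller supplies the same keys through `**kwargs`, Python raises `TypeError: got multiple values for keyword argument ...`. A minimal standalone sketch of the collision and the `kwargs.pop` fix, with `generate` stubbed out (not the real transformers API):

    # Standalone sketch; `generate` is a stub standing in for
    # transformers' GenerationMixin.generate.
    def generate(inputs_embeds=None, output_hidden_states=False,
                 return_dict_in_generate=False, **kwargs):
        return {"hidden": output_hidden_states,
                "dict": return_dict_in_generate, **kwargs}

    def _decode_broken(inputs_embeds, **kwargs):
        # Explicit keywords collide with identical keys forwarded via **kwargs.
        return generate(inputs_embeds=inputs_embeds,
                        output_hidden_states=True,
                        return_dict_in_generate=True,
                        **kwargs)

    def _decode_fixed(inputs_embeds, **kwargs):
        # The commit's approach: drop caller-supplied duplicates up front,
        # so the explicit values below can never clash.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict_in_generate", None)
        return generate(inputs_embeds=inputs_embeds,
                        output_hidden_states=True,
                        return_dict_in_generate=True,
                        **kwargs)

    caller_kwargs = {"max_new_tokens": 128, "return_dict_in_generate": True}
    # _decode_broken(None, **caller_kwargs)  # TypeError: generate() got multiple
    #                                        # values for keyword argument
    #                                        # 'return_dict_in_generate'
    print(_decode_fixed(None, **caller_kwargs))  # runs: duplicates were popped
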
- modeling_minicpmo.py +7 -1
modeling_minicpmo.py CHANGED

@@ -636,6 +636,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
 
     def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("return_dict_in_generate", None)
         terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
         outputs = self.llm.generate(
             inputs_embeds=inputs_embeds,
@@ -777,6 +779,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         tokenizer=None,
         vision_hidden_states=None,
         stream=False,
+        decode_text=True,
         **kwargs,
     ):
         assert input_ids is not None
@@ -814,7 +817,10 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
 
         result = self._decode_text(outputs.sequences, tokenizer)
-
+
+        if decode_text is False:
+            return outputs
+
         return result, outputs
 
     def chat(
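
For reference, the effect of the new `decode_text` flag on the patched method's return value (from the line positions this appears to be `MiniCPMO.generate`, though the hunk headers only name the class), as a runnable stub:

    # Stub mirroring the patched control flow; the real code calls
    # self._decode(...) and self._decode_text(...) where the placeholders sit.
    def generate_stub(input_ids, decode_text=True):
        assert input_ids is not None
        outputs = {"sequences": [[8, 9, 10]]}  # placeholder for _decode(...)
        result = ["decoded text"]              # placeholder for _decode_text(...)
        if decode_text is False:
            return outputs                     # new: raw generate outputs only
        return result, outputs                 # unchanged default: (text, outputs)

    print(generate_stub([0]))                     # (['decoded text'], {...})
    print(generate_stub([0], decode_text=False))  # {'sequences': [[8, 9, 10]]}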