# LexiMind — tests/test_data/test_tokenization.py
# Author: OliverPerrin
# Refactor: Consolidate dependencies, improve testing, and add CI/CD (commit d18b34d)
import unittest
from unittest.mock import MagicMock, patch
import torch
from src.data.tokenization import Tokenizer, TokenizerConfig
class TestTokenizer(unittest.TestCase):
    """Unit tests for the Tokenizer wrapper around a HuggingFace AutoTokenizer.

    Every test patches ``src.data.tokenization.AutoTokenizer`` so no model
    weights or network access are needed.
    """

    @staticmethod
    def _make_mock_hf_tokenizer() -> MagicMock:
        """Build a MagicMock standing in for a HuggingFace tokenizer.

        All tests rely on the same special-token ids (pad=0, bos=1, eos=2),
        so the shared setup lives here instead of being repeated per test.
        """
        mock_hf_tokenizer = MagicMock()
        mock_hf_tokenizer.pad_token_id = 0
        mock_hf_tokenizer.bos_token_id = 1
        mock_hf_tokenizer.eos_token_id = 2
        return mock_hf_tokenizer

    @patch("src.data.tokenization.AutoTokenizer")
    def test_tokenizer_initialization(self, mock_auto_tokenizer):
        """Tokenizer exposes the HF tokenizer's special-token ids and vocab size."""
        mock_hf_tokenizer = self._make_mock_hf_tokenizer()
        mock_hf_tokenizer.vocab_size = 1000
        mock_auto_tokenizer.from_pretrained.return_value = mock_hf_tokenizer

        config = TokenizerConfig(pretrained_model_name="test-model")
        tokenizer = Tokenizer(config)

        self.assertEqual(tokenizer.pad_token_id, 0)
        self.assertEqual(tokenizer.bos_token_id, 1)
        self.assertEqual(tokenizer.eos_token_id, 2)
        self.assertEqual(tokenizer.vocab_size, 1000)
        mock_auto_tokenizer.from_pretrained.assert_called_with("test-model")

    @patch("src.data.tokenization.AutoTokenizer")
    def test_encode(self, mock_auto_tokenizer):
        """encode() delegates to the HF tokenizer and returns its token ids."""
        mock_hf_tokenizer = self._make_mock_hf_tokenizer()
        mock_hf_tokenizer.encode.return_value = [10, 11, 12]
        mock_auto_tokenizer.from_pretrained.return_value = mock_hf_tokenizer

        tokenizer = Tokenizer()
        ids = tokenizer.encode("hello world")

        self.assertEqual(ids, [10, 11, 12])
        mock_hf_tokenizer.encode.assert_called()

    @patch("src.data.tokenization.AutoTokenizer")
    def test_batch_encode(self, mock_auto_tokenizer):
        """batch_encode() returns tensor-valued input_ids and attention_mask."""
        mock_hf_tokenizer = self._make_mock_hf_tokenizer()
        # The HF tokenizer is invoked via __call__ for batched inputs.
        mock_hf_tokenizer.return_value = {
            "input_ids": torch.tensor([[10, 11], [12, 13]]),
            "attention_mask": torch.tensor([[1, 1], [1, 1]]),
        }
        mock_auto_tokenizer.from_pretrained.return_value = mock_hf_tokenizer

        tokenizer = Tokenizer()
        output = tokenizer.batch_encode(["hello", "world"])

        self.assertIn("input_ids", output)
        self.assertIn("attention_mask", output)
        self.assertIsInstance(output["input_ids"], torch.Tensor)
        self.assertIsInstance(output["attention_mask"], torch.Tensor)

    @patch("src.data.tokenization.AutoTokenizer")
    def test_decode(self, mock_auto_tokenizer):
        """decode() delegates to the HF tokenizer and returns its text."""
        mock_hf_tokenizer = self._make_mock_hf_tokenizer()
        mock_hf_tokenizer.decode.return_value = "hello world"
        mock_auto_tokenizer.from_pretrained.return_value = mock_hf_tokenizer

        tokenizer = Tokenizer()
        text = tokenizer.decode([10, 11, 12])

        self.assertEqual(text, "hello world")
        mock_hf_tokenizer.decode.assert_called()

    @patch("src.data.tokenization.AutoTokenizer")
    def test_prepare_decoder_inputs(self, mock_auto_tokenizer):
        """prepare_decoder_inputs() shifts labels right and prepends BOS."""
        mock_hf_tokenizer = self._make_mock_hf_tokenizer()
        mock_auto_tokenizer.from_pretrained.return_value = mock_hf_tokenizer

        tokenizer = Tokenizer()
        labels = torch.tensor([[10, 11, 2], [12, 2, 0]])  # 0 is pad
        decoder_inputs = tokenizer.prepare_decoder_inputs(labels)

        # Should shift right and prepend BOS (1)
        expected = torch.tensor([[1, 10, 11], [1, 12, 2]])
        self.assertTrue(torch.equal(decoder_inputs, expected))
# Allow running this test module directly (e.g. `python test_tokenization.py`).
if __name__ == "__main__":
    unittest.main()