14 | 14 |
15 | 15 | # pylint: disable=all |
16 | 16 |
17 | | - |
18 | | -# This model will output tokens with value of 2 |
19 | | -# and will update caches with value of 1.0 |
20 | | -# class Dummy(torch.nn.Module): |
21 | | - |
22 | | -# def __init__(self): |
23 | | -# super().__init__() |
24 | | -# self.params = None |
25 | | - |
26 | | -# def forward( |
27 | | -# self, |
28 | | -# tokens: torch.Tensor, |
29 | | -# input_pos: torch.Tensor, |
30 | | -# caches: List[Any], |
31 | | -# mask, |
32 | | -# ): |
33 | | -# batch_size, seqlen = tokens.shape |
34 | | -# for cache in caches: |
35 | | -# cache.update(torch.ones((batch_size, seqlen))) |
36 | | -# return torch.ones((batch_size, seqlen), dtype=torch.int32) * 2 |
37 | | - |
38 | | - |
39 | | -# class EngineTest(unittest.TestCase): |
40 | | - |
41 | | -# def _make_small_engine(self, quantize=False): |
42 | | -# env_data = JetEngineEnvironmentData() |
43 | | -# env_data.max_input_sequence_length = 128 |
44 | | -# env_data.max_input_sequence_length = 128 |
45 | | -# env_data.cache_sequence_length = 128 |
46 | | -# env_data.model_type = 'llama-2-tiny' |
47 | | -# if quantize: |
48 | | -# env_data.enable_kv_quantization = True |
49 | | -# env_data.enable_weight_quantization = True |
50 | | - |
51 | | -# env = JetEngineEnvironment(env_data) |
52 | | -# model = Dummy() |
53 | | -# model.params = env._model_arg # llama's model arg |
54 | | - |
55 | | -# engine = PyTorchEngine(model, env) |
56 | | -# return engine |
| 17 | +import unittest |
| 18 | +import jax |
| 19 | +import jax.numpy as jnp |
| 20 | + |
| 21 | +from jetstream_pt.third_party.llama import model_exportable |
| 22 | +from jetstream_pt.engine import PyTorchEngine |
| 23 | +from tests import helpers |
| 24 | + |
| 25 | + |
| 26 | +class EngineTest(unittest.TestCase): |
| 27 | + |
| 28 | + def setup(self): |
| 29 | + env, model_arg = helpers.make_env_tiny(bf16_enable=True) |
| 30 | + model_ours = model_exportable.Transformer(model_arg, env) |
| 31 | + engine = PyTorchEngine(pt_model=model_ours, env=env) |
| 32 | + engine.rng = jax.random.PRNGKey(0) |
| 33 | + return engine |
| 34 | + |
| 35 | + def test_sampling_2D(self): |
| 36 | + # test greedy |
| 37 | + engine = self.setup() |
| 38 | + self.assertEqual(engine.env.sampling_algorithm, "greedy") |
| 39 | + logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]]) |
| 40 | + token = engine._sampling(logits, batch_size=1) |
| 41 | + self.assertEqual(token, jnp.array([[0]])) |
| 42 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 43 | + |
| 44 | + # test weighted |
| 45 | + engine.env.sampling_algorithm = "weighted" |
| 46 | + engine.env.temperature = 5.0 |
| 47 | + token = engine._sampling(logits, batch_size=1) |
| 48 | + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) |
| 49 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 50 | + |
| 51 | + # test topk |
| 52 | + engine.env.sampling_algorithm = "topk" |
| 53 | + engine.env.temperature = 5.0 |
| 54 | + engine.env.topk = 4 |
| 55 | + token = engine._sampling(logits, batch_size=1) |
| 56 | + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) |
| 57 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 58 | + |
| 59 | + # test nucleus |
| 60 | + engine.env.sampling_algorithm = "nucleus" |
| 61 | + engine.env.temperature = 0.0 |
| 62 | + engine.env.nucleus_topp = 0.8 |
| 63 | + token = engine._sampling(logits, batch_size=1) |
| 64 | + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) |
| 65 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 66 | + |
| 67 | + def test_sampling_3D(self): |
| 68 | + # test greedy |
| 69 | + engine = self.setup() |
| 70 | + self.assertEqual(engine.env.sampling_algorithm, "greedy") |
| 71 | + logits = jnp.array( |
| 72 | + [ |
| 73 | + [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], |
| 74 | + [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], |
| 75 | + ] |
| 76 | + ) |
| 77 | + token = engine._sampling(logits, batch_size=2) |
| 78 | + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) |
| 79 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 80 | + |
| 81 | + # test weighted |
| 82 | + engine.env.sampling_algorithm = "weighted" |
| 83 | + engine.env.temperature = 10.0 |
| 84 | + token = engine._sampling(logits, batch_size=2) |
| 85 | + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) |
| 86 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 87 | + |
| 88 | + # test topk |
| 89 | + engine.env.sampling_algorithm = "topk" |
| 90 | + engine.env.temperature = 1.0 |
| 91 | + engine.env.topk = 3 |
| 92 | + token = engine._sampling(logits, batch_size=2) |
| 93 | + self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) |
| 94 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
| 95 | + |
| 96 | + # test nucleus |
| 97 | + engine.env.sampling_algorithm = "nucleus" |
| 98 | + engine.env.temperature = 1.0 |
| 99 | + engine.env.nucleus_topp = 0.8 |
| 100 | + token = engine._sampling(logits, batch_size=2) |
| 101 | + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) |
| 102 | + self.assertTrue(jnp.isdtype(token, jnp.int32)) |
57 | 103 |
58 | 104 |
59 | 105 | # def test_insert(self):
(lines 60-228 / 106-274 are unchanged and collapsed in the diff view)
229 | 275 | # # prefill
230 | 276 |
231 | 277 |
232 | | -# if __name__ == '__main__': |
233 | | -# unittest.main() |
| 278 | +if __name__ == "__main__": |
| 279 | + unittest.main() |