# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe import unittest import torch class UpdateQuantizedKVCacheTest(unittest.TestCase): def _reset(self): self.quantized_k_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, self.head_dim), dtype=torch.int8, ) self.quantized_v_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, self.head_dim), dtype=torch.int8, ) self.k_scales_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.float64 ) self.v_scales_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.float64 ) self.k_zero_points_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.int64 ) self.v_zero_points_cache = torch.zeros( (self.batch_size, self.seq_len, self.num_heads, 1), dtype=torch.int64 ) def setUp(self): torch.manual_seed(42) self.batch_size = 1 self.seq_len = 10 self.num_heads = 8 self.head_dim = 4 self._reset() def _update_k(self, start_pos, value, scales, zero_points): seq_len = value.size(1) self.quantized_k_cache[:, start_pos : start_pos + seq_len, :, :] = value self.k_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales self.k_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = zero_points def _update_v(self, start_pos, value, scales, zero_points): seq_len = value.size(1) self.quantized_v_cache[:, start_pos : start_pos + seq_len, :, :] = value self.v_scales_cache[:, start_pos : start_pos + seq_len, :, :] = scales self.v_zero_points_cache[:, start_pos : start_pos + seq_len, :, :] = zero_points def _update_and_validate( self, k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ): k_cache = self.quantized_k_cache.clone() v_cache = self.quantized_v_cache.clone() k_scales_cache = self.k_scales_cache.clone() v_scales_cache = self.v_scales_cache.clone() k_zero_points_cache = self.k_zero_points_cache.clone() v_zero_points_cache = self.v_zero_points_cache.clone() self._update_k(start_pos, k, k_scales, k_zero_points) self._update_v(start_pos, v, v_scales, v_zero_points) torch.ops.llama.update_quantized_cache(k, k_cache, start_pos) torch.ops.llama.update_quantized_cache(k_scales, k_scales_cache, start_pos) torch.ops.llama.update_quantized_cache( k_zero_points, k_zero_points_cache, start_pos ) torch.ops.llama.update_quantized_cache(v, v_cache, start_pos) torch.ops.llama.update_quantized_cache(v_scales, v_scales_cache, start_pos) torch.ops.llama.update_quantized_cache( v_zero_points, v_zero_points_cache, start_pos ) self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache)) self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache)) self.assertTrue(torch.allclose(k_scales_cache, self.k_scales_cache)) self.assertTrue(torch.allclose(v_scales_cache, self.v_scales_cache)) self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache)) self.assertTrue(torch.allclose(v_zero_points_cache, self.v_zero_points_cache)) def test_update_kv_cache_simple(self): k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) start_pos = 0 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) def test_update_kv_cache_large_update(self): self._reset() k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) v_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) v_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) start_pos = 0 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) def test_update_kv_cache_update_nonzero_offset(self): self._reset() k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) start_pos = 2 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) def test_update_kv_cache_more_updates(self): self._reset() k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) start_pos = 2 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) v_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) start_pos = 4 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) def test_batched_update_kv_cache_more_updates(self): self.batch_size = 7 self._reset() k = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint( 0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64 ) v_zero_points = torch.randint( 0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64 ) start_pos = 2 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) k = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (self.batch_size, 1, 8, 4), dtype=torch.int8) k_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64) v_scales = torch.rand((self.batch_size, 1, 8, 1), dtype=torch.float64) k_zero_points = torch.randint( 0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64 ) v_zero_points = torch.randint( 0, 20, (self.batch_size, 1, 8, 1), dtype=torch.int64 ) start_pos = 4 self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos )