import numpy
from pt import configs

import operator_benchmark as op_bench

import torch
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import default_embedding_qat_qconfig


"""
Microbenchmarks for QAT Embedding + EmbeddingBag operators.
"""


class QATEmbeddingBagBenchmark(op_bench.TorchBenchmarkBase):
    def init(
        self,
        embeddingbags,
        dim,
        mode,
        input_size,
        offset,
        sparse,
        include_last_offset,
        device,
    ):
        qconfig = default_embedding_qat_qconfig
        self.embedding = nnqat.EmbeddingBag(
            num_embeddings=embeddingbags,
            embedding_dim=dim,
            mode=mode,
            include_last_offset=include_last_offset,
            sparse=sparse,
            device=device,
            qconfig=qconfig,
        )
        # Fix the RNG seed so every run benchmarks the same random indices.
        numpy.random.seed((1 << 32) - 1)
        offsets = torch.LongTensor([offset], device=device)
        input = torch.tensor(
            numpy.random.randint(0, embeddingbags, input_size), device=device
        ).long()
        self.inputs = {
            "input": input,
            # Append the total input length as the final offset so the last
            # bag extends to the end of the input.
            "offset": torch.cat(
                (offsets, torch.tensor([input.size(0)], dtype=torch.long)), 0
            ),
        }
        self.set_module_name("qatEmbeddingBag")

    def forward(self, input, offset):
        return self.embedding(input, offset)


# Currently, EmbeddingBag QAT does not support sparse embeddings.
# Each config is a list of single-attribute dicts, so a membership test is
# enough to drop the sparse variants.
embeddingbag_short_dense_configs = [
    config
    for config in configs.embeddingbag_short_configs
    if {"sparse": True} not in config
]

op_bench.generate_pt_test(embeddingbag_short_dense_configs, QATEmbeddingBagBenchmark)
op_bench.generate_pt_gradient_test(
    embeddingbag_short_dense_configs, QATEmbeddingBagBenchmark
)


class QATEmbeddingBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, input_size, device):
        qconfig = default_embedding_qat_qconfig
        self.embedding = nnqat.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            qconfig=qconfig,
            device=device,
        )
        self.embedding.qconfig = default_embedding_qat_qconfig
        # Fix the RNG seed so every run benchmarks the same random indices.
        numpy.random.seed((1 << 32) - 1)
        self.input = torch.tensor(
            numpy.random.randint(0, num_embeddings, input_size), device=device
        ).long()
        self.inputs = {"input": self.input}
        self.set_module_name("qatEmbedding")

    def forward(self, input):
        return self.embedding(input)


op_bench.generate_pt_test(configs.embedding_short_configs, QATEmbeddingBenchmark)
op_bench.generate_pt_gradient_test(
    configs.embedding_short_configs, QATEmbeddingBenchmark
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()