"""Operator benchmarks for the quantized embedding-bag rowwise-offsets lookups.

Two ops are benchmarked over the same configuration grid:
``quantized::embedding_bag_4bit_rowwise_offsets`` and
``quantized::embedding_bag_byte_rowwise_offsets``.
"""

from typing import Optional

import numpy as np

import operator_benchmark as op_bench

import torch
from torch.testing._internal.common_quantization import lengths_to_offsets


# Provides torch.ops.fb.embedding_bag_rowwise_prune, used for the pruned configs.
torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators")


embedding_bag_rowwise_offsets_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256),
    num_offsets=range(2, 10),
    enable_per_sample_weights=(True, False),
    include_last_offset=(True, False),
    is_pruned_weights=(
        True,
        False,
    ),
    use_32bit_indices=(True, False),
    use_32bit_offsets=(True, False),
    tags=["short"],
)


embedding_bag_rowwise_offsets_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000, 10_000, 20_000),
    embedding_dim=(16, 64, 128, 256),
    num_offsets=range(10, 20),
    enable_per_sample_weights=(True, False),
    include_last_offset=(True, False),
    is_pruned_weights=(
        True,
        False,
    ),
    use_32bit_indices=(True, False),
    use_32bit_offsets=(True, False),
    tags=["long"],
)


# Both benchmark classes below are registered against the short + long grids.
full_configs = (
    embedding_bag_rowwise_offsets_short_configs
    + embedding_bag_rowwise_offsets_long_configs
)

four_bit_rowwise_ops = op_bench.op_list(
    attrs=(
        (
            "qembeddingbag_4bit_rowwise_offsets",
            torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
        ),
    ),
    attr_names=("op_name", "op_func"),
)

byte_rowwise_ops = op_bench.op_list(
    attrs=(
        (
            "qembeddingbag_byte_rowwise_offsets",
            torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
        ),
    ),
    attr_names=("op_name", "op_func"),
)


def get_pruned_weights_and_mapping(q_weights):
    """Prune ``q_weights`` row-wise using a random per-row indicator.

    One float32 indicator value is drawn uniformly from [-1.0, 1.0) for each
    row of ``q_weights`` and handed to ``torch.ops.fb.embedding_bag_rowwise_prune``
    together with the constant ``0.01`` and ``torch.int32``.

    NOTE(review): presumably 0.01 is the pruning threshold and ``torch.int32``
    the dtype of the returned mapping -- confirm against the op's schema.

    Args:
        q_weights: prepacked (quantized) embedding table.

    Returns:
        ``(q_pruned_weights, compressed_indices_mapping)`` exactly as returned
        by the prune op.
    """
    indicator = torch.from_numpy(
        np.random.uniform(low=-1.0, high=1.0, size=[q_weights.shape[0]]).astype(
            np.float32
        )
    )

    (
        q_pruned_weights,
        compressed_indices_mapping,
    ) = torch.ops.fb.embedding_bag_rowwise_prune(
        q_weights, indicator, 0.01, torch.int32
    )

    return q_pruned_weights, compressed_indices_mapping


def _make_indices_and_offsets(
    lengths: np.ndarray,
    num_embeddings: int,
    include_last_offset: bool,
    use_32bit_indices: bool,
    use_32bit_offsets: bool,
):
    """Build random lookup indices and bag offsets with the requested dtypes.

    Args:
        lengths: per-bag segment lengths.
        num_embeddings: exclusive upper bound for the generated row indices.
        include_last_offset: when true, append ``len(indices)`` to the offsets.
        use_32bit_indices: cast indices to int32 (otherwise int64).
        use_32bit_offsets: cast offsets to int32 (otherwise whatever dtype
            ``lengths_to_offsets`` returns).

    Returns:
        ``(indices, offsets)`` tensors.
    """
    num_indices = int(np.sum(lengths))
    indices = torch.from_numpy(
        np.random.randint(
            low=0, high=num_embeddings, size=num_indices, dtype=np.int64
        )
    )
    offsets = lengths_to_offsets(lengths)

    if use_32bit_indices:
        indices = indices.int()
    if use_32bit_offsets:
        offsets = offsets.int()

    if include_last_offset:
        # Build the trailing offset in the dtype already chosen for `offsets`;
        # concatenating a torch.long tensor would undo the int32 cast (or fail
        # on torch versions without type promotion in cat).
        last_offset = torch.tensor([indices.size(0)], dtype=offsets.dtype)
        offsets = torch.cat((offsets, last_offset), 0)

    return indices, offsets


class EmbedddingBag4BitRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the 4-bit rowwise-offsets embedding-bag lookup.

    The class name (with its triple "d") is kept as-is so existing references
    keep working.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        num_offsets: int,
        enable_per_sample_weights: bool,
        include_last_offset: bool,
        is_pruned_weights: bool,
        use_32bit_indices: bool,
        use_32bit_offsets: bool,
        op_func,
    ):
        """Generate random inputs for one benchmark configuration.

        Args:
            num_embeddings: number of rows in the embedding table.
            embedding_dim: number of columns in the float table before prepack.
            num_offsets: upper bound on the number of bags; the actual count is
                drawn uniformly from ``[1, num_offsets]``.
            enable_per_sample_weights: pass random per-index weights if true,
                ``None`` otherwise.
            include_last_offset: append the total index count to the offsets
                and forward the flag to the op.
            is_pruned_weights: prune the prepacked table and pass the resulting
                index mapping to the op.
            use_32bit_indices: use int32 instead of int64 indices.
            use_32bit_offsets: use int32 offsets.
            op_func: the lookup op under test.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets

        # Each bag holds between 0 and max_segment_length indices.
        self.max_segment_length = 20
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(
            0, self.max_segment_length + 1, size=self.num_lengths
        ).astype(np.int32)
        self.num_indices = int(np.sum(self.lengths))

        # Indices are generated exactly once so the 32-bit cast is preserved.
        self.indices, self.offsets = _make_indices_and_offsets(
            self.lengths,
            self.num_embeddings,
            self.include_last_offset,
            self.use_32bit_indices,
            self.use_32bit_offsets,
        )

        # Float table with values in [1, 2), then quantized/prepacked.
        self.weights = torch.from_numpy(
            (
                np.random.random_sample((self.num_embeddings, self.embedding_dim)) + 1
            ).astype(np.float32)
        )
        self.prepack_func = torch.ops.quantized.embedding_bag_4bit_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        self.per_sample_weights = (
            torch.from_numpy(
                np.random.uniform(low=0.01, high=0.5, size=[len(self.indices)]).astype(
                    np.float32
                )
            )
            if self.enable_per_sample_weights
            else None
        )

        self.compressed_indices = None

        if self.is_pruned_weights:
            (
                self.prepacked_weights,
                self.compressed_indices,
            ) = get_pruned_weights_and_mapping(self.prepacked_weights)

        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            # NOTE(review): mode 0 is presumably "sum" -- confirm against the op.
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices,
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor],
    ):
        """Invoke the op under test with the inputs built by ``init``."""
        return self.op_func(
            prepacked_weights,
            indices,
            offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            pruned_weights=is_pruned_weights,
            compressed_indices_mapping=compressed_indices,
        )


class EmbedddingBagByteRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the byte (8-bit) rowwise-offsets embedding-bag lookup.

    The class name (with its triple "d") is kept as-is so existing references
    keep working.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        num_offsets: int,
        enable_per_sample_weights: bool,
        include_last_offset: bool,
        is_pruned_weights: bool,
        use_32bit_indices: bool,
        use_32bit_offsets: bool,
        op_func,
    ):
        """Generate random inputs for one benchmark configuration.

        Args:
            num_embeddings: number of rows in the embedding table.
            embedding_dim: number of columns in the float table before prepack.
            num_offsets: upper bound on the number of bags; the actual count is
                drawn uniformly from ``[1, num_offsets]``.
            enable_per_sample_weights: pass random per-index weights if true,
                ``None`` otherwise.
            include_last_offset: append the total index count to the offsets
                and forward the flag to the op.
            is_pruned_weights: prune the prepacked table and pass the resulting
                index mapping to the op.
            use_32bit_indices: use int32 instead of int64 indices.
            use_32bit_offsets: use int32 offsets.
            op_func: the lookup op under test.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets

        # Each bag holds between 0 and max_segment_length indices.
        self.max_segment_length = 20
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(
            0, self.max_segment_length + 1, size=self.num_lengths
        ).astype(np.int32)
        self.num_indices = int(np.sum(self.lengths))

        # Indices are generated exactly once so the 32-bit cast is preserved.
        self.indices, self.offsets = _make_indices_and_offsets(
            self.lengths,
            self.num_embeddings,
            self.include_last_offset,
            self.use_32bit_indices,
            self.use_32bit_offsets,
        )

        # Float table with values in [1, 2), then quantized/prepacked.
        self.weights = torch.from_numpy(
            (
                np.random.random_sample((self.num_embeddings, self.embedding_dim)) + 1
            ).astype(np.float32)
        )
        self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        self.per_sample_weights = (
            torch.from_numpy(
                np.random.uniform(low=0.01, high=0.5, size=[len(self.indices)]).astype(
                    np.float32
                )
            )
            if self.enable_per_sample_weights
            else None
        )

        self.compressed_indices = None

        if self.is_pruned_weights:
            (
                self.prepacked_weights,
                self.compressed_indices,
            ) = get_pruned_weights_and_mapping(self.prepacked_weights)

        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            # NOTE(review): mode 0 is presumably "sum" -- confirm against the op.
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices,
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor],
    ):
        """Invoke the op under test with the inputs built by ``init``.

        The arguments are used directly (rather than the copies stored on
        ``self``), matching the 4-bit benchmark's ``forward``.
        """
        return self.op_func(
            prepacked_weights,
            indices,
            offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            pruned_weights=is_pruned_weights,
            compressed_indices_mapping=compressed_indices,
        )


op_bench.generate_pt_tests_from_op_list(
    four_bit_rowwise_ops, full_configs, EmbedddingBag4BitRowwiseOffsetsTest
)
op_bench.generate_pt_tests_from_op_list(
    byte_rowwise_ops, full_configs, EmbedddingBagByteRowwiseOffsetsTest
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()