• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from typing import Optional
2
3import numpy as np
4
5import operator_benchmark as op_bench
6
7import torch
8from torch.testing._internal.common_quantization import lengths_to_offsets
9
10
# Load the operator library before any benchmark input is built.
# NOTE(review): presumably this registers the `torch.ops.fb.*` operator used
# by the pruning helper in this file -- confirm against the library target.
torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators")


# Boolean switches swept identically by the "short" and "long" sweeps below.
# They are unpacked between `num_offsets` and `tags`, so the keyword order
# seen by `cross_product_configs` is the same as spelling each one out.
_ON_OFF = (True, False)
_SHARED_TOGGLES = {
    "enable_per_sample_weights": _ON_OFF,
    "include_last_offset": _ON_OFF,
    "is_pruned_weights": _ON_OFF,
    "use_32bit_indices": _ON_OFF,
    "use_32bit_offsets": _ON_OFF,
}

# Small sweep: a single table size, two embedding widths, 2..9 offsets.
embedding_bag_rowwise_offsets_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256),
    num_offsets=range(2, 10),
    **_SHARED_TOGGLES,
    tags=["short"],
)

# Large sweep: five table sizes, four embedding widths, 10..19 offsets.
embedding_bag_rowwise_offsets_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000, 10_000, 20_000),
    embedding_dim=(16, 64, 128, 256),
    num_offsets=range(10, 20),
    **_SHARED_TOGGLES,
    tags=["long"],
)

# Every benchmark class in this file is registered against both sweeps,
# short configurations first.
full_configs = (
    embedding_bag_rowwise_offsets_short_configs
    + embedding_bag_rowwise_offsets_long_configs
)
50
# Each op-list entry below is an (op_name, op_func) pair; these are the
# attribute names the entries are published under.
_OP_ATTR_NAMES = ("op_name", "op_func")

_FOUR_BIT_ENTRY = (
    "qembeddingbag_4bit_rowwise_offsets",
    torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
)
_BYTE_ENTRY = (
    "qembeddingbag_byte_rowwise_offsets",
    torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
)

# Op list for the 4-bit row-wise quantized lookup.
four_bit_rowwise_ops = op_bench.op_list(
    attrs=(_FOUR_BIT_ENTRY,),
    attr_names=_OP_ATTR_NAMES,
)

# Op list for the byte (8-bit) row-wise quantized lookup.
byte_rowwise_ops = op_bench.op_list(
    attrs=(_BYTE_ENTRY,),
    attr_names=_OP_ATTR_NAMES,
)
70
71
def get_pruned_weights_and_mapping(q_weights):
    """Run ``torch.ops.fb.embedding_bag_rowwise_prune`` with a random indicator.

    One float32 indicator value per row of ``q_weights`` (``shape[0]``
    entries) is drawn uniformly between -1.0 and 1.0 and handed to the op
    together with the constants ``0.01`` and ``torch.int32``.
    NOTE(review): ``0.01`` is presumably the pruning threshold and
    ``torch.int32`` the dtype of the returned mapping -- confirm against the
    op schema.

    Args:
        q_weights: Tensor whose first dimension is the row count.

    Returns:
        A 2-tuple with the op's two results, in order: the pruned weights
        and the compressed-indices mapping.
    """
    row_count = q_weights.shape[0]
    raw_indicator = np.random.uniform(low=-1.0, high=1.0, size=[row_count])
    indicator = torch.from_numpy(raw_indicator.astype(np.float32))

    pruned, mapping = torch.ops.fb.embedding_bag_rowwise_prune(
        q_weights, indicator, 0.01, torch.int32
    )
    return pruned, mapping
87
88
class EmbedddingBag4BitRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the 4-bit row-wise quantized embedding-bag lookup.

    ``init`` builds random indices, offsets, weights (prepacked with
    ``torch.ops.quantized.embedding_bag_4bit_prepack``) and optional
    per-sample weights; ``forward`` hands them to ``op_func``.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        num_offsets: int,
        enable_per_sample_weights: bool,
        include_last_offset: bool,
        is_pruned_weights: bool,
        use_32bit_indices: bool,
        use_32bit_offsets: bool,
        op_func,
    ):
        """Build the random inputs for one benchmark configuration.

        Args:
            num_embeddings: Row count of the float weight matrix; indices are
                drawn from ``[0, num_embeddings)``.
            embedding_dim: Column count of the float weight matrix.
            num_offsets: Inclusive upper bound on the randomly chosen number
                of segments.
            enable_per_sample_weights: When true, one random float32 weight
                per index is supplied; otherwise ``None`` is passed.
            include_last_offset: When true, the total index count is appended
                to the offsets; the flag is also forwarded to the op.
            is_pruned_weights: When true, the prepacked weights are replaced
                by the output of ``get_pruned_weights_and_mapping`` and the
                returned mapping is passed to the op.
            use_32bit_indices: When true, indices are int32, else int64.
            use_32bit_offsets: When true, offsets are int32, else int64.
            op_func: Callable invoked by ``forward``.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets

        # Random segmentation: 1..num_offsets segments, each holding
        # 0..max_segment_length indices.
        self.max_segment_length = 20
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(
            0, self.max_segment_length + 1, size=self.num_lengths
        ).astype(np.int32)
        self.num_indices = np.sum(self.lengths)

        # Indices are generated exactly once. Regenerating them after the
        # cast would silently revert them to int64 and make
        # `use_32bit_indices` a no-op.
        indices = torch.from_numpy(
            np.random.randint(
                low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64
            )
        )
        self.indices = indices.int() if use_32bit_indices else indices

        offsets = lengths_to_offsets(self.lengths)
        if use_32bit_offsets:
            offsets = offsets.int()
        if include_last_offset:
            # Append the end offset in the offsets' own dtype; concatenating
            # an int64 tensor would promote int32 offsets back to int64.
            last_offset = torch.tensor(
                [self.indices.size(0)], dtype=offsets.dtype
            )
            offsets = torch.cat((offsets, last_offset), 0)
        self.offsets = offsets

        # Float weights in [1, 2), then packed into the quantized layout.
        self.weights = torch.from_numpy(
            (
                np.random.random_sample((self.num_embeddings, self.embedding_dim)) + 1
            ).astype(np.float32)
        )
        self.prepack_func = torch.ops.quantized.embedding_bag_4bit_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        # One weight per index, or None when per-sample weighting is off.
        self.per_sample_weights = (
            torch.from_numpy(
                np.random.uniform(low=0.01, high=0.5, size=[len(self.indices)]).astype(
                    np.float32
                )
            )
            if self.enable_per_sample_weights
            else None
        )

        self.compressed_indices = None

        if self.is_pruned_weights:
            (
                self.prepacked_weights,
                self.compressed_indices,
            ) = get_pruned_weights_and_mapping(self.prepacked_weights)

        # Key order matches the parameter order of `forward`.
        # NOTE(review): mode 0 is presumably the "sum" reduction -- confirm
        # against the op schema.
        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices,
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor],
    ):
        """Invoke ``op_func`` with the given inputs and return its result.

        ``is_pruned_weights`` and ``compressed_indices`` are forwarded under
        the op's keyword names ``pruned_weights`` and
        ``compressed_indices_mapping``.
        """
        return self.op_func(
            prepacked_weights,
            indices,
            offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            pruned_weights=is_pruned_weights,
            compressed_indices_mapping=compressed_indices,
        )
198
199
class EmbedddingBagByteRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the byte row-wise quantized embedding-bag lookup.

    ``init`` builds random indices, offsets, weights (prepacked with
    ``torch.ops.quantized.embedding_bag_byte_prepack``) and optional
    per-sample weights; ``forward`` hands them to ``op_func``.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        num_offsets: int,
        enable_per_sample_weights: bool,
        include_last_offset: bool,
        is_pruned_weights: bool,
        use_32bit_indices: bool,
        use_32bit_offsets: bool,
        op_func,
    ):
        """Build the random inputs for one benchmark configuration.

        Args:
            num_embeddings: Row count of the float weight matrix; indices are
                drawn from ``[0, num_embeddings)``.
            embedding_dim: Column count of the float weight matrix.
            num_offsets: Inclusive upper bound on the randomly chosen number
                of segments.
            enable_per_sample_weights: When true, one random float32 weight
                per index is supplied; otherwise ``None`` is passed.
            include_last_offset: When true, the total index count is appended
                to the offsets; the flag is also forwarded to the op.
            is_pruned_weights: When true, the prepacked weights are replaced
                by the output of ``get_pruned_weights_and_mapping`` and the
                returned mapping is passed to the op.
            use_32bit_indices: When true, indices are int32, else int64.
            use_32bit_offsets: When true, offsets are int32, else int64.
            op_func: Callable invoked by ``forward``.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets

        # Random segmentation: 1..num_offsets segments, each holding
        # 0..max_segment_length indices.
        self.max_segment_length = 20
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(
            0, self.max_segment_length + 1, size=self.num_lengths
        ).astype(np.int32)
        self.num_indices = np.sum(self.lengths)

        # Indices are generated exactly once. Regenerating them after the
        # cast would silently revert them to int64 and make
        # `use_32bit_indices` a no-op.
        indices = torch.from_numpy(
            np.random.randint(
                low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64
            )
        )
        self.indices = indices.int() if use_32bit_indices else indices

        offsets = lengths_to_offsets(self.lengths)
        if use_32bit_offsets:
            offsets = offsets.int()
        if include_last_offset:
            # Append the end offset in the offsets' own dtype; concatenating
            # an int64 tensor would promote int32 offsets back to int64.
            last_offset = torch.tensor(
                [self.indices.size(0)], dtype=offsets.dtype
            )
            offsets = torch.cat((offsets, last_offset), 0)
        self.offsets = offsets

        # Float weights in [1, 2), then packed into the quantized layout.
        self.weights = torch.from_numpy(
            (
                np.random.random_sample((self.num_embeddings, self.embedding_dim)) + 1
            ).astype(np.float32)
        )
        self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        # One weight per index, or None when per-sample weighting is off.
        self.per_sample_weights = (
            torch.from_numpy(
                np.random.uniform(low=0.01, high=0.5, size=[len(self.indices)]).astype(
                    np.float32
                )
            )
            if self.enable_per_sample_weights
            else None
        )

        self.compressed_indices = None

        if self.is_pruned_weights:
            (
                self.prepacked_weights,
                self.compressed_indices,
            ) = get_pruned_weights_and_mapping(self.prepacked_weights)

        # Key order matches the parameter order of `forward`.
        # NOTE(review): mode 0 is presumably the "sum" reduction -- confirm
        # against the op schema.
        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices,
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor],
    ):
        """Invoke ``op_func`` with the given inputs and return its result.

        Every argument is taken from the parameters (not from ``self``), the
        same way the 4-bit benchmark class in this file does it.
        ``is_pruned_weights`` and ``compressed_indices`` are forwarded under
        the op's keyword names ``pruned_weights`` and
        ``compressed_indices_mapping``.
        """
        return self.op_func(
            prepacked_weights,
            indices,
            offsets,
            mode=mode,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            pruned_weights=is_pruned_weights,
            compressed_indices_mapping=compressed_indices,
        )
310
311
# Register the benchmark tests: each op list is paired with its benchmark
# class and run against the combined short + long configurations. The 4-bit
# variant is registered first, then the byte variant.
for _op_list, _benchmark_cls in (
    (four_bit_rowwise_ops, EmbedddingBag4BitRowwiseOffsetsTest),
    (byte_rowwise_ops, EmbedddingBagByteRowwiseOffsetsTest),
):
    op_bench.generate_pt_tests_from_op_list(_op_list, full_configs, _benchmark_cls)
# Drop the loop variables so no extra names linger at module level.
del _op_list, _benchmark_cls


# Hand control to the benchmark runner when executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
322