• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import numpy
2from pt import configs
3
4import operator_benchmark as op_bench
5
6import torch
7
8
9"""Embedding and EmbeddingBag Operator Benchmark"""
10
11
class EmbeddingBagBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.nn.EmbeddingBag`` on a flat 1-D index tensor.

    Per-case arguments are supplied by ``configs.embeddingbag_short_configs``;
    the class is registered for both forward-only and gradient benchmarks.
    """

    def init(
        self,
        embeddingbags,
        dim,
        mode,
        input_size,
        offset,
        sparse,
        include_last_offset,
        device,
    ):
        """Build the EmbeddingBag module and its inputs on ``device``.

        Args:
            embeddingbags: number of rows in the table (``num_embeddings``);
                also the exclusive upper bound of the random indices.
            dim: embedding dimension of each row.
            mode: reduction mode forwarded to ``torch.nn.EmbeddingBag``.
            input_size: number of indices in the flat 1-D input.
            offset: value placed first in the offsets tensor, i.e. the start
                position of the first bag.
            sparse: forwarded to ``torch.nn.EmbeddingBag`` (sparse weight
                gradients).
            include_last_offset: forwarded to ``torch.nn.EmbeddingBag``.
            device: device on which the module and every input tensor live.
        """
        self.embedding = torch.nn.EmbeddingBag(
            num_embeddings=embeddingbags,
            embedding_dim=dim,
            mode=mode,
            include_last_offset=include_last_offset,
            sparse=sparse,
        ).to(device=device)
        # Fixed seed so every run benchmarks the same index pattern.
        numpy.random.seed((1 << 32) - 1)
        indices = torch.tensor(
            numpy.random.randint(0, embeddingbags, input_size),
            dtype=torch.long,
            device=device,
        )
        # Offsets are [offset, len(indices)], created directly on `device`.
        # The previous code failed on non-CPU devices for two reasons:
        # the legacy `torch.LongTensor(..., device=...)` constructor only
        # accepts CPU, and the appended length tensor was always on CPU, which
        # `torch.cat` rejects when mixed with a tensor on another device.
        # NOTE(review): with include_last_offset=False the trailing
        # len(indices) entry starts an extra, empty bag — confirm intended.
        offsets = torch.tensor(
            [offset, indices.size(0)], dtype=torch.long, device=device
        )
        self.inputs = {
            "input": indices,
            "offset": offsets,
        }
        self.set_module_name("embeddingbag")

    def forward(self, input, offset):
        """Run the EmbeddingBag module on the indices and their offsets."""
        return self.embedding(input, offset)
46
47
# Register EmbeddingBagBenchmark for the short config set: first the
# forward-only variant, then the gradient variant.
op_bench.generate_pt_test(
    configs.embeddingbag_short_configs,
    EmbeddingBagBenchmark,
)
op_bench.generate_pt_gradient_test(
    configs.embeddingbag_short_configs,
    EmbeddingBagBenchmark,
)
52
53
class EmbeddingBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.nn.Embedding`` lookups.

    Per-case arguments are supplied by ``configs.embedding_short_configs``;
    the class is registered for both forward-only and gradient benchmarks.
    """

    def init(self, num_embeddings, embedding_dim, input_size, device):
        """Create the embedding table and a reproducible batch of indices.

        Args:
            num_embeddings: table size; indices are drawn from
                ``[0, num_embeddings)``.
            embedding_dim: width of each embedding row.
            input_size: number of indices to look up.
            device: device for both the module and the index tensor.
        """
        table = torch.nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.embedding = table.to(device=device)

        # Seed right before sampling so the index pattern is identical
        # across runs.
        numpy.random.seed((1 << 32) - 1)
        raw_indices = numpy.random.randint(0, num_embeddings, input_size)
        lookup = torch.tensor(raw_indices, device=device).long()

        self.inputs = {"input": lookup}
        self.set_module_name("embedding")

    def forward(self, input):
        """Return the embedding rows selected by ``input``."""
        return self.embedding(input)
68
69
# Register EmbeddingBenchmark for the short config set: first the
# forward-only variant, then the gradient variant.
op_bench.generate_pt_test(
    configs.embedding_short_configs,
    EmbeddingBenchmark,
)
op_bench.generate_pt_gradient_test(
    configs.embedding_short_configs,
    EmbeddingBenchmark,
)

if __name__ == "__main__":
    # Delegate to the operator_benchmark command-line runner.
    op_bench.benchmark_runner.main()
75