• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import torch
2
3from . import benchmark
4
5
class RNNEltwise(benchmark.Benchmark):
    """Benchmark for the element-wise portion of an LSTM cell update.

    Only the pointwise gate math is measured: all gate pre-activations
    (``input``, ``hx``, ``b_ih``, ``b_hh``) are supplied as already-projected
    ``[b, 4 * hs]`` tensors, so no matrix multiplies appear in ``forward``.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # (attribute name, width multiplier on hs). Tensors are created in
        # this fixed order so RNG-dependent initialization stays reproducible.
        tensor_specs = [
            ("input", 4),
            ("hx", 4),
            ("cx", 1),
            ("b_ih", 4),
            ("b_hh", 4),
        ]
        for attr, mult in tensor_specs:
            setattr(
                self,
                attr,
                self.rand(
                    [b, mult * hs],
                    device=device,
                    dtype=dtype,
                    requires_grad=self.requires_grad,
                ),
            )
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    def forward(self, input, hx, cx, b_ih, b_hh):
        # Sum the pre-computed gate contributions, then split along the
        # feature dimension into the four LSTM gates.
        preact = input + hx + b_ih + b_hh
        i_gate, f_gate, g_gate, o_gate = preact.chunk(4, 1)

        i_gate = torch.sigmoid(i_gate)
        f_gate = torch.sigmoid(f_gate)
        g_gate = torch.tanh(g_gate)
        o_gate = torch.sigmoid(o_gate)

        new_c = f_gate * cx + i_gate * g_gate
        new_h = o_gate * torch.tanh(new_c)

        return new_h, new_c

    def config(self):
        """Return [batch, hidden_size] identifying this configuration."""
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        """Estimate bytes moved: every input read once, hy and cy written."""

        def nbytes(t):
            return t.numel() * t.element_size()

        # Outputs hy and cy each have the same size as cx.
        total = sum(nbytes(t) for t in self.inputs) + 2 * nbytes(self.cx)
        return {"sol": total, "algorithmic": total}

    @staticmethod
    def default_configs():
        return [[64, 512]]
68
69
# Register with the benchmark harness so it is discoverable by name.
benchmark.register_benchmark_class(RNNEltwise)
71
72
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise.

    All input tensors are rebuilt with a randomly perturbed (batch, hidden)
    shape before each measurement via ``instantiate_input``.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        # Perturb the nominal sizes, then rebuild every input tensor at the
        # new shape. Creation order mirrors RNNEltwise.__init__ so any
        # RNG-dependent initialization stays consistent.
        batch, hidden = self.rand_shape([self.b, self.hs])

        def fresh(width_mult):
            # New random tensor of shape [batch, width_mult * hidden].
            return self.rand(
                [batch, width_mult * hidden],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        self.input = fresh(4)
        self.hx = fresh(4)
        self.cx = fresh(1)
        self.b_ih = fresh(4)
        self.b_hh = fresh(4)
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    @staticmethod
    def module():
        return "dynamic_lstm"
122
123
# Register the dynamic-shape variant with the benchmark harness as well.
benchmark.register_benchmark_class(DynamicLSTM)
125