import torch

from . import benchmark


class RNNEltwise(benchmark.Benchmark):
    """Element-wise portion of an LSTM cell.

    The gate pre-activations (input, hx, b_ih, b_hh) are generated directly as
    [b, 4 * hs] tensors, so the benchmark measures only the pointwise ops, not
    the matrix multiplies that would normally produce them.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        self.input = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.hx = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.cx = self.rand(
            [b, hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_ih = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_hh = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    def forward(self, input, hx, cx, b_ih, b_hh):
        # Sum the precomputed gate contributions, then split into the four gates.
        gates = input + hx + b_ih + b_hh

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # Standard LSTM cell/hidden state update.
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy

    def config(self):
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        # Bytes read (all five inputs) plus bytes written (hy and cy, each the
        # same size as cx).
        input_size = sum(memsize(t) for t in self.inputs)
        output_size = 2 * memsize(self.cx)
        io_size = input_size + output_size
        return {"sol": io_size, "algorithmic": io_size}

    @staticmethod
    def default_configs():
        return [[64, 512]]


benchmark.register_benchmark_class(RNNEltwise)


class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise: instantiate_input() rebuilds the
    inputs with shapes drawn from self.rand_shape, so each instantiation can
    use a different batch/hidden size."""

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        b, hs = self.rand_shape([self.b, self.hs])

        self.input = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.hx = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.cx = self.rand(
            [b, hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_ih = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_hh = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    @staticmethod
    def module():
        return "dynamic_lstm"


benchmark.register_benchmark_class(DynamicLSTM)
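
# Illustrative usage sketch (an assumption, not part of this module): given that
# benchmark.Benchmark takes (mode, device, dtype) as used above, a single forward
# pass could be exercised directly, e.g.
#
#   bench = RNNEltwise("fwd", "cpu", torch.float32, 64, 512)  # "fwd" mode string is assumed
#   hy, cy = bench.forward(*bench.inputs)                     # hy, cy: [64, 512] each
#
# In the repository these classes are normally driven through the harness that
# register_benchmark_class() hooks them into, rather than being called directly.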