• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from collections import namedtuple
2from typing import List, Tuple
3
4import torch
5from torch import Tensor
6
7from .cells import flat_lstm_cell, lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias
8
9
def flatten_list(lst):
    """Concatenate a list of lists into one new flat list.

    Type shape: ``list[list[T]] -> list[T]``. The inner lists are not modified.
    """
    return [item for inner in lst for item in inner]
16
17
"""
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] of all parameters with requires_grad=True.
forward: function / graph executor / module.
    One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
    Then, we pass backward_inputs to backward. If None, it is assumed to
    be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
    backpropagation. If None, nothing happens.

fastrnns.bench times the forward and backward invocations.
"""
34
35
# Record returned by every *_creator in this module; the meaning of each
# field (inputs, params, forward, backward_setup, backward) is described in
# the module-level explanatory string.
ModelDef = namedtuple("ModelDef", "inputs params forward backward_setup backward")
39
40
def lstm_backward_setup(lstm_outputs, seed=None):
    """Prepare ``(output, grad_output)`` for backpropagating an LSTM result.

    ``lstm_outputs`` must unpack into exactly two items, e.g.
    ``(output, (h, c))``. Only the first item is used; it is handed to
    ``simple_backward_setup`` together with ``seed``.
    """
    first, _discarded = lstm_outputs
    return simple_backward_setup(first, seed)
44
45
def simple_backward_setup(output, seed=None):
    """Pair ``output`` with a random gradient of the same shape.

    Args:
        output: tensor to backpropagate from.
        seed: if not None, ``torch.manual_seed(seed)`` is called before the
            gradient is drawn, so results are reproducible. ``0`` is a valid
            seed and is honored, consistent with ``lstm_inputs``.

    Returns:
        ``(output, grad_output)``, where ``grad_output = torch.randn_like(output)``.
    """
    assert isinstance(output, torch.Tensor)
    # `is not None` (not truthiness) so that seed=0 is not silently ignored.
    if seed is not None:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output
52
53
def simple_backward(output, grad_output, **kwargs):
    """Backpropagate ``grad_output`` through ``output`` and return the result.

    Extra keyword arguments (for example ``retain_graph``) are forwarded
    unchanged to ``output.backward``.
    """
    run_backward = output.backward
    return run_backward(grad_output, **kwargs)
56
57
def pytorch_lstm_creator(**kwargs):
    """ModelDef that benchmarks the stock ``torch.nn.LSTM`` module.

    ``kwargs`` are forwarded to ``lstm_inputs``. The module itself is the
    forward callable and its flattened ``all_weights`` are the params.
    """
    x, initial_state, _weights, lstm_module = lstm_inputs(return_module=True, **kwargs)
    model_params = flatten_list(lstm_module.all_weights)
    return ModelDef(
        inputs=[x, initial_state],
        params=model_params,
        forward=lstm_module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
67
68
def lstm_creator(script=True, **kwargs):
    """ModelDef for the hand-written LSTM loop built by ``lstm_factory``.

    The forward receives ``(input, hidden, wih, whh, bih, bhh)``, where the
    four weights are layer 0 of the ``torch.nn.LSTM`` created by
    ``lstm_inputs``. ``script`` selects TorchScript compilation.
    """
    x, initial_state, layer_weights, _ = lstm_inputs(return_module=False, **kwargs)
    forward_inputs = [x, initial_state] + layer_weights[0]
    all_params = flatten_list(layer_weights)
    runner = lstm_factory(lstm_cell, script)
    return ModelDef(
        inputs=forward_inputs,
        params=all_params,
        forward=runner,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
79
80
def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    """ModelDef for the scripted LayerNorm LSTM from ``custom_lstms``.

    Only the TorchScript variant exists, so ``script`` must be True. Reads
    ``inputSize``, ``hiddenSize``, ``seqLength`` and ``miniBatch`` from
    ``kwargs``. The model and all tensors are placed on CUDA.
    """
    assert script is True
    from .custom_lstms import script_lnlstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    # The third argument (1) is presumably the layer count; it matches the
    # single state pair built below — confirm against custom_lstms.
    ge = script_lnlstm(
        input_size, hidden_size, 1, decompose_layernorm=decompose_layernorm
    ).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        (
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
    ]

    return ModelDef(
        inputs=[input, states],
        # Materialize the generator: ModelDef.params is a List[Tensor], and a
        # generator would be empty after its first traversal.
        params=list(ge.parameters()),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
108
109
def dropoutlstm_creator(script=True, **kwargs):
    """ModelDef for the scripted multi-layer LSTM with dropout from ``custom_lstms``.

    Only the TorchScript variant exists, so ``script`` must be True. Reads
    ``inputSize``, ``hiddenSize``, ``seqLength``, ``miniBatch`` and
    ``numLayers`` from ``kwargs``. One ``LSTMState`` is created per layer, and
    everything is placed on CUDA.
    """
    assert script is True
    from .custom_lstms import LSTMState, script_lstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    num_layers = kwargs["numLayers"]
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        LSTMState(
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
        for _ in range(num_layers)
    ]
    return ModelDef(
        inputs=[input, states],
        # Materialize the generator: ModelDef.params is a List[Tensor], and a
        # generator would be empty after its first traversal.
        params=list(ge.parameters()),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
136
137
def lstm_premul_creator(script=True, **kwargs):
    """ModelDef for the LSTM loop that premultiplies inputs by ``wih`` up front.

    Uses ``lstm_factory_premul`` with ``premul_lstm_cell``. The forward inputs
    are ``(input, hidden)`` followed by layer 0's four weights.
    """
    x, initial_state, layer_weights, _ = lstm_inputs(return_module=False, **kwargs)
    forward_inputs = [x, initial_state] + layer_weights[0]
    all_params = flatten_list(layer_weights)
    runner = lstm_factory_premul(premul_lstm_cell, script)
    return ModelDef(
        inputs=forward_inputs,
        params=all_params,
        forward=runner,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
148
149
def lstm_premul_bias_creator(script=True, **kwargs):
    """ModelDef for the LSTM loop that premultiplies inputs and adds ``bih`` up front.

    Uses ``lstm_factory_premul_bias`` with ``premul_lstm_cell_no_bias``. The
    forward inputs are ``(input, hidden)`` followed by layer 0's four weights.
    """
    x, initial_state, layer_weights, _ = lstm_inputs(return_module=False, **kwargs)
    forward_inputs = [x, initial_state] + layer_weights[0]
    all_params = flatten_list(layer_weights)
    runner = lstm_factory_premul_bias(premul_lstm_cell_no_bias, script)
    return ModelDef(
        inputs=forward_inputs,
        params=all_params,
        forward=runner,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
160
161
def lstm_simple_creator(script=True, **kwargs):
    """ModelDef for the flat-argument LSTM loop from ``lstm_factory_simple``.

    The forward inputs are ``input``, then layer 0 of each initial state
    (``hx[0]``, ``cx[0]``), then layer 0's four weights — all passed flat.
    """
    x, initial_state, layer_weights, _ = lstm_inputs(return_module=False, **kwargs)
    first_layer_states = [state[0] for state in initial_state]
    forward_inputs = [x] + first_layer_states + layer_weights[0]
    all_params = flatten_list(layer_weights)
    runner = lstm_factory_simple(flat_lstm_cell, script)
    return ModelDef(
        inputs=forward_inputs,
        params=all_params,
        forward=runner,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
172
173
def lstm_multilayer_creator(script=True, **kwargs):
    """ModelDef for the multi-layer LSTM loop from ``lstm_factory_multilayer``.

    The forward inputs are ``(input, hidden, flat_weights)``, where
    ``flat_weights`` lists every layer's ``wih, whh, bih, bhh`` in order.
    """
    x, initial_state, layer_weights, _ = lstm_inputs(return_module=False, **kwargs)
    flat_weights = flatten_list(layer_weights)
    runner = lstm_factory_multilayer(lstm_cell, script)
    return ModelDef(
        inputs=[x, initial_state, flat_weights],
        # A separate list object with the same tensors, as in the original.
        params=list(flat_weights),
        forward=runner,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
184
185
def imagenet_cnn_creator(arch, jit=True):
    """Return a creator that benchmarks the model produced by ``arch()``.

    The returned ``creator(device="cuda", **kwargs)`` moves ``arch()`` to
    ``device`` and draws a random input of shape ``(32, 3, 224, 224)``. When
    ``jit`` is True, it traces the model with ``torch.jit.trace`` on that input.
    Extra ``kwargs`` are ignored.
    """

    def creator(device="cuda", **kwargs):
        net = arch().to(device)
        example = torch.randn(32, 3, 224, 224, device=device)
        net = torch.jit.trace(net, example) if jit else net
        return ModelDef(
            inputs=(example,),
            params=list(net.parameters()),
            forward=net,
            backward_setup=simple_backward_setup,
            backward=simple_backward,
        )

    return creator
201
202
def varlen_lstm_inputs(
    minlen=30,
    maxlen=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    return_module=False,
    device="cuda",
    seed=None,
    **kwargs,
):
    """Create random variable-length LSTM inputs and a reference ``torch.nn.LSTM``.

    Returns ``(x, lengths, (hx, cx), all_weights, module)``:
      * ``lengths``: ``miniBatch`` ints drawn from ``[minlen, maxlen)``.
      * ``x``: a list with one ``(length, inputSize)`` tensor per entry of ``lengths``.
      * ``hx``/``cx``: tensors of shape ``(numLayers, miniBatch, hiddenSize)``.
      * ``all_weights``: ``lstm.all_weights``, so per layer
        ``wih, whh, bih, bhh = all_weights[layer]``.
      * ``module``: the LSTM if ``return_module`` is True, else None.
    Unknown keyword arguments are accepted and ignored.
    """
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch], dtype=torch.long, device=device
    )
    x = [torch.randn(n, inputSize, device=device) for n in lengths]
    state_shape = (numLayers, miniBatch, hiddenSize)
    hx = torch.randn(state_shape, device=device)
    cx = torch.randn(state_shape, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    module = lstm if return_module else None
    return x, lengths, (hx, cx), lstm.all_weights, module
231
232
def varlen_lstm_backward_setup(forward_output, seed=None):
    """Pad the per-sequence outputs and pair them with a random gradient.

    Args:
        forward_output: a sequence whose first item is a list of tensors with
            varying length along dim 0 (e.g. the ``outputs`` list returned by
            the runner from ``varlen_lstm_factory``).
        seed: if not None, ``torch.manual_seed(seed)`` is called first so the
            gradient is reproducible; ``0`` is a valid seed.

    Returns:
        ``(padded, grad)``, where ``padded = pad_sequence(forward_output[0])``
        (time-major, zero-padded) and ``grad`` is random noise of the same shape.
    """
    # `is not None` (not truthiness) so that seed=0 is not silently ignored.
    if seed is not None:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad
241
242
def varlen_pytorch_lstm_creator(**kwargs):
    """ModelDef for ``torch.nn.LSTM`` on variable-length input via packing.

    The forward packs the list of sequences (no sorting required), runs the
    module, and returns the padded output together with the new hidden state.
    ``kwargs`` are forwarded to ``varlen_lstm_inputs``.
    """
    rnn_utils = torch.nn.utils.rnn
    seqs, _lengths, initial_state, _weights, lstm_module = varlen_lstm_inputs(
        return_module=True, **kwargs
    )

    def forward(sequences, hidden):
        packed_in = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        packed_out, new_hidden = lstm_module(packed_in, hidden)
        # XXX: The output is kept padded. Un-padding it back into a list of
        # per-sequence tensors might suit loss computation better, but it makes
        # the backward pass about 2x slower.
        padded, _ = rnn_utils.pad_packed_sequence(packed_out)
        return padded, new_hidden

    return ModelDef(
        inputs=[seqs, initial_state],
        params=flatten_list(lstm_module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
264
265
def varlen_lstm_factory(cell, script):
    """Build an LSTM runner over a list of variable-length sequences.

    Each sequence is processed independently, one timestep at a time, with
    ``cell`` and the shared weights ``wih, whh, bih, bhh``.

    Args:
        cell: step function called as ``cell(x_t, (hy, cy), wih, whh, bih, bhh)``;
            it must return the new ``(hy, cy)``.
        script: if True, both ``cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh)``, which returns
        ``(outputs, (hx_outs, cx_outs))``. ``outputs`` holds one stacked tensor of
        per-step hidden states per sequence; ``hx_outs``/``cx_outs`` hold each
        sequence's final hidden/cell state with an extra leading dim of size 1.
    """
    def dynamic_rnn(
        sequences: List[Tensor],
        hiddens: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        # Split the initial states along dim 1, so hxs[b]/cxs[b] belong to
        # sequence b (varlen_lstm_inputs builds hx/cx as
        # (numLayers, miniBatch, hiddenSize)).
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # Per-sequence results: stacked step outputs, final hx, final cx.
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            # One tensor per timestep (dim 0 of the sequence).
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                # unsqueeze(0) adds a leading dim of size 1 to the step input —
                # presumably a batch of one for the cell; confirm against cell.
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh
                )
                output += [hy]
            # NOTE(review): torch.stack raises on an empty list, so zero-length
            # sequences are not supported.
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        # `cell` is rebound before dynamic_rnn is scripted, so the closure
        # variable dynamic_rnn refers to holds the scripted cell.
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
304
305
def varlen_lstm_creator(script=False, **kwargs):
    """ModelDef for the hand-written per-sequence LSTM from ``varlen_lstm_factory``.

    The forward inputs are ``(sequences, hidden)`` followed by layer 0's four
    weights. The backward pads the outputs via ``varlen_lstm_backward_setup``.
    """
    seqs, _lengths, initial_state, layer_weights, _ = varlen_lstm_inputs(
        return_module=False, **kwargs
    )
    forward_inputs = [seqs, initial_state] + layer_weights[0]
    all_params = flatten_list(layer_weights)
    runner = varlen_lstm_factory(lstm_cell, script)
    return ModelDef(
        inputs=forward_inputs,
        params=all_params,
        forward=runner,
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward,
    )
316
317
# cudnn_layernorm_lstm: since cudnn does not have a LayerNorm LSTM, we cannot benchmark
# the lower bound directly. Instead, we only benchmark the forward pass by mimicking the
# computation of a cudnn LSTM plus seq_len * 3 layernorm computations. This should serve
# as a perf lower bound for the LayerNorm LSTM forward pass (given that LayerNorm itself
# is invariant). The lower bound of the backward pass is hard to get because we lose the
# intermediate results; we can still optimize the layernorm implementation to make
# the forward lower bound faster, though.
def layernorm_pytorch_lstm_creator(**kwargs):
    """ModelDef approximating a forward lower bound for a LayerNorm LSTM.

    Runs the stock ``torch.nn.LSTM`` module, then ``seq_len`` rounds of three
    LayerNorm applications (two on a fixed ``(miniBatch, 4*hiddenSize)`` tensor,
    one on the cell state). Only the forward is benchmarked; ``backward`` is
    None. ``kwargs`` go to ``lstm_inputs`` and must contain ``miniBatch`` and
    ``hiddenSize``. An optional ``device`` (default ``"cuda"``) is also used
    for the LayerNorm modules and their input.
    """
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    batch_size = kwargs["miniBatch"]
    hidden_size = kwargs["hiddenSize"]
    # Follow the same device as lstm_inputs (whose default is "cuda").
    device = kwargs.get("device", "cuda")
    ln_i = torch.nn.LayerNorm(4 * hidden_size).to(device)
    ln_h = torch.nn.LayerNorm(4 * hidden_size).to(device)
    ln_c = torch.nn.LayerNorm(hidden_size).to(device)
    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device=device)

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        # Add seq_len * three LayerNorm cell computations to mimic the lower
        # bound of a LayerNorm cudnn LSTM in the forward pass. input.size(0)
        # avoids building seq_len throwaway views inside the timed region.
        seq_len = input.size(0)
        hy, cy = new_hidden
        for _ in range(seq_len):
            # Results are discarded on purpose: only the compute cost matters.
            ln_i(ln_input1)
            ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None,
    )
354
355
def stack_weights(weights):
    """Stack per-layer LSTM weights into one tensor per weight kind.

    ``weights`` uses the ``lstm.all_weights`` layout, i.e.
    ``wih, whh, bih, bhh = weights[layer]``. The result is
    ``[stack(wih), stack(whh), stack(bih), stack(bhh)]``, each with a new
    leading ``layer`` dimension:
      * ``[0]`` wih: ``(layer, 4*hiddenSize, inputSize)``
      * ``[1]`` whh: ``(layer, 4*hiddenSize, hiddenSize)``
      * ``[2]`` bih: ``(layer, 4*hiddenSize)``
      * ``[3]`` bhh: ``(layer, 4*hiddenSize)``

    XXX: script functions have had trouble indexing multi-dimensional lists,
    so stacked tensors are used instead.
    """
    assert isinstance(weights, list)
    assert isinstance(weights[0], list)
    num_columns = len(weights[0])
    # Gather every column first (transpose layer-major -> kind-major) ...
    columns = []
    for col in range(num_columns):
        columns.append([row[col] for row in weights])
    # ... then stack each column into a single tensor.
    return [torch.stack(column) for column in columns]
375
376
def lstm_inputs(
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    dropout=0.0,
    return_module=False,
    device="cuda",
    seed=None,
):
    """Create random LSTM inputs and a reference ``torch.nn.LSTM``.

    Returns ``(x, (hx, cx), all_weights, module)``:
      * ``x``: ``(seqLength, miniBatch, inputSize)``.
      * ``hx``/``cx``: ``(numLayers, miniBatch, hiddenSize)``.
      * ``all_weights``: ``lstm.all_weights``, so per layer
        ``wih, whh, bih, bhh = all_weights[layer]``.
      * ``module``: the LSTM (whose ``all_weights`` are its params) when
        ``return_module`` is True, else None.
    The LSTM is moved to CUDA whenever ``"cuda"`` appears in ``device``.
    """
    if seed is not None:
        torch.manual_seed(seed)
    state_shape = (numLayers, miniBatch, hiddenSize)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(state_shape, device=device)
    cx = torch.randn(state_shape, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if "cuda" in device:
        lstm = lstm.cuda()

    module = lstm if return_module else None
    return x, (hx, cx), lstm.all_weights, module
404
405
def lstm_factory(cell, script):
    """Build a single-layer LSTM runner that loops ``cell`` over timesteps.

    Args:
        cell: step function called as ``cell(x_t, (hy, cy), wih, whh, bih, bhh)``;
            it must return the new ``(hy, cy)``.
        script: if True, ``cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(input, (hx, cx), wih, whh, bih, bhh)``, which returns
        ``(outputs, (hy, cy))``. ``outputs`` stacks the per-step hidden states
        along a new dim 0; the final states get a leading dim of size 1.
    """
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # Timesteps are taken along dim 0 (lstm_inputs builds input as
        # (seqLength, miniBatch, inputSize)).
        inputs = input.unbind(0)
        # Only the first layer's initial state is used.
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        # NOTE(review): torch.stack raises on an empty list, so a zero-length
        # input is not supported.
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        # `cell` is rebound before dynamic_rnn is scripted, so the closure
        # variable dynamic_rnn refers to holds the scripted cell.
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
429
430
# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
    """Build an LSTM runner that multiplies all inputs by ``wih`` in one matmul.

    Args:
        premul_cell: step function called as
            ``premul_cell(xw_t, (hy, cy), whh, bih, bhh)``, where ``xw_t`` is
            timestep t of ``input @ wih.t()``; it must return the new ``(hy, cy)``.
        script: if True, ``premul_cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(input, (hx, cx), wih, whh, bih, bhh)``, which returns
        ``(outputs, (hy, cy))``. ``outputs`` stacks the per-step hidden states
        along a new dim 0; the final states get a leading dim of size 1.
    """
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # Input projection for every timestep at once, then split along dim 0.
        inputs = torch.matmul(input, wih.t()).unbind(0)
        # Only the first layer's initial state is used.
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        # `premul_cell` is rebound before dynamic_rnn is scripted, so the
        # closure variable dynamic_rnn refers to holds the scripted cell.
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
455
456
# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    """Build an LSTM runner that applies ``x @ wih.t() + bih`` to all steps at once.

    Args:
        premul_cell: step function called as ``premul_cell(g_t, (hy, cy), whh, bhh)``,
            where ``g_t`` is timestep t of ``input @ wih.t() + bih``; it must
            return the new ``(hy, cy)``.
        script: if True, ``premul_cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(input, (hx, cx), wih, whh, bih, bhh)`` with a rank-3
        ``input``, which returns ``(outputs, (hy, cy))``. ``outputs`` stacks the
        per-step hidden states along a new dim 0; the final states get a leading
        dim of size 1.
    """

    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # Add the input bias for all timesteps at once instead of step by step;
        # this results in a single reduction kernel in the backward.
        # FIXME matmul(x,y) + bias currently goes through jit AD, and the backward
        # formula in AD is not optimized for this case. Workaround with mm and views.
        inpSize = input.size()
        # (seq, batch, in) -> (seq*batch, in) for mm, then back to (seq, batch, -1).
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        # Only the first layer's initial state is used.
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        # `premul_cell` is rebound before dynamic_rnn is scripted, so the
        # closure variable dynamic_rnn refers to holds the scripted cell.
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
487
488
# simple: flat inputs (no tuples), no list to accumulate outputs
#         useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    """Build a minimal LSTM runner with flat arguments that keeps only final states.

    Args:
        cell: step function called as ``cell(x_t, hy, cy, wih, whh, bih, bhh)``;
            it must return the new ``(hy, cy)``.
        script: if True, ``cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(input, hx, cx, wih, whh, bih, bhh)``, which returns the final
        ``(hy, cy)`` after looping over dim 0 of ``input``. Per-step outputs are
        not collected, and no leading dim is added to the result.
    """
    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        # `cell` is rebound before dynamic_rnn is scripted, so the closure
        # variable dynamic_rnn refers to holds the scripted cell.
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
505
506
def lstm_factory_multilayer(cell, script):
    """Build a stacked (multi-layer) LSTM runner over a flat weight list.

    Args:
        cell: step function called as ``cell(x_t, (hy, cy), wih, whh, bih, bhh)``;
            it must return the new ``(hy, cy)``.
        script: if True, ``cell`` and the runner are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(input, (hx, cx), params)``. ``params`` holds 4 tensors per
        layer in the order ``wih, whh, bih, bhh``, and the number of layers is
        ``hx.size(0)``. It returns ``(outputs, (hy, cy))``: the last layer's
        per-step hidden states stacked along dim 0, and the last layer's final
        states with a leading dim of size 1. Other layers' final states are
        not returned.
    """
    def dynamic_rnn(
        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
        hx, cx = hidden
        # Pre-bind hy/cy so they are defined after the layer loop
        # (the original note: "for scoping").
        hy, cy = hidden  # for scoping...
        inputs, outputs = input.unbind(0), []
        for layer in range(hx.size(0)):
            hy = hx[layer]
            cy = cx[layer]
            base_idx = layer * params_stride
            wih = params[base_idx]
            whh = params[base_idx + 1]
            bih = params[base_idx + 2]
            bhh = params[base_idx + 3]
            for seq_idx in range(len(inputs)):
                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
                outputs += [hy]
            # This layer's per-step outputs become the next layer's inputs.
            inputs, outputs = outputs, []
        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        # `cell` is rebound before dynamic_rnn is scripted, so the closure
        # variable dynamic_rnn refers to holds the scripted cell.
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
534