# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''RNN operators module, including RNN, GRU, RNNCell and GRUCell.'''
import math
import numpy as np
import mindspore.ops as P
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.common.initializer import initializer, Uniform
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.nn.cell import Cell
from mindspore import nn
from mindspore import log as logger
from mindspore._checkparam import Validator as validator

__all__ = ['GRU', 'RNN', 'GRUCell', 'RNNCell']


@constexpr
def _init_state(shape, dtype, is_lstm):
    """Internal function, used to generate a zero-initialized hidden state (and cell state for LSTM)."""
    hx = Tensor(np.zeros(shape), dtype)
    cx = Tensor(np.zeros(shape), dtype)
    if is_lstm:
        return (hx, cx)
    return hx


@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Internal function, used to check the dtype of the input data."""
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


@constexpr
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
    """Internal function, used to check that the batch sizes of x and hx are equal."""
    if batch_size_x != batch_size_hx:
        raise ValueError(f"For '{cls_name}', the batch size of x and hx should be equal, but got {batch_size_x} of x "
                         f"and {batch_size_hx} of hx.")


@constexpr
def _check_is_tensor(param_name, input_data, cls_name):
    """Internal function, used to check whether the input data is Tensor."""
    if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                        f"but got '{P.typeof(input_data)}'")


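# The single-step cell helpers below share the signature
# (inputs, hidden, w_ih, w_hh, b_ih, b_hh): `inputs` is (batch_size, input_size),
# `hidden` is (batch_size, hidden_size) (an (h, c) pair of such tensors for LSTM),
# and the weights use the (gate_size, input_size) / (gate_size, hidden_size) layout,
# so P.MatMul(False, True) computes inputs @ w_ih^T and hidden @ w_hh^T.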
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
    '''RNN cell function with tanh activation'''
    if b_ih is None:
        igates = P.MatMul(False, True)(inputs, w_ih)
        hgates = P.MatMul(False, True)(hidden, w_hh)
    else:
        igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
        hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
    return P.Tanh()(igates + hgates)


def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
    '''RNN cell function with relu activation'''
    if b_ih is None:
        igates = P.MatMul(False, True)(inputs, w_ih)
        hgates = P.MatMul(False, True)(hidden, w_hh)
    else:
        igates = P.MatMul(False, True)(inputs, w_ih) + b_ih
        hgates = P.MatMul(False, True)(hidden, w_hh) + b_hh
    return P.ReLU()(igates + hgates)


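# For the LSTM cell, w_ih/w_hh pack the four gates along axis 0 in the order
# [input, forget, cell, output] (gate_size = 4 * hidden_size), matching the
# P.Split(1, 4) below that recovers ingate, forgetgate, cellgate and outgate.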
def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
    '''LSTM cell function'''
    hx, cx = hidden
    if b_ih is None:
        gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
    else:
        gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = P.Split(1, 4)(gates)

    ingate = P.Sigmoid()(ingate)
    forgetgate = P.Sigmoid()(forgetgate)
    cellgate = P.Tanh()(cellgate)
    outgate = P.Sigmoid()(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * P.Tanh()(cy)

    return hy, cy


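# For the GRU cell, w_ih/w_hh pack the three gates in the order [reset, update, new]
# (gate_size = 3 * hidden_size). The final update
#     hy = newgate + inputgate * (hidden - newgate)
# is an algebraic rewrite of (1 - z_t) * n_t + z_t * h_{t-1} from the GRU equations.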
def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
    '''GRU cell function'''
    if b_ih is None:
        gi = P.MatMul(False, True)(inputs, w_ih)
        gh = P.MatMul(False, True)(hidden, w_hh)
    else:
        gi = P.MatMul(False, True)(inputs, w_ih) + b_ih
        gh = P.MatMul(False, True)(hidden, w_hh) + b_hh
    i_r, i_i, i_n = P.Split(1, 3)(gi)
    h_r, h_i, h_n = P.Split(1, 3)(gh)

    resetgate = P.Sigmoid()(i_r + h_r)
    inputgate = P.Sigmoid()(i_i + h_i)
    newgate = P.Tanh()(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy


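# _DynamicRNN unrolls one of the cell functions above over the time axis (axis 0)
# for a single layer and direction; multi-layer and bidirectional stacking is
# handled by _RNNBase.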
class _DynamicRNN(Cell):
    '''Dynamic RNN module that applies an RNN cell over timesteps'''
    def __init__(self, mode):
        super().__init__()
        if mode == "RNN_RELU":
            cell = _rnn_relu_cell
        elif mode == "RNN_TANH":
            cell = _rnn_tanh_cell
        elif mode == "LSTM":
            cell = _lstm_cell
        elif mode == "GRU":
            cell = _gru_cell
        else:
            raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
                             f"but got {mode}.")
        self.cell = cell
        self.is_lstm = mode == "LSTM"

    def recurrent(self, x, h_0, w_ih, w_hh, b_ih, b_hh):
        '''recurrent steps without sequence length'''
        time_step = x.shape[0]
        outputs = []
        t = 0
        h = h_0
        while t < time_step:
            x_t = x[t:t+1:1]
            x_t = P.Squeeze(0)(x_t)
            h = self.cell(x_t, h, w_ih, w_hh, b_ih, b_hh)
            if self.is_lstm:
                outputs.append(h[0])
            else:
                outputs.append(h)
            t += 1
        outputs = P.Stack()(outputs)
        return outputs, h

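    # variable_recurrent masks the recurrence with `seq_length`: the lengths are
    # broadcast to a (batch_size, hidden_size) grid and compared with the current
    # step t, and P.Select keeps the previous state and emits zero outputs for the
    # timesteps beyond each sequence's real length.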
    def variable_recurrent(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
        '''recurrent steps with sequence length'''
        time_step = x.shape[0]
        h_t = h
        if self.is_lstm:
            hidden_size = h[0].shape[-1]
            zero_output = P.ZerosLike()(h_t[0])
        else:
            hidden_size = h.shape[-1]
            zero_output = P.ZerosLike()(h_t)
        seq_length = P.Cast()(seq_length, mstype.float32)
        seq_length = P.BroadcastTo((hidden_size, -1))(seq_length)
        seq_length = P.Cast()(seq_length, mstype.int32)
        seq_length = P.Transpose()(seq_length, (1, 0))

        outputs = []
        state_t = h_t
        t = 0
        while t < time_step:
            x_t = x[t:t+1:1]
            x_t = P.Squeeze(0)(x_t)
            h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh)
            seq_cond = seq_length > t
            if self.is_lstm:
                state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
                state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1])
                output = P.Select()(seq_cond, h_t[0], zero_output)
                state_t = (state_t_0, state_t_1)
            else:
                state_t = P.Select()(seq_cond, h_t, state_t)
                output = P.Select()(seq_cond, h_t, zero_output)
            outputs.append(output)
            t += 1
        outputs = P.Stack()(outputs)
        return outputs, state_t

    def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
        if seq_length is None:
            return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
        return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)


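# _RNNBase registers weights per layer and direction as weight_ih_l{layer}{suffix}
# / weight_hh_l{layer}{suffix} (suffix '_reverse' for the backward direction), with
# shapes (gate_size, layer_input_size) and (gate_size, hidden_size); the biases use
# the same naming with shape (gate_size,).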
class _RNNBase(Cell):
    '''Basic class for RNN operators'''
    def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
                 batch_first=False, dropout=0.0, bidirectional=False):
        super().__init__()
        validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
        validator.check_positive_int(input_size, "input_size", self.cls_name)
        validator.check_positive_int(num_layers, "num_layers", self.cls_name)
        validator.check_is_float(dropout, "dropout", self.cls_name)
        validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
        validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
        validator.check_value_type("bidirectional", bidirectional, [bool], self.cls_name)

        if not 0 <= dropout < 1:
            raise ValueError(f"For '{self.cls_name}', the 'dropout' should be a number in range [0, 1) "
                             f"representing the probability of an element being zeroed, but got {dropout}.")

        if dropout > 0 and num_layers == 1:
            logger.warning("dropout option adds dropout after all but last "
                           "recurrent layer, so non-zero dropout expects "
                           "num_layers greater than 1, but got dropout={} and "
                           "num_layers={}".format(dropout, num_layers))
        if mode == "LSTM":
            gate_size = 4 * hidden_size
        elif mode == "GRU":
            gate_size = 3 * hidden_size
        elif mode == "RNN_TANH":
            gate_size = hidden_size
        elif mode == "RNN_RELU":
            gate_size = hidden_size
        else:
            raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
                             f"but got {mode}.")

        self.reverse = P.ReverseV2([0])
        self.reverse_sequence = P.ReverseSequence(0, 1)
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.num_layers = num_layers
        self.dropout = dropout
        self.dropout_op = nn.Dropout(float(1 - dropout))
        self.bidirectional = bidirectional
        self.has_bias = has_bias
        self.rnn = _DynamicRNN(mode)
        num_directions = 2 if bidirectional else 1
        self.is_lstm = mode == "LSTM"

        self.w_ih_list = []
        self.w_hh_list = []
        self.b_ih_list = []
        self.b_hh_list = []
        stdv = 1 / math.sqrt(self.hidden_size)
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions
                suffix = '_reverse' if direction == 1 else ''

                self.w_ih_list.append(Parameter(
                    Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)),
                    name='weight_ih_l{}{}'.format(layer, suffix)))
                self.w_hh_list.append(Parameter(
                    Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32)),
                    name='weight_hh_l{}{}'.format(layer, suffix)))
                if has_bias:
                    self.b_ih_list.append(Parameter(
                        Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
                        name='bias_ih_l{}{}'.format(layer, suffix)))
                    self.b_hh_list.append(Parameter(
                        Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
                        name='bias_hh_l{}{}'.format(layer, suffix)))
        self.w_ih_list = ParameterTuple(self.w_ih_list)
        self.w_hh_list = ParameterTuple(self.w_hh_list)
        self.b_ih_list = ParameterTuple(self.b_ih_list)
        self.b_hh_list = ParameterTuple(self.b_hh_list)

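    # For the bidirectional case each layer runs two cells: the forward cell sees
    # the input as-is, the backward cell sees it reversed along the time axis
    # (ReverseSequence when seq_length is given), its output is reversed back, and
    # the two directions are concatenated on the feature axis.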
    def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
        """stacked bidirectional dynamic_rnn"""
        pre_layer = x
        h_n = ()
        c_n = ()
        output = 0
        for i in range(self.num_layers):
            offset = i * 2
            if self.has_bias:
                w_f_ih, w_f_hh, b_f_ih, b_f_hh = \
                    self.w_ih_list[offset], self.w_hh_list[offset], \
                    self.b_ih_list[offset], self.b_hh_list[offset]
                w_b_ih, w_b_hh, b_b_ih, b_b_hh = \
                    self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \
                    self.b_ih_list[offset + 1], self.b_hh_list[offset + 1]
            else:
                w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset]
                w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1]
                b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None
            if self.is_lstm:
                h_f_i = (h[0][offset], h[1][offset])
                h_b_i = (h[0][offset + 1], h[1][offset + 1])
            else:
                h_f_i = h[offset]
                h_b_i = h[offset + 1]
            if seq_length is None:
                x_b = self.reverse(pre_layer)
            else:
                x_b = self.reverse_sequence(pre_layer, seq_length)
            output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh)
            output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh)
            if seq_length is None:
                output_b = self.reverse(output_b)
            else:
                output_b = self.reverse_sequence(output_b, seq_length)
            output = P.Concat(2)((output_f, output_b))
            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
            if self.is_lstm:
                h_n += (h_t_f[0], h_t_b[0],)
                c_n += (h_t_f[1], h_t_b[1],)
            else:
                h_n += (h_t_f, h_t_b,)
        if self.is_lstm:
            h_n = P.Concat(0)(h_n)
            c_n = P.Concat(0)(c_n)
            h_n = h_n.view(h[0].shape)
            c_n = c_n.view(h[1].shape)
            return output, (h_n, c_n)
        h_n = P.Concat(0)(h_n)
        return output, h_n.view(h.shape)

    def _stacked_dynamic_rnn(self, x, h, seq_length):
        """stacked multi-layer dynamic_rnn"""
        pre_layer = x
        h_n = ()
        c_n = ()
        output = 0
        for i in range(self.num_layers):
            if self.has_bias:
                w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i]
            else:
                w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i]
                b_ih, b_hh = None, None
            if self.is_lstm:
                h_i = (h[0][i], h[1][i])
            else:
                h_i = h[i]
            output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh)
            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
            if self.is_lstm:
                h_n += (h_t[0],)
                c_n += (h_t[1],)
            else:
                h_n += (h_t,)
        if self.is_lstm:
            h_n = P.Concat(0)(h_n)
            c_n = P.Concat(0)(c_n)
            h_n = h_n.view(h[0].shape)
            c_n = c_n.view(h[1].shape)
            return output, (h_n, c_n)
        h_n = P.Concat(0)(h_n)
        return output, h_n.view(h.shape)

    def construct(self, x, hx=None, seq_length=None):
        '''Defines the RNN-like computation to be performed'''
        x_dtype = P.dtype(x)
        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
        if hx is not None:
            hx_dtype = P.dtype(hx)
            _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
        if seq_length is not None:
            seq_length_dtype = P.dtype(seq_length)
            _check_input_dtype(seq_length_dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)

        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
        num_directions = 2 if self.bidirectional else 1
        if hx is None:
            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size),
                             x.dtype, self.is_lstm)
        if self.batch_first:
            x = P.Transpose()(x, (1, 0, 2))
        if self.bidirectional:
            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
        else:
            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
        if self.batch_first:
            x = P.Transpose()(x, (1, 0, 2))
        return x, h


class RNN(_RNNBase):
    r"""
    Stacked Elman RNN layers.

    Applies an RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.

    For each element in the input sequence, each layer computes the following function:

    .. math::
        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

    Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    previous layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    Args:
        input_size (int): Number of features of input.
        hidden_size (int):  Number of features of hidden layer.
        num_layers (int): Number of layers of stacked RNN. Default: 1.
        nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``.
        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
        batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
        dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
            RNN layer except the last layer. Default: 0.0. The range of dropout is [0.0, 1.0).
        bidirectional (bool): Specifies whether it is a bidirectional RNN,
            num_directions=2 if bidirectional=True otherwise 1. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of data type mindspore.float32 and
          shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
          shape (num_directions * `num_layers`, batch_size, `hidden_size`). Data type of `hx` must be the same as `x`.
        - **seq_length** (Tensor) - The length of each sequence in an input batch.
          Tensor of shape :math:`(\text{batch_size})`. Default: None.
          This input indicates the real sequence length before padding, so that padded elements
          are not used to compute the hidden state and do not affect the final output. It is recommended to
          use this input when **x** contains padding elements.

    Outputs:
        Tuple, a tuple contains (`output`, `h_n`).

        - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`) or
          (batch_size, seq_len, num_directions * `hidden_size`).
        - **h_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).

    Raises:
        TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
        TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
        TypeError: If `dropout` is neither a float nor an int.
        ValueError: If `dropout` is not in range [0.0, 1.0).
        ValueError: If `nonlinearity` is not in ['tanh', 'relu'].

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
        >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
        >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
        >>> output, hn = net(x, h0)
        >>> print(output.shape)
        (3, 5, 16)
    """
    def __init__(self, *args, **kwargs):
        if 'nonlinearity' in kwargs:
            if kwargs['nonlinearity'] == 'tanh':
                mode = 'RNN_TANH'
            elif kwargs['nonlinearity'] == 'relu':
                mode = 'RNN_RELU'
            else:
                raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' should be in ['tanh', 'relu'], "
                                 f"but got {kwargs['nonlinearity']}.")
            del kwargs['nonlinearity']
        else:
            mode = 'RNN_TANH'

        super(RNN, self).__init__(mode, *args, **kwargs)


class GRU(_RNNBase):
    r"""
    Stacked GRU (Gated Recurrent Unit) layers.

    Applies a GRU layer to the input.

    There are two gates in a GRU model: one is the update gate and the other is the reset gate.
    Denote two consecutive time nodes as :math:`t-1` and :math:`t`.
    Given an input :math:`x_t` at time :math:`t` and a hidden state :math:`h_{t-1}`, the update and reset gates at
    time :math:`t` are computed using a gating mechanism. The update gate :math:`z_t` is designed to protect the cell
    from perturbation by irrelevant inputs and the past hidden state. The reset gate :math:`r_t` determines how much
    information should be reset from the old hidden state. The new memory state :math:`{n}_t` is
    calculated with the current input, on which the reset gate is applied. Finally, the current hidden state
    :math:`h_{t}` is computed from the calculated update gate and new memory state. The complete
    formulation is as follows.

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}

    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
    are learnable weights between the output and the input in the formula. For instance,
    :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
    Details can be found in paper
    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
    <https://aclanthology.org/D14-1179.pdf>`_.

    Args:
        input_size (int): Number of features of input.
        hidden_size (int):  Number of features of hidden layer.
        num_layers (int): Number of layers of stacked GRU. Default: 1.
        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
        batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
        dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
            GRU layer except the last layer. Default: 0.0. The range of dropout is [0.0, 1.0).
        bidirectional (bool): Specifies whether it is a bidirectional GRU,
            num_directions=2 if bidirectional=True otherwise 1. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of data type mindspore.float32 and
          shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
          shape (num_directions * `num_layers`, batch_size, `hidden_size`). Data type of `hx` must be the same as `x`.
        - **seq_length** (Tensor) - The length of each sequence in an input batch.
          Tensor of shape :math:`(\text{batch_size})`. Default: None.
          This input indicates the real sequence length before padding, so that padded elements
          are not used to compute the hidden state and do not affect the final output. It is recommended to
          use this input when **x** contains padding elements.

    Outputs:
        Tuple, a tuple contains (`output`, `h_n`).

        - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`) or
          (batch_size, seq_len, num_directions * `hidden_size`).
        - **h_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).

    Raises:
        TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
        TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
        TypeError: If `dropout` is neither a float nor an int.
        ValueError: If `dropout` is not in range [0.0, 1.0).

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
        >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
        >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
        >>> output, hn = net(x, h0)
        >>> print(output.shape)
        (3, 5, 16)
    """
    def __init__(self, *args, **kwargs):
        mode = 'GRU'
        super(GRU, self).__init__(mode, *args, **kwargs)


class _RNNCellBase(Cell):
    '''Basic class for RNN Cells'''
    def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
        super().__init__()
        validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
        validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
        validator.check_positive_int(input_size, "input_size", self.cls_name)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
        self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
        if has_bias:
            self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
            self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
        else:
            self.bias_ih = None
            self.bias_hh = None
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1 / math.sqrt(self.hidden_size)
        for weight in self.get_parameters():
            weight.set_data(initializer(Uniform(stdv), weight.shape))


class RNNCell(_RNNCellBase):
    r"""
    An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::
        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

    Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    previous layer at time `t-1` or the initial hidden state at time `0`.
    If `nonlinearity` is `relu`, then `relu` is used instead of `tanh`.

    Args:
        input_size (int): Number of features of input.
        hidden_size (int):  Number of features of hidden layer.
        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
        nonlinearity (str): The non-linearity to use. Can be either `tanh` or `relu`. Default: `tanh`.

    Inputs:
        - **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
        - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
          Data type of `hx` must be the same as `x`.

    Outputs:
        - **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).

    Raises:
        TypeError: If `input_size` or `hidden_size` is not an int or not greater than 0.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `nonlinearity` is not in ['tanh', 'relu'].

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.RNNCell(10, 16)
        >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
        >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
        >>> output = []
        >>> for i in range(5):
        ...     hx = net(x[i], hx)
        ...     output.append(hx)
        >>> print(output[0].shape)
        (3, 16)
    """
    _non_linearity = ['tanh', 'relu']

    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
        super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
        validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
        validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
        self.nonlinearity = nonlinearity

    def construct(self, x, hx):
        _check_is_tensor('x', x, self.cls_name)
        _check_is_tensor('hx', hx, self.cls_name)
        x_dtype = P.dtype(x)
        hx_dtype = P.dtype(hx)
        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
        _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)

        if self.nonlinearity == "tanh":
            ret = _rnn_tanh_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
        else:
            ret = _rnn_relu_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
        return ret


class GRUCell(_RNNCellBase):
    r"""
    A GRU (Gated Recurrent Unit) cell.

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
    are learnable weights between the output and the input in the formula. For instance,
    :math:`W_{ir}, b_{ir}` are the weight and bias used to transform from input :math:`x` to :math:`r`.
    Details can be found in paper
    `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
    <https://aclanthology.org/D14-1179.pdf>`_.

    Args:
        input_size (int): Number of features of input.
        hidden_size (int):  Number of features of hidden layer.
        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.

    Inputs:
        - **x** (Tensor) - Tensor of shape (batch_size, `input_size`).
        - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, `hidden_size`).
          Data type of `hx` must be the same as `x`.

    Outputs:
        - **h'** (Tensor) - Tensor of shape (batch_size, `hidden_size`).

    Raises:
        TypeError: If `input_size` or `hidden_size` is not an int.
        TypeError: If `has_bias` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.GRUCell(10, 16)
        >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
        >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
        >>> output = []
        >>> for i in range(5):
        ...     hx = net(x[i], hx)
        ...     output.append(hx)
        >>> print(output[0].shape)
        (3, 16)
    """
    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
        super().__init__(input_size, hidden_size, has_bias, num_chunks=3)

    def construct(self, x, hx):
        _check_is_tensor('x', x, self.cls_name)
        _check_is_tensor('hx', hx, self.cls_name)
        x_dtype = P.dtype(x)
        hx_dtype = P.dtype(hx)
        _check_input_dtype(x_dtype, "x", [mstype.float32], self.cls_name)
        _check_input_dtype(hx_dtype, "hx", [mstype.float32], self.cls_name)
        _check_batch_size_equal(x.shape[0], hx.shape[0], self.cls_name)

        return _gru_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)