• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Time Distributed."""
16
17from mindspore.ops.primitive import constexpr, Primitive
18from mindspore.ops import Reshape, Transpose, Stack, Unstack
19from mindspore.common import Tensor
20from mindspore._checkparam import Validator
21from ..cell import Cell
22
# Only the TimeDistributed wrapper is public; the module-level helpers are private.
__all__ = ["TimeDistributed"]
24
25
@constexpr
def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
    """Validate that the merged axis comes out of the wrapped layer unchanged.

    On the reshape path of ``TimeDistributed`` the axis at ``reshape_pos`` holds
    the merged (time x reshape_with_axis) dimension. It must still exist in the
    layer's output with the same size, so that it can be split back apart.

    Args:
        reshape_pos (int): Index of the merged axis.
        inputs_shape (tuple): Shape that was fed to the wrapped layer.
        outputs_shape (tuple): Shape returned by the wrapped layer.
        prim_name (str, optional): Name used to prefix the error message.

    Raises:
        ValueError: If ``reshape_pos`` is not a valid index into ``outputs_shape``,
            or the sizes at ``reshape_pos`` differ between input and output.
    """
    in_range = reshape_pos < len(outputs_shape)
    if in_range and inputs_shape[reshape_pos] == outputs_shape[reshape_pos]:
        return
    prefix = "The" if not prim_name else f"For '{prim_name}', the"
    raise ValueError(f"{prefix} 'reshape_with_axis' is invalid in the input and output. "
                     f"The 'reshape_pos' should be less than the length of 'outputs_shape', and the "
                     f"'inputs_shape[reshape_pos]' should be equal to 'outputs_shape[reshape_pos]', but got "
                     f"'reshape_pos': {reshape_pos}, 'inputs_shape': {inputs_shape}, 'outputs_shape': "
                     f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")
35
36
@constexpr
def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
    """Validate that ``time_axis`` can be used to stack tensors of rank ``ndim``.

    Stacking rank-``ndim`` tensors produces a rank-``ndim + 1`` tensor, so the
    stack axis must lie in ``[-ndim - 1, ndim]``.

    Args:
        time_axis (int): Axis along which the per-step outputs will be stacked.
        ndim (int): Rank of each per-step output.
        prim_name (str, optional): Name used to prefix the error message.

    Raises:
        ValueError: If ``time_axis`` lies outside ``[-ndim - 1, ndim]``.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    # Enforce both bounds so the check matches the range quoted in the message.
    # Previously only the upper bound was tested, so too-negative axes slipped through.
    if not -ndim - 1 <= time_axis <= ndim:
        raise ValueError(f"{msg_prefix} value of 'time_axis' should be in range of [{-ndim - 1}, {ndim}], "
                         f"but got {time_axis}.")
43
44
@constexpr
def _generate_perm(axis_a, axis_b, length):
    """Build a permutation that moves axis ``axis_b`` to position ``axis_a``.

    All other axes keep their relative order. ``TimeDistributed.construct`` calls
    this as ``_generate_perm(time_axis_new, time_axis, len(inputs_shape))`` to bring
    the time axis next to ``reshape_with_axis`` before the two are merged.

    Args:
        axis_a (int): Destination position, expected in ``[0, length)``.
        axis_b (int): Index of the axis to move, expected in ``[0, length)``.
        length (int): Rank of the tensor being permuted.

    Returns:
        tuple[int]: A permutation of ``range(length)`` suitable for ``Transpose``.
    """
    # Remove axis_b and re-insert it at axis_a. The former "move the larger index
    # to the smaller position" scheme only worked when axis_b > axis_a. When the
    # time axis had to move right (time_axis < reshape_with_axis - 1), it moved
    # the wrong axis, so time was never made adjacent to reshape_with_axis.
    perm = [axis for axis in range(length) if axis != axis_b]
    perm.insert(axis_a, axis_b)
    return tuple(perm)
50
51
@constexpr
def _check_data(flag, prim_name=None):
    """Raise ``TypeError`` unless ``flag`` is truthy.

    Callers pass ``isinstance(value, Tensor)`` so that a non-Tensor input or
    layer output is rejected with a uniform message.

    Args:
        flag (bool): Result of the caller's type test.
        prim_name (str, optional): Name used to prefix the error message.

    Raises:
        TypeError: If ``flag`` is falsy.
    """
    if flag:
        return
    prefix = "The" if not prim_name else f"For '{prim_name}', the"
    raise TypeError(f"{prefix} inputs and outputs should be a Tensor.")
57
58
@constexpr
def _check_inputs_dim(shape, prim_name=None):
    """Ensure ``shape`` describes a tensor of rank 3 or higher.

    Args:
        shape (tuple): Shape of the input tensor.
        prim_name (str, optional): Name used to prefix the error message.

    Raises:
        ValueError: If ``shape`` has fewer than three entries.
    """
    rank = len(shape)
    if rank >= 3:
        return
    prefix = "The" if not prim_name else f"For '{prim_name}', the"
    raise ValueError(f"{prefix} inputs shape should be at least 3D, but got {rank}.")
64
65
@constexpr
def _check_time_and_reshape_axes(time_axis, reshape_with_axis, prim_name=None):
    """Reject a ``reshape_with_axis`` that resolves to the same axis as ``time_axis``.

    Args:
        time_axis (int): Normalized (non-negative) time axis.
        reshape_with_axis (int): Normalized (non-negative) axis to merge with.
        prim_name (str, optional): Name used to prefix the error message.

    Raises:
        ValueError: If the two axes are equal.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if time_axis == reshape_with_axis:
        raise ValueError(f"{msg_prefix} 'reshape_with_axis' should refer to a different axis than 'time_axis', "
                         f"but both resolve to axis {time_axis}.")


class TimeDistributed(Cell):
    r"""
    The time distributed layer.

    Time distributed is a wrapper that applies a layer to every temporal slice of an input.
    The input `x` must be at least 3D.
    There are two cases in the implementation.
    When `reshape_with_axis` is provided, the reshape method is chosen, which is more efficient.
    The time axis is merged with `reshape_with_axis`, the layer is applied once, and the merged
    axis is split back afterwards.
    Otherwise, the input is divided along the time axis, the layer is applied to each slice, and
    the results are stacked again along the time axis, which is more general.
    For example, `reshape_with_axis` should not be provided when dealing with Batch Normalization.

    Args:
        layer (Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
        time_axis (int): The axis of time_step. Negative values count from the last axis.
        reshape_with_axis (int): The axis which will be reshaped with `time_axis`. It must refer to a
            different axis than `time_axis`. Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, T, *)`,
          where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor of shape :math:`(N, T, *)`

    Raises:
        TypeError: If `layer` is not a Cell or Primitive.
        TypeError: If `x` or the output of `layer` is not a Tensor.
        ValueError: If `x` has fewer than 3 dimensions.
        ValueError: If `reshape_with_axis` and `time_axis` refer to the same axis.
        ValueError: If, on the reshape path, `layer` changes the size of the merged axis.
        ValueError: If, on the slicing path, `time_axis` is not a valid stack axis for the
            per-step output of `layer`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
        >>> dense = nn.Dense(3, 6)
        >>> net = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
        >>> output = net(x)
        >>> print(output.shape)
        (32, 10, 6)
    """

    def __init__(self, layer, time_axis, reshape_with_axis=None):
        """Initialize TimeDistributed."""
        # Initialize the Cell first so that `self.cls_name` is read from a fully
        # constructed Cell, including in the error path below.
        super(TimeDistributed, self).__init__()
        if not isinstance(layer, (Cell, Primitive)):
            raise TypeError(f"For '{self.cls_name}', the 'layer' should be Cell or Primitive instance, "
                            f"but got type: {type(layer).__name__}.")
        Validator.check_is_int(time_axis, "time_axis", self.cls_name)
        if reshape_with_axis is not None:
            Validator.check_is_int(reshape_with_axis, "reshape_with_axis", self.cls_name)
        self.layer = layer
        self.time_axis = time_axis
        self.reshape_with_axis = reshape_with_axis
        self.transpose = Transpose()
        self.reshape = Reshape()

    def construct(self, inputs):
        _check_data(isinstance(inputs, Tensor), self.cls_name)
        _check_inputs_dim(inputs.shape, self.cls_name)
        ndim = len(inputs.shape)
        # Normalize a possibly negative axis into [0, ndim).
        time_axis = self.time_axis % ndim
        if self.reshape_with_axis is not None:
            # Reshape path: fold the time axis into `reshape_with_axis`, run the layer
            # once on the merged tensor, then split the merged axis back apart.
            reshape_with_axis = self.reshape_with_axis % ndim
            # Equal axes would give time_axis_new == reshape_with_axis - 1 (possibly -1)
            # and merge the wrong dimensions, so reject them up front.
            _check_time_and_reshape_axes(time_axis, reshape_with_axis, self.cls_name)
            inputs_shape = inputs.shape
            # Target position for the time axis: directly after `reshape_with_axis` when it
            # currently lies after it, otherwise directly before it.
            time_axis_new = reshape_with_axis + 1 if time_axis > reshape_with_axis else reshape_with_axis - 1
            # Index of the first of the two adjacent axes that will be merged.
            reshape_pos = time_axis_new if time_axis_new < reshape_with_axis else reshape_with_axis
            perm = _generate_perm(time_axis_new, time_axis, len(inputs_shape))
            inputs = self.transpose(inputs, perm)
            inputs_shape_new = inputs.shape
            # Merge axes `reshape_pos` and `reshape_pos + 1` into a single axis.
            inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
            outputs = self.layer(inputs)
            _check_data(isinstance(outputs, Tensor), self.cls_name)
            # The merged axis must leave the layer unchanged so it can be split again.
            _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape, self.cls_name)
            # Split the merged axis back into its two pre-merge sizes. No inverse transpose is
            # applied, so the output keeps the axis order produced by `perm`, which differs from
            # the input's axis order whenever `perm` is not the identity.
            outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
            if reshape_pos + 1 < len(outputs.shape):
                outputs_shape_new += outputs.shape[reshape_pos + 1:]
            return self.reshape(outputs, outputs_shape_new)

        # Slicing path: split the input along the time axis, apply the layer to each
        # slice, and stack the per-step outputs back along the same (normalized) axis.
        unstack = Unstack(time_axis)
        inputs = unstack(inputs)
        y = ()
        for item in inputs:
            outputs = self.layer(item)
            _check_data(isinstance(outputs, Tensor), self.cls_name)
            # Stack must be able to re-insert the time axis at `time_axis`.
            _check_expand_dims_axis(time_axis, outputs.ndim, self.cls_name)
            y += (outputs,)
        y = Stack(time_axis)(y)
        return y
152