# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Time Distributed."""

from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Stack, Unstack
from mindspore.common import Tensor
from mindspore._checkparam import Validator
from ..cell import Cell

__all__ = ['TimeDistributed']


@constexpr
def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
    """
    Check that the wrapped layer kept the merged dimension at `reshape_pos`.

    Args:
        reshape_pos (int): Index of the merged (time x reshape) dimension.
        inputs_shape (tuple): Shape of the tensor fed to the wrapped layer.
        outputs_shape (tuple): Shape of the tensor returned by the wrapped layer.
        prim_name (str): Name used as prefix of the error message. Default: None.

    Raises:
        ValueError: If `reshape_pos` is not a valid index of `outputs_shape`, or the sizes of
            `inputs_shape` and `outputs_shape` differ at `reshape_pos`.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if reshape_pos >= len(outputs_shape) or inputs_shape[reshape_pos] != outputs_shape[reshape_pos]:
        raise ValueError(f"{msg_prefix} 'reshape_with_axis' is invalid in the input and output. "
                         f"The 'reshape_pos' should be less than the length of 'outputs_shape', and the "
                         f"'inputs_shape[reshape_pos]' should be equal to 'outputs_shape[reshape_pos]', but got "
                         f"'reshape_pos': {reshape_pos}, 'inputs_shape': {inputs_shape}, 'outputs_shape': "
                         f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")


@constexpr
def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
    """
    Check that `time_axis` can be used to stack tensors that have `ndim` dimensions.

    Args:
        time_axis (int): The axis along which the per-step outputs will be stacked.
        ndim (int): Number of dimensions of a single per-step output.
        prim_name (str): Name used as prefix of the error message. Default: None.

    Raises:
        ValueError: If `time_axis` is greater than `ndim`.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if time_axis > ndim:
        raise ValueError(f"{msg_prefix} value of 'time_axis' should be in range of [{-ndim - 1}, {ndim}], "
                         f"but got {time_axis}.")


@constexpr
def _generate_perm(axis_a, axis_b, length):
    """
    Build the permutation that moves axis `axis_b` to position `axis_a`.

    All other axes keep their relative order.

    Args:
        axis_a (int): Destination position of the moved axis.
        axis_b (int): Current position of the axis to move.
        length (int): Number of dimensions of the tensor to be transposed.

    Returns:
        tuple[int], the permutation to pass to Transpose.

    Examples:
        >>> _generate_perm(1, 3, 4)
        (0, 3, 1, 2)
        >>> _generate_perm(2, 0, 4)
        (1, 2, 0, 3)
    """
    perm = tuple(range(length))
    if axis_a <= axis_b:
        # Move the axis towards the front: the axes in [axis_a, axis_b) shift right by one.
        return perm[:axis_a] + (perm[axis_b],) + perm[axis_a: axis_b] + perm[axis_b + 1:]
    # Move the axis towards the back: the axes in (axis_b, axis_a] shift left by one.
    # Sorting the two axes and reusing the branch above would instead pull `axis_a` forward,
    # which is only equivalent when the two axes are adjacent.
    return perm[:axis_b] + perm[axis_b + 1: axis_a + 1] + (perm[axis_b],) + perm[axis_a + 1:]


@constexpr
def _check_data(flag, prim_name=None):
    """
    Raise if `flag` is falsy; callers pass the result of an `isinstance(..., Tensor)` test.

    Args:
        flag (bool): Whether the checked object is a Tensor.
        prim_name (str): Name used as prefix of the error message. Default: None.

    Raises:
        TypeError: If `flag` is falsy.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if not flag:
        raise TypeError(f"{msg_prefix} inputs and outputs should be a Tensor.")


@constexpr
def _check_inputs_dim(shape, prim_name=None):
    """
    Check that `shape` describes a tensor with at least 3 dimensions.

    Args:
        shape (tuple): Shape of the inputs.
        prim_name (str): Name used as prefix of the error message. Default: None.

    Raises:
        ValueError: If `shape` has fewer than 3 entries.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(shape) < 3:
        raise ValueError(f"{msg_prefix} inputs shape should be at least 3D, but got {len(shape)}.")


class TimeDistributed(Cell):
    r"""
    The time distributed layer.

    Time distributed is a wrapper which allows to apply a layer to every temporal slice of an input.
    And the `x` should be at least 3D.
    There are two cases in the implementation.
    When reshape_with_axis provided, the reshape method will be chosen, which is more efficient;
    otherwise, the method of dividing the inputs along time axis will be used, which is more general.
    For example, reshape_with_axis could not be provided when deal with Batch Normalization.

    Args:
        layer(Union[Cell, Primitive]): The Cell or Primitive which will be wrapped.
        time_axis(int): The axis of time_step.
        reshape_with_axis(int): The axis which will be reshaped with time_axis. Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, T, *)`,
          where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor of shape :math:`(N, T, *)`

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If layer is not a Cell or Primitive.
        TypeError: If the inputs or the outputs of `layer` are not a Tensor.
        ValueError: If `x` has fewer than 3 dimensions.
        ValueError: If `reshape_with_axis` is provided and the output of `layer` does not keep
            the size of the merged dimension.

    Examples:
        >>> x = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
        >>> dense = nn.Dense(3, 6)
        >>> net = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
        >>> output = net(x)
        >>> print(output.shape)
        (32, 10, 6)
    """

    def __init__(self, layer, time_axis, reshape_with_axis=None):
        """Initialize TimeDistributed."""
        if not isinstance(layer, (Cell, Primitive)):
            raise TypeError(f"For '{self.cls_name}', the 'layer' should be Cell or Primitive instance, "
                            f"but got type: {type(layer).__name__}.")
        super(TimeDistributed, self).__init__()
        Validator.check_is_int(time_axis, "time_axis", self.cls_name)
        if reshape_with_axis is not None:
            Validator.check_is_int(reshape_with_axis, "reshape_with_axis", self.cls_name)
        self.layer = layer
        self.time_axis = time_axis
        self.reshape_with_axis = reshape_with_axis
        self.transpose = Transpose()
        self.reshape = Reshape()

    def construct(self, inputs):
        """Apply the wrapped layer to every slice of `inputs` taken along the time axis."""
        _check_data(isinstance(inputs, Tensor), self.cls_name)
        _check_inputs_dim(inputs.shape, self.cls_name)
        # Normalize to [0, ndim); note that the modulo also wraps out-of-range values.
        time_axis = self.time_axis % len(inputs.shape)
        if self.reshape_with_axis is not None:
            # Reshape method: move the time axis next to `reshape_with_axis`, merge the two
            # dimensions, run the layer once, then split the merged dimension again.
            reshape_with_axis = self.reshape_with_axis % len(inputs.shape)
            inputs_shape = inputs.shape
            # Position the time axis takes after the transpose: directly after the reshape
            # axis when it started behind it, directly before it otherwise.
            time_axis_new = len(inputs_shape) - 2 if reshape_with_axis == len(inputs_shape) - 1 \
                else (reshape_with_axis + 1 if time_axis > reshape_with_axis else
                      reshape_with_axis - 1)
            # First of the two adjacent positions that get merged.
            reshape_pos = time_axis_new if time_axis_new < reshape_with_axis else reshape_with_axis
            perm = _generate_perm(time_axis_new, time_axis, len(inputs_shape))
            inputs = self.transpose(inputs, perm)
            inputs_shape_new = inputs.shape
            inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) +
                                  inputs_shape_new[reshape_pos + 2:])
            outputs = self.layer(inputs)
            _check_data(isinstance(outputs, Tensor), self.cls_name)
            _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape, self.cls_name)
            # Restore the two merged dimensions; every other dimension comes from the layer output.
            outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
            if reshape_pos + 1 < len(outputs.shape):
                outputs_shape_new += outputs.shape[reshape_pos + 1:]
            return self.reshape(outputs, outputs_shape_new)

        # General method: split along the time axis, run the layer per slice, stack the results.
        unstack = Unstack(time_axis)
        inputs = unstack(inputs)
        y = ()
        for item in inputs:
            outputs = self.layer(item)
            _check_data(isinstance(outputs, Tensor), self.cls_name)
            _check_expand_dims_axis(time_axis, outputs.ndim, self.cls_name)
            y += (outputs,)
        y = Stack(time_axis)(y)
        return y