# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cache_ops"""

from __future__ import absolute_import
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithCheck
from mindspore.ops import signature as sig


class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of `input_x`, similar to ScatterNdUpdate.
    The difference is that UpdateCache does not update when indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of `input_x`.
        - **updates** (Tensor) - The update values.
        - **max_num** (Tensor) - The upper bound of valid indices; updates whose
          indices are < 0 or >= `max_num` are skipped.

    Outputs:
        - **out** (Tensor) - A Tensor of shape [1], whose value is not meaningful.
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T),
        sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """init UpdateCache"""

        self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                outputs=['out'])

    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
        return [1]

    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
        validator.check_tensor_dtype_valid(
            "indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype


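# A minimal usage sketch for UpdateCache (illustrative only; the helper is
# defined but never called at import time). The shapes, dtypes and the in-place
# update on a Parameter below are assumptions based on the docstring above.
def _update_cache_example():
    import numpy as np
    from mindspore import Tensor, Parameter

    cache = Parameter(Tensor(np.zeros((5, 2), np.float32)), name="cache")
    indices = Tensor(np.array([0, 3, -1]), mstype.int32)  # -1 is out of range and skipped
    updates = Tensor(np.ones((3, 2), np.float32))
    max_num = Tensor(5, mstype.int32)
    # Rows 0 and 3 of `cache` are overwritten in place; the returned tensor
    # has shape [1] and carries no meaningful value.
    return UpdateCache()(cache, indices, updates, max_num)

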
class SubAndFilter(PrimitiveWithCheck):
    """
    Dynamic kernel that subtracts an offset from `input_x` and
    returns the elements that fall in the range [0, max_num).

    Inputs:
        - **input_x** (Tensor) - Input tensor.
        - **max_num** (int) - The exclusive upper bound of the elements after `offset` is subtracted.
        - **offset** (int) - The offset value subtracted from `input_x`.

    Outputs:
        tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.
        - **filter_res** (Tensor) - The elements of `input_x` minus `offset`
          that fall in the range [0, max_num).
        - **filter_idx** (Tensor) - A tensor containing the indices of the elements in the input
          corresponding to the output tensor.

    Supported Platforms:
        `CPU`

    Examples:
        >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
        >>> max_num = 10
        >>> offset = 5
        >>> output = ops.SubAndFilter()(x, max_num, offset)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [0, 3, 4]),
         Tensor(shape=[3], dtype=Int32, value= [2, 3, 4]))
    """
    @prim_attr_register
    def __init__(self):
        """init SubAndFilter"""

        self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
                                outputs=['sub_res', 'sub_idx'])

    def check_shape(self, input_x_shape, max_num_shape, offset_shape):
        return ((-1,), (-1,))

    def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
        validator.check_tensor_dtype_valid(
            "input_x", input_x_dtype, mstype.int_type, self.name)
        return input_x_dtype


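# A plain-NumPy sketch of what SubAndFilter computes, mirroring the docstring
# example above (illustrative only; the helper is never called at import time).
def _sub_and_filter_reference(x, max_num, offset):
    import numpy as np
    sub = np.asarray(x) - offset
    keep = (sub >= 0) & (sub < max_num)
    # For x=[1, 3, 5, 8, 9, 16], max_num=10, offset=5 this returns
    # ([0, 3, 4], [2, 3, 4]), matching the docstring example.
    # filter_res: the shifted values kept; filter_idx: their positions in `x`.
    return sub[keep], np.nonzero(keep)[0]

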
class MapUniform(PrimitiveWithCheck):
    """
    Map a tensor using the formula: value = key % `group_num` * `per_group_size` + key // `group_num`.

    Inputs:
        - **input** (Tensor) - Input Tensor.
        - **per_group_size** (int) - The size of each group.
        - **group_num** (int) - The number of groups.

    Outputs:
        Tensor, has the same dtype and shape as `input`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
        >>> per_group_size = 4
        >>> group_num = 2
        >>> map_uniform = ops.MapUniform()
        >>> output = map_uniform(input_x, per_group_size, group_num)
        >>> print(output)
        [0, 4, 1, 5, 2, 6, 3, 7]
    """

    @prim_attr_register
    def __init__(self):
        """init MapUniform"""
        self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
                                outputs=['output'])

    def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
        validator.check_tensor_dtype_valid(
            "input", input_dtype, mstype.int_type, self.name)
        validator.check_value_type(
            'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
        validator.check_value_type(
            'group_num', group_num_dtype, [mstype.Int], self.name)


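# A plain-NumPy sketch of the MapUniform formula (illustrative only; the helper
# is never called at import time).
def _map_uniform_reference(keys, per_group_size, group_num):
    import numpy as np
    keys = np.asarray(keys)
    # For keys=[0..7], per_group_size=4, group_num=2 this yields
    # [0, 4, 1, 5, 2, 6, 3, 7], matching the docstring example.
    return keys % group_num * per_group_size + keys // group_num

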
class CacheSwapTable(PrimitiveWithCheck):
    """
    Delete a hashmap entry, insert a new key into the hashmap, and return the key and value
    of the deleted entry.

    Inputs:
        - **cache_table** (Parameter) - The cache table, which resides on device.
        - **swap_cache_idx** (Tensor) - The indices of the cache table that need to be swapped.
          Entries of -1 are skipped.
        - **miss_value** (int) - The values that are going to be swapped into the cache table.

    Outputs:
        - **old_value** (Tensor) - The values that are swapped out.
    """
    __mindspore_signature__ = (
        sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
        sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init CacheSwapTable"""

        self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                outputs=['old_value'])

    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
        if len(cache_table_shape) != 2:
            raise ValueError(
                "The dimension of 'cache_table' must be 2, but got %d" % len(cache_table_shape))

        return miss_value_shape

    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
        validator.check_tensor_dtype_valid(
            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
        return miss_value_dtype


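# A minimal usage sketch for CacheSwapTable (illustrative only; the helper is
# defined but never called at import time). The concrete shapes and dtypes are
# assumptions based on the docstring above.
def _cache_swap_table_example():
    import numpy as np
    from mindspore import Tensor, Parameter

    cache_table = Parameter(Tensor(np.zeros((4, 3), np.float32)), name="cache_table")
    swap_cache_idx = Tensor(np.array([1, -1, 2]), mstype.int32)  # -1 entries are skipped
    miss_value = Tensor(np.ones((3, 3), np.float32))
    # Rows 1 and 2 of `cache_table` are replaced by the corresponding rows of
    # `miss_value`; the previous contents are returned as `old_value`.
    return CacheSwapTable()(cache_table, swap_cache_idx, miss_value)

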
class MapCacheIdx(PrimitiveWithCheck):
    """
    MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
    Given an indices tensor, it outputs the cache indices found by searching the hashmap.
    """
    __mindspore_signature__ = (
        sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('step', dtype=sig.sig_dtype.T),
        sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
        sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init MapCacheIdx"""

        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])

    def __check__(self, hashmap, indices, step, emb_max_num, offset):
        hashmap_shape = hashmap['shape']
        if len(hashmap_shape) != 2:
            raise ValueError("The dimension of 'hashmap' in MapCacheIdx must be 2, "
                             "but got %d." % len(hashmap_shape))
        out_shape = (indices['shape'], -1, -1, -1)

        hashmap_dtype = hashmap['dtype']
        indices_dtype = indices['dtype']
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
        validator.check_tensors_dtypes_same_and_valid(
            args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype, hashmap_dtype,
                     hashmap_dtype, hashmap_dtype)

        out = {'shape': out_shape,
               'dtype': out_dtype,
               'value': None}

        return out


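# A minimal calling-convention sketch for MapCacheIdx (illustrative only; the
# helper is never called at import time). The hashmap width, its initial
# contents and all concrete values below are assumptions; in practice the
# hashmap is managed by the embedding-cache machinery.
def _map_cache_idx_example():
    import numpy as np
    from mindspore import Tensor, Parameter

    hashmap = Parameter(Tensor(np.zeros((16, 4), np.int32)), name="hashmap")
    indices = Tensor(np.array([10, 2, 25]), mstype.int32)
    step = Tensor(0, mstype.int32)
    emb_max_num = Tensor(25, mstype.int32)
    cache_max_num = Tensor(8, mstype.int32)
    # Returns (cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx).
    return MapCacheIdx()(hashmap, indices, step, emb_max_num, cache_max_num)

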
class DynamicAssign(PrimitiveWithCheck):
    """
    Assigns a value to a `Parameter`; the `value` can have a dynamic shape.

    Inputs:
        - **variable** (Parameter) - The `Parameter` to be assigned.
        - **value** (Tensor) - The value to be assigned.

    Outputs:
        Tensor, has the same type as the original `variable`.

    Supported Platforms:
        `CPU`
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])

    def check_dtype(self, variable, value):
        if variable != mstype.type_refkey:
            validator.check_tensor_dtype_valid(
                "variable", variable, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same(
            {"value": value}, mstype.number_type, self.name)


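# A minimal usage sketch for DynamicAssign (illustrative only; the helper is
# defined but never called at import time). The dtype and shapes are assumptions.
def _dynamic_assign_example():
    import numpy as np
    from mindspore import Tensor, Parameter

    variable = Parameter(Tensor(np.zeros(4, np.float32)), name="variable")
    value = Tensor(np.array([1.0, 2.0, 3.0, 4.0], np.float32))
    # Assigns `value` to `variable`; `value` may have a dynamic shape.
    return DynamicAssign()(variable, value)

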
class PadAndShift(PrimitiveWithCheck):
    """
    Initialize a tensor with -1, and copy a slice from `input_x` into the padded tensor.

    Note:
        In Python, PadAndShift is equivalent to:
            output = [-1] * cum_sum_arr[-1]
            start = cum_sum_arr[shift_idx]
            end = cum_sum_arr[shift_idx + 1]
            output[start:end] = input_x[:(end-start)]

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
          to `output`.
        - **cum_sum_arr** (Tensor) - The last value of `cum_sum_arr` is
          the pad length of the output tensor, `cum_sum_arr[shift_idx]` is
          the start of the shift, and `cum_sum_arr[shift_idx+1]` is the end.
        - **shift_idx** (int) - The index into `cum_sum_arr`.

    Outputs:
        - **output** (Tensor) - Tensor, has the same type as `input_x`.

    Raises:
        TypeError: If `input_x` or `cum_sum_arr` is not a Tensor.
        TypeError: If `shift_idx` is not an int.
        ValueError: If the value of `shift_idx` is larger than or equal to the length of `cum_sum_arr`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        return input_x_dtype