# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cache_ops"""
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithCheck
from .. import signature as sig


class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of input_x, similar to ScatterNdUpdate.
    The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of input_x.
        - **updates** (Tensor) - The update values.

    Outputs:
        - **out** (Tensor) - A Tensor of shape [1], which is not useful.
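
    Examples:
        The snippet below is an illustrative sketch rather than a verified doctest; the
        concrete shapes and the direct construction of the operator are assumptions.
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> input_x = Parameter(Tensor(np.zeros((4, 2), np.float32)), name="cache")
        >>> indices = Tensor(np.array([0, 2, -1]), mstype.int32)
        >>> updates = Tensor(np.ones((3, 2), np.float32))
        >>> max_num = 4
        >>> out = UpdateCache()(input_x, indices, updates, max_num)
        >>> # Rows 0 and 2 of `input_x` are overwritten with the update values;
        >>> # the index -1 lies outside [0, max_num) and is therefore skipped.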
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T),
        sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """init UpdateCache"""

        self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                outputs=['out'])

    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
        return [1]

    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
        validator.check_tensor_dtype_valid(
            "indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype


class SubAndFilter(PrimitiveWithCheck):
    """
    Dynamic kernel that subtracts an offset from the input and
    returns the elements that fall in the range [0, max_num).

    Inputs:
        - **input_x** (Tensor) - Input tensor.
        - **max_num** (int) - The exclusive upper bound of the elements after subtracting `offset`.
        - **offset** (int) - Specifies the offset value subtracted from `input_x`.

    Outputs:
        tuple(Tensor), a tuple of 2 tensors, filter_res and filter_idx.
        - **filter_res** (Tensor) - The elements of `input_x` minus `offset`
          that fall in the range [0, max_num).
        - **filter_idx** (Tensor) - A tensor containing the indices of the elements in the input
          corresponding to the output tensor.

    Supported Platforms:
        `CPU`

    Examples:
        >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
        >>> max_num = 10
        >>> offset = 5
        >>> output = ops.SubAndFilter()(x, max_num, offset)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [0, 3, 4]),
         Tensor(shape=[3], dtype=Int32, value= [2, 3, 4]))
    """
    @prim_attr_register
    def __init__(self):
        """init SubAndFilter"""

        self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
                                outputs=['sub_res', 'sub_idx'])

    def check_shape(self, input_x_shape, max_num_shape, offset_shape):
        return ((-1,), (-1,))

    def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
        validator.check_tensor_dtype_valid(
            "input_x", input_x_dtype, mstype.int_type, self.name)
        return input_x_dtype


class MapUniform(PrimitiveWithCheck):
    """
    Map a tensor by using the formula: value = key % `group_num` * `per_group_size` + key // `group_num`.

    Inputs:
        - **input** (Tensor) - Input Tensor.
        - **per_group_size** (int) - The size of each group.
        - **group_num** (int) - The number of groups.

    Outputs:
        Tensor, has the same dtype and shape as the `input`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
        >>> per_group_size = 4
        >>> group_num = 2
        >>> map_uniform = ops.MapUniform()
        >>> output = map_uniform(input_x, per_group_size, group_num)
        >>> print(output)
        [0, 4, 1, 5, 2, 6, 3, 7]
    """

    @prim_attr_register
    def __init__(self):
        """init MapUniform"""
        self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
                                outputs=['output'])

    def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
        validator.check_tensor_dtype_valid(
            "input", input_dtype, mstype.int_type, self.name)
        validator.check_value_type(
            'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
        validator.check_value_type(
            'group_num', group_num_dtype, [mstype.Int], self.name)


class CacheSwapTable(PrimitiveWithCheck):
    """
    Delete a hashmap entry and insert a new key into the hashmap, returning the key and value of the deleted entry.

    Inputs:
        - **cache_table** (Parameter) - The cache table which resides on the device.
        - **swap_cache_idx** (Tensor) - The indices of the table entries that need to be swapped; entries of -1 are skipped.
        - **miss_value** (Tensor) - The values that are going to be swapped into the cache table.

    Outputs:
        - **old_value** (Tensor) - The values that are swapped out.
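
    Examples:
        An illustrative sketch rather than a verified doctest; the shapes and the direct
        construction of the operator below are assumptions.
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> cache_table = Parameter(Tensor(np.zeros((4, 2), np.float32)), name="cache_table")
        >>> swap_cache_idx = Tensor(np.array([1, -1, 3]), mstype.int32)
        >>> miss_value = Tensor(np.ones((3, 2), np.float32))
        >>> old_value = CacheSwapTable()(cache_table, swap_cache_idx, miss_value)
        >>> # `old_value` holds the rows previously stored at the swapped positions;
        >>> # positions marked -1 are skipped.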
    """
    __mindspore_signature__ = (
        sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
        sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init CacheSwapTable"""

        self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                outputs=['old_value'])

    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
        if len(cache_table_shape) != 2:
            raise ValueError(
                "The dimension of 'cache_table' must be 2, but got %d" % len(cache_table_shape))

        return miss_value_shape

    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
        validator.check_tensor_dtype_valid(
            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
        return miss_value_dtype


class MapCacheIdx(PrimitiveWithCheck):
    """
    MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap, and UpdateCache together.
    Given an indices tensor as input, it outputs the cache indices found by searching the hashmap.
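
    Examples:
        The call below only sketches the interface; the hashmap layout and the returned
        index values depend on the cache implementation and are not shown here.
        >>> map_cache_idx = MapCacheIdx()
        >>> cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
        ...     hashmap, indices, step, emb_max_num, cache_max_num)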
    """
    __mindspore_signature__ = (
        sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('step', dtype=sig.sig_dtype.T),
        sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
        sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init MapCacheIdx"""

        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])

    def __check__(self, hashmap, indices, step, emb_max_num, offset):
        hashmap_shape = hashmap['shape']
        if len(hashmap_shape) != 2:
            raise ValueError("The dimension of 'hashmap' in MapCacheIdx must be 2, "
                             "but got %d." % len(hashmap_shape))
        out_shape = (indices['shape'], -1, -1, -1)

        hashmap_dtype = hashmap['dtype']
        indices_dtype = indices['dtype']
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
        validator.check_tensors_dtypes_same_and_valid(
            args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype, hashmap_dtype,
                     hashmap_dtype, hashmap_dtype)

        out = {'shape': out_shape,
               'dtype': out_dtype,
               'value': None}
        if 'max_shape' in indices:
            out['max_shape'] = (indices['max_shape'], indices['max_shape'],
                                indices['max_shape'], indices['max_shape'])
        else:
            out['max_shape'] = (indices['shape'], indices['shape'],
                                indices['shape'], indices['shape'])
        if 'min_shape' in indices:
            out['min_shape'] = (indices['min_shape'], 0, 0, 0)
        else:
            out['min_shape'] = (0, 0, 0, 0)
        return out


class DynamicAssign(PrimitiveWithCheck):
    """
    Assigns a value to a `Parameter`; the `value` can have a dynamic shape.

    Inputs:
        - **variable** (Parameter) - The `Parameter` to be assigned.
        - **value** (Tensor) - The value to be assigned.

    Outputs:
        Tensor, has the same type as the original `variable`.

    Supported Platforms:
        `CPU`
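
    Examples:
        A minimal sketch rather than a verified doctest; the float32 dtype and the direct
        construction of the operator are assumptions.
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> variable = Parameter(Tensor(np.zeros((2,), np.float32)), name="var")
        >>> value = Tensor(np.array([1.0, 2.0]), mstype.float32)
        >>> output = DynamicAssign()(variable, value)
        >>> # `variable` now holds the assigned value; unlike a static assign,
        >>> # `value` is allowed to have a shape that is only known at runtime.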
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])

    def check_dtype(self, variable, value):
        if variable != mstype.type_refkey:
            validator.check_tensor_dtype_valid(
                "variable", variable, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same(
            {"value": value}, mstype.number_type, self.name)


class PadAndShift(PrimitiveWithCheck):
    """
    Pad a tensor with -1 and shift the elements by an offset.

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
            to `output`.
        - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
            the pad length of the output tensor, cum_sum_arr[shift_idx] is
            the start of the shift, and cum_sum_arr[shift_idx+1] is the end.
        - **shift_idx** (int) - The index into cum_sum_arr.

        Expressed in Python, PadAndShift is equivalent to:
            output = [-1] * cum_sum_arr[-1]
            start = cum_sum_arr[shift_idx]
            end = cum_sum_arr[shift_idx + 1]
            output[start:end] = input_x[:(end-start)]

    Outputs:
        Tensor, has the same type as `input_x`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        return input_x_dtype