# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cache_ops: primitives for maintaining cache tables and cache hashmaps."""

from __future__ import absolute_import
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithCheck
from mindspore.ops import signature as sig


class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of `input_x`, similar to ScatterNdUpdate.
    The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of input_x. Must be an integer tensor.
        - **updates** (Tensor) - The update values.
        - **max_num** - Exclusive upper bound of valid `indices`; entries of `indices`
          outside [0, max_num) are skipped. Shares the signature dtype group (T1)
          with `indices`.

    Outputs:
        - **out** (Tensor) - Returns a [1] Tensor, which is not useful.
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T),
        sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """init UpdateCache"""
        # NOTE(review): the IO name is 'update' while the signature above says 'updates'.
        # Left unchanged because the IO names are presumably matched by the backend kernel.
        self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                outputs=['out'])

    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
        # The output is only a [1] placeholder; the useful effect is the write to `input_x`.
        return [1]

    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
        validator.check_tensor_dtype_valid(
            "indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype


class SubAndFilter(PrimitiveWithCheck):
    """
    Dynamic kernel: subtract an offset and
    return the elements that are in range [0, max_num).

    Inputs:
        - **input_x** (Tensor) - Input tensor. Must be an integer tensor.
        - **max_num** (int) - The max value of element that after sub `offset`.
        - **offset** (int) - Specifies the offset value of this `input_x`.

    Outputs:
        tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.
        - **filter_res** (Tensor) - The result that `input_x` minus `offset`,
          and return which in the range [0, max_num).
        - **filter_idx** (Tensor) - A tensor containing indices of elements in the input
          corresponding to the output tensor.

    Supported Platforms:
        `CPU`

    Examples:
        >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
        >>> max_num = 10
        >>> offset = 5
        >>> output = ops.SubAndFilter()(x, max_num, offset)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [0, 3, 4]),
        Tensor(shape=[3], dtype=Int32, value= [2, 3, 4]))
    """
    @prim_attr_register
    def __init__(self):
        """init SubAndFilter"""

        self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
                                outputs=['sub_res', 'sub_idx'])

    def check_shape(self, input_x_shape, max_num_shape, offset_shape):
        # Both outputs are 1-D with a length only known at run time (dynamic shape).
        return ((-1,), (-1,))

    def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
        validator.check_tensor_dtype_valid(
            "input_x", input_x_dtype, mstype.int_type, self.name)
        return input_x_dtype


class MapUniform(PrimitiveWithCheck):
    """
    Map a tensor by using formula : value = key % `group_num` * `per_group_size` + key // `group_num`.

    Inputs:
        - **input** (Tensor) - Input Tensor. Must be an integer tensor.
        - **per_group_size** (int) - The size of each group.
        - **group_num** (int) - The number of groups.

    Outputs:
        Tensor, has the same dtype and shape as the `input`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
        >>> per_group_size = 4
        >>> group_num = 2
        >>> map_uniform = ops.MapUniform()
        >>> output = map_uniform(input_x, per_group_size, group_num)
        >>> print(output)
        [0, 4, 1, 5, 2, 6, 3, 7]
    """

    @prim_attr_register
    def __init__(self):
        """init MapUniform"""
        self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
                                outputs=['output'])

    def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
        validator.check_tensor_dtype_valid(
            "input", input_dtype, mstype.int_type, self.name)
        # The two scalar arguments arrive as dtype objects; require integer types.
        validator.check_value_type(
            'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
        validator.check_value_type(
            'group_num', group_num_dtype, [mstype.Int], self.name)


class CacheSwapTable(PrimitiveWithCheck):
    """
    Swap rows of `cache_table`: the values at `swap_cache_idx` are returned as
    `old_value` and replaced by `miss_value`.

    Inputs:
        - **cache_table** (Parameter) - The cache table which is on device. Must be 2-D.
        - **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped.
          Must be an integer tensor.
        - **miss_value** (Tensor) - The values which are going to be swapped into the cache table.

    Outputs:
        - **old_value** (Tensor) - The values which are swapped out. Has the same shape and
          dtype as `miss_value`.

    Raises:
        ValueError: If `cache_table` is not 2-D.
    """
    __mindspore_signature__ = (
        sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
        sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init CacheSwapTable"""

        self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                outputs=['old_value'])

    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
        if len(cache_table_shape) != 2:
            raise ValueError(
                "The dimension of 'cache_table' in CacheSwapTable must be 2, "
                "but got %d." % len(cache_table_shape))

        return miss_value_shape

    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
        validator.check_tensor_dtype_valid(
            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
        return miss_value_dtype


class MapCacheIdx(PrimitiveWithCheck):
    """
    MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
    Given an indices tensor, it outputs the cache indices found by searching the hashmap.

    Inputs:
        - **hashmap** (Parameter) - The hashmap, written in place. Must be 2-D.
        - **indices** (Tensor) - The indices to look up. Must have the same integer
          dtype as `hashmap`.
        - **step** - Presumably the current training step; not validated here.
        - **emb_max_num** - Presumably the size of the full embedding table; not validated here.
        - **offset** - Not validated here; semantics defined by the backend kernel.

    Outputs:
        Tuple of 4 tensors, all with the same dtype as `hashmap`:
        `cache_idx`, `old_emb_idx`, `miss_emb_idx`, `swap_cache_idx`.

    Raises:
        ValueError: If `hashmap` is not 2-D.
        TypeError: If `hashmap` and `indices` are not of the same integer dtype.
    """
    # NOTE(review): the last signature entry is named 'cache_max_num' while the IO name
    # and `__check__` use 'offset'. Left unchanged; confirm which name is intended.
    __mindspore_signature__ = (
        sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('step', dtype=sig.sig_dtype.T),
        sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
        sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init MapCacheIdx"""

        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])

    def __check__(self, hashmap, indices, step, emb_max_num, offset):
        # Overrides the generic check: receives the abstract dicts ('shape', 'dtype', ...)
        # of every input and builds the combined output description directly.
        hashmap_shape = hashmap['shape']
        if len(hashmap_shape) != 2:
            raise ValueError("The dimension of 'hashmap' in MapCacheIdx must be 2, "
                             "but got %d." % len(hashmap_shape))
        # Only the first output's shape is known statically; the rest are dynamic (-1).
        out_shape = (indices['shape'], -1, -1, -1)

        hashmap_dtype = hashmap['dtype']
        indices_dtype = indices['dtype']
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
        validator.check_tensors_dtypes_same_and_valid(
            args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype, hashmap_dtype,
                     hashmap_dtype, hashmap_dtype)

        out = {'shape': out_shape,
               'dtype': out_dtype,
               'value': None}

        return out


class DynamicAssign(PrimitiveWithCheck):
    """
    Assigns `Parameter` with a value, the `value` can have a dynamic shape.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to be assigned.

    Outputs:
        Tensor, has the same type as original `variable`.

    Supported Platforms:
        `CPU`
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])

    def check_dtype(self, variable, value):
        # When `variable` arrives as a RefKey type, its tensor dtype is presumably not
        # available here, so only `value` is validated in that case.
        if variable != mstype.type_refkey:
            validator.check_tensor_dtype_valid(
                "variable", variable, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same(
            {"value": value}, mstype.number_type, self.name)


class PadAndShift(PrimitiveWithCheck):
    """
    Initialize a tensor with -1, and copy a slice from `input_x` to the padded Tensor.

    Note:
        Equivalent Python pseudo-code:
            output = [-1] * cum_sum_arr[-1]
            start = cum_sum_arr[shift_idx]
            end = cum_sum_arr[shift_idx + 1]
            output[start:end] = input_x[:(end-start)]

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
          to `output`.
        - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
          the pad length of output tensor, `cum_sum_arr[shift_idx]` is
          the start to shift, and `cum_sum_arr[shift_idx+1]` is the end.
        - **shift_idx** (int) - The idx of `cum_sum_arr` .

    Outputs:
        - **output** (Tensor) - Tensor, has the same type as `input_x`.

    Raises:
        TypeError: `input_x` or `cum_sum_arr` is not Tensor.
        TypeError: `shift_idx` is not int.
        ValueError: Value of `shift_idx` is larger than or equal to the length of `cum_sum_arr` .

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        # NOTE(review): the real output length is cum_sum_arr[-1], which is only known at
        # run time; `input_x`'s shape is returned here as the frontend placeholder.
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        return input_x_dtype