# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cache_ops"""
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithCheck
from .. import signature as sig


class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of input_x, similar to ScatterNdUpdate.
    The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of input_x.
        - **updates** (Tensor) - The update values.
        - **max_num** - Exclusive upper bound of the indices that are updated.

    Outputs:
        - **out** (Tensor) - Returns a [1] Tensor, which is not useful.
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T),
        sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """init UpdateCache"""
        # NOTE(review): the third io name is 'update' while the signature uses
        # 'updates'; kept as-is in case a backend relies on the name - confirm.
        self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                outputs=['out'])

    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
        """Return the fixed shape [1]; the input shapes are not validated here."""
        return [1]

    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
        """Require `indices` to have a dtype listed in `mstype.int_type`; return the dtype of `input_x`."""
        validator.check_tensor_dtype_valid(
            "indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype


class SubAndFilter(PrimitiveWithCheck):
    """
    Dynamic kernel, subtract an offset and
    return the elements which are in range [0, max_num).

    Inputs:
        - **input_x** (Tensor) - Input tensor.
        - **max_num** (Int) - The max value of an element after subtracting `offset`.
        - **offset** (int) - Specifies the offset value of this `input_x`.

    Outputs:
        tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.
        - **filter_res** (Tensor) - The result of `input_x` minus `offset`,
          keeping only the values in the range [0, max_num).
        - **filter_idx** (Tensor) - A tensor containing indices of elements in the input
          corresponding to the output tensor.

    Supported Platforms:
        `CPU`

    Examples:
        >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
        >>> max_num = 10
        >>> offset = 5
        >>> output = ops.SubAndFilter()(x, max_num, offset)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [0, 3, 4]),
         Tensor(shape=[3], dtype=Int32, value= [2, 3, 4]))
    """
    @prim_attr_register
    def __init__(self):
        """init SubAndFilter"""
        self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
                                outputs=['sub_res', 'sub_idx'])

    def check_shape(self, input_x_shape, max_num_shape, offset_shape):
        """Return ((-1,), (-1,)) for the two outputs; the input shapes are not validated here."""
        return ((-1,), (-1,))

    def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
        """Require `input_x` to have a dtype listed in `mstype.int_type`; return that dtype."""
        validator.check_tensor_dtype_valid(
            "input_x", input_x_dtype, mstype.int_type, self.name)
        return input_x_dtype


class MapUniform(PrimitiveWithCheck):
    """
    Map a tensor by using formula : value = key % `group_num` * `per_group_size` + key // `group_num`.

    Inputs:
        - **input** (Tensor) - Input Tensor.
        - **per_group_size** (int) - The size of each group.
        - **group_num** (int) - The number of group.

    Outputs:
        Tensor, has the same dtype and shape as the `input`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
        >>> per_group_size = 4
        >>> group_num = 2
        >>> map_uniform = ops.MapUniform()
        >>> output = map_uniform(input_x, per_group_size, group_num)
        >>> print(output)
        [0, 4, 1, 5, 2, 6, 3, 7]
    """

    @prim_attr_register
    def __init__(self):
        """init MapUniform"""
        self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
                                outputs=['output'])

    def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
        """Validate `input` against `mstype.int_type` and both scalars against `mstype.Int`."""
        validator.check_tensor_dtype_valid(
            "input", input_dtype, mstype.int_type, self.name)
        validator.check_value_type(
            'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
        validator.check_value_type(
            'group_num', group_num_dtype, [mstype.Int], self.name)


class CacheSwapTable(PrimitiveWithCheck):
    """
    Delete a hashmap entry, and insert a new key to hashmap, return the key and value of the deleted entry.

    Inputs:
        - **cache_table** (Parameter) - The cache table which is on device.
        - **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped.
        - **miss_value** (int) - The values which are going to swap into cache table.

    Outputs:
        - **old_value** (Tensor) - The values which are swapped out.
    """
    __mindspore_signature__ = (
        sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
        sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init CacheSwapTable"""
        self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                outputs=['old_value'])

    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
        """
        Require `cache_table` to be 2-dimensional and return the shape of `miss_value`.

        Raises:
            ValueError: If the number of dimensions of `cache_table` is not 2.
        """
        cache_table_rank = len(cache_table_shape)
        if cache_table_rank != 2:
            # The check is on the rank, so the message reports a dimension, not a shape.
            raise ValueError("The dimension of 'cache_table' in %s must be 2, "
                             "but got %d." % (self.name, cache_table_rank))

        return miss_value_shape

    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
        """Require `swap_cache_idx` to have a dtype listed in `mstype.int_type`; return the dtype of `miss_value`."""
        validator.check_tensor_dtype_valid(
            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
        return miss_value_dtype


class MapCacheIdx(PrimitiveWithCheck):
    """
    MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together.
    When input an indices tensor, it will output the cache indices which search in hashmap.

    Inputs:
        - **hashmap** (Parameter) - The 2-dimensional hashmap, written in place.
        - **indices** (Tensor) - The indices to look up; must share the integer dtype of `hashmap`.
        - **step** - Not inspected by the Python-side check.
        - **emb_max_num** - Not inspected by the Python-side check.
        - **offset** - Not inspected by the Python-side check.

    Outputs:
        Tuple of 4 tensors named cache_idx, old_emb_idx, miss_emb_idx and swap_cache_idx.
        All of them are declared with the dtype of `hashmap`; cache_idx has the shape of
        `indices` and the other three have a dynamic shape.
    """
    # NOTE(review): the last signature entry is 'cache_max_num' while the io name and
    # the `__check__` parameter are 'offset'; kept as-is - confirm which one is intended.
    __mindspore_signature__ = (
        sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('step', dtype=sig.sig_dtype.T),
        sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
        sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init MapCacheIdx"""
        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])

    def __check__(self, hashmap, indices, step, emb_max_num, offset):
        """
        Validate the inputs and build the description of the four outputs.

        `hashmap` and `indices` are dicts read through their 'shape' and 'dtype' keys;
        `indices` may additionally carry 'max_shape' and 'min_shape'.

        Returns:
            dict with the keys 'shape', 'dtype', 'value', 'max_shape' and 'min_shape'.

        Raises:
            ValueError: If `hashmap` is not 2-dimensional.
        """
        hashmap_shape = hashmap['shape']
        if len(hashmap_shape) != 2:
            # Report the actual operator name instead of a hard-coded, unrelated one.
            raise ValueError("The dimension of 'hashmap' in %s must be 2, "
                             "but got %d." % (self.name, len(hashmap_shape)))
        indices_shape = indices['shape']
        # Only the first output has a known shape; the rest are marked dynamic with -1.
        out_shape = (indices_shape, -1, -1, -1)

        hashmap_dtype = hashmap['dtype']
        args = {"hashmap": hashmap_dtype, "indices": indices['dtype']}
        validator.check_tensors_dtypes_same_and_valid(
            args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype,) * 4

        out = {'shape': out_shape,
               'dtype': out_dtype,
               'value': None}
        # Every output is bounded above by the (max) shape of `indices`.
        max_shape = indices['max_shape'] if 'max_shape' in indices else indices_shape
        out['max_shape'] = (max_shape,) * 4
        # Only the first output inherits a lower bound from `indices`; the others use 0.
        first_min_shape = indices['min_shape'] if 'min_shape' in indices else 0
        out['min_shape'] = (first_min_shape, 0, 0, 0)
        return out


class DynamicAssign(PrimitiveWithCheck):
    """
    Assigns `Parameter` with a value, the `value` can have a dynamic shape.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to be assigned.

    Outputs:
        Tensor, has the same type as original `variable`.

    Supported Platforms:
        `CPU`
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """init DynamicAssign"""
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])

    def check_dtype(self, variable, value):
        """Validate both dtypes against `mstype.number_type`, unless `variable` is `mstype.type_refkey`."""
        if variable != mstype.type_refkey:
            validator.check_tensor_dtype_valid(
                "variable", variable, mstype.number_type, self.name)
            validator.check_scalar_or_tensor_types_same(
                {"value": value}, mstype.number_type, self.name)


class PadAndShift(PrimitiveWithCheck):
    """
    Pad a tensor with -1, and shift with a length.

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
          to `output`.
        - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
          the pad length of output tensor, cum_sum_arr[shift_idx] is
          the start to shift, and cum_sum_arr[shift_idx+1] is the end.
        - **shift_idx** (Int) - The idx of cum_sum_arr.

    If written in Python, PadAndShift is:
        output = [-1] * cum_sum_arr[-1]
        start = cum_sum_arr[shift_idx]
        end = cum_sum_arr[shift_idx + 1]
        output[start:end] = input_x[:(end-start)]

    Outputs:
        Tensor, has the same type as `input_x`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        """init PadAndShift"""
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        """Return the shape of `input_x`; no shape is validated here."""
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        """Return the dtype of `input_x`; no dtype is validated here."""
        return input_x_dtype