# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File  : test_row_tensor.py
@Author:
@Date  : 2020-06-08
@Desc  : test MindSpore RowTensor operations
"""
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad_experimental.grad_base import bprop_getters
from mindspore.ops._utils.utils import generate_shape_index
from mindspore import Tensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore.common import dtype as mstype
from mindspore import _checkparam as validator
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train import Model
from ....dataset_mock import MindData


context.set_context(mode=context.GRAPH_MODE)

reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()


def get_axis(x):
    shape = shape_op(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm


class MSELoss(nn.Cell):
    def __init__(self):
        super(MSELoss, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()

    def construct(self, data, label):
        diff = data - label
        return self.reduce_mean(self.square(diff), get_axis(diff))


class MindDataSet(MindData):
    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))

    def __next__(self):
        if self._size < self._iter_num:
            raise StopIteration
        self._iter_num += 1
        lst = []
        for shape_, type_ in zip(self._output_shapes, self._np_types):
            lst.append(Tensor(np.ones(shape_).astype(type_)))
        return tuple(lst)


@constexpr
def _generate_inverse_index(x_shape, axis):
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm


# pylint: disable=W0231
class MySparseGatherV2(PrimitiveWithInfer):
    """
    For test
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MySparseGatherV2."""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
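        # 'bprop_return_sparse' tells the autodiff pass that this primitive's
        # bprop (registered below) returns a RowTensor rather than a dense
        # gradient for `params`.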
        self.add_prim_attr('bprop_return_sparse', True)

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor_type, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, "axis", self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out


@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2."""

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return RowTensorInner(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop


adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "RowTensor", "Bool")
def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                      m, v, gradient, decay_flag):
    return gradient.values


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                  m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
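
    # F.depend threads each Assign's side effect into `next_v` so graph
    # compilation cannot prune or reorder the parameter and moment updates.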
    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the types and ranges of the hyperparameters."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)


class AdamWeightDecaySparse(Optimizer):
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity


def test_row_tensor_make_row_tensor():
    class MakeRowTensor(nn.Cell):
        def __init__(self):
            super(MakeRowTensor, self).__init__()
            self.dense_shape = (3, 2)

        def construct(self, indices, values):
            ret = (RowTensorInner(indices, values, self.dense_shape),)
            return ret[0]

    indices = Tensor([1, 2])
    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
    MakeRowTensor()(indices, values)


class RowTensorGetAttr(nn.Cell):
    def __init__(self, dense_shape):
        super(RowTensorGetAttr, self).__init__()
        self.dense_shape = dense_shape

    def construct(self, indices, values):
        x = RowTensorInner(indices, values, self.dense_shape)
        return x.values, x.indices, x.dense_shape


def test_row_tensor_attr():
    indices = Tensor([0])
    values = Tensor([[1, 2]], dtype=ms.float32)
    RowTensorGetAttr((3, 2))(indices, values)


def test_row_tensor_sparse_gatherv2_grad_all():
    grad_all = C.GradOperation(get_all=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad[0].indices, grad[0].values, grad[0].dense_shape

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
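
        # axis 0 exercises the sparse branch of MySparseGatherV2's bprop, so the
        # gradient w.r.t. `params` arrives as a RowTensor (unpacked in GradWrap).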
        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)

    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)


def test_row_tensor_sparse_gatherv2_grad_with_param():
    grad_by_list = C.GradOperation(get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x.values, x.indices, x.dense_shape

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")

        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)

    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)


def test_row_tensor_env_get():
    class Loss(nn.Cell):
        def __init__(self, empty=None):
            super(Loss, self).__init__()
            self.empty = empty

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)


def test_row_tensor_model_train():
    class Net(nn.Cell):
        def __init__(self, in_features, out_features):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
            self.add = P.Add()
            self.cast = P.Cast()
            self.flag = True

        def construct(self, inputs, label):
            x = self.add(inputs, self.weight)
            if self.flag:
                x = self.cast(x, mstype.float32)
            return x

    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    dataset = MindDataSet(dataset_types, dataset_shapes)
    net = Net(16, 16)
    net.set_train()

    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_row_tensor_values_dim_greater_than_dense_shape_dim():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
    dense_shape = (3, 4)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_values_dim_less_than_dense_shape_dim():
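    # values has rank 2 while dense_shape has rank 3, so RowTensor construction
    # must reject the mismatch.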
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 4, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_value_and_dense_shape_illegal():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


class RowTensorValuesDouble(nn.Cell):
    def __init__(self, empty=None):
        super(RowTensorValuesDouble, self).__init__()
        self.empty = empty

    def construct(self, x):
        indices = x.indices
        values = x.values * 2
        dense_shape = x.dense_shape
        return RowTensorInner(indices, values, dense_shape)


class RowTensorValuesAdd2(nn.Cell):
    def __init__(self, empty=None):
        super(RowTensorValuesAdd2, self).__init__()
        self.empty = empty

    def construct(self, x):
        indices = x.indices
        values = x.values + 2
        dense_shape = x.dense_shape
        return RowTensorInner(indices, values, dense_shape)


class RowTensorWithControlIf(nn.Cell):
    def __init__(self, dense_shape):
        super().__init__()
        self.op1 = RowTensorValuesDouble()
        self.op2 = RowTensorValuesAdd2()
        self.dense_shape = dense_shape

    def construct(self, a, b, indices, values):
        x = RowTensorInner(indices, values, self.dense_shape)
        if a > b:
            x = self.op1(x)
        else:
            x = self.op2(x)
        return x.indices, x.values


def test_row_tensor_with_control_flow_if():
    a = Tensor(np.array(0).astype(np.int32))
    b = Tensor(np.array(2).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)

    net = RowTensorWithControlIf(dense_shape)
    net(a, b, indices, values)


class EmbeddingLookUpBnNet(nn.Cell):
    def __init__(self, vocab_size, embedding_size, target='CPU'):
        super().__init__()
        self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
        self.bn = nn.BatchNorm2d(num_features=3)
        self.mul = P.Mul()
        self.reshape = P.Reshape()
        self.relu = nn.PReLU()

    def construct(self, indices):
        x = self.embedding_lookup(indices)
        x = self.reshape(x, (2, 3, 2, 2))
        x = self.relu(x)
        x = self.bn(x)
        x = x[0, 0, :, :]
        return x


def test_embedding_lookup_with_mix_precision():
    data = Tensor(np.array([0, 1, 2]).astype(np.int32))
    label = Tensor(np.random.randn(*(2, 2)).astype(np.float32))
    net = EmbeddingLookUpBnNet(8, 8, target='CPU')

    criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
    optimizer.target = 'CPU'
    train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
    train_network.set_train()
    for _ in range(2):
        train_network(data, label)
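

# The tests above manipulate RowTensor pieces (indices, values, dense_shape)
# without ever materializing the dense tensor they stand for. As a minimal
# reference sketch (the helper and test below are illustrative additions, not
# MindSpore APIs), a RowTensor that is row-sparse along axis 0 denotes the
# dense array with dense[indices[i]] = values[i] and zeros elsewhere.
def _dense_from_row_tensor(indices, values, dense_shape):
    """Materialize the dense numpy array described by (indices, values, dense_shape)."""
    dense = np.zeros(dense_shape, dtype=values.dtype)
    for i, row in zip(indices, values):
        dense[i] = row
    return dense


def test_row_tensor_dense_semantics_reference():
    indices = np.array([0, 2])
    values = np.ones([2, 2], dtype=np.float32)
    dense = _dense_from_row_tensor(indices, values, (5, 2))
    assert np.array_equal(dense[0], values[0])
    assert np.array_equal(dense[2], values[1])
    assert not dense[1].any() and not dense[3].any() and not dense[4].any()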