1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16@File : test_row_tensor.py 17@Author: 18@Date : 2020-06-08 19@Desc : test mindspore row_tensor's operation 20""" 21import numpy as np 22import pytest 23 24import mindspore as ms 25import mindspore.nn as nn 26from mindspore.ops import composite as C 27from mindspore.ops import functional as F 28from mindspore.ops import operations as P 29from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like 30from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register 31from mindspore.ops._grad.grad_base import bprop_getters 32from mindspore.ops._utils.utils import generate_shape_index 33from mindspore import Tensor, RowTensor, context 34from mindspore.common.parameter import Parameter, ParameterTuple 35from mindspore.common import dtype as mstype 36from mindspore._checkparam import Validator as validator 37from mindspore._checkparam import Rel 38from mindspore.nn import Optimizer 39from mindspore.nn import TrainOneStepCell, WithLossCell 40from mindspore.nn.optim import Momentum 41from mindspore.train import Model 42from ....dataset_mock import MindData 43 44@pytest.fixture(scope="module", autouse=True) 45def setup_teardown(): 46 context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) 47 yield 48 context.set_context(enable_sparse=False) 49 50reduce_sum = P.ReduceSum() 51unsorted_segment_sum = P.UnsortedSegmentSum() 52transpose = P.Transpose() 53shape_op = P.Shape() 54reshape = P.Reshape() 55size_op = P.Size() 56invert_permutation = P.InvertPermutation() 57logical_and = P.LogicalAnd() 58 59def get_axis(x): 60 shape = shape_op(x) 61 length = F.tuple_len(shape) 62 perm = F.make_range(0, length) 63 return perm 64 65class MSELoss(nn.Cell): 66 def __init__(self): 67 super(MSELoss, self).__init__() 68 self.reduce_sum = P.ReduceSum() 69 self.square = P.Square() 70 self.reduce_mean = P.ReduceMean() 71 72 def construct(self, data, label): 73 diff = data - label 74 return self.reduce_mean(self.square(diff), get_axis(diff)) 75 76 77class MindDataSet(MindData): 78 def __init__(self, dataset_types, dataset_shapes): 79 super(MindDataSet, self).__init__(size=2, batch_size=32, 80 np_types=dataset_types, 81 output_shapes=dataset_shapes, 82 input_indexs=(0, 1)) 83 def __next__(self): 84 if self._size < self._iter_num: 85 raise StopIteration 86 self._iter_num += 1 87 lst = [] 88 for shape_, type_ in zip(self._output_shapes, self._np_types): 89 lst.append(Tensor(np.ones(shape_).astype(type_))) 90 return tuple(lst) 91 92 93@constexpr 94def _generate_inverse_index(x_shape, axis): 95 x_rank = len(x_shape) 96 index = tuple(range(x_rank)) 97 if axis < 0: 98 axis += x_rank 99 perm = index[1:1 + axis] + (0,) + index[1 + axis:] 100 return perm 101 102# pylint: disable=W0231 103class MySparseGatherV2(PrimitiveWithInfer): 104 """ 105 For test 106 """ 107 @prim_attr_register 108 def __init__(self): 109 """init index_select""" 110 self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) 111 112 def __infer__(self, params, indices, axis): 113 validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) 114 validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) 115 validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) 116 axis_v = axis['value'] 117 params_shp = params['shape'] 118 rank = len(params_shp) 119 validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) 120 if axis_v < 0: 121 axis_v += rank 122 out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] 123 out = {'shape': out_shape, 124 'dtype': params['dtype'], 125 'value': None} 126 return out 127 128@bprop_getters.register(MySparseGatherV2) 129def get_bprop_sparse_gather_v2(self): 130 """Generate bprop for MySparseGatherV2""" 131 132 def bprop(x, indices, axis, out, dout): 133 x_shp = shape_op(x) 134 if axis == 0: 135 indices_size = (size_op(indices),) 136 x_tail_shp = x_shp[1:] 137 values_shape = indices_size + x_tail_shp 138 values = reshape(dout, values_shape) 139 indices = reshape(indices, indices_size) 140 return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) 141 if F.rank(dout) == 0: 142 dout = P.ExpandDims()(dout, -1) 143 if F.rank(indices) == 0: 144 indices = P.ExpandDims()(indices, -1) 145 out_shp = shape_op(dout) 146 ind_shp = shape_op(indices) 147 # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) 148 perm_1 = generate_shape_index(out_shp, ind_shp, axis) 149 values_transpose = transpose(dout, perm_1) 150 params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) 151 # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) 152 perm_2 = _generate_inverse_index(x_shp, axis) 153 params_grad = transpose(params_grad, perm_2) 154 return params_grad, zeros_like(indices), zeros_like(axis) 155 156 return bprop 157 158adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") 159@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", 160 "Tensor", "Tensor", "Tensor", "RowTensor", "Bool") 161def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, 162 m, v, gradient, decay_flag): 163 return gradient.values 164 165@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", 166 "Tensor", "Tensor", "Tensor", "Tensor", "Bool") 167def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, 168 m, v, gradient, decay_flag): 169 op_mul = P.Mul() 170 op_square = P.Square() 171 op_sqrt = P.Sqrt() 172 op_cast = P.Cast() 173 op_reshape = P.Reshape() 174 op_shape = P.Shape() 175 176 param_fp32 = op_cast(param, mstype.float32) 177 m_fp32 = op_cast(m, mstype.float32) 178 v_fp32 = op_cast(v, mstype.float32) 179 gradient_fp32 = op_cast(gradient, mstype.float32) 180 181 next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) 182 183 next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) 184 - beta2, op_square(gradient_fp32)) 185 186 update = next_m / (op_sqrt(next_v) + eps) 187 if decay_flag: 188 update = update + op_mul(weight_decay_tensor, param_fp32) 189 190 update_with_lr = op_mul(lr, update) 191 next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) 192 193 next_v = F.depend(next_v, F.assign(param, next_param)) 194 next_v = F.depend(next_v, F.assign(m, next_m)) 195 next_v = F.depend(next_v, F.assign(v, next_v)) 196 return next_v 197 198 199def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): 200 """Check the type of inputs.""" 201 validator.check_value_type("beta1", beta1, [float], prim_name) 202 validator.check_value_type("beta2", beta2, [float], prim_name) 203 validator.check_value_type("eps", eps, [float], prim_name) 204 validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) 205 validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) 206 validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) 207 validator.check_positive_float(eps, "eps", prim_name) 208 validator.check_non_negative_float(weight_decay, "weight_decay", prim_name) 209 210 211class AdamWeightDecaySparse(Optimizer): 212 def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, 213 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): 214 super(AdamWeightDecaySparse, self).__init__(learning_rate, params) 215 if self.is_group: 216 raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") 217 _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) 218 self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) 219 self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) 220 self.eps = Tensor(np.array([eps]).astype(np.float32)) 221 self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) 222 223 self.params = self.parameters 224 self.moments1 = self.params.clone(prefix="adam_m", init='zeros') 225 self.moments2 = self.params.clone(prefix="adam_v", init='zeros') 226 self.decay_flag = tuple(decay_filter(x) for x in self.params) 227 self.map = C.Map() 228 229 def construct(self, gradients): 230 lr = self.get_lr() 231 updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, 232 self.weight_decay_tensor), 233 self.params, self.moments1, self.moments2, gradients, self.decay_flag) 234 return updated_velocity 235 236 237def test_row_tensor_make_row_tensor(): 238 class MakeRowTensor(nn.Cell): 239 def __init__(self): 240 super(MakeRowTensor, self).__init__() 241 self.dense_shape = (3, 2) 242 def construct(self, indices, values): 243 ret = (RowTensor(indices, values, self.dense_shape),) 244 return ret[0] 245 indices = Tensor([1, 2]) 246 values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) 247 MakeRowTensor()(indices, values) 248 249 250class RowTensorGetAttr(nn.Cell): 251 def __init__(self, dense_shape): 252 super(RowTensorGetAttr, self).__init__() 253 self.dense_shape = dense_shape 254 def construct(self, indices, values): 255 x = RowTensor(indices, values, self.dense_shape) 256 return x.values, x.indices, x.dense_shape 257 258 259def test_row_tensor_attr(): 260 indices = Tensor([0]) 261 values = Tensor([[1, 2]], dtype=ms.float32) 262 RowTensorGetAttr((3, 2))(indices, values) 263 264 265def test_row_tensor_sparse_gatherv2_grad_all(): 266 grad_all = C.GradOperation(get_all=True) 267 class GradWrap(nn.Cell): 268 def __init__(self, network): 269 super(GradWrap, self).__init__() 270 self.network = network 271 def construct(self, x, y): 272 grad = grad_all(self.network)(x, y) 273 return grad[0].indices, grad[0].values, grad[0].dense_shape 274 class SparseGatherV2(nn.Cell): 275 def __init__(self): 276 super(SparseGatherV2, self).__init__() 277 self.sparse_gatherv2 = MySparseGatherV2() 278 self.axis = 0 279 def construct(self, params, indices): 280 return self.sparse_gatherv2(params, indices, self.axis) 281 params = Tensor(np.ones([3, 1, 2]).astype(np.int32)) 282 indices = Tensor(np.array([0, 1]).astype(np.int32)) 283 GradWrap(SparseGatherV2())(params, indices) 284 285 286def test_row_tensor_sparse_gatherv2_grad_with_pram(): 287 grad_by_list = C.GradOperation(get_by_list=True) 288 class GradWrap(nn.Cell): 289 def __init__(self, network): 290 super(GradWrap, self).__init__() 291 self.network = network 292 self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) 293 def construct(self, x): 294 weights = self.weights 295 grad = grad_by_list(self.network, weights)(x) 296 x = grad[0] 297 return x.values, x.indices, x.dense_shape 298 class SparseGatherV2(nn.Cell): 299 def __init__(self): 300 super(SparseGatherV2, self).__init__() 301 self.sparse_gatherv2 = MySparseGatherV2() 302 self.axis = 0 303 self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") 304 def construct(self, indices): 305 return self.sparse_gatherv2(self.params, indices, self.axis) 306 indices = Tensor(np.array([0, 1]).astype(np.int32)) 307 network = GradWrap(SparseGatherV2()) 308 network(indices) 309 310 311def test_row_tensor_env_get(): 312 class Loss(nn.Cell): 313 def __init__(self): 314 super(Loss, self).__init__() 315 def construct(self, base, target): 316 return base 317 class NetWithSparseGatherV2(nn.Cell): 318 def __init__(self): 319 super(NetWithSparseGatherV2, self).__init__() 320 self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") 321 self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") 322 self.gatherv2 = MySparseGatherV2() 323 self.axis = 0 324 def construct(self, indices): 325 return self.gatherv2(self.w1, indices, self.axis) * self.w2 326 327 inputs = Tensor(np.array([0, 1]).astype(np.int32)) 328 label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) 329 net = NetWithSparseGatherV2() 330 net.set_train() 331 loss = Loss() 332 optimizer = AdamWeightDecaySparse(net.trainable_params()) 333 334 net_with_loss = WithLossCell(net, loss) 335 train_network = TrainOneStepCell(net_with_loss, optimizer) 336 train_network(inputs, label) 337 338 339def test_row_tensor_model_train(): 340 class Net(nn.Cell): 341 def __init__(self, in_features, out_features): 342 super(Net, self).__init__() 343 self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight") 344 self.add = P.Add() 345 self.cast = P.Cast() 346 self.flag = True 347 348 def construct(self, inputs, label): 349 x = self.add(inputs, self.weight) 350 if self.flag: 351 x = self.cast(x, mstype.float32) 352 return x 353 354 dataset_types = (np.float32, np.float32) 355 dataset_shapes = ((16, 16), (16, 16)) 356 dataset = MindDataSet(dataset_types, dataset_shapes) 357 net = Net(16, 16) 358 net.set_train() 359 360 optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 361 model = Model(net, optimizer=optimizer) 362 model.train(2, dataset, dataset_sink_mode=False) 363 364 365def test_row_tensor_values_dim_greater_than_dense_shape_dim(): 366 indices = Tensor(np.array([0, 1], dtype=np.int32)) 367 values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) 368 dense_shape = (3, 4) 369 with pytest.raises(TypeError): 370 RowTensorGetAttr(dense_shape)(indices, values) 371 372 373def test_row_tensor_values_dim_less_than_dense_shape_dim(): 374 indices = Tensor(np.array([0, 1], dtype=np.int32)) 375 values = Tensor(np.random.randn(2, 4).astype(np.float32)) 376 dense_shape = (3, 4, 5) 377 with pytest.raises(TypeError): 378 RowTensorGetAttr(dense_shape)(indices, values) 379 380 381def test_row_tensor_value_and_dense_shape_illegal(): 382 indices = Tensor(np.array([0, 1], dtype=np.int32)) 383 values = Tensor(np.random.randn(2, 4).astype(np.float32)) 384 dense_shape = (3, 5) 385 with pytest.raises(TypeError): 386 RowTensorGetAttr(dense_shape)(indices, values) 387 388 389class RowTensorValuesDouble(nn.Cell): 390 def __init__(self): 391 super().__init__() 392 393 def construct(self, x): 394 indices = x.indices 395 values = x.values * 2 396 dense_shape = x.dense_shape 397 return RowTensor(indices, values, dense_shape) 398 399 400class RowTensorValuesAdd2(nn.Cell): 401 def __init__(self): 402 super().__init__() 403 404 def construct(self, x): 405 indices = x.indices 406 values = x.values + 2 407 dense_shape = x.dense_shape 408 return RowTensor(indices, values, dense_shape) 409 410 411class RowTensorWithControlIf(nn.Cell): 412 def __init__(self, dense_shape): 413 super().__init__() 414 self.op1 = RowTensorValuesDouble() 415 self.op2 = RowTensorValuesAdd2() 416 self.dense_shape = dense_shape 417 418 def construct(self, a, b, indices, values): 419 x = RowTensor(indices, values, self.dense_shape) 420 if a > b: 421 x = self.op1(x) 422 else: 423 x = self.op2(x) 424 return x.indices, x.values 425 426 427def test_row_tensor_with_control_flow_if(): 428 a = Tensor(np.array(0).astype(np.int32)) 429 b = Tensor(np.array(2).astype(np.int32)) 430 indices = Tensor(np.array([0, 2]).astype(np.int32)) 431 values = Tensor(np.ones([2, 2]).astype(np.float32)) 432 dense_shape = (5, 2) 433 434 net = RowTensorWithControlIf(dense_shape) 435 net(a, b, indices, values) 436 437 438class EmbeddingLookUpBnNet(nn.Cell): 439 def __init__(self, vocab_size, embedding_size, target='CPU'): 440 super().__init__() 441 self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target) 442 self.bn = nn.BatchNorm2d(num_features=3) 443 self.mul = P.Mul() 444 self.reshape = P.Reshape() 445 self.relu = nn.PReLU() 446 447 def construct(self, indices): 448 x = self.embedding_lookup(indices) 449 x = self.reshape(x, (2, 3, 2, 2)) 450 x = self.relu(x) 451 x = self.bn(x) 452 return x 453 454 455def test_embedding_lookup_with_mix_precision(): 456 data = Tensor(np.array([0, 1, 2]).astype(np.int32)) 457 label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) 458 net = EmbeddingLookUpBnNet(8, 8, target='CPU') 459 460 criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') 461 optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) 462 optimizer.target = 'CPU' 463 train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2") 464 train_network.set_train() 465 for _ in range(2): 466 train_network(data, label) 467