# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16@File  : test_row_tensor.py
17@Author:
18@Date  : 2020-06-08
19@Desc  : test mindspore row_tensor's operation
20"""
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad_experimental.grad_base import bprop_getters
from mindspore.ops._utils.utils import generate_shape_index
from mindspore import Tensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore.common import dtype as mstype
from mindspore import _checkparam as validator
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train import Model
from ....dataset_mock import MindData


context.set_context(mode=context.GRAPH_MODE)

reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()


def get_axis(x):
    """Return the tuple of all axes of x, so callers can reduce over every dimension."""
    shape = shape_op(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm


class MSELoss(nn.Cell):
    def __init__(self):
        super(MSELoss, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()

    def construct(self, data, label):
        diff = data - label
        return self.reduce_mean(self.square(diff), get_axis(diff))


class MindDataSet(MindData):
    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))

    def __next__(self):
        if self._size < self._iter_num:
            raise StopIteration
        self._iter_num += 1
        lst = []
        for shape_, type_ in zip(self._output_shapes, self._np_types):
            lst.append(Tensor(np.ones(shape_).astype(type_)))
        return tuple(lst)


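# MindDataSet yields a fixed number of all-ones batches matching dataset_shapes,
# so the Model.train run below exercises shapes deterministically without real data.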
@constexpr
def _generate_inverse_index(x_shape, axis):
    """Build the permutation that undoes moving `axis` to the front, e.g. rank 3, axis 1 -> (1, 0, 2)."""
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm


# pylint: disable=W0231
class MySparseGatherV2(PrimitiveWithInfer):
    """
    A GatherV2 variant whose bprop returns a sparse RowTensor; for test only.
    """

    @prim_attr_register
    def __init__(self):
        """init MySparseGatherV2"""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
        self.add_prim_attr('bprop_return_sparse', True)

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor_type, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, "axis", self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out


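# Shape sketch for MySparseGatherV2's forward inference: with axis 0 the output
# shape is indices.shape + params.shape[1:], e.g. gathering params of shape
# (3, 1, 2) with indices of shape (2,) yields (2, 1, 2) -- the same result shape
# as np.take(params, indices, axis=0).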
@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2"""

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return RowTensorInner(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop


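# In the axis-0 branch above, the gradient w.r.t. x is returned as a RowTensor:
# row indices[i] of the (virtual) dense gradient holds values[i], and rows not
# listed in `indices` are implicitly zero.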
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "RowTensor", "Bool")
def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                      m, v, gradient, decay_flag):
    # Simplified sparse branch for test purposes: only the RowTensor's values are used.
    return gradient.values


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                  m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v


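# The dense overload above implements AdamWeightDecay without bias correction:
#   m <- beta1 * m + (1 - beta1) * g
#   v <- beta2 * v + (1 - beta2) * g**2
#   update = m / (sqrt(v) + eps)    (+ weight_decay * param if decay_flag)
#   param <- param - lr * update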
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)


class AdamWeightDecaySparse(Optimizer):
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity


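# Gradient dispatch in AdamWeightDecaySparse.construct: C.Map picks the
# adam_opt_for_map overload by argument type, so RowTensor gradients (e.g. from
# MySparseGatherV2's bprop) hit the sparse branch while plain Tensor gradients
# take the full AdamWeightDecay update.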
def test_row_tensor_make_row_tensor():
    class MakeRowTensor(nn.Cell):
        def __init__(self):
            super(MakeRowTensor, self).__init__()
            self.dense_shape = (3, 2)

        def construct(self, indices, values):
            ret = (RowTensorInner(indices, values, self.dense_shape),)
            return ret[0]

    indices = Tensor([1, 2])
    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
    MakeRowTensor()(indices, values)


class RowTensorGetAttr(nn.Cell):
    def __init__(self, dense_shape):
        super(RowTensorGetAttr, self).__init__()
        self.dense_shape = dense_shape

    def construct(self, indices, values):
        x = RowTensorInner(indices, values, self.dense_shape)
        return x.values, x.indices, x.dense_shape


def test_row_tensor_attr():
    indices = Tensor([0])
    values = Tensor([[1, 2]], dtype=ms.float32)
    RowTensorGetAttr((3, 2))(indices, values)


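# The RowTensor built in test_row_tensor_attr stands for the dense tensor
# [[1., 2.], [0., 0.], [0., 0.]]: row i of `values` occupies row indices[i]
# of an otherwise-zero tensor of shape dense_shape.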
def test_row_tensor_sparse_gatherv2_grad_all():
    grad_all = C.GradOperation(get_all=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad[0].indices, grad[0].values, grad[0].dense_shape

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)

    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)


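# In the test above, grad[0] is the RowTensor from MySparseGatherV2's axis-0
# bprop: indices [0, 1], values of shape (2, 1, 2), dense_shape (3, 1, 2).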
def test_row_tensor_sparse_gatherv2_grad_with_param():
    grad_by_list = C.GradOperation(get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x.values, x.indices, x.dense_shape

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")

        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)

    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)


def test_row_tensor_env_get():
    class Loss(nn.Cell):
        def __init__(self, empty=None):
            super(Loss, self).__init__()
            self.empty = empty

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)


def test_row_tensor_model_train():
    class Net(nn.Cell):
        def __init__(self, in_features, out_features):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
            self.add = P.Add()
            self.cast = P.Cast()
            self.flag = True

        def construct(self, inputs, label):
            x = self.add(inputs, self.weight)
            if self.flag:
                x = self.cast(x, mstype.float32)
            return x

    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    dataset = MindDataSet(dataset_types, dataset_shapes)
    net = Net(16, 16)
    net.set_train()

    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_row_tensor_values_dim_greater_than_dense_shape_dim():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
    dense_shape = (3, 4)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_values_dim_less_than_dense_shape_dim():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 4, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_value_and_dense_shape_illegal():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


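# All three TypeError cases above break RowTensor's expected consistency:
# values should have the same rank as dense_shape, with
# values.shape[0] == indices.shape[0] and values.shape[1:] == dense_shape[1:].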
class RowTensorValuesDouble(nn.Cell):
    def __init__(self, empty=None):
        super(RowTensorValuesDouble, self).__init__()
        self.empty = empty

    def construct(self, x):
        indices = x.indices
        values = x.values * 2
        dense_shape = x.dense_shape
        return RowTensorInner(indices, values, dense_shape)


class RowTensorValuesAdd2(nn.Cell):
    def __init__(self, empty=None):
        super(RowTensorValuesAdd2, self).__init__()
        self.empty = empty

    def construct(self, x):
        indices = x.indices
        values = x.values + 2
        dense_shape = x.dense_shape
        return RowTensorInner(indices, values, dense_shape)


class RowTensorWithControlIf(nn.Cell):
    def __init__(self, dense_shape):
        super().__init__()
        self.op1 = RowTensorValuesDouble()
        self.op2 = RowTensorValuesAdd2()
        self.dense_shape = dense_shape

    def construct(self, a, b, indices, values):
        x = RowTensorInner(indices, values, self.dense_shape)
        if a > b:
            x = self.op1(x)
        else:
            x = self.op2(x)
        return x.indices, x.values


def test_row_tensor_with_control_flow_if():
    # With a=0 and b=2, a > b is false, so the graph takes the op2 (values + 2) branch.
    a = Tensor(np.array(0).astype(np.int32))
    b = Tensor(np.array(2).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)

    net = RowTensorWithControlIf(dense_shape)
    net(a, b, indices, values)


class EmbeddingLookUpBnNet(nn.Cell):
    def __init__(self, vocab_size, embedding_size, target='CPU'):
        super().__init__()
        self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
        self.bn = nn.BatchNorm2d(num_features=3)
        self.mul = P.Mul()
        self.reshape = P.Reshape()
        self.relu = nn.PReLU()

    def construct(self, indices):
        x = self.embedding_lookup(indices)
        x = self.reshape(x, (2, 3, 2, 2))
        x = self.relu(x)
        x = self.bn(x)
        x = x[0, 0, :, :]
        return x


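# Shape flow in EmbeddingLookUpBnNet: indices (3,) -> embedding_lookup (3, 8),
# i.e. 24 elements, reshaped to (2, 3, 2, 2) for PReLU/BatchNorm2d, then sliced
# to the (2, 2) output compared against the (2, 2) label in the test below.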
def test_embedding_lookup_with_mix_precision():
    data = Tensor(np.array([0, 1, 2]).astype(np.int32))
    label = Tensor(np.random.randn(2, 2).astype(np.float32))
    net = EmbeddingLookUpBnNet(8, 8, target='CPU')

    criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
    optimizer.target = 'CPU'
    train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
    train_network.set_train()
    for _ in range(2):
        train_network(data, label)