# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16@File  : test_row_tensor.py
17@Author:
18@Date  : 2020-06-08
19@Desc  : test mindspore row_tensor's operation
20"""
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._utils.utils import generate_shape_index
from mindspore import Tensor, RowTensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train import Model
from ....dataset_mock import MindData

@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run every test in this module in graph mode with enable_sparse=True, and disable it afterwards."""
    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
    yield
    context.set_context(enable_sparse=False)

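# Module-level primitive instances; several are reused by get_axis and the custom bprop below.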
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()

def get_axis(x):
    """Return a tuple of every axis index of x, used to reduce over all dimensions."""
    shape = shape_op(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm

class MSELoss(nn.Cell):
    """Mean of squared differences, reduced over all axes."""
    def __init__(self):
        super(MSELoss, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()

    def construct(self, data, label):
        diff = data - label
        return self.reduce_mean(self.square(diff), get_axis(diff))


class MindDataSet(MindData):
    """Mock dataset that yields all-ones tensors of the given shapes and types."""
    def __init__(self, dataset_types, dataset_shapes):
        super(MindDataSet, self).__init__(size=2, batch_size=32,
                                          np_types=dataset_types,
                                          output_shapes=dataset_shapes,
                                          input_indexs=(0, 1))
    def __next__(self):
        if self._size < self._iter_num:
            raise StopIteration
        self._iter_num += 1
        lst = []
        for shape_, type_ in zip(self._output_shapes, self._np_types):
            lst.append(Tensor(np.ones(shape_).astype(type_)))
        return tuple(lst)


@constexpr
def _generate_inverse_index(x_shape, axis):
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm

# pylint: disable=W0231
class MySparseGatherV2(PrimitiveWithInfer):
    """
    Test-only SparseGatherV2; its registered bprop (below) returns a RowTensor gradient when axis is 0.
    """
    @prim_attr_register
    def __init__(self):
        """init index_select"""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out

@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2"""

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop

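# Adam-style update registered per gradient type: a RowTensor gradient short-circuits to its
# dense values block, while a dense Tensor gradient runs the full Adam-with-weight-decay update.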
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "RowTensor", "Bool")
def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                      m, v, gradient, decay_flag):
    return gradient.values

@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                  m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)


class AdamWeightDecaySparse(Optimizer):
    """Minimal AdamWeightDecay optimizer whose update (adam_opt_for_map) accepts RowTensor gradients."""
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer does not support group settings.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity


def test_row_tensor_make_row_tensor():
    """Construct a RowTensor inside a Cell in graph mode."""
    class MakeRowTensor(nn.Cell):
        def __init__(self):
            super(MakeRowTensor, self).__init__()
            self.dense_shape = (3, 2)
        def construct(self, indices, values):
            ret = (RowTensor(indices, values, self.dense_shape),)
            return ret[0]
    indices = Tensor([1, 2])
    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
    MakeRowTensor()(indices, values)


class RowTensorGetAttr(nn.Cell):
    """Build a RowTensor and return its values, indices and dense_shape."""
    def __init__(self, dense_shape):
        super(RowTensorGetAttr, self).__init__()
        self.dense_shape = dense_shape
    def construct(self, indices, values):
        x = RowTensor(indices, values, self.dense_shape)
        return x.values, x.indices, x.dense_shape


def test_row_tensor_attr():
    """Read values, indices and dense_shape from a RowTensor in graph mode."""
    indices = Tensor([0])
    values = Tensor([[1, 2]], dtype=ms.float32)
    RowTensorGetAttr((3, 2))(indices, values)


def test_row_tensor_sparse_gatherv2_grad_all():
    """The gradient of MySparseGatherV2 w.r.t. params is a RowTensor."""
    grad_all = C.GradOperation(get_all=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad[0].indices, grad[0].values, grad[0].dense_shape
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)
    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)


def test_row_tensor_sparse_gatherv2_grad_with_param():
    """The gradient of MySparseGatherV2 w.r.t. a Parameter is a RowTensor."""
    grad_by_list = C.GradOperation(get_by_list=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x.values, x.indices, x.dense_shape
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")
        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)


def test_row_tensor_env_get():
    """Run one training step in which the gradient of w1 is a RowTensor."""
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()
        def construct(self, base, target):
            return base
    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0
        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)


def test_row_tensor_model_train():
    """Train a simple dense network with Model.train under the module's sparse-enabled context."""
    class Net(nn.Cell):
        def __init__(self, in_features, out_features):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
            self.add = P.Add()
            self.cast = P.Cast()
            self.flag = True

        def construct(self, inputs, label):
            x = self.add(inputs, self.weight)
            if self.flag:
                x = self.cast(x, mstype.float32)
            return x

    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    dataset = MindDataSet(dataset_types, dataset_shapes)
    net = Net(16, 16)
    net.set_train()

    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


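# The three tests below build RowTensors whose values shape is inconsistent with dense_shape
# and expect construction to be rejected with a TypeError.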
def test_row_tensor_values_dim_greater_than_dense_shape_dim():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
    dense_shape = (3, 4)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_values_dim_less_than_dense_shape_dim():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 4, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


def test_row_tensor_value_and_dense_shape_illegal():
    indices = Tensor(np.array([0, 1], dtype=np.int32))
    values = Tensor(np.random.randn(2, 4).astype(np.float32))
    dense_shape = (3, 5)
    with pytest.raises(TypeError):
        RowTensorGetAttr(dense_shape)(indices, values)


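# Cells used by the control-flow test: each rebuilds a RowTensor from the input's
# indices / values / dense_shape, either doubling the values or adding 2 to them.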
class RowTensorValuesDouble(nn.Cell):
    def __init__(self):
        super().__init__()

    def construct(self, x):
        indices = x.indices
        values = x.values * 2
        dense_shape = x.dense_shape
        return RowTensor(indices, values, dense_shape)


class RowTensorValuesAdd2(nn.Cell):
    def __init__(self):
        super().__init__()

    def construct(self, x):
        indices = x.indices
        values = x.values + 2
        dense_shape = x.dense_shape
        return RowTensor(indices, values, dense_shape)


class RowTensorWithControlIf(nn.Cell):
    def __init__(self, dense_shape):
        super().__init__()
        self.op1 = RowTensorValuesDouble()
        self.op2 = RowTensorValuesAdd2()
        self.dense_shape = dense_shape

    def construct(self, a, b, indices, values):
        x = RowTensor(indices, values, self.dense_shape)
        if a > b:
            x = self.op1(x)
        else:
            x = self.op2(x)
        return x.indices, x.values


def test_row_tensor_with_control_flow_if():
    """A RowTensor flows through an if/else branch in graph mode."""
    a = Tensor(np.array(0).astype(np.int32))
    b = Tensor(np.array(2).astype(np.int32))
    indices = Tensor(np.array([0, 2]).astype(np.int32))
    values = Tensor(np.ones([2, 2]).astype(np.float32))
    dense_shape = (5, 2)

    net = RowTensorWithControlIf(dense_shape)
    net(a, b, indices, values)


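# An EmbeddingLookup-based network used by the mixed-precision test below.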
class EmbeddingLookUpBnNet(nn.Cell):
    def __init__(self, vocab_size, embedding_size, target='CPU'):
        super().__init__()
        self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
        self.bn = nn.BatchNorm2d(num_features=3)
        self.mul = P.Mul()
        self.reshape = P.Reshape()
        self.relu = nn.PReLU()

    def construct(self, indices):
        x = self.embedding_lookup(indices)
        x = self.reshape(x, (2, 3, 2, 2))
        x = self.relu(x)
        x = self.bn(x)
        return x


def test_embedding_lookup_with_mix_precision():
    """Train the EmbeddingLookup network (target='CPU') for two steps under O2 mixed precision."""
    data = Tensor(np.array([0, 1, 2]).astype(np.int32))
    label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
    net = EmbeddingLookUpBnNet(8, 8, target='CPU')

    criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
    optimizer.target = 'CPU'
    train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
    train_network.set_train()
    for _ in range(2):
        train_network(data, label)