• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import inspect
17import numpy as np
18import pytest
19from mindspore import context, ops, Tensor
20from mindspore.common import dtype as mstype
21from mindspore.nn import Cell
22
23
class UserDefined(ops.PrimitiveWithInfer):
    """Custom AKG primitive described by the source text of a kernel function.

    Args:
        func: Kernel function. If its ``__dict__`` holds ``__wrapped__``, that
            inner function is used instead (one level only).
        shape: Output shape, or a callable receiving the input shapes and
            returning the output shape.
        dtype: Output dtype, or a callable receiving the input dtypes and
            returning the output dtype.
        func_type: Kernel flavour. When ``None`` it is guessed from the source
            text: "ir_builder" if it contains "ir_builder", otherwise
            "tvm_compute" if it contains "compute", otherwise "hybrid".
    """

    def __init__(self, func, shape, dtype, func_type=None):
        super().__init__("UserDefined")
        self.add_prim_attr('akg', True)

        # Peel off one decorator layer so getsource sees the real kernel body.
        kernel = func.__dict__.get("__wrapped__", func)
        self.add_prim_attr('func_name', kernel.__name__)
        source = inspect.getsource(kernel)

        if func_type is None:
            # Plain substring search over the whole source (comments included);
            # the first marker found wins, hybrid script is the fallback.
            markers = (("ir_builder", "ir_builder"), ("compute", "tvm_compute"))
            func_type = next(
                (kind for marker, kind in markers if marker in source), "hybrid")

        self.add_prim_attr('func_source_str', source)
        self.add_prim_attr('func_type', func_type)

        self._shape = shape
        self._dtype = dtype

    def infer_shape(self, *args):
        """Return the output shape, calling ``shape`` on the input shapes if callable."""
        shape = self._shape
        return shape(*args) if callable(shape) else shape

    def infer_dtype(self, *args):
        """Return the output dtype, calling ``dtype`` on the input dtypes if callable."""
        dtype = self._dtype
        return dtype(*args) if callable(dtype) else dtype
58
59
# Hybrid-script kernel consumed by AKG: UserDefined sends the source text of
# this function (via inspect.getsource) to the backend; nothing in this module
# calls it as ordinary Python. Presumably only constructs accepted by AKG's
# hybrid-script front end are allowed in the body -- confirm before editing.
# Despite the name, it forms the matrix product a @ b (test_user_defined_hybrid
# compares its result with np.matmul).
# NOTE: UserDefined picks func_type by plain substring search over this
# function's source, comments included; keep its marker words out of the body
# or the kernel stops being classified as "hybrid".
def outer_product(a, b):
    # output_tensor is neither defined nor imported in this module; presumably
    # the hybrid-script front end provides it -- confirm. The result buffer is
    # always 'float32', whatever the input dtypes are.
    c = output_tensor((a.shape[0], b.shape[1]), 'float32')

    for i0 in range(a.shape[0]):
        for i1 in range(b.shape[1]):
            # Zero the accumulator, then sum a[i0, i2] * b[i2, i1] over i2.
            c[i0, i1] = 0.0
            for i2 in range(a.shape[1]):
                c[i0, i1] = c[i0, i1] + (a[i0, i2] * b[i2, i1])
    return c
69
70
class TestHybrid(Cell):
    """Cell that runs the hybrid-script ``outer_product`` kernel on (x, y).

    ``outer_product`` forms the matrix product of its two inputs (the test
    compares it with ``np.matmul``) and always allocates a 'float32' result.
    """

    def __init__(self):
        super(TestHybrid, self).__init__()

        def infer_shape(x_shape, y_shape):
            # The kernel writes an (a.shape[0], b.shape[1]) result; x's own
            # shape only coincides with that for square inputs.
            return [x_shape[0], y_shape[1]]

        # The kernel hard-codes output_tensor(..., 'float32'), so declare
        # float32 rather than echoing the first input's dtype.
        self.program = UserDefined(
            outer_product, shape=infer_shape, dtype=mstype.float32)

    def construct(self, x, y):
        return self.program(x, y)
83
84
# IR-builder kernel consumed by AKG: UserDefined sends the source text of this
# function (via inspect.getsource) to the backend and classifies it as
# "ir_builder" because that word appears in the source.
# ``tvm`` is neither defined nor imported in this module; presumably AKG
# supplies it when it builds this source -- confirm.
def v_add(inputs, attrs):
    # inputs: only inputs[0] and inputs[1] are read.
    # attrs: unused here; presumably required by the calling convention AKG
    # expects -- confirm before removing.
    def vadd_func(dst, data_1, data_2):
        # Build IR storing data_1[i] + data_2[i] into dst[i]; for_range_n
        # presumably iterates every index of data_1.shape -- confirm.
        ib = tvm.ir_builder.create()
        with ib.for_range_n(data_1.shape, "i") as i:
            ib.store(dst, i, ib.load(data_1, i) + ib.load(data_2, i))
        return ib.get()
    data_1, data_2 = inputs[0], inputs[1]
    # One extern op whose output shape and dtype mirror data_1; the lambda
    # passes the single output buffer and both input buffers to vadd_func.
    return tvm.extern(data_1.shape, [data_1, data_2],
                      lambda ins, outs: vadd_func(outs[0], ins[0], ins[1]),
                      name="v_add", dtype=data_1.dtype)
95
96
class TestIRbuilder(Cell):
    """Cell that runs the IR-builder ``v_add`` kernel on (x, y).

    Args:
        shape: Output shape to declare. When ``None``, the first input's shape
            is used, matching ``v_add``, which sizes its output from
            ``data_1.shape``.
    """

    def __init__(self, shape=None):
        super(TestIRbuilder, self).__init__()

        def infer_shape(x_shape, _y_shape):
            return x_shape if shape is None else shape

        def infer_dtype(x_dtype, _y_dtype):
            # v_add creates its output with dtype=data_1.dtype, so mirror the
            # first input instead of hard-coding float16.
            return x_dtype

        self.program = UserDefined(
            v_add, shape=infer_shape, dtype=infer_dtype)

    def construct(self, x, y):
        return self.program(x, y)
105
106
def test_user_defined_hybrid():
    """Check the hybrid-script kernel against NumPy's matrix product."""
    lhs = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    rhs = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    result = TestHybrid()(Tensor(lhs), Tensor(rhs)).asnumpy()
    assert np.allclose(np.matmul(lhs, rhs), result, rtol=0.001, atol=0.001)
116
117
def test_user_defined_irbuider():
    """Check the IR-builder kernel against NumPy element-wise addition."""
    shape = (4, 5)
    lhs, rhs = (np.random.normal(0, 1, shape).astype(np.float16)
                for _ in range(2))

    net = TestIRbuilder(shape)
    result = net(Tensor(lhs), Tensor(rhs)).asnumpy()
    assert np.allclose(lhs + rhs, result, rtol=0.001, atol=0.001)
127
128
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_defined_gpu():
    """Run both custom-kernel checks in graph mode with enable_graph_kernel on."""
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True)
    for check in (test_user_defined_hybrid, test_user_defined_irbuider):
        check()
136