# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for the AKG ``UserDefined`` primitive (hybrid and IR-builder kernels)."""

import inspect
import numpy as np
import pytest
from mindspore import context, ops, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell


class UserDefined(ops.PrimitiveWithInfer):
    """Primitive whose kernel is described by the Python source of ``func``.

    The primitive is tagged with the ``akg`` attribute and records the
    function's name, full source text and kernel flavour as primitive
    attributes. ``func`` itself is never called here; only its source text
    is read via ``inspect.getsource``.

    Args:
        func: Kernel-description function. If its ``__dict__`` contains
            ``__wrapped__`` (as set by ``functools.wraps``), the wrapped
            function is used instead (one level of unwrapping).
        shape: Output shape, or a callable receiving the input shapes and
            returning the output shape.
        dtype: Output dtype, or a callable receiving the input dtypes and
            returning the output dtype.
        func_type: Kernel flavour. When ``None`` it is guessed from the
            source text: ``"ir_builder"`` if that substring occurs, otherwise
            ``"tvm_compute"`` if ``"compute"`` occurs, otherwise ``"hybrid"``.
            NOTE: this is a plain substring search over the raw source,
            comments included, so pass ``func_type`` explicitly if the
            kernel text merely mentions one of those words.
    """

    def __init__(self, func, shape, dtype, func_type=None):
        super().__init__("UserDefined")
        self.add_prim_attr('akg', True)

        # Unwrap a single decorator layer so the real kernel source is shipped.
        func = func.__dict__.get("__wrapped__", func)
        func_name = func.__name__
        self.add_prim_attr('func_name', func_name)
        func_source_str = inspect.getsource(func)

        if func_type is None:
            if "ir_builder" in func_source_str:
                func_type = "ir_builder"
            elif "compute" in func_source_str:
                func_type = "tvm_compute"
            else:
                func_type = "hybrid"

        self.add_prim_attr('func_source_str', func_source_str)
        self.add_prim_attr('func_type', func_type)

        self._shape = shape
        self._dtype = dtype

    def infer_shape(self, *args):
        """Return the output shape, calling ``shape`` with the input shapes if callable."""
        if callable(self._shape):
            return self._shape(*args)
        return self._shape

    def infer_dtype(self, *args):
        """Return the output dtype, calling ``dtype`` with the input dtypes if callable."""
        if callable(self._dtype):
            return self._dtype(*args)
        return self._dtype


# Hybrid-script kernel: c = a @ b for 2-D operands, written as explicit loops.
# This function is not called as Python in this file; UserDefined only reads its
# source text (``output_tensor`` is a hybrid-script intrinsic, not a Python name).
# Keep the words "ir_builder" and "compute" out of its body, comments included,
# or UserDefined's substring detection will misclassify it.
def outer_product(a, b):
    c = output_tensor((a.shape[0], b.shape[1]), 'float32')

    for i0 in range(a.shape[0]):
        for i1 in range(b.shape[1]):
            c[i0, i1] = 0.0
            for i2 in range(a.shape[1]):
                c[i0, i1] = c[i0, i1] + (a[i0, i2] * b[i2, i1])
    return c


class TestHybrid(Cell):
    """Cell running ``outer_product`` through a hybrid ``UserDefined`` primitive."""

    def __init__(self):
        super().__init__()

        def infer_shape(x_shape, y_shape):
            # outer_product allocates (a.shape[0], b.shape[1]); echoing x's shape
            # is only correct when the operands happen to be square.
            return [x_shape[0], y_shape[1]]

        # outer_product always allocates a 'float32' output regardless of inputs.
        self.program = UserDefined(
            outer_product, shape=infer_shape, dtype=mstype.float32)

    def construct(self, x, y):
        return self.program(x, y)


# IR-builder kernel: dst[i] = data_1[i] + data_2[i] over data_1's shape, emitted
# via tvm.extern. Like outer_product, it is not called as Python here (``tvm`` is
# not imported); UserDefined reads its source, and the "ir_builder" text in the
# body is what selects func_type "ir_builder". ``attrs`` is unused.
def v_add(inputs, attrs):
    def vadd_func(dst, data_1, data_2):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(data_1.shape, "i") as i:
            ib.store(dst, i, ib.load(data_1, i) + ib.load(data_2, i))
        return ib.get()
    data_1, data_2 = inputs[0], inputs[1]
    return tvm.extern(data_1.shape, [data_1, data_2],
                      lambda ins, outs: vadd_func(outs[0], ins[0], ins[1]),
                      name="v_add", dtype=data_1.dtype)


class TestIRbuilder(Cell):
    """Cell running ``v_add`` through an IR-builder ``UserDefined`` primitive.

    Args:
        shape: Fixed output shape reported by shape inference; must match the
            inputs' shape since v_add is element-wise.
    """

    def __init__(self, shape):
        super().__init__()
        self.program = UserDefined(
            v_add, shape=shape, dtype=mstype.float16)

    def construct(self, x, y):
        return self.program(x, y)


def test_user_defined_hybrid():
    """Check the hybrid-script outer product against ``np.matmul``."""
    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    test = TestHybrid()
    output = test(Tensor(input_x), Tensor(input_y))
    expect = np.matmul(input_x, input_y)
    assert np.allclose(expect, output.asnumpy(), rtol=0.001, atol=0.001)


def test_user_defined_irbuider():
    """Check the IR-builder element-wise add against NumPy addition."""
    shape = (4, 5)
    input_x = np.random.normal(0, 1, shape).astype(np.float16)
    input_y = np.random.normal(0, 1, shape).astype(np.float16)

    test = TestIRbuilder(shape)
    output = test(Tensor(input_x), Tensor(input_y))
    assert np.allclose(input_x + input_y, output.asnumpy(), rtol=0.001, atol=0.001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_defined_gpu():
    """Run both UserDefined variants in graph mode with graph-kernel fusion enabled."""
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True)
    test_user_defined_hybrid()
    test_user_defined_irbuider()