• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_hypermap """
16import numpy as np
17
18from mindspore import Tensor
19from mindspore.common.api import ms_function
20from mindspore.ops import Primitive
21from mindspore.ops import _constants
22from mindspore.ops import composite as C
23from mindspore.ops import functional as F
24from mindspore.ops import operations as P
25from ...ut_filter import non_graph_engine
26
27# pylint: disable=W0613
28# W0613: unused-argument
29
30
# Multi-type dispatcher; its overloads are attached below through `add.register`.
add = C.MultitypeFuncGraph('add')
# Primitive created from the `_constants.kScalarAdd` op name (used by the Number overloads).
scala_add = Primitive(_constants.kScalarAdd)
# `P.Add` operator instance (used by the Tensor overloads).
tensor_add = P.Add()
34
35
@add.register("Number", "Number")
def add_scala(x, y):
    """Overload of `add` registered for ("Number", "Number"); forwards to `scala_add`."""
    total = scala_add(x, y)
    return total
39
40
@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    """Overload of `add` registered for ("Tensor", "Tensor"); forwards to `tensor_add`."""
    total = tensor_add(x, y)
    return total
44
45
# HyperMap constructed with the `add` dispatcher bound up front, so callers
# pass only the operands (see `mainf`).
hyper_add = C.HyperMap(add)
47
48
@ms_function
def mainf(x, y):
    """Apply `hyper_add` to `x` and `y` and return its result (wrapped by `ms_function`)."""
    mapped = hyper_add(x, y)
    return mapped
52
53
@non_graph_engine
def test_hypermap_tensor():
    """Call `mainf` with two float32 Tensors and print what it returns."""
    lhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    rhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    outcome = mainf(lhs, rhs)
    print("test_hypermap_tensor:", outcome)
59
60
def test_hypermap_scalar():
    """Call `mainf` with the integers 1 and 2 and print what it returns."""
    outcome = mainf(1, 2)
    print("test_hypermap_scalar", outcome)
63
64
def test_hypermap_tuple():
    """Call `mainf` with two tuples of integers and print what it returns."""
    lhs = (1, 1)
    rhs = (2, 2)
    outcome = mainf(lhs, rhs)
    print("test_hypermap_tuple", outcome)
67
68
@non_graph_engine
def test_hypermap_tuple_tensor():
    """Call `mainf` with two tuples of float32 Tensors and print what it returns."""
    first = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    second = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    # Each tuple repeats the same Tensor object twice, as in the original call.
    lhs = (first, first)
    rhs = (second, second)
    outcome = mainf(lhs, rhs)
    print("test_hypermap_tuple_tensor", outcome)
74
75
@non_graph_engine
def test_hypermap_tuple_mix():
    """Call `mainf` with tuples that pair a float32 Tensor with an integer; print the result."""
    first = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    second = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    lhs = (first, 1)
    rhs = (second, 2)
    outcome = mainf(lhs, rhs)
    print("test_hypermap_tuple_mix", outcome)
81
82
# HyperMap constructed without an operation; callers pass the operation as the
# first argument at call time (see `main_noleaf`, `main_add3`).
hyper_map = C.HyperMap()
84
85
@ms_function
def main_noleaf(x, y):
    """Apply `hyper_map` with `add` as the operation to `x` and `y` and return its result."""
    mapped = hyper_map(add, x, y)
    return mapped
89
90
def test_hypermap_noleaf_scalar():
    """Call `main_noleaf` with the integers 1 and 2; the return value is not checked."""
    lhs, rhs = 1, 2
    main_noleaf(lhs, rhs)
93
94
@non_graph_engine
def test_hypermap_noleaf_tensor():
    """Call `main_noleaf` with two float32 Tensors; the return value is not checked."""
    lhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    rhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    main_noleaf(lhs, rhs)
100
101
def test_hypermap_noleaf_tuple():
    """Call `main_noleaf` with two tuples of integers; the return value is not checked."""
    lhs = (1, 1)
    rhs = (2, 2)
    main_noleaf(lhs, rhs)
104
105
@non_graph_engine
def test_hypermap_noleaf_tuple_tensor():
    """Call `main_noleaf` with two tuples of float32 Tensors; the return value is not checked.

    Each tuple holds one Tensor built from a 2x2 literal and one built from a
    2x1 literal.
    """
    square_a = Tensor(np.array([[1.1, 2.1], [2.1, 3.1]], dtype=np.float32))
    square_b = Tensor(np.array([[1.2, 2.2], [2.2, 3.2]], dtype=np.float32))
    column_a = Tensor(np.array([[2.2], [3.2]], dtype=np.float32))
    column_b = Tensor(np.array([[2.2], [3.2]], dtype=np.float32))
    lhs = (square_a, column_a)
    rhs = (square_b, column_b)
    main_noleaf(lhs, rhs)
113
114
def test_hypermap_noleaf_tuple_mix():
    """Call `main_noleaf` with tuples that pair a float32 Tensor with an integer.

    The return value is not checked.
    """
    # NOTE(review): the other Tensor-based tests in this file are wrapped in
    # @non_graph_engine and this one is not -- confirm whether that is intended.
    first = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    second = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    lhs = (first, 1)
    rhs = (second, 2)
    main_noleaf(lhs, rhs)
119
120
def add3_scalar(x, y, z):
    """Combine three operands with two chained `scala_add` calls: (x, y) first, then z."""
    first_two = scala_add(x, y)
    return scala_add(first_two, z)
123
124
@ms_function
def main_add3_easy(x, y):
    """Bind 1 as the first argument of `add3_scalar` via `F.partial`, then call it with x, y."""
    bound_add = F.partial(add3_scalar, 1)
    return bound_add(x, y)
129
130
def test_hypermap_add3_easy():
    """Call `main_add3_easy` with the integers 1 and 2; the return value is not checked."""
    lhs, rhs = 1, 2
    main_add3_easy(lhs, rhs)
133
134
# `P.Partial` operator instance used to pre-bind the first argument of a callable.
partial = P.Partial()
# Three-argument multi-type dispatcher; overloads are attached below via `add3.register`.
# NOTE(review): this reuses the graph name 'add' already given to the two-argument
# `add` dispatcher above -- confirm whether 'add3' was intended.
add3 = C.MultitypeFuncGraph('add')
137
138
@add3.register("Number", "Number", "Number")
def add3_scala(x, y, z):
    """Overload of `add3` for three Numbers: chains `scala_add` over (x, y), then z."""
    first_two = scala_add(x, y)
    return scala_add(first_two, z)
142
143
@add3.register("Number", "Tensor", "Tensor")
def add3_tensor(x, y, z):
    """Overload of `add3` for (Number, Tensor, Tensor): returns `tensor_add(y, z)`.

    The Number argument `x` is accepted but not used in the result.
    """
    tensor_total = tensor_add(y, z)
    return tensor_total
147
148
@ms_function
def main_add3_scala(x, y):
    """Bind 1 as the first argument of `add3_scala` via `partial`, then map it over x, y.

    No test in the visible part of this file calls this function.
    """
    bound_add = partial(add3_scala, 1)
    return hyper_map(bound_add, x, y)
153
154
@ms_function
def main_add3(x, y):
    """Bind 1 as the first argument of the `add3` dispatcher via `partial`, then map it over x, y."""
    bound_add = partial(add3, 1)
    return hyper_map(bound_add, x, y)
159
160
@non_graph_engine
def test_hypermap_add3_tensor():
    """Call `main_add3` with two float32 Tensors; the return value is not checked."""
    lhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    rhs = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    main_add3(lhs, rhs)
166
167
def test_hypermap_add3_tuple():
    """Call `main_add3` with tuples that pair a float32 Tensor with an integer.

    The return value is not checked.
    """
    # NOTE(review): the other Tensor-based tests in this file are wrapped in
    # @non_graph_engine and this one is not -- confirm whether that is intended.
    first = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    second = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    lhs = (first, 1)
    rhs = (second, 1)
    main_add3(lhs, rhs)
173