• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_operator """
16import numpy as np
17
18from mindspore import Tensor, Model, context
19from mindspore.nn import Cell
20from mindspore.nn import ReLU
21from mindspore.ops import operations as P
22from ...ut_filter import non_graph_engine
23
24
class arithmetic_Net(Cell):
    """ arithmetic_Net definition

    Exercises graph-mode parsing of arithmetic operators inside ``construct``.
    ``symbol`` selects which branch is compiled:

    * 1: ``a += b``
    * 2: ``b -= a``
    * 3: ``a + b``
    * 4: ``b - a``
    * 5: ``a * b``
    * 6: ``b / a``
    * 7: ``b % a + 1``
    * any other value: ``not a``

    Each numbered branch applies ReLU to the input twice, by iterating over a
    2-element tuple. The fallback branch applies ReLU only when ``a`` is falsy.

    Args:
        symbol (int): branch selector compared against 1-7 in ``construct``.
        loop_count (tuple): pair ``(a, b)`` of operands used by the branches.
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.relu = ReLU()

    def construct(self, x):
        a, b = self.loop_count
        y = self.symbol
        if y == 1:
            a += b
            # NOTE: this iterates over the 2-element tuple, not a range, so the
            # body runs exactly twice whatever the values of a and b are.
            for _ in (b, a):
                x = self.relu(x)
        elif y == 2:
            b -= a
            for _ in (a, b):
                x = self.relu(x)
        elif y == 3:
            z = a + b
            for _ in (b, z):
                x = self.relu(x)
        elif y == 4:
            z = b - a
            for _ in (z, b):
                x = self.relu(x)
        elif y == 5:
            z = a * b
            for _ in (a, z):
                x = self.relu(x)
        elif y == 6:
            z = b / a
            for _ in (a, z):
                x = self.relu(x)
        elif y == 7:
            z = b % a + 1
            for _ in (a, z):
                x = self.relu(x)
        else:
            # With the default loop_count, a == 1, so `not a` is False and
            # x is returned unchanged.
            if not a:
                x = self.relu(x)
        return x
69
70
class logical_Net(Cell):
    """ logical_Net definition

    Exercises graph-mode parsing of the ``and`` / ``or`` operators.
    ``symbol == 1`` compiles the ``b and a`` branch; any other value compiles
    the ``b or a`` branch. A truthy condition applies ReLU to the input.
    Otherwise the input is flattened with ``P.Flatten``.

    Args:
        symbol (int): branch selector (1 selects ``and``, anything else ``or``).
        loop_count (tuple): pair ``(a, b)`` of operands for the condition.
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.fla = P.Flatten()
        self.relu = ReLU()

    def construct(self, x):
        a, b = self.loop_count
        y = self.symbol
        # With the default loop_count (1, 3) both operands are truthy, so
        # either branch ends up applying ReLU.
        if y == 1:
            if b and a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        else:
            if b or a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        return x
95
96
def _predict_in_graph_mode(net):
    """Compile ``net`` in graph mode and run it once on a random float32 input.

    Args:
        net (Cell): network whose ``construct`` takes a single tensor.

    Returns:
        The value returned by ``Model.predict``.
    """
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net)
    return model.predict(input_me)


def arithmetic_operator_base(symbol):
    """ arithmetic_operator_base

    Build an ``arithmetic_Net`` for the given operator symbol and run it in
    graph mode.

    Args:
        symbol (str): one of "++", "--", "+", "-", "*", "/", "%", "not".

    Returns:
        The prediction produced by the compiled network.

    Raises:
        KeyError: if ``symbol`` is not a supported operator.
    """
    # Maps each operator to the branch selector understood by arithmetic_Net.
    arithmetic_operator = {"++": 1, "--": 2, "+": 3, "-": 4, "*": 5, "/": 6, "%": 7, "not": 8}
    net = arithmetic_Net(arithmetic_operator[symbol])
    return _predict_in_graph_mode(net)


def logical_operator_base(symbol):
    """ logical_operator_base

    Build a ``logical_Net`` for the given operator symbol and run it in
    graph mode.

    Args:
        symbol (str): either "and" or "or".

    Returns:
        The prediction produced by the compiled network.

    Raises:
        KeyError: if ``symbol`` is not a supported operator.
    """
    # Maps each operator to the branch selector understood by logical_Net.
    logical_operator = {"and": 1, "or": 2}
    net = logical_Net(logical_operator[symbol])
    return _predict_in_graph_mode(net)
119
120
@non_graph_engine
def test_ME_arithmetic_operator_0080():
    """Compile arithmetic_Net on its fallback branch, which evaluates ``not``."""
    arithmetic_operator_base(symbol='not')
125
126
@non_graph_engine
def test_ME_arithmetic_operator_0070():
    """Compile logical_Net with the ``and`` operator.

    Despite the ``arithmetic`` prefix in its name, this case exercises a
    logical operator.
    """
    logical_operator_base(symbol='and')
131
132
@non_graph_engine
def test_ME_logical_operator_0020():
    """Compile logical_Net with the ``or`` operator."""
    logical_operator_base(symbol='or')
137
138
def test_ops():
    """Compile a graph-mode net that mixes ``//``, ``**`` and ``%``.

    Each operator is applied both to the construct inputs and to constant
    Python ints stored on the cell. The test only runs the network; it does
    not assert on the output.
    """
    class OpsNet(Cell):
        """ OpsNet definition

        Args:
            x (int): constant left operand stored as ``self.x``.
            y (int): constant right operand stored as ``self.y``.
        """

        def __init__(self, x, y):
            super(OpsNet, self).__init__()
            self.x = x
            self.y = y
            self.int = 4
            self.float = 3.2
            self.str_a = "hello"
            self.str_b = "world"

        def construct(self, x, y):
            # Operators applied to the construct inputs.
            h = x // y
            m = x ** y
            n = x % y
            # The same operators applied to the constant ints held by the cell.
            r = self.x // self.y
            s = self.x ** self.y
            t = self.x % self.y
            p = h + m + n
            q = r + s + t
            ret_pow = p ** q + q ** p
            ret_mod = p % q + q % p
            ret_floor = p // q + q // p
            ret = ret_pow + ret_mod + ret_floor
            # With OpsNet(9, 2): q = 9 // 2 + 9 ** 2 + 9 % 2 = 4 + 81 + 1 = 86,
            # and every condition below is true, so `ret` is returned.
            if self.int > self.float:
                if [1, 2, 3] is not None:
                    if self.str_a + self.str_b == "helloworld":
                        if q == 86:
                            return ret
            return x

    net = OpsNet(9, 2)
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
    # NOTE(review): x ** y with bases up to 9 and exponents up to 19 can exceed
    # the int32 range. The result is not checked, so this only verifies that
    # the graph compiles and runs.
    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
    context.set_context(mode=context.GRAPH_MODE)
    net(x, y)
177
178
def test_in_dict():
    """Check ``in`` / ``not in`` on a dict built inside a graph-mode construct.

    ``key_in="a"`` is a key of the dict and ``key_not_in="c"`` is not, so both
    branches are taken and the network must return ``x + z``.
    """
    class InDictNet(Cell):
        """ InDictNet definition

        Args:
            key_in: key tested with ``in`` against ``{"a": x, "b": y}``.
            key_not_in: key tested with ``not in`` against the same dict.
        """

        def __init__(self, key_in, key_not_in):
            super(InDictNet, self).__init__()
            self.key_in = key_in
            self.key_not_in = key_not_in

        def construct(self, x, y, z):
            d = {"a": x, "b": y}
            ret_in = 1
            ret_not_in = 2
            if self.key_in in d:
                ret_in = d[self.key_in]
            if self.key_not_in not in d:
                ret_not_in = z
            ret = ret_in + ret_not_in
            return ret

    net = InDictNet("a", "c")
    x_np = np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)
    y_np = np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)
    z_np = np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32)
    x = Tensor(x_np)
    y = Tensor(y_np)
    z = Tensor(z_np)
    context.set_context(mode=context.GRAPH_MODE)
    out = net(x, y, z)
    # "a" is found, so ret_in is x; "c" is absent, so ret_not_in is z.
    # Without this check the test would pass even if both membership tests
    # were mis-evaluated.
    assert np.array_equal(out.asnumpy(), x_np + z_np)
205