• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test control ops """
16import functools
17import numpy as np
18
19from mindspore import Tensor
20from mindspore import context
21from mindspore import nn
22from mindspore.common import dtype as mstype
23from mindspore.ops import operations as P
24from ....mindspore_test_framework.mindspore_test import mindspore_test
25from ....mindspore_test_framework.pipeline.forward.compile_forward \
26    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
27
# Module-level side effect: switch MindSpore to GRAPH_MODE before any Cell in
# this file is instantiated, so construct() bodies are compiled as graphs.
context.set_context(mode=context.GRAPH_MODE)
29
30
class ComparisonOpsNet(nn.Cell):
    """Exercise comparison operators and unary plus inside ``construct``.

    Compares tensor against tensor and tensor against Python scalars using
    ``<=``, ``>=``, ``<``, ``>``, ``==`` and ``!=`` (in both operand orders
    for the tensor/tensor cases), applies unary ``+`` to a tensor and to a
    scalar literal, and combines every result with ``or``.
    """

    def __init__(self):
        super(ComparisonOpsNet, self).__init__()

    def construct(self, x, y):
        # Tensor/tensor and tensor/scalar comparisons.
        a = x <= y
        b = x <= 1.0
        c = y >= 1.0
        d = y >= x
        e = x < y
        f = x < 1.0
        g = y < 1.0
        h = y > x
        i = y == 3.0
        j = x != 4
        # Unary plus on a tensor and on a scalar literal. `pos_one` replaces
        # the former ambiguous single-letter name `l` (PEP 8 / E741).
        k = + x
        pos_one = + 1.0
        m = k != pos_one
        # Combine all comparison results with the `or` operator.
        return a or b or c or d or e or f or g or h or i or j or m
50
51
class MathOpsNet(nn.Cell):
    """Exercise subtraction of a parenthesised negative literal, then ReLU.

    ``construct`` takes two inputs to match the two-tensor ``desc_inputs``
    used for this net, but ``y`` is not used.
    """

    def __init__(self):
        super(MathOpsNet, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x, y):
        # Numerically x + 1; presumably written as `- (-1)` to exercise unary
        # minus on a literal inside a binary op.
        x = x - (-1)
        return self.relu(x)
60
61
class ScalarCompareNet(nn.Cell):
    """Exercise ``if``/``else`` on comparisons between Python scalar constants.

    Every condition depends only on literal constants, so each branch choice
    is fixed regardless of the input tensors: the result is ((x - y) - x) * x.
    """

    def __init__(self):
        super(ScalarCompareNet, self).__init__()
        # NOTE(review): not used by construct(); kept unchanged.
        self.relu = P.ReLU()

    def construct(self, x, y):
        # NOTE(review): dead store -- both branches below reassign `t`.
        t = 0
        if 3 > 3.2:  # False: the else branch is taken
            t = x + y
        else:
            t = x - y
        if 3.1 <= 5:  # True: the if branch is taken
            t = t - x
        else:
            t = t + x
        a = 32.0 * 12  # 384.0
        b = 12 / 3.0  # 4.0
        if a > b:  # True: the if branch is taken
            t = t * x
        else:
            t = t / x
        return t
84
85
class LogicalNumberOpsNet(nn.Cell):
    """Exercise ``and``/``or``/``not`` on Python scalar attributes in a condition."""

    def __init__(self):
        super(LogicalNumberOpsNet, self).__init__()
        self.cond = True
        # NOTE(review): named `one` but set to 0; confirm this is intended
        # (changing it would flip which branch construct() takes).
        self.one = 0
        self.zero = 0.0

    def construct(self, x, y):
        # `and` binds tighter than `or`, so this is
        # (cond and one) or (zero and not one). With cond=True, one=0,
        # zero=0.0 that is 0 or 0.0 -> 0.0 (falsy) under Python semantics,
        # so `x - y` is returned.
        if self.cond and self.one or self.zero and not self.one:
            return x + y
        return x - y
97
98
class LogicalTensorOpsNet(nn.Cell):
    """Exercise ``and``/``or``/``not`` applied to boolean tensors."""

    def __init__(self):
        """Create the net and a constant ``True`` bool tensor used by ``construct``."""
        super(LogicalTensorOpsNet, self).__init__()
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def construct(self, x, y):
        # Combines the inputs and the constant with Python logical operators;
        # how these are lowered for tensors is up to the graph compiler.
        # NOTE(review): under plain Python truthiness this can never be truthy,
        # since it requires both `y` and `not y`.
        ret = x and y and (y or self.const_true) and (not y)
        return ret
108
109
def _ones_zeros_f32():
    """Build a fresh [ones, zeros] pair of float32 tensors of shape (6, 9, 10)."""
    return [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
            Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]


# Each case is (name, config): 'block' is the Cell under test and
# 'desc_inputs' holds its two inputs, matching construct(self, x, y).
test_case_ops = [
    ('CompareOpsNet', {
        'block': ComparisonOpsNet(),
        'desc_inputs': [Tensor(1.0, dtype=mstype.float32),
                        Tensor(1.0, dtype=mstype.float32)],
    }),
    ('MathOpsNet', {
        'block': MathOpsNet(),
        'desc_inputs': _ones_zeros_f32(),
    }),
    ('ScalarCompareNet', {
        'block': ScalarCompareNet(),
        'desc_inputs': _ones_zeros_f32(),
    }),
    ('LogicalNumberOps', {
        'block': LogicalNumberOpsNet(),
        'desc_inputs': _ones_zeros_f32(),
    }),
    ('LogicalTensorOps', {
        'block': LogicalTensorOpsNet(),
        'desc_inputs': [Tensor(True, dtype=mstype.bool_),
                        Tensor(False, dtype=mstype.bool_)],
    }),
]

# Concatenate every case list into the single list handed to the pipeline.
test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(list.__add__, test_case_lists)
135
136
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
    """Hand every case in ``test_exec_case`` to the compile-forward pipeline.

    The decorator consumes the returned list; presumably it compiles each
    case's ``block`` with its ``desc_inputs`` (GE graph backend, per the
    pipeline name) -- confirm against mindspore_test_framework.
    """
    return test_exec_case
140