# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops: comparison, arithmetic, scalar-conditional and logical
operators compiled in graph mode """
import functools
import numpy as np

from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE)


class ComparisonOpsNet(nn.Cell):
    """Cell exercising comparison operators and unary plus.

    `construct` mixes tensor/tensor and tensor/Python-float comparisons
    (<=, >=, <, >, ==, !=), applies unary plus to a tensor and to a float
    literal, and chains every result with `or`.
    """

    def __init__(self):
        super(ComparisonOpsNet, self).__init__()

    def construct(self, x, y):
        a = x <= y
        b = x <= 1.0
        c = y >= 1.0
        d = y >= x
        e = x < y
        f = x < 1.0
        g = y < 1.0
        h = y > x
        i = y == 3.0
        j = x != 4
        # Unary plus on a tensor and on a Python float literal.
        k = + x
        pos_one = + 1.0
        m = k != pos_one
        return a or b or c or d or e or f or g or h or i or j or m


class MathOpsNet(nn.Cell):
    """Cell exercising subtraction of a negated literal followed by ReLU."""

    def __init__(self):
        super(MathOpsNet, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x, y):
        # `y` is unused; presumably kept so the signature matches the two
        # `desc_inputs` given for this case in `test_case_ops`.
        x = x - (-1)
        return self.relu(x)


class ScalarCompareNet(nn.Cell):
    """Cell exercising `if`/`else` branches whose conditions compare Python scalars."""

    def __init__(self):
        super(ScalarCompareNet, self).__init__()
        # NOTE(review): not referenced by `construct`.
        self.relu = P.ReLU()

    def construct(self, x, y):
        # This initial value is overwritten by both branches of the first `if`.
        t = 0
        if 3 > 3.2:
            t = x + y
        else:
            t = x - y
        if 3.1 <= 5:
            t = t - x
        else:
            t = t + x
        a = 32.0 * 12
        b = 12 / 3.0
        if a > b:
            t = t * x
        else:
            t = t / x
        return t


class LogicalNumberOpsNet(nn.Cell):
    """Cell exercising `and`/`or`/`not` on bool/int/float attributes."""

    def __init__(self):
        super(LogicalNumberOpsNet, self).__init__()
        self.cond = True
        # NOTE(review): named `one` but holds 0. With these values the
        # condition in `construct` is falsy under Python truthiness rules
        # (`True and 0` -> 0, `0.0 and ...` -> 0.0), so the `x - y` path is
        # taken; confirm whether 1 was intended.
        self.one = 0
        self.zero = 0.0

    def construct(self, x, y):
        if self.cond and self.one or self.zero and not self.one:
            return x + y
        return x - y


class LogicalTensorOpsNet(nn.Cell):
    """Cell exercising `and`/`or`/`not` on boolean tensors."""

    def __init__(self):
        """Create the constant `True` bool tensor used as an `or` operand."""
        super(LogicalTensorOpsNet, self).__init__()
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def construct(self, x, y):
        ret = x and y and (y or self.const_true) and (not y)
        return ret


# Each entry is (case name, config): 'block' is the Cell under test and
# 'desc_inputs' are presumably the positional inputs fed to its `construct`.
test_case_ops = [
    ('CompareOpsNet', {
        'block': ComparisonOpsNet(),
        'desc_inputs': [Tensor(1.0, dtype=mstype.float32),
                        Tensor(1.0, dtype=mstype.float32)]}),
    ('MathOpsNet', {
        'block': MathOpsNet(),
        'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
                        Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
    ('ScalarCompareNet', {
        'block': ScalarCompareNet(),
        'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
                        Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
    ('LogicalNumberOps', {
        'block': LogicalNumberOpsNet(),
        'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
                        Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
    ('LogicalTensorOps', {
        'block': LogicalTensorOpsNet(),
        'desc_inputs': [Tensor(True, dtype=mstype.bool_),
                        Tensor(False, dtype=mstype.bool_)]}),
]

# Concatenate every case list into the single list returned by `test_compile`.
test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
    """Return the test cases; `mindspore_test` presumably compiles each one
    with the forward GE-graph pipeline passed to the decorator."""
    return test_exec_case