# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Graph-mode control-flow system tests covering:
#   * a functional while loop,
#   * a range-based for loop ("fori_loop"),
#   * a scan-style loop,
#   * gradients taken through the while and for loops,
#   * repeated calls to a function tagged with a `no_inline` attribute.
# Every test switches to GRAPH_MODE and prints its result; none of them
# asserts a value, so they check that compilation and execution succeed.
from tests.st.control.cases_register import case_register
import mindspore.context as context
from mindspore import Tensor, jit
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore import mutable
import numpy as np

# Gradient transforms shared by the tests below.
# NOTE(review): grad_by_list (get_by_list=True) is not referenced anywhere
# in this file.
grad_by_list = C.GradOperation(get_by_list=True)
# get_all=True: differentiate with respect to all positional inputs.
grad_all = C.GradOperation(get_all=True)


# Decorator that sets a `no_inline = True` attribute on `func` and returns the
# same function object unchanged.
# NOTE(review): presumably MindSpore's graph compiler reads this attribute to
# keep the function as a separate call instead of inlining it -- confirm.
def no_inline(func):
    func.no_inline = True
    return func


# The helpers below are called from @jit functions, so they are compiled as
# part of the graph rather than executed as plain Python. Keep them within
# the Python subset MindSpore's graph mode supports.

# Functional while loop: starting from init_val, repeatedly apply
# body_fun(val) while cond_fun(val) is truthy, then return the final value.
# NOTE(review): wrapping the loop state with mutable() presumably keeps graph
# mode from treating it as a compile-time constant (which would allow the loop
# to be folded/unrolled) -- TODO confirm against the mutable() docs.
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    val = mutable(val)
    while cond_fun(val):
        val = body_fun(val)
    return val


# Loop condition for call_while_loop: keep iterating while val < 10.
def while_cond(val):
    return val < 10


# Loop body for call_while_loop: val -> val * 3 - 1.
def while_body_fun(val):
    val = val * 3 - 1
    return val


# Runs while_loop with while_cond/while_body_fun starting from x.
@jit
def call_while_loop(x):
    val = while_loop(while_cond, while_body_fun, x)
    return val


# Gradient of call_while_loop with respect to all of its inputs. The result is
# whatever grad_all returns (presumably a tuple with one gradient per input).
@jit
def grad_while_loop(x):
    x = grad_all(call_while_loop)(x)
    return x


# Range-based loop: for i = lower, ..., upper - 1, set val = body_fun(i, val).
# Returns the final val (init_val unchanged if the range is empty).
# NOTE(review): only the upper bound is wrapped with mutable(), presumably so
# the trip count is not treated as a compile-time constant -- confirm.
def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    upper = mutable(upper)
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


# Scan-style loop that threads a carry through f over the elements of xs.
#   f(carry, x) must return (new_carry, y).
#   Returns (final_carry, ys), where ys is a Python list of each step's y.
# If xs is None, the loop instead runs `length` times with x = None.
# NOTE(review): with xs=None and length=None, `[None] * length` raises
# TypeError. The xs=None path is not exercised by any test in this file, and
# it is unverified whether mutable(None) is accepted.
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        x = mutable(x)
        carry, y = f(carry, x)  # carry is the carryover
        # Each step's `y` is appended to the Python list `ys`; nothing here
        # stacks them into a single array.
        ys.append(y)
    return carry, ys


# Scan step for a cumulative sum: the new running total is both the carry and
# the per-step output.
def cumsum(res, el):
    res = res + el
    return res, res  # ("carryover", "accumulated")


# Cumulative sum of `a` via scan, starting from 0. Returns (total, ys), where
# ys is the list of running sums.
@jit
def call_scan(a):
    result_init = 0
    return scan(cumsum, result_init, a)


# fori_loop body: returns (i * 3) * val * val.
def for_body_fun(i, val):
    x = i * 3
    x = x * val * val
    return x


# Runs fori_loop over i = 1..99 with for_body_fun, starting from x.
@jit
def call_fori_loop(x):
    x = fori_loop(1, 100, for_body_fun, x)
    return x


# Gradient of call_fori_loop with respect to all of its inputs.
@jit
def grad_for_loop(x):
    x = grad_all(call_fori_loop)(x)
    return x


@case_register.level1
@case_register.target_ascend
def test_grad_for_loop():
    """
    Feature: control flow function.
    Description: test grad of for_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    x = Tensor([1], mstype.int32)
    x = grad_for_loop(x)
    print(x)


@case_register.level1
@case_register.target_ascend
def test_fori_loop():
    """
    Feature: control flow function.
    Description: test fori_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    # NOTE(review): starting from 1, the intermediate values are 3, 54, 26244,
    # ... and exceed the int32 range at i = 4. The printed value therefore
    # depends on overflow behavior; this test only checks that it runs.
    x = Tensor([1], mstype.int32)
    x = call_fori_loop(x)
    print(x)


@case_register.level1
@case_register.target_ascend
def test_scan():
    """
    Feature: control flow function.
    Description: test scan.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    # A numpy array (not a Tensor) is passed straight into the @jit function.
    # In plain-Python terms the final carry is the total sum, 59.
    x = np.array([1, 2, 3, 5, 7, 11, 13, 17])
    x, _ = call_scan(x)
    print(x)


@case_register.level1
@case_register.target_ascend
def test_while_loop():
    """
    Feature: control flow function.
    Description: test while_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    # Starting from 1: 1 -> 2 -> 5 -> 14, and the loop stops once val >= 10.
    x = Tensor([1], mstype.int32)
    x = call_while_loop(x)
    print(x)


@case_register.level1
@case_register.target_ascend
def test_grad_while_loop():
    """
    Feature: control flow function.
    Description: test grad of while_loop.
    Expectation: Null.
    """

    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    x = Tensor([1], mstype.int32)
    x = grad_while_loop(x)
    print(x)


# Function tagged with the `no_inline` attribute; computes val * 3 + 2.
@no_inline
def no_inline_fun(val):
    x = val * 3 + 2
    return x


# Applies no_inline_fun 100 times in a constant-trip-count loop.
# NOTE(review): presumably this checks that the no_inline callee is handled
# as a shared call rather than being inlined 100 times -- confirm intent.
@jit
def call_no_inline_fun(val):
    for _ in range(100):
        val = no_inline_fun(val)
    return val


@case_register.level1
@case_register.target_ascend
def test_no_inline_fun():
    """
    Feature: control flow function.
    Description: test no inline function.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    # NOTE(review): the value roughly triples on each of the 100 iterations,
    # so an int32 input overflows long before the end. The printed value
    # depends on overflow behavior; only successful execution is checked.
    x = Tensor([1], mstype.int32)
    x = call_no_inline_fun(x)
    print(x)