• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15from tests.st.control.cases_register import case_register
16import mindspore.context as context
17from mindspore import Tensor, jit
18from mindspore.common import dtype as mstype
19from mindspore.ops import composite as C
20from mindspore import mutable
21import numpy as np
22
# Module-level gradient operators.
# grad_all is GradOperation(get_all=True); grad_while_loop and grad_for_loop
# use it to differentiate the jit-compiled loops below.
grad_all = C.GradOperation(get_all=True)
# grad_by_list is GradOperation(get_by_list=True); it is not referenced in the
# visible part of this file (NOTE(review): confirm whether it is still needed).
grad_by_list = C.GradOperation(get_by_list=True)
25
26
def no_inline(func):
    """Mark ``func`` with a ``no_inline`` attribute and return the same object.

    No wrapper is created: the very function passed in is returned, with its
    ``no_inline`` attribute set to ``True``.  (NOTE(review): presumably the
    graph compiler reads this flag to keep calls from being inlined — confirm.)
    """
    setattr(func, "no_inline", True)
    return func
30
31
def while_loop(cond_fun, body_fun, init_val):
    """Apply ``body_fun`` repeatedly for as long as ``cond_fun`` holds.

    ``init_val`` is wrapped with ``mutable`` before the loop starts, so the
    loop carries a variable rather than a compile-time constant
    (NOTE(review): presumably so graph mode keeps a real loop instead of
    constant-folding it — confirm).

    Returns the first state for which ``cond_fun`` is falsy.
    """
    state = mutable(init_val)
    while cond_fun(state):
        state = body_fun(state)
    return state
38
39
def while_cond(val):
    """Loop predicate used with ``while_loop``: true while ``val`` is below 10."""
    keep_going = val < 10
    return keep_going
42
43
def while_body_fun(val):
    """Loop body used with ``while_loop``: maps ``val`` to ``val * 3 - 1``."""
    return val * 3 - 1
47
48
@jit
def call_while_loop(x):
    """Run ``while_loop`` from ``x`` with ``while_cond`` and ``while_body_fun``."""
    return while_loop(while_cond, while_body_fun, x)
53
54
@jit
def grad_while_loop(x):
    """Differentiate ``call_while_loop`` at ``x`` using ``grad_all``.

    ``grad_all`` is the module-level ``GradOperation(get_all=True)``.
    """
    return grad_all(call_while_loop)(x)
59
60
def fori_loop(lower, upper, body_fun, init_val):
    """Fold ``body_fun(i, acc)`` over ``i`` in ``range(lower, upper)``.

    ``upper`` is wrapped with ``mutable`` so the loop bound is not a
    compile-time constant (NOTE(review): presumably to stop graph mode from
    unrolling the loop — confirm).

    Returns the accumulator after the last iteration, or ``init_val`` if the
    range is empty.
    """
    stop = mutable(upper)
    acc = init_val
    for step in range(lower, stop):
        acc = body_fun(step, acc)
    return acc
67
68
def scan(f, init, xs, length=None):
    """Thread a carry value through ``f`` over each element of ``xs``.

    For every ``x`` in ``xs`` (each wrapped with ``mutable`` first), calls
    ``carry, y = f(carry, x)`` and collects the ``y`` values.

    Args:
        f: Step function ``(carry, x) -> (new_carry, y)``.
        init: Initial carry value.
        xs: Iterable of per-step inputs, or ``None`` to run ``length`` steps
            with ``x`` set to ``None``.
        length: Number of steps; only used when ``xs`` is ``None``.

    Returns:
        ``(carry, ys)`` where ``carry`` is the final carry and ``ys`` is a
        plain Python list of every step's ``y``, in order.

    Raises:
        ValueError: If ``xs`` and ``length`` are both ``None`` (previously
            this surfaced as an opaque ``TypeError`` from ``[None] * None``).
    """
    if xs is None:
        if length is None:
            raise ValueError("scan: `length` must be given when `xs` is None.")
        # NOTE(review): this path passes mutable(None) to f; confirm that
        # MindSpore's mutable accepts None.
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        x = mutable(x)
        carry, y = f(carry, x)  # carry is fed into the next step
        ys.append(y)  # outputs are collected into a list (not stacked)
    return carry, ys
79
80
def cumsum(res, el):
    """Scan step for a running sum.

    The updated total is returned twice: once as the carry for the next
    step and once as this step's output.
    """
    total = res + el
    return total, total
84
85
@jit
def call_scan(a):
    """Running sum of ``a`` via ``scan``/``cumsum`` with an initial carry of 0.

    Returns the ``(carry, ys)`` pair produced by ``scan``.
    """
    return scan(cumsum, 0, a)
90
91
def for_body_fun(i, val):
    """Loop body used with ``fori_loop``: returns ``i * 3 * val * val``."""
    return i * 3 * val * val
96
97
@jit
def call_fori_loop(x):
    """Run ``fori_loop`` over ``range(1, 100)`` with ``for_body_fun``, seeded by ``x``."""
    return fori_loop(1, 100, for_body_fun, x)
102
103
@jit
def grad_for_loop(x):
    """Differentiate ``call_fori_loop`` at ``x`` using ``grad_all``.

    ``grad_all`` is the module-level ``GradOperation(get_all=True)``.
    """
    return grad_all(call_fori_loop)(x)
108
109
@case_register.level1
@case_register.target_ascend
def test_grad_for_loop():
    """
    Feature: control flow function.
    Description: test grad of for_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    x = Tensor([1], mstype.int32)
    # Smoke test: checks that differentiating through fori_loop compiles and
    # runs; the gradient is printed, not asserted.
    x = grad_for_loop(x)
    print(x)
122
123
@case_register.level1
@case_register.target_ascend
def test_fori_loop():
    """
    Feature: control flow function.
    Description: test fori_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=0,
                        save_graphs_path="./ir")
    seed = Tensor([1], mstype.int32)
    # Smoke test only: the exact product i*3*val*val leaves the int32 range by
    # the 4th iteration, so the printed value is not checked.
    result = call_fori_loop(seed)
    print(result)
136
137
@case_register.level1
@case_register.target_ascend
def test_scan():
    """
    Feature: control flow function.
    Description: test scan.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    x = np.array([1, 2, 3, 5, 7, 11, 13, 17])
    # call_scan returns (final_carry, per-step outputs); only the carry is printed.
    x, _ = call_scan(x)
    print(x)
150
151
@case_register.level1
@case_register.target_ascend
def test_while_loop():
    """
    Feature: control flow function.
    Description: test while_loop.
    Expectation: starting from 1, the loop visits 2 and 5 and exits at 14.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=0, save_graphs_path="./ir")
    x = Tensor([1], mstype.int32)
    x = call_while_loop(x)
    print(x)
    # while_body_fun (val * 3 - 1) runs while val < 10: 1 -> 2 -> 5 -> 14.
    assert np.array_equal(x.asnumpy(), np.array([14], np.int32))
164
165
@case_register.level1
@case_register.target_ascend
def test_grad_while_loop():
    """
    Feature: control flow function.
    Description: test grad of while_loop.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=0,
                        save_graphs_path="./ir")
    start = Tensor([1], mstype.int32)
    # Smoke test: checks that differentiating through while_loop compiles and
    # runs; the gradient is printed, not asserted.
    grads = grad_while_loop(start)
    print(grads)
179
180
@no_inline
def no_inline_fun(val):
    """Affine step ``val * 3 + 2``, tagged via the ``no_inline`` decorator."""
    return val * 3 + 2
185
186
@jit
def call_no_inline_fun(val):
    """Feed ``val`` through ``no_inline_fun`` 100 times in sequence."""
    out = val
    for _ in range(100):
        out = no_inline_fun(out)
    return out
192
193
@case_register.level1
@case_register.target_ascend
def test_no_inline_fun():
    """
    Feature: control flow function.
    Description: test no inline function.
    Expectation: Null.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=0,
                        save_graphs_path="./ir")
    seed = Tensor([1], mstype.int32)
    # Smoke test only: 100 applications of 3*v + 2 exceed the int32 range long
    # before the loop ends, so the printed value is not checked.
    result = call_no_inline_fun(seed)
    print(result)
206