• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
16import numpy as np
17
18import mindspore as ms
19from mindspore import Tensor, context, nn, ms_function
20from mindspore.nn import Cell
21from mindspore.ops import operations as P
22
23
class WhileSubGraphParam(Cell):
    """Cell whose while loop reassigns a Parameter attribute on every iteration.

    Used by ``test_while_loop_phi`` (run in graph mode). Judging by the test
    name, it presumably exercises phi-node handling for a Parameter that is
    reassigned inside a loop, alongside ordinary local loop variables.
    """

    def __init__(self):
        super().__init__()
        # Scalar float32 parameter, initialised to 1 and incremented inside the loop.
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        # x is the loop counter, y the exclusive upper bound and z the starting
        # value of out1 (all scalar float32 tensors in test_while_loop_phi).
        # Returns the tuple (out1, self.update) after the loop finishes.
        out1 = z
        while x < y:
            # The Parameter attribute is reassigned inside the loop body.
            self.update = self.update + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update
36
37
def test_while_loop_phi():
    """Build and run WhileSubGraphParam in graph mode.

    Smoke test only: the returned tuple is not inspected, so the test passes
    as long as the network compiles and executes without raising.
    """
    context.set_context(mode=context.GRAPH_MODE)
    start, stop, base = (Tensor(value, ms.float32) for value in (0, 10, 100))
    WhileSubGraphParam()(start, stop, base)
46
class WhileSubGraphParam2(Cell):
    """Cell whose while loop carries a local copy of a Parameter.

    Used by ``test_while_loop_phi_2`` (run in graph mode). Unlike
    ``WhileSubGraphParam``, the loop never reassigns ``self.update``; it only
    advances a local variable ``i`` that starts from the Parameter's value.
    """

    def __init__(self):
        super().__init__()
        # Scalar float32 parameter, initialised to 1; never reassigned in construct.
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        # x is the loop counter, y the exclusive upper bound and z the starting
        # value of out1 (all scalar float32 tensors in test_while_loop_phi_2).
        # Returns the tuple (out1, self.update).
        out1 = z
        # NOTE(review): `i` is updated in the loop but never returned or read
        # afterwards. Presumably this is deliberate, to make the compiler
        # build a loop-carried value derived from a Parameter; confirm.
        i = self.update
        while x < y:
            i = i + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update
60
61
def test_while_loop_phi_2():
    """Build and run WhileSubGraphParam2 in graph mode.

    Smoke test only: the returned tuple is not inspected, so the test passes
    as long as the network compiles and executes without raising.
    """
    context.set_context(mode=context.GRAPH_MODE)
    start, stop, base = (Tensor(value, ms.float32) for value in (0, 10, 100))
    network = WhileSubGraphParam2()
    network(start, stop, base)
70
71
class WhileSubGraphParam3(Cell):
    """Cell that accumulates one Parameter into another in a fixed-count loop.

    ``construct`` takes no inputs. It adds ``self.Y`` to ``self.X`` three
    times, reassigning the ``self.X`` attribute on each iteration, and then
    returns ``self.X``. Used by ``test_while_loop_phi_3`` (run in graph mode).
    """

    def __init__(self, initial_input_x):
        super().__init__()
        # The initial tensor is stored on the cell and also used to
        # initialise both parameters below.
        # NOTE(review): X and Y are created from the same Tensor object;
        # confirm whether ms.Parameter copies it or shares its storage.
        self.initial_input_x = initial_input_x
        self.X = ms.Parameter(initial_input_x, name="parameter_x")
        self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")

    def construct(self):
        # `a` is a plain Python int loop counter; the body runs three times.
        a = 0
        while a < 3:
            self.X = self.X + self.Y
            a += 1
        return self.X
85
86
def test_while_loop_phi_3():
    """Build and run the input-less WhileSubGraphParam3 in graph mode.

    Smoke test only: the returned value is not inspected.
    """
    context.set_context(mode=context.GRAPH_MODE)
    WhileSubGraphParam3(Tensor(0, ms.float32))()
93
class ControlMixedWhileIf(nn.Cell):
    """Cell mixing nested while loops, if/else branches and ``P.Assign`` calls.

    ``construct`` is decorated with ``ms_function``. ``test_mixed_while_if``
    calls it in PyNative mode and compares the result with 3318 for its
    specific inputs.
    """

    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        # Single-element float32 parameter; every Assign call in construct writes c4 into it.
        self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")

    @ms_function
    def construct(self, x, y, z, c2, c4):
        # Each self.assign(self.var, c4) call writes c4 into self.var, and its
        # return value is then used as a starting value for a loop variable.
        # NOTE(review): assumes Assign returns the updated variable; confirm
        # for the MindSpore version in use.
        # The incoming y and z values are never read: both are overwritten by
        # an Assign result before their first use.
        out = self.assign(self.var, c4)
        # Phase 1: outer loop advancing x until it reaches c2.
        while x < c2:
            y = self.assign(self.var, c4)
            # Step y by 2 while 2 * y < c2, otherwise by 1, until y is no longer < c2.
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
            out = out + y
            z = self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        # Phase 2: continue advancing x until it reaches 2 * c2.
        while x < 2 * c2:
            y = self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                # After phase 1, x >= c2, and x has just been incremented, so
                # the `x < c2` branch looks unreachable here. It is presumably
                # kept to exercise if/else compilation inside the loop.
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out
132
def test_mixed_while_if():
    """Run ControlMixedWhileIf in PyNative mode and check its numeric result.

    The cell's ``construct`` is decorated with ``ms_function``, so it is
    compiled even though the global mode is switched to PYNATIVE_MODE here.

    The global mode is restored to GRAPH_MODE in a ``finally`` clause. That
    way a failing assertion (or any exception while building or running the
    network) cannot leak PyNative mode into tests that run afterwards.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        x = np.array(2).astype(np.int32)
        y = np.array(14).astype(np.int32)
        z = np.array(1).astype(np.int32)
        c2 = Tensor([14], ms.int32)
        c4 = Tensor([0], ms.int32)
        net = ControlMixedWhileIf()
        output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
        expect = np.array(3318).astype(np.int32)
        assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
    finally:
        # Always restore the mode the rest of this module runs under.
        context.set_context(mode=context.GRAPH_MODE)
145