# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operator that returns gradients with respect to all network inputs.
grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps a network with a virtual loss so the full graph can be compiled."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Returns gradients of the wrapped network with respect to all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net(net, x, y, b):
    """Compile the wrapped network in training mode under auto-parallel."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def test_rhombus1():
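    """Single rhombus: the MatMul branch and the y+z Add branch reconverge at tadd2."""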
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.tadd1 = P.Add()
            self.tadd2 = P.Add()
            self.weight = Parameter(Tensor(np.ones([128, 128]).astype(np.float32) * 0.01), "w", requires_grad=True)

        def construct(self, x, y, z):
            mm_out = self.matmul(x, self.weight)
            ta1_out = self.tadd1(y, z)
            out = self.tadd2(ta1_out, mm_out)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)
    b = Tensor(np.ones([128, 128]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    compile_net(net, x, y, b)


def test_rhombus2():
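    """Nested rhombus: ta1_out feeds both tadd2 and matmul2, whose outputs reconverge at tadd3."""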
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.tadd1 = P.Add()
            self.tadd2 = P.Add()
            self.tadd3 = P.Add()
            # Distinct parameter names, so the two weights do not collide.
            self.weight1 = Parameter(Tensor(np.ones([128, 128]).astype(np.float32) * 0.01), "w1", requires_grad=True)
            self.weight2 = Parameter(Tensor(np.ones([128, 128]).astype(np.float32) * 0.01), "w2", requires_grad=True)

        def construct(self, x, y, z):
            mm1_out = self.matmul1(x, self.weight1)
            ta1_out = self.tadd1(y, z)
            ta2_out = self.tadd2(mm1_out, ta1_out)
            mm2_out = self.matmul2(ta1_out, self.weight2)
            ta3_out = self.tadd3(ta2_out, mm2_out)
            return ta3_out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)
    b = Tensor(np.ones([128, 128]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    compile_net(net, x, y, b)


def test_rhombus3():
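    """Stacked rhombus: ta1_out feeds both tadd2 and tadd3, whose outputs reconverge at tadd4."""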
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.tadd1 = P.Add()
            self.tadd2 = P.Add()
            self.tadd3 = P.Add()
            self.tadd4 = P.Add()
            self.weight1 = Parameter(Tensor(np.ones([128, 128]).astype(np.float32) * 0.01), "w", requires_grad=True)
            self.t = Tensor(np.ones([128, 128]).astype(np.float32) * 0.01)

        def construct(self, x, y, z):
            mm1_out = self.matmul1(x, self.weight1)
            ta1_out = self.tadd1(y, z)
            ta2_out = self.tadd2(mm1_out, ta1_out)
            ta3_out = self.tadd3(ta1_out, self.t)
            ta4_out = self.tadd4(ta2_out, ta3_out)
            return ta4_out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)
    z = Tensor(np.ones([128, 128]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    compile_net(net, x, y, z)