• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Parallel (semi-auto and auto) unit tests for ResizeBilinear and ResizeNearestNeighbor.'''
15import numpy as np
16
17import mindspore as ms
18from mindspore import context, Tensor, Parameter
19from mindspore.common.api import _cell_graph_executor
20from mindspore.nn import Cell, TrainOneStepCell, Momentum
21from mindspore.ops import operations as P
22
23
class Net(Cell):
    '''
    Conv2D followed by ResizeBilinear to a fixed 16x16 output, used to
    exercise sharding strategies for ResizeBilinear.

    Args:
        conv2d_weight: Initial value of the convolution weight, registered as
            the Parameter named "w1".
        out_channel, kernel_size, pad_mode, stride: Forwarded unchanged to
            ``P.Conv2D``.
        strategy1: Strategy handed to the Conv2D primitive's ``shard``.
        strategy2: Strategy handed to the ResizeBilinear primitive's ``shard``.
    '''
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        conv = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                        pad_mode=pad_mode, stride=stride)
        self.conv2d = conv.shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        resize = P.ResizeBilinear((16, 16))
        self.resize_bilinear = resize.shard(strategy2)

    def construct(self, x):
        # Convolve with the "w1" weight, then resize the result to 16x16.
        features = self.conv2d(x, self.conv2d_weight)
        return self.resize_bilinear(features)
40
41
class Net2(Cell):
    '''
    Conv2D followed by ResizeNearestNeighbor to a fixed 16x16 output, used to
    exercise sharding strategies for ResizeNearestNeighbor.

    Args:
        conv2d_weight: Initial value of the convolution weight, registered as
            the Parameter named "w1".
        out_channel, kernel_size, pad_mode, stride: Forwarded unchanged to
            ``P.Conv2D``.
        strategy1: Strategy handed to the Conv2D primitive's ``shard``.
        strategy2: Strategy handed to the ResizeNearestNeighbor primitive's
            ``shard``.
    '''
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        conv = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                        pad_mode=pad_mode, stride=stride)
        self.conv2d = conv.shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        resize = P.ResizeNearestNeighbor((16, 16))
        self.resize_neighbor = resize.shard(strategy2)

    def construct(self, x):
        # Convolve with the "w1" weight, then nearest-neighbor resize to 16x16.
        features = self.conv2d(x, self.conv2d_weight)
        return self.resize_neighbor(features)
58
class Net3(Cell):
    '''
    Conv2D followed by ResizeBilinear to a fixed 16x16 output, where
    ResizeBilinear is never given a ``shard`` call at all (only the Conv2D
    primitive receives a strategy).

    Args:
        conv2d_weight: Initial value of the convolution weight, registered as
            the Parameter named "w1".
        out_channel, kernel_size, pad_mode, stride: Forwarded unchanged to
            ``P.Conv2D``.
        strategy1: Strategy handed to the Conv2D primitive's ``shard``.
    '''
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None):
        super().__init__()
        conv = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                        pad_mode=pad_mode, stride=stride)
        self.conv2d = conv.shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        # Deliberately no .shard(...) here: this is the "no strategy" case.
        self.resize_bilinear = P.ResizeBilinear((16, 16))

    def construct(self, x):
        # Convolve with the "w1" weight, then resize the result to 16x16.
        features = self.conv2d(x, self.conv2d_weight)
        return self.resize_bilinear(features)
75
76
# Shared all-ones float32 fixtures. _x is the network input (32, 16, 8, 8);
# _w1 is the conv weight (8, 16, 2, 2), sized for the tests' Conv2D with
# out_channel=8 and kernel_size=2 over the 16 channels at dimension 1 of _x.
_x = Tensor(np.ones((32, 16, 8, 8)), dtype=ms.float32)
_w1 = Tensor(np.ones((8, 16, 2, 2)), dtype=ms.float32)
79
80
def compile_net(net, inputs=_x):
    '''
    Wrap ``net`` in a Momentum training cell and compile it under the
    auto-parallel context the caller has already configured.

    The auto-parallel context is reset in a ``finally`` clause, so it is
    cleared whether compilation succeeds or raises. Previously a failing
    compile skipped the reset and leaked its parallel settings
    (mode, device_num, global_rank) into every later test.

    Args:
        net: Cell under test; its trainable parameters are optimized with
            Momentum (learning_rate=0.1, momentum=0.9).
        inputs: Input passed to the graph compiler; defaults to the
            module-level ``_x`` fixture.
    '''
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, inputs)
    finally:
        # Always restore the global parallel context so tests stay isolated.
        context.reset_auto_parallel_context()
88
89
def test_bililear_data_parallel():
    '''
    ResizeBilinear, semi_auto_parallel, pure data parallelism.

    Conv2D and ResizeBilinear both split only dimension 0 of their input
    eight ways, and the Conv2D weight is left unsplit. The function name
    keeps its historical misspelling ("bililear") so the test id is stable.
    '''
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    resize_strategy = ((8, 1, 1, 1),)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net(_w1, strategy1=conv_strategy, strategy2=resize_strategy, **conv_args))
97
98
def test_bilinear_model_parallel1():
    '''
    ResizeBilinear, semi_auto_parallel, model parallelism.

    Conv2D splits dimensions 0 and 1 of both input and weight by 2, while
    ResizeBilinear uses a different layout: dimension 0 by 4 and dimension 1
    by 2.
    '''
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    resize_strategy = ((4, 2, 1, 1),)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net(_w1, strategy1=conv_strategy, strategy2=resize_strategy, **conv_args))
106
107
def test_bilinear_model_parallel2():
    '''
    ResizeBilinear, semi_auto_parallel, with a resize strategy whose
    partition count (2, on dimension 0 only) is smaller than device_num (8).
    Conv2D uses the same 2x2 split of dimensions 0 and 1 as model_parallel1.
    '''
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    resize_strategy = ((2, 1, 1, 1),)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net(_w1, strategy1=conv_strategy, strategy2=resize_strategy, **conv_args))
115
116
def test_bilinear_auto_parallel():
    '''
    ResizeBilinear under auto_parallel. Net is built with strategy1 and
    strategy2 left at their None defaults, so strategy selection is
    presumably left to the auto-parallel search.
    '''
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net(_w1, **conv_args))
121
122
def test_bilinear_no_strategy():
    '''
    ResizeBilinear under auto_parallel when the primitive never has
    ``shard`` called on it (Net3), as opposed to being sharded with None.
    '''
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net3(_w1, **conv_args))
127
128
def test_neighbor_data_parallel():
    '''
    ResizeNearestNeighbor, semi_auto_parallel, pure data parallelism.

    Conv2D and ResizeNearestNeighbor both split only dimension 0 of their
    input eight ways, and the Conv2D weight is left unsplit.
    '''
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    resize_strategy = ((8, 1, 1, 1),)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net2(_w1, strategy1=conv_strategy, strategy2=resize_strategy, **conv_args))
136
137
def test_neighbor_model_parallel1():
    '''
    ResizeNearestNeighbor, semi_auto_parallel, model parallelism.

    Conv2D splits dimensions 0 and 1 of both input and weight by 2, while
    ResizeNearestNeighbor splits dimension 0 by 4 and dimension 1 by 2.
    '''
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    resize_strategy = ((4, 2, 1, 1),)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net2(_w1, strategy1=conv_strategy, strategy2=resize_strategy, **conv_args))
145
146
def test_neighbor_auto_parallel():
    '''
    ResizeNearestNeighbor under auto_parallel. Net2 is built with strategy1
    and strategy2 left at their None defaults, so strategy selection is
    presumably left to the auto-parallel search.
    '''
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    conv_args = {"out_channel": 8, "kernel_size": 2, "pad_mode": "same", "stride": 1}
    compile_net(Net2(_w1, **conv_args))
151