# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, w1, w2):
        predict = self.network(x, w1, w2)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, w1, w2):
        return grad_all(self.network)(x, w1, w2)


class NetConv(nn.Cell):
    def __init__(self,
                 cin,
                 cout,
                 kernel_size,
                 stride=1,
                 pad_mode='pad',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 strategy=None):
        super(NetConv, self).__init__()
        self.conv = nn.Conv2d(cin,
                              cout,
                              kernel_size,
                              stride,
                              pad_mode,
                              padding,
                              dilation,
                              group,
                              has_bias,
                              weight_init,
                              bias_init)
        # Set the shard strategy on the underlying Conv2D primitive.
        self.conv.conv2d.shard(strategy)

    def construct(self, input_x):
        return self.conv(input_x)


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
        self.mul1 = P.Mul().shard(strategy2)
        self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, w1, w2):
        out1 = self.conv1(x)
        out2 = self.mul1(out1, w1)
        out3 = self.conv2(out2)
        out4 = self.mul2(out3, w2)
        return out4


def test_batch():
    """
    Feature: Batch parallel
    Description: test batch parallel
    Expectation: compile ok
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8), (1, 1, 1, 8))
    strategy3 = ((4, 1, 1, 2), (4, 1, 1, 2))

    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))

    x = Tensor(np.ones([128, 16, 34, 34]), dtype=ms.float32)
    w1 = Tensor(np.ones([128, 8, 32, 32]), dtype=ms.float32)
    w2 = Tensor(np.ones([128, 64, 24, 24]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, w1, w2)


def test_batch_shape_less_than_devices():
    """
    Feature: Batch parallel
    Description: test batch parallel when tensor shapes are smaller than the device number.
    Expectation: compile ok
    """
    context.set_auto_parallel_context(device_num=512, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # No explicit shard strategies: leave strategy selection to the framework.
    strategy1 = None
    strategy2 = None
    strategy3 = None

    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))

    x = Tensor(np.ones([128, 16, 34, 34]), dtype=ms.float32)
    w1 = Tensor(np.ones([128, 8, 32, 32]), dtype=ms.float32)
    w2 = Tensor(np.ones([128, 64, 24, 24]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, w1, w2)
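

# ---------------------------------------------------------------------------
# Supplementary sketch (not part of the original test cases): illustrates the
# per-device slice shapes implied by the strategies in test_batch, assuming
# the usual reading of a shard strategy, i.e. each entry gives the number of
# slices taken along the corresponding dimension of one operator input.
# `_slice_shape` is a hypothetical helper added only for this illustration;
# it is not a MindSpore API.
def _slice_shape(shape, input_strategy):
    """Per-device slice shape of one tensor under a single input's strategy."""
    return tuple(dim // cuts for dim, cuts in zip(shape, input_strategy))


if __name__ == "__main__":
    # strategy1: split the conv input 8-way on the batch axis; replicate the weight.
    assert _slice_shape((128, 16, 34, 34), (8, 1, 1, 1)) == (16, 16, 34, 34)
    # strategy2: split both Mul inputs 8-way along the last axis.
    assert _slice_shape((128, 8, 32, 32), (1, 1, 1, 8)) == (128, 8, 32, 4)
    # strategy3: split both Mul inputs 4-way on batch and 2-way on the last axis.
    assert _slice_shape((128, 64, 24, 24), (4, 1, 1, 2)) == (32, 64, 24, 12)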