# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(nn.Cell):
    """MatMul -> Cast(int32) -> OneHot -> Mul network used to compile sharded operators.

    Args:
        wi: Initial value of the weight parameter multiplied with the input.
        stra1: Shard strategy handed to the ``MatMul`` operator.
        stra2: Shard strategy handed to the ``OneHot`` operator.
        stra3: Shard strategy handed to the ``Mul`` operator.
    """

    def __init__(self, wi, stra1=None, stra2=None, stra3=None):
        super().__init__()
        self.wi = Parameter(wi, "wi")
        self.matmul = P.MatMul().shard(stra1)
        self.onehot = P.OneHot(axis=-1).shard(stra2)
        self.mul = P.Mul().shard(stra3)
        # Values written by OneHot at the "hot" position and everywhere else.
        self.on_value = Tensor(1.0, ms.float32)
        self.off_value = Tensor(0.0, ms.float32)
        self.cast = P.Cast()
        # Depth argument passed to OneHot in construct().
        self.depth = 48

    def construct(self, x):
        """Run the MatMul result through Cast, OneHot and an element-wise self-multiply."""
        output = self.matmul(x, self.wi)
        # The MatMul result is cast to int32 before being used as OneHot indices.
        output = self.cast(output, ms.int32)
        output = self.onehot(output, self.depth, self.on_value, self.off_value)
        output = self.mul(output, output)
        return output


# Shared all-ones input and weight tensors used by the test below.
_x = Tensor(np.ones([16, 48]), dtype=ms.float32)
_wi = Tensor(np.ones([48, 16]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it in graph mode.

    The auto-parallel context is reset when this function exits, whether the
    compilation succeeded or raised, so that settings made by one test do not
    leak into the next one.

    Args:
        net: The cell to compile; it is called with the module-level ``_x``.
    """
    context.set_context(mode=context.GRAPH_MODE)
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x)
    finally:
        # Always restore the global auto-parallel settings, even on failure.
        context.reset_auto_parallel_context()


def test_onehot():
    """
    Feature: semi-auto-parallel compilation with all-to-all enabled.
    Description: compile Net with shard strategies set on MatMul, OneHot and Mul
        for 8 devices.
    Expectation: compilation finishes without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
                                      global_rank=0)
    # One tuple per operator input; see the MindSpore shard documentation for
    # how each entry maps to the slicing of that input's dimensions.
    stra1 = ((8, 1), (1, 1))
    stra2 = ((8, 1, 1), (), ())
    stra3 = ((8, 1, 1), (8, 1, 1))
    net = Net(_wi, stra1=stra1, stra2=stra2, stra3=stra3)
    compile_net(net)