# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.train.model import Model


class CrossEntropyLoss(nn.Cell):
    """Softmax cross-entropy loss with an optional mean reduction.

    Args:
        reduction: ``'mean'`` applies ReduceMean over the last axis of the
            per-sample loss; ``'none'`` returns the loss unreduced.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'`` or ``'none'``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        supported = ('mean', 'none')
        # Reject unknown modes up front: otherwise e.g. 'sum' or a typo would
        # silently fall through construct() and return the unreduced loss.
        if reduction not in supported:
            raise ValueError(f"reduction must be one of {supported}, but got {reduction!r}")

        self.reduce_mean = P.ReduceMean()
        self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits()
        self.reduction = reduction

    def construct(self, logits, label):
        # Note: construct() is compiled by MindSpore in GRAPH_MODE, so it is
        # documented with comments only (no docstring statement).
        loss = self.cross_entropy(logits, label)
        if self.reduction == 'mean':
            loss = self.reduce_mean(loss, (-1,))
        return loss


class DatasetLenet:
    """Minimal in-memory stand-in for a dataset that repeats one fixed batch.

    Iterating yields the tuple ``(predict, label)`` exactly ``length`` times
    and then raises ``StopIteration``; ``reset`` rewinds the counter.
    """

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        # NOTE(review): hard-coded, independent of ``length`` (the iterator
        # only yields ``length`` batches). Confirm whether Model.train relies
        # on this particular value before tying it to ``self.length``.
        return 32

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        # Both arguments are accepted only for signature compatibility and are
        # ignored. The object is its own iterator; the counter is NOT reset here.
        return self


class Net(nn.Cell):
    """1x1 biased Conv2d -> ReduceMean over the last axis -> Flatten.

    Explicit shard strategies are set for an 8-device layout; remaining
    operators are left for auto-parallel strategy search.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, pad_mode='valid',
                              has_bias=True, weight_init='ones', bias_init='ones')
        # One strategy tuple per input of the inner conv2d primitive: split the
        # activation's first (batch) dimension 8 ways, keep the weight unsplit.
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        # Split the last axis 8 ways -- the same axis construct() reduces over.
        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1, 1, 8),))
        self.flat = nn.Flatten()

    def construct(self, inputs):
        # Graph-mode method: comments only, no docstring statement.
        x = self.conv(inputs)
        x = self.reduce_mean(x, -1)
        x = self.flat(x)
        return x


def test_bias_add():
    """Train ``Net`` for one epoch over a one-batch dataset under 8-device auto-parallel.

    The test has no asserts; it passes if graph compilation and training
    complete without raising. The name presumably refers to the bias add
    inside the biased Conv2d (TODO confirm).
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
    # The auto-parallel context is process-global; always restore it so the
    # 8-device auto_parallel setting does not leak into later tests.
    try:
        # Input (16, 3, 32, 32) -> 1x1 conv (16, 64, 32, 32) -> mean over last
        # axis (16, 64, 32) -> flatten (16, 2048), matching the label shape.
        input_np = np.ones([16, 3, 32, 32]).astype(np.float32)
        label_np = np.zeros([16, 2048]).astype(np.float32)
        dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
        net = Net()
        loss = CrossEntropyLoss()
        opt = nn.Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
        model = Model(network=net, loss_fn=loss, optimizer=opt)
        model.train(epoch=1, train_dataset=dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()