1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16import os 17import argparse 18import numpy as np 19import mindspore.context as context 20import mindspore.dataset as ds 21import mindspore.dataset.transforms.c_transforms as C 22import mindspore.dataset.vision.c_transforms as CV 23from mindspore.common import dtype as mstype 24from mindspore.dataset.vision import Inter 25from mindspore.common.tensor import Tensor 26from mindspore.nn import Cell 27from mindspore.nn import Flatten 28from mindspore.nn import Conv2d 29from mindspore.nn import BatchNorm2d 30from mindspore.nn import SoftmaxCrossEntropyWithLogits 31from mindspore.nn import Adam 32from mindspore.nn import EmbeddingLookup 33from mindspore.nn import ReLU 34import mindspore 35import mindspore.ops.operations as op 36from mindspore.common.parameter import Parameter 37from mindspore.train import Model 38from mindspore.common import set_seed 39 40parser = argparse.ArgumentParser(description='test_ps_lenet') 41parser.add_argument("--device_target", type=str, default="Ascend") 42parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist") 43args, _ = parser.parse_known_args() 44device_target = args.device_target 45dataset_path = args.dataset_path 46context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True) 
# PS mode is enabled globally for the whole script;
# NetFactory.no_ps_impl switches it off only for the duration of the baseline run.
context.set_ps_context(enable_ps=True)


class Menet(Cell):
    """Small conv -> embedding network used to compare PS and non-PS training.

    Forward pass: Conv2d -> BatchNorm2d -> Flatten -> ReLU -> cast to int32 ->
    EmbeddingLookup -> Flatten -> BiasAdd (applied twice with the same bias).

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (also the
            BatchNorm2d feature count).
        kernel_size: convolution kernel size.
        vocab_size: number of rows in the embedding table.
        embedding_size: width of each embedding row.
        output_channels: length of the bias vector; BiasAdd requires it to
            equal the second dimension of the flattened embedding output.
        target: device target passed to EmbeddingLookup (e.g. 'CPU').
        sparse: passed to EmbeddingLookup as its ``sparse`` flag.
    """

    def __init__(self, in_channels, out_channels, kernel_size, vocab_size, embedding_size,
                 output_channels, target, sparse):
        super().__init__()
        # Reset the global seed on every construction so that each Menet instance
        # starts from the same 'normal' initialisation. The PS and non-PS runs are
        # compared numerically, so their starting weights must match.
        set_seed(5)
        self.relu = ReLU()
        self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, has_bias=True, weight_init='normal')
        self.batchnorm = BatchNorm2d(num_features=out_channels)
        self.embedding_lookup = EmbeddingLookup(vocab_size=vocab_size,
                                                embedding_size=embedding_size,
                                                param_init='normal', target=target, sparse=sparse)
        self.flatten = Flatten()
        self.cast = op.Cast()
        self.bias = Parameter(Tensor(np.ones([output_channels]).astype(np.float32)), name='bias')
        self.biasadd = op.BiasAdd()
        # dtype the activations are cast to before being used as embedding indices.
        self.type = mindspore.int32

    def construct(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.flatten(x)
        x = self.relu(x)
        # The activations are converted to int32 and used as embedding indices.
        x = self.cast(x, self.type)
        x = self.embedding_lookup(x)
        x = self.flatten(x)
        # NOTE(review): the same bias parameter is added twice; presumably
        # intentional, to exercise one parameter used twice in a graph -- confirm.
        x = self.biasadd(x, self.bias)
        x = self.biasadd(x, self.bias)
        return x


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create the MNIST dataset pipeline used for training.

    Images are resized to 32x32 with bilinear interpolation, scaled by 1/255,
    normalised as (x - 0.1307) / 0.3081, and converted from HWC to CHW.
    Labels are cast to int32. The dataset is then shuffled, batched (the
    last partial batch is dropped) and repeated.

    Args:
        data_path: directory containing the MNIST files.
        batch_size: number of samples per batch.
        repeat_size: number of times the dataset is repeated.
        num_parallel_workers: worker count for each map operation.

    Returns:
        The resulting MindSpore dataset.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Applied after the 1/255 rescale: x * (1/std) - mean/std == (x - mean) / std.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class NetFactory:
    """Train Menet with and without a PS-hosted embedding and compare the results.

    Args:
        input_shape: shape of the random float32 input used for prediction.
        in_channels, out_channels, kernel_size, vocab_size, embedding_size,
            output_channels, target, sparse: forwarded to ``Menet``.
            NOTE(review): output_channels=3072 presumably equals
            out_channels * 32 * 32 (conv with 'same' padding on 32x32 images,
            embedding_size 1) -- confirm when changing other defaults.
        epoch_size: number of training epochs for each run.
    """

    def __init__(self, input_shape=(2, 1, 32, 32), in_channels=1, out_channels=3,
                 kernel_size=5, vocab_size=5, embedding_size=1, output_channels=3072,
                 epoch_size=1, target='CPU', sparse=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_channels = output_channels
        self.epoch_size = epoch_size
        self.target = target
        self.sparse = sparse
        # One fixed input shared by both runs, so their predictions are comparable.
        self.input_np = np.random.randn(*input_shape).astype(np.float32)

    def _train_and_predict(self, dataset, ps_embedding):
        """Train a fresh Menet on ``dataset`` and predict on ``self.input_np``.

        Args:
            dataset: training dataset (consumed by ``Model.train``).
            ps_embedding: if True, the embedding table is marked as a
                parameter-server parameter via ``set_param_ps()``.

        Returns:
            The prediction as a numpy array.
        """
        net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
                    self.embedding_size, self.output_channels, self.target, self.sparse)
        if ps_embedding:
            net.embedding_lookup.set_param_ps()
        # Heterogeneous placement: the convolution and its bias add run on CPU.
        net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
        net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
        net.set_train()
        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
        # Perform the parameter update on the host CPU.
        opt.target = 'CPU'
        model = Model(net, loss, opt)
        model.train(self.epoch_size, dataset, dataset_sink_mode=False)
        out_me = model.predict(Tensor(self.input_np))
        return out_me.asnumpy()

    def no_ps_impl(self, dataset):
        """Train and predict with PS mode disabled (the baseline run).

        PS mode is switched off for this run and is always switched back on
        afterwards, even if training or prediction raises. This prevents the
        global context from being left with PS disabled.
        """
        context.set_ps_context(enable_ps=False)
        try:
            return self._train_and_predict(dataset, ps_embedding=False)
        finally:
            context.set_ps_context(enable_ps=True)

    def part_ps_impl(self, dataset):
        """Train and predict with only the embedding table on the parameter server."""
        return self._train_and_predict(dataset, ps_embedding=True)

    def part_cmp(self):
        """Run both variants on identically built datasets and assert the outputs match."""
        ds1 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
        ds2 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
        part_ps = self.part_ps_impl(ds1)
        no_ps = self.no_ps_impl(ds2)
        print(part_ps)
        print(no_ps)
        assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4)


if __name__ == "__main__":
    fact = NetFactory()
    fact.part_cmp()