• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse

import numpy as np

import mindspore
import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.ops.operations as op
from mindspore.common import dtype as mstype
from mindspore.common import set_seed
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.dataset.vision import Inter
from mindspore.nn import (Adam, BatchNorm2d, Cell, Conv2d, EmbeddingLookup,
                          Flatten, ReLU, SoftmaxCrossEntropyWithLogits)
from mindspore.train import Model

# Command-line configuration for this parameter-server LeNet test.
parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path

# Graph mode with sparse support; parameter-server mode is enabled globally and
# toggled off temporarily inside NetFactory.no_ps_impl for the baseline run.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True)
context.set_ps_context(enable_ps=True)
48
49
class Menet(Cell):
    """Small conv + embedding-lookup network used to compare PS and non-PS training.

    Args:
        in_channels (int): input channels of the convolution.
        out_channels (int): output channels of the convolution.
        kernel_size (int): convolution kernel size.
        vocab_size (int): vocabulary size of the embedding table.
        embedding_size (int): embedding vector width.
        output_channels (int): length of the bias vector added after flattening.
        target (str): device the embedding lookup runs on (e.g. 'CPU').
        sparse (bool): whether the embedding lookup uses sparse gradients.
    """

    def __init__(self, in_channels, out_channels, kernel_size, vocab_size, embedding_size,
                 output_channels, target, sparse):
        super().__init__()
        set_seed(5)  # fixed seed so PS and non-PS runs start from identical weights
        self.relu = ReLU()
        self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, has_bias=True, weight_init='normal')
        self.batchnorm = BatchNorm2d(num_features=out_channels)
        self.embedding_lookup = EmbeddingLookup(vocab_size=vocab_size,
                                                embedding_size=embedding_size,
                                                param_init='normal', target=target, sparse=sparse)
        self.flatten = Flatten()
        self.cast = op.Cast()
        self.bias = Parameter(Tensor(np.ones([output_channels]).astype(np.float32)), name='bias')
        self.biasadd = op.BiasAdd()
        self.type = mindspore.int32

    def construct(self, x):
        """Forward pass: conv -> batchnorm -> flatten -> relu -> int cast ->
        embedding lookup -> flatten -> bias add (twice)."""
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.flatten(x)
        x = self.relu(x)
        x = self.cast(x, self.type)  # embedding lookup expects integer indices
        x = self.embedding_lookup(x)
        x = self.flatten(x)
        # NOTE(review): BiasAdd is applied twice — possibly deliberate to exercise
        # the op in PS mode, possibly a copy-paste slip. Behavior kept as-is: both
        # the PS and non-PS paths run the same graph, so the comparison is unaffected.
        x = self.biasadd(x, self.bias)
        x = self.biasadd(x, self.bias)
        return x
79
80
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create the MNIST dataset pipeline for train or test.

    Args:
        data_path (str): directory containing the MNIST data files.
        batch_size (int): samples per batch; incomplete batches are dropped.
        repeat_size (int): number of times the dataset is repeated.
        num_parallel_workers (int): workers used by each map operation.

    Returns:
        A batched, shuffled, repeated ``MnistDataset`` with images resized to
        32x32, rescaled, normalized, and transposed to CHW; labels cast to int32.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Standard MNIST normalization constants (mean 0.1307, std 0.3081).
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
122
123
class NetFactory:
    """Builds and trains Menet twice — with and without parameter-server mode —
    and asserts that both produce numerically matching predictions."""

    def __init__(self, input_shape=(2, 1, 32, 32), in_channels=1, out_channels=3,
                 kernel_size=5, vocab_size=5, embedding_size=1, output_channels=3072,
                 epoch_size=1, target='CPU', sparse=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_channels = output_channels
        self.epoch_size = epoch_size
        self.target = target
        self.sparse = sparse
        # Fixed random input reused by both runs so predictions are comparable.
        self.input_np = np.random.randn(*input_shape).astype(np.float32)

    def no_ps_impl(self, dataset):
        """Train and predict with parameter-server mode disabled (baseline)."""
        context.set_ps_context(enable_ps=False)
        net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
                    self.embedding_size, self.output_channels, self.target, self.sparse)
        # Pin the convolution ops to CPU so both runs execute on the same device.
        net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
        net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
        net.set_train()
        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
        opt.target = 'CPU'
        model = Model(net, loss, opt)
        model.train(self.epoch_size, dataset, dataset_sink_mode=False)
        input_me = Tensor(self.input_np)
        out_me = model.predict(input_me)
        context.set_ps_context(enable_ps=True)  # restore global PS setting
        return out_me.asnumpy()

    def part_ps_impl(self, dataset):
        """Train and predict with the embedding table hosted on parameter servers."""
        net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
                    self.embedding_size, self.output_channels, self.target, self.sparse)
        net.embedding_lookup.set_param_ps()  # only the embedding lives on the PS
        net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
        net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
        net.set_train()
        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
        opt.target = 'CPU'
        model = Model(net, loss, opt)
        model.train(self.epoch_size, dataset, dataset_sink_mode=False)
        input_me = Tensor(self.input_np)
        out_me = model.predict(input_me)
        return out_me.asnumpy()

    def part_cmp(self):
        """Run both implementations on identical data and assert closeness."""
        ds1 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
        ds2 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
        part_ps = self.part_ps_impl(ds1)
        no_ps = self.no_ps_impl(ds2)
        print(part_ps)
        print(no_ps)
        assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4)
180
181
if __name__ == "__main__":
    # Entry point: compare partial-PS training against the non-PS baseline.
    fact = NetFactory()
    fact.part_cmp()