# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test came """
import os
import sys

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Tensor, Parameter, context, build_searched_strategy
from mindspore.common.api import _cell_graph_executor
from mindspore.communication import init
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.train import Callback, Model

# Make the local came module importable before importing from it.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from came import Came

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
# Default device for the single-card tests; came_parallel overrides this
# with the per-rank DEVICE_ID set by the multi-card launcher.
context.set_context(device_id=1)

class Net(nn.Cell):
    """ Net definition: a single dense layer (matmul + bias add). """
    def __init__(self):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        x = self.bias_add(self.matmul(x, self.weight), self.bias)
        return x

class ModelCallback(Callback):
    """ Callback that records the mean loss at the end of every training step. """
    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []

    def on_train_step_end(self, run_context):
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())

class MyDataset:
    """ Fake dataset of n identical samples: a ramp input and a one-hot label. """
    def __init__(self, n, in_dim, out_dim):
        self.input_data = []
        self.label_data = []
        for _ in range(n):
            self.input_data.append(np.arange(0.0, in_dim, dtype=np.float32) * 0.1)
            label_data = np.zeros(out_dim, dtype=np.float32)
            label_data[0] = 1.0
            self.label_data.append(label_data)

    def __getitem__(self, index):
        return self.input_data[index], self.label_data[index]

    def __len__(self):
        return len(self.input_data)

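# Illustrative shape check (not part of the original file): each MyDataset
# sample pairs a ramp input with a one-hot label, e.g.
#   x, y = MyDataset(8, 64, 10)[0]   # x.shape == (64,), y.shape == (10,)
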
def came_compile():
    """ Test that a training network using the Came optimizer compiles. """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()

    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)

def came_loss():
    """ Test that training with the Came optimizer decreases the loss. """
    net = Net()
    net.set_train()
    parallel_callback = ModelCallback()

    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)
    fake_dataset = MyDataset(8, 64, 10)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"]).batch(1)
    model = Model(net, loss_fn=loss, optimizer=optimizer)
    model.train(1, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_values = np.array(parallel_callback.loss_list)
    assert abs(loss_values[-1]) < abs(loss_values[0])

def came_parallel():
    """ Test that the Came optimizer state is sharded across two cards and the loss decreases. """
    strategy_file_path = "./strategy_stage1.ckpt"
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
    ms.set_auto_parallel_context(enable_parallel_optimizer=True)
    ms.set_auto_parallel_context(strategy_ckpt_config={"save_file": strategy_file_path})
    ms.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
    # DEVICE_ID is set per rank by the multi-card launcher.
    context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    init()
    net = Net()
    net.set_train()
    parallel_callback = ModelCallback()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)
    fake_dataset = MyDataset(8, 64, 10)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"]).batch(2)
    model = Model(net, loss_fn=loss, optimizer=optimizer)
    model.train(1, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    # With optimizer parallel across two cards, the (64,) row statistics of the
    # [64, 10] weight are split in half per card (-> (32,)), while the column
    # statistics and the bias state stay unsharded.
    assert optimizer.exp_avg_sq_row[0].shape == (32,)
    assert optimizer.exp_avg_sq_row[1].shape == (1,)
    assert optimizer.exp_avg_sq_col[0].shape == (10,)
    assert optimizer.exp_avg_sq_col[1].shape == (1,)
    assert optimizer.exp_avg_insta_row[0].shape == (32,)
    assert optimizer.exp_avg_insta_row[1].shape == (1,)
    assert optimizer.exp_avg_insta_col[0].shape == (10,)
    assert optimizer.exp_avg_insta_col[1].shape == (1,)
    assert optimizer.exp_avg_sq[0].shape == (1,)
    assert optimizer.exp_avg_sq[1].shape == (10,)
    loss_values = np.array(parallel_callback.loss_list)
    assert abs(loss_values[-1]) < abs(loss_values[0])
    strategy = build_searched_strategy(strategy_file_path)
    # The saved strategy file should record seven parameters whose names
    # contain 'avg', i.e. the sharded optimizer state.
    matched_count = sum(1 for key in strategy.keys() if 'avg' in key)
    assert matched_count == 7

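# Minimal local entry point (an addition for convenience, not part of the
# original test harness): the single-card cases run directly, while
# came_parallel assumes a two-card launch (e.g. via msrun or mpirun) that
# sets DEVICE_ID for each rank.
if __name__ == "__main__":
    if os.getenv('DEVICE_ID') is not None:
        came_parallel()
    else:
        came_compile()
        came_loss()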