# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test came """
import os
import sys

# Make the local `came` module importable regardless of the caller's CWD.
# This must run BEFORE `from came import Came` below — the original code
# appended the path AFTER the import, where it could no longer help it.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context, build_searched_strategy
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.train import Callback
from mindspore.train import Model
import mindspore.dataset as ds
from mindspore.communication import init
from came import Came

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
context.set_context(device_id=1)


class Net(nn.Cell):
    """ Net definition: y = x @ weight + bias, mapping [N, 64] -> [N, 10]. """
    def __init__(self):
        super(Net, self).__init__()
        # Weights are all-ones so the initial loss is deterministic.
        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        """Forward pass: matmul followed by bias add."""
        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
        return x


class ModelCallback(Callback):
    """Callback recording the mean of the step output (the loss) each step."""
    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []  # one float per training step

    def on_train_step_end(self, run_context):
        """Append this step's mean network output to ``loss_list``."""
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())


class MyDataset:
    """Synthetic dataset: ``n`` identical (ramp input, one-hot label) pairs."""
    def __init__(self, n, in_dim, out_dim):
        self.input_data = []
        self.label_data = []
        for _ in range(n):
            # Input is 0.0, 0.1, 0.2, ...; label is one-hot at class 0.
            self.input_data.append(np.arange(0.0, in_dim, dtype=np.float32) * 0.1)
            label_data = np.zeros(out_dim, dtype=np.float32)
            label_data[0] = 1.0
            self.label_data.append(label_data)

    def __getitem__(self, index):
        return self.input_data[index], self.label_data[index]

    def __len__(self):
        return len(self.input_data)


def came_compile():
    """ test came compile"""
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()

    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)


def came_loss():
    """ test came with loss decrease"""
    net = Net()
    net.set_train()
    parallel_callback = ModelCallback()

    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)
    fake_dataset = MyDataset(8, 64, 10)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"]).batch(1)
    model = Model(net, loss_fn=loss, optimizer=optimizer)
    model.train(1, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_values = np.array(parallel_callback.loss_list)
    # Loss after one epoch must be strictly smaller than the initial loss.
    assert abs(loss_values[-1]) < abs(loss_values[0])


def came_parallel():
    """test came optimizer shard with two cards with loss decrease"""
    strategy_file_path = "./strategy_stage1.ckpt"
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
    ms.set_auto_parallel_context(enable_parallel_optimizer=True)
    ms.set_auto_parallel_context(strategy_ckpt_config={"save_file": strategy_file_path})
    # Threshold of 1 KB forces even these small parameters to be sharded.
    ms.set_auto_parallel_context(parallel_optimizer_config={"parallel_optimizer_threshold": 1})
    # NOTE(review): assumes DEVICE_ID is set by the distributed launcher;
    # int(None) would raise here if it is missing — confirm launch script.
    context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    init()
    net = Net()
    net.set_train()
    parallel_callback = ModelCallback()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Came(net.trainable_params(), learning_rate=0.1)
    fake_dataset = MyDataset(8, 64, 10)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"]).batch(2)
    model = Model(net, loss_fn=loss, optimizer=optimizer)
    model.train(1, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    # With two cards, the 64-row factored statistics are split along rows
    # (64 / 2 = 32); the bias ([10]) gets the unfactored exp_avg_sq state.
    assert optimizer.exp_avg_sq_row[0].shape == (32,)
    assert optimizer.exp_avg_sq_row[1].shape == (1,)
    assert optimizer.exp_avg_sq_col[0].shape == (10,)
    assert optimizer.exp_avg_sq_col[1].shape == (1,)
    assert optimizer.exp_avg_insta_row[0].shape == (32,)
    assert optimizer.exp_avg_insta_row[1].shape == (1,)
    assert optimizer.exp_avg_insta_col[0].shape == (10,)
    assert optimizer.exp_avg_insta_col[1].shape == (1,)
    assert optimizer.exp_avg_sq[0].shape == (1,)
    assert optimizer.exp_avg_sq[1].shape == (10,)
    loss_values = np.array(parallel_callback.loss_list)
    assert abs(loss_values[-1]) < abs(loss_values[0])
    # The saved strategy must contain exactly 7 optimizer-state tensors
    # (name contains 'avg'). Iterate keys directly — values are unused.
    strategy = build_searched_strategy(strategy_file_path)
    matched_count = sum(1 for key in strategy if 'avg' in key)
    assert matched_count == 7