/third_party/mindspore/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/ |
D | model_test.cc | 126 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local 127 train_cfg->accumulate_gradients_ = true; in TEST_F() 130 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F() 162 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local 163 train_cfg->accumulate_gradients_ = true; in TEST_F() 166 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F() 194 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local 195 train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; in TEST_F() 198 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F() 200 train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false; in TEST_F() [all …]
|
/third_party/mindspore/tests/st/fl/albert/ |
D | cloud_train.py | 25 from src.config import train_cfg, server_net_cfg 112 train_cfg.max_global_epoch = fl_iteration_num 171 filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params) 174 filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), server_params) 177 …{'params': server_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_dec… 182 learning_rate=train_cfg.server_cfg.learning_rate, 183 eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps) 194 input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32)) 195 … attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32)) 196 … token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32)) [all …]
|
/third_party/mindspore/mindspore/lite/src/cxx_api/train/ |
D | train_support.cc | 61 lite::TrainCfg train_cfg; in CreateTrainSession() local 63 auto status = A2L_ConvertConfig(cfg.get(), &train_cfg); in CreateTrainSession() 70 auto ret = session->Init(context, &train_cfg); in CreateTrainSession()
|
/third_party/mindspore/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/ |
D | network_test.cc | 171 lite::TrainCfg train_cfg; in TEST_F() local 172 train_cfg.loss_name_ = "nhwc"; in TEST_F() 174 …to session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg); in TEST_F()
|
/third_party/mindspore/mindspore/lite/tools/benchmark_train/ |
D | net_train.cc | 337 … const TrainCfg &train_cfg, int epochs) { in CreateAndRunNetworkForTrain() argument 344 … session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); in CreateAndRunNetworkForTrain() 353 …std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ … in CreateAndRunNetworkForTrain() 355 session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); in CreateAndRunNetworkForTrain() 412 TrainCfg train_cfg; in CreateAndRunNetwork() local 414 train_cfg.loss_name_ = flags_->loss_name_; in CreateAndRunNetwork() 416 train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_; in CreateAndRunNetwork() 419 session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs); in CreateAndRunNetwork()
|
D | net_train.h | 141 … const Context &context, const TrainCfg &train_cfg,
|
/third_party/mindspore/mindspore/lite/examples/train_lenet/src/ |
D | net_runner.cc | 153 mindspore::lite::TrainCfg train_cfg; in InitAndFigureInputs() local 154 train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = is_raw_mix_precision_; in InitAndFigureInputs() 155 …sion_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true, &train_cfg); in InitAndFigureInputs()
|
/third_party/mindspore/tests/st/fl/albert/src/ |
D | config.py | 29 train_cfg = edict({ variable
|
/third_party/mindspore/mindspore/lite/src/cxx_api/model/ |
D | model.cc | 68 const std::shared_ptr<TrainCfg> &train_cfg) { in Build() argument 91 impl_->SetConfig(train_cfg); in Build()
|
/third_party/mindspore/include/api/ |
D | model.h | 57 const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
/third_party/mindspore/mindspore/lite/src/train/ |
D | train_session.cc | 58 int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) { in Init() argument 59 if (train_cfg != nullptr) { in Init() 60 if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) { in Init() 64 cfg_ = *train_cfg; in Init()
|
D | train_session.h | 57 virtual int Init(InnerContext *context, const TrainCfg *train_cfg);
|