Home
last modified time | relevance | path

Searched refs:train_cfg (Results 1 – 12 of 12) sorted by relevance

/third_party/mindspore/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/
Dmodel_test.cc126 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local
127 train_cfg->accumulate_gradients_ = true; in TEST_F()
130 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F()
162 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local
163 train_cfg->accumulate_gradients_ = true; in TEST_F()
166 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F()
194 auto train_cfg = std::make_shared<TrainCfg>(); in TEST_F() local
195 train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; in TEST_F()
198 ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); in TEST_F()
200 train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false; in TEST_F()
[all …]
/third_party/mindspore/tests/st/fl/albert/
Dcloud_train.py25 from src.config import train_cfg, server_net_cfg
112 train_cfg.max_global_epoch = fl_iteration_num
171 filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, server_params)
174 filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), server_params)
177 …{'params': server_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_dec…
182 learning_rate=train_cfg.server_cfg.learning_rate,
183 eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
194 input_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
195 … attention_mask = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
196 … token_type_ids = Tensor(np.zeros((train_cfg.batch_size, server_net_cfg.seq_length), np.int32))
[all …]
/third_party/mindspore/mindspore/lite/src/cxx_api/train/
Dtrain_support.cc61 lite::TrainCfg train_cfg; in CreateTrainSession() local
63 auto status = A2L_ConvertConfig(cfg.get(), &train_cfg); in CreateTrainSession()
70 auto ret = session->Init(context, &train_cfg); in CreateTrainSession()
/third_party/mindspore/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/
Dnetwork_test.cc171 lite::TrainCfg train_cfg; in TEST_F() local
172 train_cfg.loss_name_ = "nhwc"; in TEST_F()
174 …to session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg); in TEST_F()
/third_party/mindspore/mindspore/lite/tools/benchmark_train/
Dnet_train.cc337 … const TrainCfg &train_cfg, int epochs) { in CreateAndRunNetworkForTrain() argument
344 … session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg)); in CreateAndRunNetworkForTrain()
353 …std::cout << "Is raw mix precision model: " << train_cfg.mix_precision_cfg_.is_raw_mix_precision_ … in CreateAndRunNetworkForTrain()
355 session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); in CreateAndRunNetworkForTrain()
412 TrainCfg train_cfg; in CreateAndRunNetwork() local
414 train_cfg.loss_name_ = flags_->loss_name_; in CreateAndRunNetwork()
416 train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = flags_->is_raw_mix_precision_; in CreateAndRunNetwork()
419 session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs); in CreateAndRunNetwork()
Dnet_train.h141 … const Context &context, const TrainCfg &train_cfg,
/third_party/mindspore/mindspore/lite/examples/train_lenet/src/
Dnet_runner.cc153 mindspore::lite::TrainCfg train_cfg; in InitAndFigureInputs() local
154 train_cfg.mix_precision_cfg_.is_raw_mix_precision_ = is_raw_mix_precision_; in InitAndFigureInputs()
155 …sion_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true, &train_cfg); in InitAndFigureInputs()
/third_party/mindspore/tests/st/fl/albert/src/
Dconfig.py29 train_cfg = edict({ variable
/third_party/mindspore/mindspore/lite/src/cxx_api/model/
Dmodel.cc68 const std::shared_ptr<TrainCfg> &train_cfg) { in Build() argument
91 impl_->SetConfig(train_cfg); in Build()
/third_party/mindspore/include/api/
Dmodel.h57 const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/third_party/mindspore/mindspore/lite/src/train/
Dtrain_session.cc58 int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) { in Init() argument
59 if (train_cfg != nullptr) { in Init()
60 if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) { in Init()
64 cfg_ = *train_cfg; in Init()
Dtrain_session.h57 virtual int Init(InnerContext *context, const TrainCfg *train_cfg);