Home
last modified time | relevance | path

Searched refs:WeightUpdateMode (Results 1 – 5 of 5) sorted by relevance

/third_party/mindspore/mindspore/lite/src/train/
Doptimizer_kernel.h37 enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS }; enum
47 WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; } in get_optimizer_mode()
124 int SetOptimizerMode(WeightUpdateMode mod) { in SetOptimizerMode()
125 if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) { in SetOptimizerMode()
171 if (weight_update_mod_ != WeightUpdateMode::ACCUMULATE_GRADS) { in Eval()
221 WeightUpdateMode weight_update_mod_ = WeightUpdateMode::NORMAL;
Dtrain_session.cc653 optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS); in CompileOptimizedKernels()
814 …(virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIR… in AdminSetupVirtualBatch()
821 if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL && in AdminSetupVirtualBatch()
822 optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()
831 if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()
846 if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()
/third_party/mindspore/mindspore/lite/src/runtime/kernel/arm/fp32_grad/
Dsgd.cc125 if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in SgdRun()
127 } else if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in SgdRun()
143 if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in SgdRunInit()
Dapply_momentum.cc77 if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in ApplyMomentumRun()
79 } else if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in ApplyMomentumRun()
Dadam.cc96 if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in AdamRun()
98 } else if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in AdamRun()