Searched refs:WeightUpdateMode (Results 1 – 5 of 5) sorted by relevance
37 enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS }; enum47 WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; } in get_optimizer_mode()124 int SetOptimizerMode(WeightUpdateMode mod) { in SetOptimizerMode()125 if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) { in SetOptimizerMode()171 if (weight_update_mod_ != WeightUpdateMode::ACCUMULATE_GRADS) { in Eval()221 WeightUpdateMode weight_update_mod_ = WeightUpdateMode::NORMAL;
653 optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS); in CompileOptimizedKernels()814 …(virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIR… in AdminSetupVirtualBatch()821 if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL && in AdminSetupVirtualBatch()822 optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()831 if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()846 if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) { in AdminSetupVirtualBatch()
125 if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in SgdRun()127 } else if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in SgdRun()143 if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in SgdRunInit()
77 if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in ApplyMomentumRun()79 } else if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in ApplyMomentumRun()
96 if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) { in AdamRun()98 } else if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) { in AdamRun()