Home
last modified time | relevance | path

Searched refs:grad_accumulation_step (Results 1 – 9 of 9) sorted by relevance

/third_party/mindspore/mindspore/ops/operations/
Dcomm_ops.py260 …def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): argument
269 self.grad_accumulation_step = grad_accumulation_step
755 def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): argument
760 self.grad_accumulation_step = grad_accumulation_step
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/
Dparameter_manager.cc346 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in HandleNoUsedParameter() local
347 if (grad_accumulation_step > 1) { in HandleNoUsedParameter()
419 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in HandleFullySplitParameters() local
420 if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) { in HandleFullySplitParameters()
Dcontext.cc98 void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) { in set_grad_accumulation_step() argument
99 grad_accumulation_step_ = grad_accumulation_step; in set_grad_accumulation_step()
220 …if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMU… in ParallelParameterContextInitShape()
Dcontext.h86 void set_grad_accumulation_step(int64_t grad_accumulation_step);
87 int64_t grad_accumulation_step() const { return grad_accumulation_step_; } in grad_accumulation_step() function
Dstep_parallel.cc133 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in CreateMirrorInput() local
136 if (grad_accumulation_step > 1 || split_stage_num > 1) { in CreateMirrorInput()
191 if (grad_accumulation_step > 1) { in CreateMirrorInput()
1107 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in MirrorOpName() local
1110 if (grad_accumulation_step > 1) { in MirrorOpName()
1563 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in ApplyParallelOptOnParam() local
1566 if (grad_accumulation_step > 1) { in ApplyParallelOptOnParam()
/third_party/mindspore/mindspore/parallel/
D_auto_parallel_context.py301 def set_grad_accumulation_step(self, grad_accumulation_step): argument
309 Validator.check_positive_int(grad_accumulation_step)
310 self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
684grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/ops_info/
Doperator_info.cc415 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in CreateMiniStepAllGatherOp() local
421 ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step in CreateMiniStepAllGatherOp()
484 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); in CreateMirrorOps() local
501 if (grad_accumulation_step > 1) { in CreateMirrorOps()
503 ValuePtr attr3_value = MakeValue(grad_accumulation_step); in CreateMirrorOps()
506 …MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mir… in CreateMirrorOps()
/third_party/mindspore/mindspore/
Dcontext.py348 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
/third_party/mindspore/mindspore/ccsrc/pipeline/jit/
Dinit.cc151 ….def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulatio… in PYBIND11_MODULE()