Home
last modified time | relevance | path

Searched refs:mean_flag (Results 1 – 3 of 3) sorted by relevance

/third_party/mindspore/mindspore/ops/operations/
Dcomm_ops.py260 …def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): argument
270 self.mean_flag = mean_flag
293 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None): argument
301 self.mean_flag = mean_flag
725 def __init__(self, group=None, dev_num=None, mean_flag=None): argument
729 self.mean_flag = mean_flag
755 def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): argument
759 self.mean_flag = mean_flag
883 def __init__(self, group=None, dev_num=None, mean_flag=None): argument
887 self.mean_flag = mean_flag
/third_party/mindspore/mindspore/ops/_grad/
Dgrad_comm_ops.py162 mean_flag = self.mean_flag
187 if mean_flag:
221 mean_flag = self.get_attr_dict()["mean_flag"]
226 if mean_flag:
237 mean_flag = self.get_attr_dict()["mean_flag"]
250 if mean_flag:
270 mean_flag = self.get_attr_dict()["mean_flag"]
287 if mean_flag:
400 mean_flag = self.mean_flag
418 if mean_flag:
[all …]
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/ops_info/
Doperator_info.cc378 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in AddCommOpMeanFlag() local
379 attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag); in AddCommOpMeanFlag()
416 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMiniStepAllGatherOp() local
423 ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag in CreateMiniStepAllGatherOp()
439 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMicroStepAllGatherOp() local
444 ValuePtr attr1_value = MakeValue(mean_flag); // mean_flag in CreateMicroStepAllGatherOp()
483 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMirrorOps() local
489 ValuePtr attr2_value = MakeValue(mean_flag); in CreateMirrorOps()
520 << mean_flag; in CreateMirrorOps()