Searched refs:mean_flag (Results 1 – 3 of 3) sorted by relevance
/third_party/mindspore/mindspore/ops/operations/ |
D | comm_ops.py | 260 …def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): argument 270 self.mean_flag = mean_flag 293 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None): argument 301 self.mean_flag = mean_flag 725 def __init__(self, group=None, dev_num=None, mean_flag=None): argument 729 self.mean_flag = mean_flag 755 def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): argument 759 self.mean_flag = mean_flag 883 def __init__(self, group=None, dev_num=None, mean_flag=None): argument 887 self.mean_flag = mean_flag
|
/third_party/mindspore/mindspore/ops/_grad/ |
D | grad_comm_ops.py | 162 mean_flag = self.mean_flag 187 if mean_flag: 221 mean_flag = self.get_attr_dict()["mean_flag"] 226 if mean_flag: 237 mean_flag = self.get_attr_dict()["mean_flag"] 250 if mean_flag: 270 mean_flag = self.get_attr_dict()["mean_flag"] 287 if mean_flag: 400 mean_flag = self.mean_flag 418 if mean_flag: [all …]
|
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/ops_info/ |
D | operator_info.cc | 378 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in AddCommOpMeanFlag() local 379 attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag); in AddCommOpMeanFlag() 416 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMiniStepAllGatherOp() local 423 ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag in CreateMiniStepAllGatherOp() 439 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMicroStepAllGatherOp() local 444 ValuePtr attr1_value = MakeValue(mean_flag); // mean_flag in CreateMicroStepAllGatherOp() 483 bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); in CreateMirrorOps() local 489 ValuePtr attr2_value = MakeValue(mean_flag); in CreateMirrorOps() 520 << mean_flag; in CreateMirrorOps()
|