Searched refs:stage_num (Results 1 – 9 of 9) sorted by relevance
58 static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { in InferStage() argument59 if (stage_num == 0) { in InferStage()62 if (device_num % stage_num != 0) { in InferStage()64 << "stage_num: " << stage_num; in InferStage()66 auto per_stage_rank_num = device_num / stage_num; in InferStage()78 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in PipelineSplit() local79 if (stage_num <= 1) { in PipelineSplit()80 MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split."; in PipelineSplit()104 auto stage = InferStage(global_rank, stage_num, device_num); in PipelineSplit()105 auto per_stage_rank_num = device_num / stage_num; in PipelineSplit()
80 auto stage_num = g_device_manager->stage_num(); in IsLastStage() local82 return ((stage_num - 1) == stage_id); in IsLastStage()293 auto stage_num = g_device_manager->stage_num(); in ReorderForForward() local295 for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) { in ReorderForForward()309 auto stage_num = g_device_manager->stage_num(); in ReorderForBackward() local311 for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) { in ReorderForBackward()313 … auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)]; in ReorderForBackward()315 auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)]; in ReorderForBackward()319 …for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) { in ReorderForBackward()321 … auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)]; in ReorderForBackward()[all …]
182 auto stage_num = g_device_manager->stage_num(); in CreateForwardGroup() local183 for (int64_t i = 0; i < stage_num; ++i) { in CreateForwardGroup()228 auto stage_num = g_device_manager->stage_num(); in Coloring() local229 if (SizeToLong(stage_set.size()) != stage_num) { in Coloring()230 …MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.… in Coloring()821 auto stage_num = g_device_manager->stage_num(); in CutBorder() local822 if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) { in CutBorder()823 …OG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; in CutBorder()
86 ASSERT_EQ(dm_.stage_num(), (int32_t)(2)); in TEST_F()136 ASSERT_EQ(dm_.stage_num(), 2); in TEST_F()
52 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in operator() local53 if (fg->stage() != -1 && stage_num > 1) { in operator()119 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in operator() local120 if (fg->stage() != -1 && stage_num > 1) { in operator()
106 …ontext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_… in set_pipeline_stage_split_num() argument
77 int64_t stage_num() const { return stage_num_; } in stage_num() function
1386 stage_num = get_auto_parallel_context("pipeline_stages")1391 stage_num = 11392 if stage_num > 1:1401 … self._cluster_analyse_filename.format(parallel_mode, stage_num, self._rank_size, self._rank_id)
1452 stage_num = context.get_auto_parallel_context("pipeline_stages")1455 per_stage_devices = device_num // stage_num