Home
last modified time | relevance | path

Searched refs:stage_num (Results 1 – 9 of 9) sorted by relevance

/third_party/mindspore/mindspore/ccsrc/pipeline/jit/
Dpipeline_split.cc58 static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { in InferStage() argument
59 if (stage_num == 0) { in InferStage()
62 if (device_num % stage_num != 0) { in InferStage()
64 << "stage_num: " << stage_num; in InferStage()
66 auto per_stage_rank_num = device_num / stage_num; in InferStage()
78 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in PipelineSplit() local
79 if (stage_num <= 1) { in PipelineSplit()
80 MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split."; in PipelineSplit()
104 auto stage = InferStage(global_rank, stage_num, device_num); in PipelineSplit()
105 auto per_stage_rank_num = device_num / stage_num; in PipelineSplit()
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/graph_util/
Dpipeline_split_utils.cc80 auto stage_num = g_device_manager->stage_num(); in IsLastStage() local
82 return ((stage_num - 1) == stage_id); in IsLastStage()
293 auto stage_num = g_device_manager->stage_num(); in ReorderForForward() local
295 for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) { in ReorderForForward()
309 auto stage_num = g_device_manager->stage_num(); in ReorderForBackward() local
311 for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) { in ReorderForBackward()
313 … auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)]; in ReorderForBackward()
315 auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)]; in ReorderForBackward()
319 …for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) { in ReorderForBackward()
321 … auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)]; in ReorderForBackward()
[all …]
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/pipeline_transformer/
Dpipeline_transformer.cc182 auto stage_num = g_device_manager->stage_num(); in CreateForwardGroup() local
183 for (int64_t i = 0; i < stage_num; ++i) { in CreateForwardGroup()
228 auto stage_num = g_device_manager->stage_num(); in Coloring() local
229 if (SizeToLong(stage_set.size()) != stage_num) { in Coloring()
230 …MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.… in Coloring()
821 auto stage_num = g_device_manager->stage_num(); in CutBorder() local
822 if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) { in CutBorder()
823 …OG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; in CutBorder()
/third_party/mindspore/tests/ut/cpp/parallel/
Ddevice_manager_test.cc86 ASSERT_EQ(dm_.stage_num(), (int32_t)(2)); in TEST_F()
136 ASSERT_EQ(dm_.stage_num(), 2); in TEST_F()
/third_party/mindspore/mindspore/ccsrc/frontend/optimizer/irpass/
Dinline.h52 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in operator() local
53 if (fg->stage() != -1 && stage_num > 1) { in operator()
119 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num(); in operator() local
120 if (fg->stage() != -1 && stage_num > 1) { in operator()
/third_party/mindspore/mindspore/ccsrc/frontend/parallel/
Dcontext.cc106 …ontext::set_pipeline_stage_split_num(const int64_t stage_num) { pipeline_stage_split_num_ = stage_… in set_pipeline_stage_split_num() argument
Ddevice_manager.h77 int64_t stage_num() const { return stage_num_; } in stage_num() function
/third_party/mindspore/mindspore/profiler/parser/
Dintegrator.py1386 stage_num = get_auto_parallel_context("pipeline_stages")
1391 stage_num = 1
1392 if stage_num > 1:
1401 … self._cluster_analyse_filename.format(parallel_mode, stage_num, self._rank_size, self._rank_id)
/third_party/mindspore/mindspore/nn/
Dcell.py1452 stage_num = context.get_auto_parallel_context("pipeline_stages")
1455 per_stage_devices = device_num // stage_num