/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#if defined(_WIN32) || defined(_WIN64)
#include <chrono>
#else
#include <sys/time.h>
#endif
#include <memory>
#include <string>
#include <vector>
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
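  // These switches come from the global ParallelContext and decide which collective
  // ops (AllReduce / AllGather / ReduceScatter) this pass fuses.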
  bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
  bool enable_all_gather_fusion = ParallelContext::GetInstance()->enable_all_gather_fusion();
  bool enable_reduce_scatter_fusion = ParallelContext::GetInstance()->enable_reduce_scatter_fusion();
  auto graph_set = ForwardGraph(root);
  // assume no change to graph
  bool changes = false;
  // control whether to use model_parallel mode
  if (!IsAutoParallelCareGraph(root) ||
      ((!enable_all_reduce_fusion) && (!enable_all_gather_fusion) && (!enable_reduce_scatter_fusion)) ||
      (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
    return changes;
  }

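  // Time the whole pass: std::chrono::steady_clock on Windows, gettimeofday elsewhere.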
#if defined(_WIN32) || defined(_WIN64)
  auto start_time = std::chrono::steady_clock::now();
#else
  struct timeval start_time {0};
  struct timeval end_time {0};
  (void)gettimeofday(&start_time, nullptr);
#endif
  MS_LOG(INFO) << "Now entering comm ops (allreduce, allgather, reducescatter) fusion by size, and the fusion set "
                  "before will be overwritten!";
  DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
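  // Reuse the optimizer's pipeline resource when one is provided; otherwise create a
  // fresh Resource bound to the root graph's manager so the pass can run standalone.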
  FuncGraphManagerPtr manager;
  pipeline::ResourceBasePtr res;
  if (optimizer == nullptr) {
    manager = root->manager();
    res = std::make_shared<pipeline::Resource>();
    res->set_manager(manager);
  } else {
    res = optimizer->resource();
    MS_EXCEPTION_IF_NULL(res);
    manager = res->manager();
  }

  MS_EXCEPTION_IF_NULL(manager);
  CNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);

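  // Run the size-based fusion once for each communication op type that is enabled.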
  AllCommFusion allcomm_fusion;
  std::vector<std::string> comm_ops = {ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER};
  std::vector<bool> fusionlist = {enable_all_reduce_fusion, enable_all_gather_fusion, enable_reduce_scatter_fusion};
  for (size_t i = 0; i < comm_ops.size(); i++) {
    if (fusionlist[i]) {
      if (allcomm_fusion.ProcessCommOpsFusion(ret, comm_ops[i]) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Process " << comm_ops[i] << " fusion failed.";
      }
    }
  }

  DumpGraph(root, std::string(ALLREDUCE_FUSION_END));

  // allreduce fusion only runs once
  root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
  // Keep all func graphs for parallel before saving the result.
  SetReserved(root);
  res->SetResult(pipeline::kStepParallelGraph, root);
#if defined(_WIN32) || defined(_WIN64)
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
  MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us";
#else
  (void)gettimeofday(&end_time, nullptr);
  uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us";
#endif
  return changes;
}
}  // namespace parallel
}  // namespace mindspore