/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
18 #include <string>
19 #include <vector>
20 #include "frontend/optimizer/optimizer.h"
21 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
22 #include "include/common/utils/parallel_context.h"
23 #include "frontend/parallel/step_parallel_utils.h"
24 #include "frontend/parallel/graph_util/graph_info.h"
25 #include "frontend/parallel/status.h"
26 #include "frontend/parallel/step_parallel.h"
27 #include "utils/log_adapter.h"
28
29 namespace mindspore {
30 namespace parallel {
StepAllreduceFusion(const FuncGraphPtr & root,const opt::OptimizerPtr & optimizer)31 bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
32 MS_EXCEPTION_IF_NULL(root);
33 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
34 bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
35 bool enable_all_gather_fusion = ParallelContext::GetInstance()->enable_all_gather_fusion();
36 bool enable_reduce_scatter_fusion = ParallelContext::GetInstance()->enable_reduce_scatter_fusion();
37 auto graph_set = ForwardGraph(root);
38 // assume no change to graph
39 bool changes = false;
40 // control whether use model_parallel mode
41 if (!IsAutoParallelCareGraph(root) ||
42 ((!enable_all_reduce_fusion) && (!enable_all_gather_fusion) && (!enable_reduce_scatter_fusion)) ||
43 (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
44 return changes;
45 }
46
47 #if defined(_WIN32) || defined(_WIN64)
48 auto start_time = std::chrono::steady_clock::now();
49 #else
50 struct timeval start_time {
51 0
52 };
53 struct timeval end_time {
54 0
55 };
56 (void)gettimeofday(&start_time, nullptr);
57 #endif
58 MS_LOG(INFO) << "Now entering comm ops (allreduce, allgather, reducescatter) fusion by size, and fusion before will "
59 "be overlapped!";
60 DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
61 FuncGraphManagerPtr manager;
62 pipeline::ResourceBasePtr res;
63 if (optimizer == nullptr) {
64 manager = root->manager();
65 res = std::make_shared<pipeline::Resource>();
66 res->set_manager(manager);
67 } else {
68 res = optimizer->resource();
69 MS_EXCEPTION_IF_NULL(res);
70 manager = res->manager();
71 }
72
73 MS_EXCEPTION_IF_NULL(manager);
74 CNodePtr ret = root->get_return();
75 MS_EXCEPTION_IF_NULL(ret);
76
77 AllCommFusion allcomm_fusion;
78 std::vector<std::string> comm_ops = {ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER};
79 std::vector<bool> fusionlist = {enable_all_reduce_fusion, enable_all_gather_fusion, enable_reduce_scatter_fusion};
80 for (size_t i = 0; i < comm_ops.size(); i++) {
81 if (fusionlist[i]) {
82 if (allcomm_fusion.ProcessCommOpsFusion(ret, comm_ops[i]) != SUCCESS) {
83 MS_LOG(EXCEPTION) << "Process" << comm_ops[i] << "Fusion failed";
84 }
85 }
86 }
87
88 DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
89
90 // allreduce fusion only run once
91 root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
92 // Keep all func graph for parallel before save result.
93 SetReserved(root);
94 res->SetResult(pipeline::kStepParallelGraph, root);
95 #if defined(_WIN32) || defined(_WIN64)
96 auto end_time = std::chrono::steady_clock::now();
97 std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
98 MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us";
99 #else
100 (void)gettimeofday(&end_time, nullptr);
101 uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
102 time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
103 MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us";
104 #endif
105 return changes;
106 }
107 } // namespace parallel
108 } // namespace mindspore
109