/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#if defined(_WIN32) || defined(_WIN64)
#include <chrono>
#else
#include <sys/time.h>
#endif
#include <ctime>
#include <string>
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/status.h"
#include "utils/log_adapter.h"

27 namespace mindspore {
28 namespace parallel {
StepAllreduceFusion(const FuncGraphPtr & root,const opt::OptimizerPtr & optimizer)29 bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
30 MS_EXCEPTION_IF_NULL(root);
31 MS_EXCEPTION_IF_NULL(optimizer);
32 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
33 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
34 bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
35 // assume no change to graph
36 bool changes = false;
37 // control whether use model_parallel mode
38 if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
39 (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
40 return changes;
41 }
42 #if defined(_WIN32) || defined(_WIN64)
43 auto start_time = std::chrono::steady_clock::now();
44 #else
45 struct timeval start_time {
46 0
47 }, end_time{0};
48 (void)gettimeofday(&start_time, nullptr);
49 #endif
50 MS_LOG(INFO) << "Now entering allreduce fusion";
51 DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN));
52
53 pipeline::ResourceBasePtr res = optimizer->resource();
54 MS_EXCEPTION_IF_NULL(res);
55
56 FuncGraphManagerPtr manager = res->manager();
57 MS_EXCEPTION_IF_NULL(manager);
58 CNodePtr ret = root->get_return();
59 MS_EXCEPTION_IF_NULL(ret);
60
61 AllreduceFusion allreduce_fusion;
62 if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) {
63 MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed";
64 }
65
66 DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
67
68 // allreduce fusion only run once
69 root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
70 res->results()[pipeline::kStepParallelGraph] = root;
71 #if defined(_WIN32) || defined(_WIN64)
72 auto end_time = std::chrono::steady_clock::now();
73 std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
74 MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us";
75 #else
76 (void)gettimeofday(&end_time, nullptr);
77 uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
78 time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
79 MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us";
80 #endif
81 return changes;
82 }
83 } // namespace parallel
84 } // namespace mindspore