• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
18 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
19 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
20 #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 
25 // this will become the RootNode:DatasetNode when it is turned on
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)26 Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
27   RETURN_UNEXPECTED_IF_NULL(root_ir);
28   RETURN_UNEXPECTED_IF_NULL(modified);
29   uint8_t config = GlobalContext::config_manager()->get_auto_worker_config();
30 
31   OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? config : 0]);
32 
33   std::string weight_str;
34   for (const auto &p : pass.weight_profile_) weight_str += ("(" + p.first + "=" + std::to_string(p.second) + ")");
35   int32_t num_shards = GlobalContext::config_manager()->get_num_shards_for_auto_num_workers();
36   num_shards = std::min(std::max(1, num_shards), thread_cnt_);
37 
38   MS_LOG(INFO) << "AutoWorkerPass is enabled; this could override existing num_workers set in each parallel op."
39                << "total number of threads on this CPU: " << thread_cnt_ << ", "
40                << "min num_workers to override:" << min_num_workers_ << ", "
41                << "max num_workers to override:" << max_num_workers_ << ", "
42                << "adjusted num_shards (between 1 and total thread cnt): " << num_shards
43                << ", weight profile:" << weight_str << ".";
44 
45   // get the maximum weight of all the ops, this value is used to ensure the ratio of num_workers between ops
46   float max_weight = 0;
47   for (const auto &p : pass.weight_profile_) max_weight = std::max(max_weight, p.second);
48 
49   CHECK_FAIL_RETURN_UNEXPECTED(max_weight != 0, "Internal error, doesn't allow divide zero.");
50   RETURN_IF_NOT_OK(pass.Run(root_ir, modified));
51   constexpr size_t max_num_ops = 3;
52   if (pass.parallel_ops_.size() > max_num_ops) {
53     MS_LOG(WARNING) << "AutoNumWorker right now is only suitable for simple dataset pipelines that has at most, 1 leaf "
54                     << "1 batch and 1 map. AutoNumWorker may not be optimal for usage on complex pipelines.";
55   }
56 
57   CHECK_FAIL_RETURN_UNEXPECTED(pass.weight_sum_ != 0, "Internal error, doesn't allow divide zero.");
58   for (auto &p : pass.parallel_ops_) {
59     // get the num worker via the weight ratio
60     int32_t num_workers = std::ceil((thread_cnt_ * p.second) / (pass.weight_sum_ * num_shards));
61     // this is to ensure when thread_cnt_ is very large let's say 192, the num_worker ratio is still kept
62     // e.g. the optional 2:1 ratio between minddataset and batch
63     int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight);
64     // this will ensure that num_workers will fall with the range of [1,cur_node_max]
65     int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
66 
67     // if the num_worker to set is same as original, skip setting and printing the logs
68     if (cur_node_num_worker == p.first->NumWorkers()) continue;
69     // log the change via warning msg so user can see what the num_worker is being set for which op
70     MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
71                     << std::to_string(p.first->NumWorkers()) + " to " + std::to_string(cur_node_num_worker);
72     p.first->SetNumWorkers(cur_node_num_worker);
73   }
74   return Status::OK();
75 }
76 
Visit(std::shared_ptr<MapNode> node,bool * const modified)77 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
78   auto itr = weight_profile_.find(node->Name());
79   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
80   int32_t weight = itr->second;
81   weight_sum_ += weight;
82   parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
83   return Status::OK();
84 }
85 
Visit(std::shared_ptr<BatchNode> node,bool * const modified)86 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
87   auto itr = weight_profile_.find(node->Name());
88   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
89   int32_t weight = itr->second;
90   weight_sum_ += weight;
91   parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
92   return Status::OK();
93 }
94 
Visit(std::shared_ptr<MappableSourceNode> node,bool * const modified)95 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
96   RETURN_OK_IF_TRUE(node->Name() == kGeneratorNode);  // generator is pipeline op, skip this
97   auto itr = weight_profile_.find("MappableSource");
98   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
99                                "LeafSourceNode::" + node->Name() + "'s weight doesn't exist.");
100   int32_t weight = itr->second;
101   weight_sum_ += weight;
102   parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
103   return Status::OK();
104 }
105 
Visit(std::shared_ptr<NonMappableSourceNode> node,bool * const modified)106 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
107   auto itr = weight_profile_.find("NonMappableSource");
108   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
109                                "NonLeafSource::" + node->Name() + "'s weight doesn't exist.");
110   int32_t weight = itr->second;
111   weight_sum_ += weight;
112   parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
113   return Status::OK();
114 }
115 
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)116 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
117   weight_sum_ += GetNodeWeightFromProfile(node);
118   return Status::OK();
119 }
120 
GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node)121 float AutoWorkerPass::OpWeightPass::GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node) {
122   auto itr = weight_profile_.find(node->Name());
123   // returns 0 if name doesn't exist in the weight profile
124   return itr == weight_profile_.end() ? 0 : itr->second;
125 }
126 
127 }  // namespace dataset
128 }  // namespace mindspore
129