/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_
#define DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_

#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

class AutoWorkerPass : public IRTreePass {
 public:
  // Candidate weight profiles for the basic pipeline ops. Each map assigns a weight per op type; a pipeline op
  // still takes up 1 thread even though it has no configurable workers.
  const std::vector<std::map<std::string, float>> kOpWeightConfigs = {
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 8}},  // config1 leaf:batch:map=1:1:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 4}},  // config2 leaf:batch:map=2:1:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 4}},  // config3 leaf:batch:map=1:2:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 4}, {kMapNode, 8}},  // config4 leaf:batch:map=1:1:2
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 4}},  // config5 leaf:batch:map=2:2:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 8}},  // config6 leaf:batch:map=2:1:2
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 8}},  // config7 leaf:batch:map=1:2:2
  };
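  // Illustrative reading of the table above (a sketch only; the exact allocation logic lives in the
  // corresponding .cc file): with config3 (leaf:batch:map = 1:2:1), a BatchNode is expected to receive
  // roughly twice as many workers as a MapNode or a leaf source, subject to the
  // [min_num_workers_, max_num_workers_] bounds declared at the bottom of this class.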
  /// \brief Constructor; per-op worker counts default to the range [1, 8] and the CPU thread count is
  ///     read from the global ConfigManager.
  AutoWorkerPass()
      : min_num_workers_(1),
        max_num_workers_(8),
        thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {}

  /// \brief Destructor; "= default" lets the compiler generate the correct destructor.
  ~AutoWorkerPass() = default;

  /// \brief Walk the IR tree, weigh the parallel ops, and set their num_workers accordingly.
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override;
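  // Illustrative call sketch (an assumption for clarity; in practice the pass is expected to be driven
  // by the IR optimizer rather than invoked directly):
  //   bool modified = false;
  //   AutoWorkerPass pass;
  //   Status rc = pass.RunOnTree(root_ir, &modified);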

 private:
  class OpWeightPass : public IRNodePass {
   public:
    explicit OpWeightPass(const std::map<std::string, float> &weight_profile)
        : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {}

    /// \brief Destructor; "= default" lets the compiler generate the correct destructor.
    ~OpWeightPass() = default;

    // Base-class visitor containing the logic that handles most of the pipeline ops. Although a pipeline op
    // can't configure num_workers, it still runs 1 thread and therefore needs to be factored into the weight.
    Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;
    // These functions calculate the weights of more complex nodes, which may depend on their input args.
    // They also push these nodes into a vector whose num_workers will be set in the tree pass.
    Status Visit(std::shared_ptr<BatchNode> node, bool *const modified) override;
    Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;
    Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;
    Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;

    // Helper function to look up the weight according to the name of this op.
    float GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node);
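    // Illustrative lookup (assuming the profile keys used in kOpWeightConfigs above): a MapNode resolves
    // to the kMapNode entry, while mappable and non-mappable leaves resolve to the
    // "MappableSource" / "NonMappableSource" entries.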

    int32_t weight_sum_;                                 // sum of all weights in the pipeline
    const std::map<std::string, float> weight_profile_;  // key: name of ir node, val: weight of this node
    std::vector<std::pair<std::shared_ptr<DatasetNode>, float>> parallel_ops_;  // first: node, second: weight
  };

  const int32_t min_num_workers_;  // minimum number of threads allowed for each op
  const int32_t max_num_workers_;  // maximum number of threads allowed for each op
  const int32_t thread_cnt_;       // thread cnt of current CPU, obtained through config manager
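  // Rough relationship between these fields (a sketch only; the exact computation is an assumption and
  // lives in the corresponding .cc file):
  //   num_workers(op) ~ clamp(thread_cnt_ * weight(op) / weight_sum_, min_num_workers_, max_num_workers_)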
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_