1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ 18 #define DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <thread> 24 #include <utility> 25 #include <vector> 26 27 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 28 #include "minddata/dataset/engine/opt/pass.h" 29 30 namespace mindspore { 31 namespace dataset { 32 33 class AutoWorkerPass : public IRTreePass { 34 public: 35 // this map will contain weight for the basic pipeline ops. Pipeline op takes up 1 thread but doesn't have workers 36 const std::vector<std::map<std::string, float>> kOpWeightConfigs = { 37 {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 8}}, // config1 leaf:batch:map=1:1:1 38 {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 4}}, // config2 leaf:batch:map=2:1:1 39 {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 4}}, // config3 leaf:batch:map=1:2:1 40 {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 4}, {kMapNode, 8}}, // config4 leaf:batch:map=1:1:2 41 {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 4}}, // config5 leaf:batch:map=2:2:1 42 {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 8}}, // config6 leaf:batch:map=2:1:2 43 {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 8}}, // config7 leaf:batch:map=1:2:2 44 }; AutoWorkerPass()45 AutoWorkerPass() 46 : min_num_workers_(1), 47 max_num_workers_(8), 48 thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {} 49 50 /// \brief destructor, by doing "= default", compiler will automatically generate the correct destructor 51 ~AutoWorkerPass() = default; 52 53 Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *) override; 54 55 private: 56 class OpWeightPass : public IRNodePass { 57 public: OpWeightPass(const std::map<std::string,float> & weight_profile)58 explicit OpWeightPass(const std::map<std::string, float> &weight_profile) 59 : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {} 60 61 /// \brief destructor, by doing "= default", compiler will automatically generate the correct destructor 62 ~OpWeightPass() = default; 63 64 // this is the base class function which contains the logic to handle most of the pipeline ops 65 // pipeline ops although can't config num_workers it still runs 1 thread they need to be factored into weight 66 Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override; 67 // these functions calculate the weights of more complex Nodes which may depend on its input arg. these functions 68 // will also push these nodes to a vector whose num_workers will be set int the Tree Pass 69 Status Visit(std::shared_ptr<BatchNode> node, bool *const modified) override; 70 Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override; 71 Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override; 72 Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override; 73 74 // helper function to look up weight according to the name of this Op. 75 float GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node); 76 77 int32_t weight_sum_; // sum of all weights in the pipeline 78 const std::map<std::string, float> weight_profile_; // key: name of ir node, val: weight of this node 79 std::vector<std::pair<std::shared_ptr<DatasetNode>, float>> parallel_ops_; // first: node second: weight 80 }; 81 82 const int32_t min_num_workers_; // minimum number of threads allowed for each op 83 const int32_t max_num_workers_; // maximum number of threads allowed for each op 84 const int32_t thread_cnt_; // thread cnt of current CPU, obtained through config manager 85 }; 86 } // namespace dataset 87 } // namespace mindspore 88 89 #endif // DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_ 90