1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <vector> 25 #include "minddata/dataset/engine/opt/pass.h" 26 27 namespace mindspore { 28 namespace dataset { 29 class DatasetOp; 30 31 /// \class NodeOffloadPass 32 /// \brief This is a tree pass that will offload nodes. It uses offload_nodes to first identify which 33 /// nodes should be offloaded, adds the nodes' namea to the offload list, then removes the nodes from the ir tree. 34 class NodeOffloadPass : public IRTreePass { 35 /// \class OffloadNodes 36 /// \brief This is a NodePass whose job is to identify which nodes should be offloaded. 37 class OffloadNodes : public IRNodePass { 38 public: 39 /// \brief Constructor 40 OffloadNodes(); 41 /// \brief Destructor 42 ~OffloadNodes() = default; 43 44 /// \brief Perform MapNode offload check 45 /// \param[in] node The node being visited 46 /// \param[in, out] modified Indicator if the node was changed at all 47 /// \return Status The status code returned 48 Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override; 49 50 /// \brief Access selected offload nodes for removal. 51 /// \return All the nodes to be removed by offload. nodes_to_offload()52 std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload() { return nodes_to_offload_; } 53 54 private: 55 /// \brief Vector of nodes to offload 56 std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_; 57 /// \brief Vector of supported offload operations 58 const std::set<std::string> supported_ops_{ 59 "HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip", 60 "RandomSharpness", "RandomVerticalFlip", "Rescale", "TypeCast"}; 61 /// \brief std::map indicating if the map op for the input column is at the end of the pipeline 62 std::map<std::string, bool> end_of_pipeline_; 63 /// \brief bool indicating whether the auto_offload config option is enabled 64 bool auto_offload_; 65 }; 66 67 public: 68 /// \brief Constructor 69 NodeOffloadPass(); 70 71 /// \brief Destructor 72 ~NodeOffloadPass() = default; 73 74 /// \brief Runs an offload_nodes pass first to find out which nodes to offload, then offloads them. 75 /// \param[in, out] root_ir The tree to operate on. 76 /// \param[in, out] modified Indicates if the tree was modified. 77 /// \return Status The status code returned 78 Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override; 79 /// \brief Getter 80 /// \return JSON of offload GetOffloadJson()81 nlohmann::json GetOffloadJson() { return offload_json_list_; } 82 83 private: 84 /// \brief JSON instance containing single offload op. 85 nlohmann::json offload_json_; 86 87 /// \brief JSON instance containing all offload ops. 88 nlohmann::json offload_json_list_; 89 }; 90 } // namespace dataset 91 } // namespace mindspore 92 93 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ 94