• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <vector>
25 #include "minddata/dataset/engine/opt/pass.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 class DatasetOp;
30 
31 /// \class NodeOffloadPass
32 /// \brief This is a tree pass that will offload nodes.  It uses offload_nodes to first identify which
33 ///     nodes should be offloaded, adds the nodes' namea to the offload list, then removes the nodes from the ir tree.
34 class NodeOffloadPass : public IRTreePass {
35   /// \class OffloadNodes
36   /// \brief This is a NodePass whose job is to identify which nodes should be offloaded.
37   class OffloadNodes : public IRNodePass {
38    public:
39     /// \brief Constructor
40     OffloadNodes();
41     /// \brief Destructor
42     ~OffloadNodes() = default;
43 
44     /// \brief Perform MapNode offload check
45     /// \param[in] node The node being visited
46     /// \param[in, out] modified Indicator if the node was changed at all
47     /// \return Status The status code returned
48     Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;
49 
50     /// \brief Access selected offload nodes for removal.
51     /// \return All the nodes to be removed by offload.
nodes_to_offload()52     std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload() { return nodes_to_offload_; }
53 
54    private:
55     /// \brief Vector of nodes to offload
56     std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_;
57     /// \brief Vector of supported offload operations
58     const std::set<std::string> supported_ops_{
59       "HwcToChw",        "Normalize",          "RandomColorAdjust", "RandomHorizontalFlip",
60       "RandomSharpness", "RandomVerticalFlip", "Rescale",           "TypeCast"};
61     /// \brief std::map indicating if the map op for the input column is at the end of the pipeline
62     std::map<std::string, bool> end_of_pipeline_;
63     /// \brief bool indicating whether the auto_offload config option is enabled
64     bool auto_offload_;
65   };
66 
67  public:
68   /// \brief Constructor
69   NodeOffloadPass();
70 
71   /// \brief Destructor
72   ~NodeOffloadPass() = default;
73 
74   /// \brief Runs an offload_nodes pass first to find out which nodes to offload, then offloads them.
75   /// \param[in, out] root_ir The tree to operate on.
76   /// \param[in, out] modified Indicates if the tree was modified.
77   /// \return Status The status code returned
78   Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
79   /// \brief Getter
80   /// \return JSON of offload
GetOffloadJson()81   nlohmann::json GetOffloadJson() { return offload_json_list_; }
82 
83  private:
84   /// \brief JSON instance containing single offload op.
85   nlohmann::json offload_json_;
86 
87   /// \brief JSON instance containing all offload ops.
88   nlohmann::json offload_json_list_;
89 };
90 }  // namespace dataset
91 }  // namespace mindspore
92 
93 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
94