1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 25 #include "minddata/dataset/engine/opt/pass.h" 26 27 namespace mindspore { 28 namespace dataset { 29 30 class DatasetOp; 31 32 class CacheClient; 33 34 /// \class CacheTransformPass cache_transform_pass.h 35 /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching 36 /// operations 37 class CacheTransformPass : public IRTreePass { 38 /// \class CachePass 39 /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache 40 /// transformation. It works in conjunction with the CacheTransformPass 41 class CachePass : public IRNodePass { 42 public: 43 /// \brief Constructor 44 /// \param[in] transform_pass Raw pointer back to controlling tree pass 45 CachePass(); 46 47 /// \brief Destructor 48 ~CachePass() = default; 49 50 /// \brief Identifies the subtree below this node as a cached descendant tree. 51 /// \param[in] node The node being visited 52 /// \param[in,out] modified Indicator if the node was changed at all 53 /// \return Status The status code returned 54 Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override; 55 56 /// \brief Resets the tracking of the cache within the tree and assigns the operators that 57 /// will be involved in a cache transformation 58 /// \param[in] node The node being visited 59 /// \param[in,out] modified Indicator if the node was changed at all 60 /// \return Status The status code returned 61 Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override; 62 63 #ifndef ENABLE_ANDROID 64 65 /// \brief Perform non-mappable leaf node cache transform identifications 66 /// \param[in] node The node being visited 67 /// \param[in,out] modified Indicator if the node was changed at all 68 /// \return Status The status code returned 69 Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override; 70 #endif 71 72 /// \brief Perform non-mappable leaf node cache transform identifications 73 /// \param[in] node The node being visited 74 /// \param[in,out] modified Indicator if the node was changed at all 75 /// \return Status The status code returned 76 Status Visit(std::shared_ptr<RandomNode> node, bool *const modified) override; 77 78 /// \brief Perform mappable leaf node cache transform identifications 79 /// \param[in] node The node being visited 80 /// \param[in,out] modified Indicator if the node was changed at all 81 /// \return Status The status code returned 82 Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override; 83 84 #ifdef ENABLE_PYTHON 85 /// \brief Perform leaf node cache transform identifications 86 /// \param[in] node The node being visited 87 /// \param[in,out] modified Indicator if the node was changed at all 88 /// \return Status The status code returned 89 Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override; 90 #endif 91 92 #ifndef ENABLE_ANDROID 93 /// \brief Perform leaf node cache transform identifications 94 /// \param[in] node The node being visited 95 /// \param[in,out] modified Indicator if the node was changed at all 96 /// \return Status The status code returned 97 Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override; 98 #endif 99 100 /// \brief Getter cache_pairs()101 std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs() { 102 return cache_pairs_; 103 } 104 105 /// \brief Getter cached_nodes()106 std::vector<std::shared_ptr<DatasetNode>> cached_nodes() { return cached_nodes_; } 107 108 /// \brief Getter sampler()109 std::shared_ptr<SamplerObj> sampler() { return sampler_; } 110 111 private: 112 bool is_caching_; 113 std::shared_ptr<MappableSourceNode> leaf_node_; 114 std::shared_ptr<SamplerObj> sampler_; 115 // The two nodes that work together to establish the cache transform 116 std::vector<std::shared_ptr<DatasetNode>> cached_nodes_; 117 std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs_; 118 }; 119 120 public: 121 /// \brief Constructor 122 CacheTransformPass(); 123 124 /// \brief Destructor 125 ~CacheTransformPass() = default; 126 127 /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations 128 /// \param[in,out] tree The tree to operate on. 129 /// \param[in,out] Indicate of the tree was modified. 130 /// \return Status The status code returned 131 Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override; 132 133 private: 134 /// \brief Helper function to execute mappable cache transformation. 135 /// 136 /// Input tree: 137 /// Sampler 138 /// | 139 /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) 140 /// 141 /// Transformed tree: 142 /// Sampler --> CacheLookupNode -------------------------> 143 /// | | 144 /// | CacheMergeNode 145 /// | | 146 /// LeafNode --> OtherNodes --> CachedNode 147 /// 148 /// \param[in] leaf_node The leaf node in the transform 149 /// \param[in] cached_node The node with cache attribute which is involved in the cache transform 150 /// \return Status The status code returned 151 Status InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node, 152 std::shared_ptr<DatasetNode> cached_node); 153 154 /// \brief Helper function to execute non-mappable cache transformation. 155 /// 156 /// Input tree: 157 /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) 158 /// 159 /// Transformed tree: 160 /// Sampler 161 /// | 162 /// LeafNode --> OtherNodes --> CachedNode --> CacheNode 163 /// 164 /// \param[in] cached_node The node with cache attribute which is involved in the cache transform 165 /// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass 166 /// \return Status The status code returned 167 Status InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, std::shared_ptr<SamplerObj> sampler); 168 }; 169 } // namespace dataset 170 } // namespace mindspore 171 172 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ 173