• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
25 #include "minddata/dataset/engine/opt/pass.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 
30 class DatasetOp;
31 
32 class CacheClient;
33 
34 /// \class CacheTransformPass cache_transform_pass.h
35 /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
36 ///     operations
37 class CacheTransformPass : public IRTreePass {
38   /// \class CachePass
39   /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
40   ///     transformation. It works in conjunction with the CacheTransformPass
41   class CachePass : public IRNodePass {
42    public:
43     /// \brief Constructor
44     /// \param[in] transform_pass Raw pointer back to controlling tree pass
45     CachePass();
46 
47     /// \brief Destructor
48     ~CachePass() = default;
49 
50     /// \brief Identifies the subtree below this node as a cached descendant tree.
51     /// \param[in] node The node being visited
52     /// \param[in,out] modified Indicator if the node was changed at all
53     /// \return Status The status code returned
54     Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;
55 
56     /// \brief Resets the tracking of the cache within the tree and assigns the operators that
57     ///     will be involved in a cache transformation
58     /// \param[in] node The node being visited
59     /// \param[in,out] modified Indicator if the node was changed at all
60     /// \return Status The status code returned
61     Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;
62 
63 #ifndef ENABLE_ANDROID
64 
65     /// \brief Perform non-mappable leaf node cache transform identifications
66     /// \param[in] node The node being visited
67     /// \param[in,out] modified Indicator if the node was changed at all
68     /// \return Status The status code returned
69     Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;
70 #endif
71 
72     /// \brief Perform non-mappable leaf node cache transform identifications
73     /// \param[in] node The node being visited
74     /// \param[in,out] modified Indicator if the node was changed at all
75     /// \return Status The status code returned
76     Status Visit(std::shared_ptr<RandomNode> node, bool *const modified) override;
77 
78     /// \brief Perform mappable leaf node cache transform identifications
79     /// \param[in] node The node being visited
80     /// \param[in,out] modified Indicator if the node was changed at all
81     /// \return Status The status code returned
82     Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;
83 
84 #ifdef ENABLE_PYTHON
85     /// \brief Perform leaf node cache transform identifications
86     /// \param[in] node The node being visited
87     /// \param[in,out] modified Indicator if the node was changed at all
88     /// \return Status The status code returned
89     Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override;
90 #endif
91 
92 #ifndef ENABLE_ANDROID
93     /// \brief Perform leaf node cache transform identifications
94     /// \param[in] node The node being visited
95     /// \param[in,out] modified Indicator if the node was changed at all
96     /// \return Status The status code returned
97     Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override;
98 #endif
99 
100     /// \brief Getter
cache_pairs()101     std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs() {
102       return cache_pairs_;
103     }
104 
105     /// \brief Getter
cached_nodes()106     std::vector<std::shared_ptr<DatasetNode>> cached_nodes() { return cached_nodes_; }
107 
108     /// \brief Getter
sampler()109     std::shared_ptr<SamplerObj> sampler() { return sampler_; }
110 
111    private:
112     bool is_caching_;
113     std::shared_ptr<MappableSourceNode> leaf_node_;
114     std::shared_ptr<SamplerObj> sampler_;
115     // The two nodes that work together to establish the cache transform
116     std::vector<std::shared_ptr<DatasetNode>> cached_nodes_;
117     std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs_;
118   };
119 
120  public:
121   /// \brief Constructor
122   CacheTransformPass();
123 
124   /// \brief Destructor
125   ~CacheTransformPass() = default;
126 
127   /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
128   /// \param[in,out] tree The tree to operate on.
129   /// \param[in,out] Indicate of the tree was modified.
130   /// \return Status The status code returned
131   Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
132 
133  private:
134   /// \brief Helper function to execute mappable cache transformation.
135   ///
136   ///     Input tree:
137   ///       Sampler
138   ///         |
139   ///       LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
140   ///
141   ///     Transformed tree:
142   ///       Sampler --> CacheLookupNode ------------------------->
143   ///                           |                                |
144   ///                           |                           CacheMergeNode
145   ///                           |                                |
146   ///                           LeafNode --> OtherNodes --> CachedNode
147   ///
148   /// \param[in] leaf_node The leaf node in the transform
149   /// \param[in] cached_node The node with cache attribute which is involved in the cache transform
150   /// \return Status The status code returned
151   Status InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
152                                  std::shared_ptr<DatasetNode> cached_node);
153 
154   /// \brief Helper function to execute non-mappable cache transformation.
155   ///
156   ///     Input tree:
157   ///       LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
158   ///
159   ///     Transformed tree:
160   ///                                                   Sampler
161   ///                                                      |
162   ///       LeafNode --> OtherNodes --> CachedNode --> CacheNode
163   ///
164   /// \param[in] cached_node The node with cache attribute which is involved in the cache transform
165   /// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass
166   /// \return Status The status code returned
167   Status InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, std::shared_ptr<SamplerObj> sampler);
168 };
169 }  // namespace dataset
170 }  // namespace mindspore
171 
172 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
173