• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
19 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
22 #ifdef ENABLE_PYTHON
23 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
24 #endif
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
27 #endif
28 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 
33 // Constructor
CachePass()34 CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {}
35 
36 // Identifies the subtree below this node as a cached descendant tree.
37 // Note that this function will only get called on non-leaf nodes.
38 // For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead.
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)39 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
40   *modified = false;
41   if (node->IsCached()) {
42     MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
43     is_caching_ = true;
44   }
45   return Status::OK();
46 }
47 
48 // Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache
49 // transformation
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)50 Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
51   *modified = false;
52   if (node->IsCached()) {
53     is_caching_ = false;  // We a no longer in a cache subtree.  clear the flag.
54     if (leaf_node_) {
55       MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
56       // Assign the leaf node into the transform pass, using move to null our copy of it,
57       // and also assign the cached node, using base class pointers.
58       // In the cases where cache is directly injected after the leaf node, these two nodes might be the same.
59       cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node));
60     } else {
61       // If there was no leaf_node_ set, then this is a non-mappable scenario.
62       // We only assign the cached node in this case.
63       cached_nodes_.push_back(node);
64     }
65   }
66 
67   return Status::OK();
68 }
69 
70 #ifndef ENABLE_ANDROID
71 // Perform leaf node cache transform identification
Visit(std::shared_ptr<NonMappableSourceNode> node,bool * const modified)72 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
73   if (node->IsCached()) {
74     MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
75     is_caching_ = true;
76   }
77   // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
78   // by the other Visit() with DatasetNode argument
79   if (is_caching_) {
80     MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
81     // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
82     if (leaf_node_) {
83       return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
84                     "There is currently no support for multiple leaf nodes under cache.");
85     }
86     // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
87     // Node that sampler for non mappable dataset only works if there is a downstream cache.
88     RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
89     // If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic
90     // source node that parses all files. Selection of data will come from the sampler on the cache instead.
91     RETURN_IF_NOT_OK(node->MakeSimpleProducer());
92   }
93   return Status::OK();
94 }
95 #endif
96 
97 // Almost the same with NonMappableSourceNode's Visit, only this one is not guarded by the compiler
98 // directive #ifndef ENABLE_ANDROID, also and there is no need to call MakeSimpleProducer() because
99 // RandomOp doesn't support sampling or sharding
Visit(std::shared_ptr<RandomNode> node,bool * const modified)100 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
101   if (node->IsCached()) {
102     MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
103     is_caching_ = true;
104   }
105   // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
106   // by the other Visit() with DatasetNode argument
107   if (is_caching_) {
108     MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
109     // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
110     if (leaf_node_) {
111       return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
112                     "There is currently no support for multiple leaf nodes under cache.");
113     }
114     // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
115     // Node that sampler for non mappable dataset only works if there is a downstream cache.
116     RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
117   }
118   return Status::OK();
119 }
120 
121 // Perform leaf node cache transform identification
Visit(std::shared_ptr<MappableSourceNode> node,bool * const modified)122 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
123   if (node->IsCached()) {
124     MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
125     is_caching_ = true;
126   }
127   // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
128   // by the other Visit() with DatasetNode argument
129   if (is_caching_) {
130     MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
131     // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
132     if (leaf_node_) {
133       return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
134                     "There is currently no support for multiple leaf nodes under cache.");
135     }
136     // If we are a leaf in the caching path, then save this leaf
137     leaf_node_ = node;
138   }
139   return Status::OK();
140 }
141 
142 #ifndef ENABLE_ANDROID
143 // Almost the same with MappableSourceNode's Visit, only in this one we also marked this node's descendant_of_cache_
144 // field to true. Later when building, MindDataNode will take different actions based on this information.
Visit(std::shared_ptr<MindDataNode> node,bool * const modified)145 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
146   if (node->IsCached()) {
147     MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
148     is_caching_ = true;
149   }
150   // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
151   // by the other Visit() with DatasetNode argument
152   if (is_caching_) {
153     node->HasCacheAbove();
154     MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
155     // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
156     if (leaf_node_) {
157       return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
158                     "There is currently no support for multiple leaf nodes under cache.");
159     }
160     // If we are a leaf in the caching path, then save this leaf
161     leaf_node_ = node;
162   }
163   return Status::OK();
164 }
165 #endif
166 
167 #ifdef ENABLE_PYTHON
168 // Perform leaf node cache transform identification
Visit(std::shared_ptr<GeneratorNode> node,bool * const modified)169 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
170   if (node->IsCached() || is_caching_) {
171     return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
172                   "There is currently no support for GeneratorOp under cache.");
173   }
174   return Status::OK();
175 }
176 #endif
177 
178 // constructor
CacheTransformPass()179 CacheTransformPass::CacheTransformPass() {}
180 
181 // Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)182 Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
183   MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
184   // Create the cache pass and run it.  The cache pass identifies and creates the leaf/cache pairs that we will
185   // use to execute a transform.
186   CachePass cache_pass = CachePass();
187   RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified));
188 
189   // Execute the transform for non-mappable cache
190   for (auto cached_node : cache_pass.cached_nodes()) {
191     MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node.";
192     RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler()));
193   }
194 
195   // Execute the transform for mappable cache
196   for (auto cache_pair : cache_pass.cache_pairs()) {
197     MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node.";
198     RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second));
199   }
200   MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
201   return Status::OK();
202 }
203 
204 // Helper function to execute mappable cache transformation.
205 // Input tree:
206 //   Sampler
207 //     |
208 //   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
209 // Transformed tree:
210 //   Sampler --> CacheLookupNode ------------------------->
211 //                       |                                |
212 //                       |                           CacheMergeNode
213 //                       |                                |
214 //                       LeafNode --> OtherNodes --> CachedNode
InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,std::shared_ptr<DatasetNode> cached_node)215 Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
216                                                    std::shared_ptr<DatasetNode> cached_node) {
217   // Create a cache merge node with defaults
218   auto cache_merge_node = std::make_shared<CacheMergeNode>(nullptr, cached_node->GetDatasetCache());
219   // Insert the cache merge node to become the cached_node's parent
220   RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node));
221 
222   // Extract the sampler from the leaf.  We will overwrite this sampler with the lookup op later.
223   std::shared_ptr<SamplerObj> leaf_sampler = leaf_node->Sampler();
224   // Create a cache lookup node with leaf_node's sampler
225   auto cache_lookup_node = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cached_node->GetDatasetCache());
226   // Insert the cache lookup node as the first child of cache merge node
227   RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node));
228   // Overwrite the old sampler in this leaf node to become the cache lookup node
229   leaf_node->SetSampler(std::static_pointer_cast<SamplerObj>(cache_lookup_node));
230   return Status::OK();
231 }
232 
233 // Helper function to execute non-mappable cache transformation.
234 // Input tree:
235 //   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
236 // Transformed tree:
237 //                                               Sampler
238 //                                                  |
239 //   LeafNode --> OtherNodes --> CachedNode --> CacheNode
InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,std::shared_ptr<SamplerObj> sampler)240 Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,
241                                                       std::shared_ptr<SamplerObj> sampler) {
242   // Create a cache node using the sampler we saved from the leaf
243   auto cache_node = std::make_shared<CacheNode>(nullptr, sampler, cached_node->GetDatasetCache());
244   // Insert the cache node to become the cached_node's parent
245   RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_node));
246   return Status::OK();
247 }
248 }  // namespace dataset
249 }  // namespace mindspore
250