1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
19 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
22 #ifdef ENABLE_PYTHON
23 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
24 #endif
25 #ifndef ENABLE_ANDROID
26 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
27 #endif
28 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
29
30 namespace mindspore {
31 namespace dataset {
32
33 // Constructor
CachePass()34 CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {}
35
36 // Identifies the subtree below this node as a cached descendant tree.
37 // Note that this function will only get called on non-leaf nodes.
38 // For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead.
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)39 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
40 *modified = false;
41 if (node->IsCached()) {
42 MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
43 is_caching_ = true;
44 }
45 return Status::OK();
46 }
47
48 // Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache
49 // transformation
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)50 Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
51 *modified = false;
52 if (node->IsCached()) {
53 is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
54 if (leaf_node_) {
55 MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
56 // Assign the leaf node into the transform pass, using move to null our copy of it,
57 // and also assign the cached node, using base class pointers.
58 // In the cases where cache is directly injected after the leaf node, these two nodes might be the same.
59 cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node));
60 } else {
61 // If there was no leaf_node_ set, then this is a non-mappable scenario.
62 // We only assign the cached node in this case.
63 cached_nodes_.push_back(node);
64 }
65 }
66
67 return Status::OK();
68 }
69
70 #ifndef ENABLE_ANDROID
71 // Perform leaf node cache transform identification
Visit(std::shared_ptr<NonMappableSourceNode> node,bool * const modified)72 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
73 if (node->IsCached()) {
74 MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
75 is_caching_ = true;
76 }
77 // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
78 // by the other Visit() with DatasetNode argument
79 if (is_caching_) {
80 MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
81 // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
82 if (leaf_node_) {
83 return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
84 "There is currently no support for multiple leaf nodes under cache.");
85 }
86 // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
87 // Node that sampler for non mappable dataset only works if there is a downstream cache.
88 RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
89 // If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic
90 // source node that parses all files. Selection of data will come from the sampler on the cache instead.
91 RETURN_IF_NOT_OK(node->MakeSimpleProducer());
92 }
93 return Status::OK();
94 }
95 #endif
96
97 // Almost the same with NonMappableSourceNode's Visit, only this one is not guarded by the compiler
98 // directive #ifndef ENABLE_ANDROID, also and there is no need to call MakeSimpleProducer() because
99 // RandomOp doesn't support sampling or sharding
Visit(std::shared_ptr<RandomNode> node,bool * const modified)100 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
101 if (node->IsCached()) {
102 MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
103 is_caching_ = true;
104 }
105 // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
106 // by the other Visit() with DatasetNode argument
107 if (is_caching_) {
108 MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
109 // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
110 if (leaf_node_) {
111 return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
112 "There is currently no support for multiple leaf nodes under cache.");
113 }
114 // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
115 // Node that sampler for non mappable dataset only works if there is a downstream cache.
116 RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
117 }
118 return Status::OK();
119 }
120
121 // Perform leaf node cache transform identification
Visit(std::shared_ptr<MappableSourceNode> node,bool * const modified)122 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
123 if (node->IsCached()) {
124 MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
125 is_caching_ = true;
126 }
127 // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
128 // by the other Visit() with DatasetNode argument
129 if (is_caching_) {
130 MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
131 // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
132 if (leaf_node_) {
133 return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
134 "There is currently no support for multiple leaf nodes under cache.");
135 }
136 // If we are a leaf in the caching path, then save this leaf
137 leaf_node_ = node;
138 }
139 return Status::OK();
140 }
141
142 #ifndef ENABLE_ANDROID
143 // Almost the same with MappableSourceNode's Visit, only in this one we also marked this node's descendant_of_cache_
144 // field to true. Later when building, MindDataNode will take different actions based on this information.
Visit(std::shared_ptr<MindDataNode> node,bool * const modified)145 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
146 if (node->IsCached()) {
147 MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
148 is_caching_ = true;
149 }
150 // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
151 // by the other Visit() with DatasetNode argument
152 if (is_caching_) {
153 node->HasCacheAbove();
154 MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
155 // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
156 if (leaf_node_) {
157 return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
158 "There is currently no support for multiple leaf nodes under cache.");
159 }
160 // If we are a leaf in the caching path, then save this leaf
161 leaf_node_ = node;
162 }
163 return Status::OK();
164 }
165 #endif
166
167 #ifdef ENABLE_PYTHON
168 // Perform leaf node cache transform identification
Visit(std::shared_ptr<GeneratorNode> node,bool * const modified)169 Status CacheTransformPass::CachePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
170 if (node->IsCached() || is_caching_) {
171 return Status(StatusCode::kMDNotImplementedYet, __LINE__, __FILE__,
172 "There is currently no support for GeneratorOp under cache.");
173 }
174 return Status::OK();
175 }
176 #endif
177
178 // constructor
CacheTransformPass()179 CacheTransformPass::CacheTransformPass() {}
180
181 // Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)182 Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
183 MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
184 // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
185 // use to execute a transform.
186 CachePass cache_pass = CachePass();
187 RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified));
188
189 // Execute the transform for non-mappable cache
190 for (auto cached_node : cache_pass.cached_nodes()) {
191 MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node.";
192 RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler()));
193 }
194
195 // Execute the transform for mappable cache
196 for (auto cache_pair : cache_pass.cache_pairs()) {
197 MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node.";
198 RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second));
199 }
200 MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
201 return Status::OK();
202 }
203
204 // Helper function to execute mappable cache transformation.
205 // Input tree:
206 // Sampler
207 // |
208 // LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
209 // Transformed tree:
210 // Sampler --> CacheLookupNode ------------------------->
211 // | |
212 // | CacheMergeNode
213 // | |
214 // LeafNode --> OtherNodes --> CachedNode
InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,std::shared_ptr<DatasetNode> cached_node)215 Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
216 std::shared_ptr<DatasetNode> cached_node) {
217 // Create a cache merge node with defaults
218 auto cache_merge_node = std::make_shared<CacheMergeNode>(nullptr, cached_node->GetDatasetCache());
219 // Insert the cache merge node to become the cached_node's parent
220 RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node));
221
222 // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
223 std::shared_ptr<SamplerObj> leaf_sampler = leaf_node->Sampler();
224 // Create a cache lookup node with leaf_node's sampler
225 auto cache_lookup_node = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cached_node->GetDatasetCache());
226 // Insert the cache lookup node as the first child of cache merge node
227 RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node));
228 // Overwrite the old sampler in this leaf node to become the cache lookup node
229 leaf_node->SetSampler(std::static_pointer_cast<SamplerObj>(cache_lookup_node));
230 return Status::OK();
231 }
232
233 // Helper function to execute non-mappable cache transformation.
234 // Input tree:
235 // LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
236 // Transformed tree:
237 // Sampler
238 // |
239 // LeafNode --> OtherNodes --> CachedNode --> CacheNode
InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,std::shared_ptr<SamplerObj> sampler)240 Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,
241 std::shared_ptr<SamplerObj> sampler) {
242 // Create a cache node using the sampler we saved from the leaf
243 auto cache_node = std::make_shared<CacheNode>(nullptr, sampler, cached_node->GetDatasetCache());
244 // Insert the cache node to become the cached_node's parent
245 RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_node));
246 return Status::OK();
247 }
248 } // namespace dataset
249 } // namespace mindspore
250