1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
18
19 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
23 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
24 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
27 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
28 #include "minddata/dataset/kernels/ir/tensor_operation.h"
29
30 namespace mindspore {
31 namespace dataset {
32
33 // Constructor
CacheValidationPass()34 CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {}
35
36 // Returns an error if BatchNode exists under a cache
Visit(std::shared_ptr<BatchNode> node,bool * const modified)37 Status CacheValidationPass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
38 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<BatchNode>): visiting " << node->Name() << ".";
39 if (is_cached_) {
40 RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache.");
41 }
42 if (node->IsCached()) {
43 RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached.");
44 }
45 return Status::OK();
46 }
47
48 // Returns an error if ConcatNode exists under a cache
Visit(std::shared_ptr<ConcatNode> node,bool * const modified)49 Status CacheValidationPass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) {
50 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ConcatNode>): visiting " << node->Name() << ".";
51 if (is_cached_) {
52 RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache.");
53 }
54 if (node->IsCached()) {
55 RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached.");
56 }
57 return Status::OK();
58 }
59
60 // Returns an error if FilterNode exists under a cache
Visit(std::shared_ptr<FilterNode> node,bool * const modified)61 Status CacheValidationPass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
62 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<FilterNode>): visiting " << node->Name() << ".";
63 if (is_cached_) {
64 RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache.");
65 }
66 if (node->IsCached()) {
67 RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached.");
68 }
69 return Status::OK();
70 }
71
72 // Returns an error if SkipNode exists under a cache
Visit(std::shared_ptr<SkipNode> node,bool * const modified)73 Status CacheValidationPass::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
74 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<SkipNode>): visiting " << node->Name() << ".";
75 if (is_cached_) {
76 RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache.");
77 }
78 if (node->IsCached()) {
79 RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached.");
80 }
81 return Status::OK();
82 }
83
84 // Returns an error if TakeNode exists under a cache
Visit(std::shared_ptr<TakeNode> node,bool * const modified)85 Status CacheValidationPass::Visit(std::shared_ptr<TakeNode> node, bool *const modified) {
86 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<TakeNode>): visiting " << node->Name() << ".";
87 if (is_cached_) {
88 RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache.");
89 }
90 if (node->IsCached()) {
91 RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached.");
92 }
93 return Status::OK();
94 }
95
96 // Returns an error if ZipNode exists under a cache
Visit(std::shared_ptr<ZipNode> node,bool * const modified)97 Status CacheValidationPass::Visit(std::shared_ptr<ZipNode> node, bool *const modified) {
98 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ZipNode>): visiting " << node->Name() << ".";
99 if (is_cached_) {
100 RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache.");
101 }
102 if (node->IsCached()) {
103 RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached.");
104 }
105 return Status::OK();
106 }
107
108 // Returns an error if MapNode with non-deterministic tensor operations exists under a cache
Visit(std::shared_ptr<MapNode> node,bool * const modified)109 Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
110 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<MapNode>): visiting " << node->Name() << ".";
111 if (node->IsCached()) {
112 if (is_cached_) {
113 RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported.");
114 }
115 // If Map is created to be cached, set the flag indicating we found an operation with a cache.
116 is_cached_ = true;
117
118 // This is temporary code.
119 // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
120 // to TensorOp, we need to check the randomness in MapNode::Build().
121 // By setting this MapNode is under a cache, we will check the randomness of its tensor operations without the need
122 // to walk the IR tree again.
123 node->HasCacheAbove();
124
125 auto tfuncs = node->TensorOperations();
126 for (size_t i = 0; i < tfuncs.size(); i++) {
127 if (tfuncs[i]->IsRandomOp()) {
128 RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
129 }
130 }
131 }
132 return Status::OK();
133 }
134
135 // Flag an error if we have a cache over another cache
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)136 Status CacheValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
137 MS_LOG(DEBUG) << "CacheValidationPass::Visit(<DatasetNode>): visiting " << node->Name() << ".";
138 if (node->IsCached()) {
139 if (is_cached_) {
140 RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported.");
141 }
142 // If this node is created to be cached, set the flag.
143 is_cached_ = true;
144 }
145 if (node->IsLeaf() && node->IsMappableDataSource()) {
146 is_mappable_ = true;
147 }
148 return Status::OK();
149 }
150
151 // Returns an error if MappableSource <- Repeat <- Node with a cache
152 // Because there is no operator in the cache hit stream to consume EoEs, caching above repeat causes problem.
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)153 Status CacheValidationPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
154 MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<RepeatNode>): visiting " << node->Name() << ".";
155 if (is_cached_ && is_mappable_) {
156 RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported.");
157 }
158 return Status::OK();
159 }
160
VisitAfter(std::shared_ptr<TFRecordNode> node,bool * const modified)161 Status CacheValidationPass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified) {
162 MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<TFRecordNode>): visiting " << node->Name() << ".";
163 if (!is_cached_) {
164 // If we are not in a cache path, then we must validate the file-based sharding config.
165 // If we are in a cache path, there is no file-based sharding so the check is not required.
166 if (!node->shard_equal_rows_ && node->dataset_files_.size() < static_cast<uint32_t>(node->num_shards_)) {
167 RETURN_STATUS_UNEXPECTED("Invalid file, not enough tfrecord files provided.\n");
168 }
169 }
170 // Reset the flag when this node is cached and is already visited
171 if (node->IsCached()) {
172 is_cached_ = false;
173 }
174 return Status::OK();
175 }
176
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)177 Status CacheValidationPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
178 MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<DatasetNode>): visiting " << node->Name() << ".";
179 // Reset the flag when all descendants are visited
180 if (node->IsCached()) {
181 is_cached_ = false;
182 }
183 return Status::OK();
184 }
185 } // namespace dataset
186 } // namespace mindspore
187