• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
18 
19 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
23 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
24 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
27 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
28 #include "minddata/dataset/kernels/ir/tensor_operation.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 
33 // Constructor
CacheValidationPass()34 CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {}
35 
36 // Returns an error if BatchNode exists under a cache
Visit(std::shared_ptr<BatchNode> node,bool * const modified)37 Status CacheValidationPass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
38   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<BatchNode>): visiting " << node->Name() << ".";
39   if (is_cached_) {
40     RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache.");
41   }
42   if (node->IsCached()) {
43     RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached.");
44   }
45   return Status::OK();
46 }
47 
48 // Returns an error if ConcatNode exists under a cache
Visit(std::shared_ptr<ConcatNode> node,bool * const modified)49 Status CacheValidationPass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) {
50   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ConcatNode>): visiting " << node->Name() << ".";
51   if (is_cached_) {
52     RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache.");
53   }
54   if (node->IsCached()) {
55     RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached.");
56   }
57   return Status::OK();
58 }
59 
60 // Returns an error if FilterNode exists under a cache
Visit(std::shared_ptr<FilterNode> node,bool * const modified)61 Status CacheValidationPass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
62   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<FilterNode>): visiting " << node->Name() << ".";
63   if (is_cached_) {
64     RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache.");
65   }
66   if (node->IsCached()) {
67     RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached.");
68   }
69   return Status::OK();
70 }
71 
72 // Returns an error if SkipNode exists under a cache
Visit(std::shared_ptr<SkipNode> node,bool * const modified)73 Status CacheValidationPass::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
74   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<SkipNode>): visiting " << node->Name() << ".";
75   if (is_cached_) {
76     RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache.");
77   }
78   if (node->IsCached()) {
79     RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached.");
80   }
81   return Status::OK();
82 }
83 
84 // Returns an error if TakeNode exists under a cache
Visit(std::shared_ptr<TakeNode> node,bool * const modified)85 Status CacheValidationPass::Visit(std::shared_ptr<TakeNode> node, bool *const modified) {
86   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<TakeNode>): visiting " << node->Name() << ".";
87   if (is_cached_) {
88     RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache.");
89   }
90   if (node->IsCached()) {
91     RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached.");
92   }
93   return Status::OK();
94 }
95 
96 // Returns an error if ZipNode exists under a cache
Visit(std::shared_ptr<ZipNode> node,bool * const modified)97 Status CacheValidationPass::Visit(std::shared_ptr<ZipNode> node, bool *const modified) {
98   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ZipNode>): visiting " << node->Name() << ".";
99   if (is_cached_) {
100     RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache.");
101   }
102   if (node->IsCached()) {
103     RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached.");
104   }
105   return Status::OK();
106 }
107 
108 // Returns an error if MapNode with non-deterministic tensor operations exists under a cache
Visit(std::shared_ptr<MapNode> node,bool * const modified)109 Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
110   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<MapNode>): visiting " << node->Name() << ".";
111   if (node->IsCached()) {
112     if (is_cached_) {
113       RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported.");
114     }
115     // If Map is created to be cached, set the flag indicating we found an operation with a cache.
116     is_cached_ = true;
117 
118     // This is temporary code.
119     // Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
120     // to TensorOp, we need to check the randomness in MapNode::Build().
121     // By setting this MapNode is under a cache, we will check the randomness of its tensor operations without the need
122     // to walk the IR tree again.
123     node->HasCacheAbove();
124 
125     auto tfuncs = node->TensorOperations();
126     for (size_t i = 0; i < tfuncs.size(); i++) {
127       if (tfuncs[i]->IsRandomOp()) {
128         RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
129       }
130     }
131   }
132   return Status::OK();
133 }
134 
135 // Flag an error if we have a cache over another cache
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)136 Status CacheValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
137   MS_LOG(DEBUG) << "CacheValidationPass::Visit(<DatasetNode>): visiting " << node->Name() << ".";
138   if (node->IsCached()) {
139     if (is_cached_) {
140       RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported.");
141     }
142     // If this node is created to be cached, set the flag.
143     is_cached_ = true;
144   }
145   if (node->IsLeaf() && node->IsMappableDataSource()) {
146     is_mappable_ = true;
147   }
148   return Status::OK();
149 }
150 
151 // Returns an error if MappableSource <- Repeat <- Node with a cache
152 // Because there is no operator in the cache hit stream to consume EoEs, caching above repeat causes problem.
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)153 Status CacheValidationPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
154   MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<RepeatNode>): visiting " << node->Name() << ".";
155   if (is_cached_ && is_mappable_) {
156     RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported.");
157   }
158   return Status::OK();
159 }
160 
VisitAfter(std::shared_ptr<TFRecordNode> node,bool * const modified)161 Status CacheValidationPass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified) {
162   MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<TFRecordNode>): visiting " << node->Name() << ".";
163   if (!is_cached_) {
164     // If we are not in a cache path, then we must validate the file-based sharding config.
165     // If we are in a cache path, there is no file-based sharding so the check is not required.
166     if (!node->shard_equal_rows_ && node->dataset_files_.size() < static_cast<uint32_t>(node->num_shards_)) {
167       RETURN_STATUS_UNEXPECTED("Invalid file, not enough tfrecord files provided.\n");
168     }
169   }
170   // Reset the flag when this node is cached and is already visited
171   if (node->IsCached()) {
172     is_cached_ = false;
173   }
174   return Status::OK();
175 }
176 
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)177 Status CacheValidationPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
178   MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<DatasetNode>): visiting " << node->Name() << ".";
179   // Reset the flag when all descendants are visited
180   if (node->IsCached()) {
181     is_cached_ = false;
182   }
183   return Status::OK();
184 }
185 }  // namespace dataset
186 }  // namespace mindspore
187