1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/dataset_utils.h"
17 #include "tensorflow/core/data/name_utils.h"
18 #include "tensorflow/core/data/serialization_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/tensor_util.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
23 #include "tensorflow/core/platform/stringprintf.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace experimental {
28 namespace {
29
using grappler::graph_utils::GetScalarConstNodeValue;

// Canonical and legacy ("Experimental"-prefixed) names of the fused
// map-and-batch op, which takes its batch size as a trailing input.
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";

// Dataset ops that introduce a batch dimension. Graph traversal stops at
// these nodes and reads the batch size from their batch_size input.
constexpr std::array<const char*, 4> kBatchDatasetOps = {
    "BatchDataset",
    "PaddedBatchDataset",
    kMapAndBatchOp,
    kExperimentalMapAndBatchOp,
};

// Dataset ops whose inputs are all themselves datasets. Traversal recurses
// into every input and requires all computed batch sizes to agree.
constexpr std::array<const char*, 2> kMultipleInputDatasetOps = {
    "ConcatenateDataset",
    "ZipDataset",
};

// Dataset ops that cannot change the 0th dimension of their output elements.
// Traversal passes through these to their first (dataset) input.
constexpr std::array<const char*, 16> kPassThroughOps = {
    "AssertCardinalityDataset",
    "CacheDataset",
    "FilterDataset",
    "FinalizeDataset",
    "Identity",
    "ModelDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "RepeatDataset",
    "ShardDataset",
    "ShuffleAndRepeatDataset",
    "ShuffleDataset",
    "SkipDataset",
    "TakeDataset",
};
65
66 template <std::size_t SIZE>
IsDatasetNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & arr)67 bool IsDatasetNodeOfType(const NodeDef& node,
68 const std::array<const char*, SIZE>& arr) {
69 for (const auto& dataset_op : arr) {
70 if (MatchesAnyVersion(dataset_op, node.op())) return true;
71 }
72 return false;
73 }
74
GetInputNode(const NodeDef & node,const grappler::GraphView & graph,int64_t input_index)75 const NodeDef* GetInputNode(const NodeDef& node,
76 const grappler::GraphView& graph,
77 int64_t input_index) {
78 if (node.input_size() == 0) return nullptr;
79 grappler::GraphView::InputPort input_port =
80 graph.GetInputPort(node.name(), input_index);
81 return graph.GetRegularFanin(input_port).node;
82 }
83
// TODO(rachelim): This op traverses the dataset graph using an allowlist-based
// approach. As an alternative, we could instead rewrite all batching datasets'
// drop_remainder parameter to True, then rerun the dataset graph to derive
// new output shapes using C++ shape inference. This is more robust in cases
// where datasets have shape inference implemented in C++. If this allowlist-
// based approach proves hard to maintain, consider doing the alternative.
90 class ComputeBatchSizeOp : public OpKernel {
91 public:
ComputeBatchSizeOp(OpKernelConstruction * ctx)92 explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
93
Compute(OpKernelContext * ctx)94 void Compute(OpKernelContext* ctx) override {
95 DatasetBase* dataset;
96 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
97
98 std::vector<std::pair<string, Tensor>> input_list;
99 GraphDef graph_def;
100 string dataset_node_name;
101 OP_REQUIRES_OK(ctx, AsGraphDefForRewrite(ctx, dataset, &input_list,
102 &graph_def, &dataset_node_name));
103
104 // Create GraphView for easier traversal of graph.
105 grappler::GraphView graph_view(&graph_def);
106
107 const NodeDef* node = graph_view.GetNode(dataset_node_name);
108 OP_REQUIRES(ctx, node != nullptr,
109 errors::InvalidArgument("Node does not exist in graph"));
110 int64_t batch_size = GetBatchSize(*node, graph_view);
111 Tensor* result;
112 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
113 result->scalar<int64_t>()() = batch_size;
114 }
115
116 private:
GetBatchSizeFromBatchNode(const NodeDef & node,const grappler::GraphView & graph)117 int64_t GetBatchSizeFromBatchNode(const NodeDef& node,
118 const grappler::GraphView& graph) {
119 int64_t arg_index;
120 if (node.op() == kMapAndBatchOp ||
121 node.op() == kExperimentalMapAndBatchOp) {
122 arg_index = node.input_size() - 3;
123 } else {
124 arg_index = 1;
125 }
126
127 auto batch_size_node = GetInputNode(node, graph, arg_index);
128 int64_t batch_size;
129 auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size);
130 if (!s.ok()) {
131 VLOG(1) << "Could not compute static batch size. Found batching dataset ("
132 << node.name() << "), but failed to get its input batch size: "
133 << s.error_message();
134 return -1;
135 }
136 return batch_size;
137 }
138
139 // Helper function that returns the static 0th dimension of a given dataset
140 // node in the graph. It starts from a node in the graph and recursively
141 // traverses its inputs until it finds a valid BatchDataset operation,
142 // and returns its batch size. If the batch size cannot be determined,
143 // returns -1.
144 //
145 // During recursion, it handles four kinds of cases:
146 // 1. BatchDataset type ops: Returns the value from its batch_size input node.
147 // 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops,
148 // which are themselves all datasets, and returns the batch sizes computed
149 // by the inputs if they are all the same.
150 // 3. Core dataset ops which cannot change the size of the 0th dimension of
151 // dataset output elements: Recurses into the first input parameter.
152 // 4. All other ops: Fail, returning -1 for unknown.
153 // TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the
154 // function definition.
GetBatchSize(const NodeDef & node,const grappler::GraphView & graph)155 int64_t GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) {
156 if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
157 return GetBatchSizeFromBatchNode(node, graph);
158 }
159 if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) {
160 const NodeDef* input_0 = GetInputNode(node, graph, 0);
161 int64_t batch_size_0 = GetBatchSize(*input_0, graph);
162 for (int i = 1; i < node.input_size(); ++i) {
163 const NodeDef* input = GetInputNode(node, graph, i);
164 auto batch_size_i = GetBatchSize(*input, graph);
165 if (batch_size_i != batch_size_0) {
166 VLOG(1) << "Could not compute batch size: inputs to " << node.name()
167 << " (" << node.op() << ") had different batch sizes."
168 << " Namely, input 0 had batch size " << batch_size_0
169 << " while input " << i << " had batch size " << batch_size_i
170 << ".";
171 return -1;
172 }
173 }
174 return batch_size_0;
175 }
176 if (IsDatasetNodeOfType(node, kPassThroughOps)) {
177 const NodeDef* input = GetInputNode(node, graph, 0);
178 return GetBatchSize(*input, graph);
179 }
180 VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op()
181 << ") that prevented further static batch size analysis.";
182
183 return -1;
184 }
185 };
186
// Registers the ComputeBatchSize kernel on CPU. The op consumes a variant
// dataset tensor (input 0) and produces a scalar int64 batch size, -1 when
// the batch size cannot be determined statically.
REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU),
                        ComputeBatchSizeOp);
189
190 } // anonymous namespace
191 } // namespace experimental
192 } // namespace data
193 } // namespace tensorflow
194