1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/dataset_utils.h"
17 #include "tensorflow/core/data/name_utils.h"
18 #include "tensorflow/core/data/serialization_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/tensor_util.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
23 #include "tensorflow/core/platform/stringprintf.h"
24 
25 namespace tensorflow {
26 namespace data {
27 namespace experimental {
28 namespace {
29 
using grappler::graph_utils::GetScalarConstNodeValue;

// Canonical op names for the (experimental and stable) fused map-and-batch
// dataset; these need special handling because their batch_size input is not
// at a fixed index (see GetBatchSizeFromBatchNode below).
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";

// Dataset ops that batch their input and therefore carry a batch_size input
// from which the static batch size can be read.
constexpr std::array<const char*, 4> kBatchDatasetOps = {
    "BatchDataset",
    "PaddedBatchDataset",
    kMapAndBatchOp,
    kExperimentalMapAndBatchOp,
};

// Dataset ops with multiple dataset inputs; the batch size is well defined
// only if all inputs agree on it.
constexpr std::array<const char*, 2> kMultipleInputDatasetOps = {
    "ConcatenateDataset",
    "ZipDataset",
};

// Dataset ops that cannot change the 0th dimension of their input elements,
// so batch-size analysis can recurse through their first (dataset) input.
constexpr std::array<const char*, 16> kPassThroughOps = {
    "AssertCardinalityDataset",
    "CacheDataset",
    "FilterDataset",
    "FinalizeDataset",
    "Identity",
    "ModelDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "RepeatDataset",
    "ShardDataset",
    "ShuffleAndRepeatDataset",
    "ShuffleDataset",
    "SkipDataset",
    "TakeDataset",
};
65 
66 template <std::size_t SIZE>
IsDatasetNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & arr)67 bool IsDatasetNodeOfType(const NodeDef& node,
68                          const std::array<const char*, SIZE>& arr) {
69   for (const auto& dataset_op : arr) {
70     if (MatchesAnyVersion(dataset_op, node.op())) return true;
71   }
72   return false;
73 }
74 
// Returns the node feeding `node`'s `input_index`-th input in `graph`, or
// nullptr if `node` does not have that many inputs.
const NodeDef* GetInputNode(const NodeDef& node,
                            const grappler::GraphView& graph,
                            int64_t input_index) {
  // Guard against any out-of-range index, not just the zero-input case;
  // otherwise we would query an invalid input port below.
  if (node.input_size() <= input_index) return nullptr;
  grappler::GraphView::InputPort input_port =
      graph.GetInputPort(node.name(), input_index);
  return graph.GetRegularFanin(input_port).node;
}
83 
84 // TODO(rachelim): This op traverses the dataset graph using a allowlist-based
85 // approach. As an alternative, we could instead rewrite all batching datasets'
86 // drop_remainder parameter to True, then rerun the dataset graph to derive
87 // new output shapes using C++ shape inference. This is more robust in cases
88 // where datasets have shape inference implemented in C++. If this allowlist-
89 // based approach proves hard to maintain, consider doing the alternative.
90 class ComputeBatchSizeOp : public OpKernel {
91  public:
ComputeBatchSizeOp(OpKernelConstruction * ctx)92   explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
93 
Compute(OpKernelContext * ctx)94   void Compute(OpKernelContext* ctx) override {
95     DatasetBase* dataset;
96     OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
97 
98     std::vector<std::pair<string, Tensor>> input_list;
99     GraphDef graph_def;
100     string dataset_node_name;
101     OP_REQUIRES_OK(ctx, AsGraphDefForRewrite(ctx, dataset, &input_list,
102                                              &graph_def, &dataset_node_name));
103 
104     // Create GraphView for easier traversal of graph.
105     grappler::GraphView graph_view(&graph_def);
106 
107     const NodeDef* node = graph_view.GetNode(dataset_node_name);
108     OP_REQUIRES(ctx, node != nullptr,
109                 errors::InvalidArgument("Node does not exist in graph"));
110     int64_t batch_size = GetBatchSize(*node, graph_view);
111     Tensor* result;
112     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
113     result->scalar<int64_t>()() = batch_size;
114   }
115 
116  private:
GetBatchSizeFromBatchNode(const NodeDef & node,const grappler::GraphView & graph)117   int64_t GetBatchSizeFromBatchNode(const NodeDef& node,
118                                     const grappler::GraphView& graph) {
119     int64_t arg_index;
120     if (node.op() == kMapAndBatchOp ||
121         node.op() == kExperimentalMapAndBatchOp) {
122       arg_index = node.input_size() - 3;
123     } else {
124       arg_index = 1;
125     }
126 
127     auto batch_size_node = GetInputNode(node, graph, arg_index);
128     int64_t batch_size;
129     auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size);
130     if (!s.ok()) {
131       VLOG(1) << "Could not compute static batch size. Found batching dataset ("
132               << node.name() << "), but failed to get its input batch size: "
133               << s.error_message();
134       return -1;
135     }
136     return batch_size;
137   }
138 
139   // Helper function that returns the static 0th dimension of a given dataset
140   // node in the graph. It starts from a node in the graph and recursively
141   // traverses its inputs until it finds a valid BatchDataset operation,
142   // and returns its batch size. If the batch size cannot be determined,
143   // returns -1.
144   //
145   // During recursion, it handles four kinds of cases:
146   // 1. BatchDataset type ops: Returns the value from its batch_size input node.
147   // 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops,
148   //    which are themselves all datasets, and returns the batch sizes computed
149   //    by the inputs if they are all the same.
150   // 3. Core dataset ops which cannot change the size of the 0th dimension of
151   //    dataset output elements: Recurses into the first input parameter.
152   // 4. All other ops: Fail, returning -1 for unknown.
153   // TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the
154   // function definition.
GetBatchSize(const NodeDef & node,const grappler::GraphView & graph)155   int64_t GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) {
156     if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
157       return GetBatchSizeFromBatchNode(node, graph);
158     }
159     if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) {
160       const NodeDef* input_0 = GetInputNode(node, graph, 0);
161       int64_t batch_size_0 = GetBatchSize(*input_0, graph);
162       for (int i = 1; i < node.input_size(); ++i) {
163         const NodeDef* input = GetInputNode(node, graph, i);
164         auto batch_size_i = GetBatchSize(*input, graph);
165         if (batch_size_i != batch_size_0) {
166           VLOG(1) << "Could not compute batch size: inputs to " << node.name()
167                   << " (" << node.op() << ") had different batch sizes."
168                   << " Namely, input 0 had batch size " << batch_size_0
169                   << " while input " << i << " had batch size " << batch_size_i
170                   << ".";
171           return -1;
172         }
173       }
174       return batch_size_0;
175     }
176     if (IsDatasetNodeOfType(node, kPassThroughOps)) {
177       const NodeDef* input = GetInputNode(node, graph, 0);
178       return GetBatchSize(*input, graph);
179     }
180     VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op()
181             << ") that prevented further static batch size analysis.";
182 
183     return -1;
184   }
185 };
186 
// The op only inspects the dataset's serialized graph, so a CPU kernel
// registration suffices.
REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU),
                        ComputeBatchSizeOp);
189 
190 }  // anonymous namespace
191 }  // namespace experimental
192 }  // namespace data
193 }  // namespace tensorflow
194