1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "tensorflow/core/framework/common_shape_fns.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/graph/graph_constructor.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/grappler/costs/utils.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/grappler/utils/functions.h"
34 #include "tensorflow/core/grappler/utils/topological_sort.h"
35 #include "tensorflow/core/lib/gtl/cleanup.h"
36 #include "tensorflow/core/lib/gtl/flatset.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38
39 namespace tensorflow {
40 namespace grappler {
41
42 namespace {
43
44 using shape_inference::DimensionHandle;
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
49
// Hash functor for handle types: the hash of a handle is the value returned
// by its own Handle() accessor.
template <typename Handle>
struct HashHandle {
  std::size_t operator()(const Handle& handle) const {
    const std::size_t hash_value = handle.Handle();
    return hash_value;
  }
};
// Equality functor for handle types: two handles are equal exactly when
// SameHandle() reports that they are.
template <typename Handle>
struct CompareHandle {
  bool operator()(const Handle& lhs, const Handle& rhs) const {
    const bool same = lhs.SameHandle(rhs);
    return same;
  }
};
60
61 template <typename Handle>
62 struct HandleToObject {};
63 template <>
64 struct HandleToObject<ShapeHandle> {
65 typedef ShapeHandle Object;
66
Unknowntensorflow::grappler::__anon949966ef0111::HandleToObject67 static ShapeHandle Unknown() { return ShapeHandle(); }
68 };
69
70 template <>
71 struct HandleToObject<DimensionHandle> {
72 typedef int64 Object;
73
Unknowntensorflow::grappler::__anon949966ef0111::HandleToObject74 static int64 Unknown() { return -1; }
75 };
76
77 template <typename Handle>
78 struct Processor {};
79
80 template <>
81 struct Processor<ShapeHandle> {
82 // Extract the shape or dim denoted by the handle.
ExtractValuetensorflow::grappler::__anon949966ef0111::Processor83 void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
84 // Merge the shapes or dims.
Mergetensorflow::grappler::__anon949966ef0111::Processor85 Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
86 if (InferenceContext::RankKnown(*result)) {
87 // The result was initialized in a previous merge to a shape of known
88 // rank, make sure we preserve that information.
89 return Status::OK();
90 }
91 if (InferenceContext::RankKnown(h1)) {
92 *result = h1;
93 } else {
94 *result = h2;
95 }
96 return Status::OK();
97 }
98 };
99
100 template <>
101 struct Processor<DimensionHandle> {
102 // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
103 // reserved by TensorFlow).
ExtractValuetensorflow::grappler::__anon949966ef0111::Processor104 void ExtractValue(DimensionHandle d, int64* result) {
105 if (!InferenceContext::ValueKnown(d)) {
106 *result = -counter;
107 counter++;
108 } else {
109 int64 val = InferenceContext::Value(d);
110 if (val >= 0) {
111 *result = val;
112 } else {
113 // A shape inference function generated an invalid dimension handle.
114 // Use a symbolic dimension to encode this.
115 *result = -counter;
116 counter++;
117 }
118 }
119 }
120
121 // Merge the dimensions d1 and d2. Return the known shape if there is one,
122 // otherwise look for a symbolic shape. If there is no symbolic shape and no
123 // known shape, the shape if fully unknown so return -1.
Mergetensorflow::grappler::__anon949966ef0111::Processor124 Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
125 const int64 dim1 = InferenceContext::Value(d1);
126 const int64 dim2 = InferenceContext::Value(d2);
127
128 if (dim1 >= 0 && dim2 >= 0) {
129 CHECK_EQ(dim1, dim2);
130 return RefineDim(dim1, result);
131 } else if (dim1 >= 0 && dim2 < 0) {
132 return RefineDim(dim1, result);
133 } else if (dim1 < 0 && dim2 >= 0) {
134 return RefineDim(dim2, result);
135 } else if (dim1 < -1) {
136 return RefineDim(dim1, result);
137 } else if (dim2 < -1) {
138 return RefineDim(dim2, result);
139 } else {
140 CHECK_EQ(dim1, dim2);
141 CHECK_EQ(-1, dim1);
142 return RefineDim(-1, result);
143 }
144 return Status::OK();
145 }
146
147 private:
RefineDimtensorflow::grappler::__anon949966ef0111::Processor148 Status RefineDim(int64 dim, int64* result) {
149 if (*result >= 0) {
150 if (!(*result == dim || dim < 0)) {
151 return errors::InvalidArgument("Inconsistent dimensions detected");
152 }
153 } else if (dim >= 0) {
154 *result = dim;
155 } else if (dim < *result) {
156 *result = dim;
157 }
158 return Status::OK();
159 }
160
161 int64 counter = 2;
162 };
163
164 // Traditional Disjoint-Set datastructure with path compression.
165 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
166 template <typename Handle>
167 class DisjointSet {
168 public:
DisjointSet()169 DisjointSet() {}
~DisjointSet()170 ~DisjointSet() {
171 for (auto rep : nodes_) {
172 delete rep.second;
173 }
174 }
175
176 Status Merge(Handle x, Handle y);
177 const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
178
179 private:
180 // All the handles that belong to the same set are part of the same tree, and
181 // utimately represented by the root of that tree.
182 struct Rep {
183 // Parent in the tree used to encode the set.
184 Rep* parent;
185 // Rank in the tree, used to figure out how to compress the path to the root
186 // of the tree.
187 int rank;
188 // The handle.
189 typename HandleToObject<Handle>::Object value;
190 };
191
192 // Create a new set for the value if none exists, or return its representative
193 // node otherwise.
194 Rep* Find(Handle value);
195
196 private:
197 Processor<Handle> processor_;
198 std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
199 nodes_;
200 };
201
202 template <typename Handle>
203 const typename HandleToObject<Handle>::Object
GetMergedValue(Handle value)204 DisjointSet<Handle>::GetMergedValue(Handle value) {
205 Rep* rep = Find(value);
206 if (!rep) {
207 // We don't know anything about this handle.
208 return HandleToObject<Handle>::Unknown();
209 }
210 return rep->value;
211 }
212
213 template <typename Handle>
Merge(Handle x,Handle y)214 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
215 Rep* x_root = Find(x);
216 Rep* y_root = Find(y);
217
218 // x and y are already in the same set
219 if (x_root == y_root) {
220 return Status::OK();
221 }
222 // x and y are not in same set, so we merge them
223 // Use the occasion to strengthen what we know about the handle by merging the
224 // information about the 2 subsets.
225 if (x_root->rank < y_root->rank) {
226 TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
227 x_root->parent = y_root;
228 } else if (x_root->rank > y_root->rank) {
229 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
230 y_root->parent = x_root;
231 } else {
232 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
233 // Arbitrarily make one root the new parent
234 y_root->parent = x_root;
235 x_root->rank = x_root->rank + 1;
236 }
237 return Status::OK();
238 }
239
240 template <typename Handle>
Find(Handle value)241 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
242 auto it = nodes_.find(value);
243 if (it == nodes_.end()) {
244 // This is the first time we process this handle, create an entry for it.
245 Rep* node = new Rep;
246 node->parent = node;
247 node->rank = 0;
248 processor_.ExtractValue(value, &node->value);
249 nodes_[value] = node;
250 return node;
251 }
252 // Return the representative for the set, which is the root of the tree. Apply
253 // path compression to speedup future queries.
254 Rep* node = it->second;
255 Rep* root = node->parent;
256 while (root != root->parent) {
257 root = root->parent;
258 }
259 while (node->parent != root) {
260 Rep* next = node->parent;
261 node->parent = root;
262 node = next;
263 }
264 return root;
265 }
266
267 // TODO(dyoon): Move many helper functions in this file (including those within
268 // SymbolicShapeRefiner class) to shared utils.
IsEnqueue(const NodeDef & n)269 bool IsEnqueue(const NodeDef& n) {
270 return (n.op().find("Enqueue") != string::npos &&
271 n.op().find("EnqueueMany") == string::npos);
272 }
273
IsDequeue(const NodeDef & n)274 bool IsDequeue(const NodeDef& n) {
275 return (n.op().find("Dequeue") != string::npos &&
276 n.op().find("DequeueMany") == string::npos);
277 }
278
HasAnyUnknownDimensions(const TensorShapeProto & proto)279 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
280 if (proto.unknown_rank()) {
281 return true;
282 }
283 for (const auto& dim : proto.dim()) {
284 if (dim.size() < 0) {
285 return true;
286 }
287 }
288 return false;
289 }
290
291 // This really should be done in an external debugging tool
VerboseLogUnknownDimensionSources(const GraphDef & graph,const std::unordered_map<string,std::vector<OpInfo::TensorProperties>> & input_properties_map,const std::unordered_map<string,std::vector<OpInfo::TensorProperties>> & output_properties_map)292 void VerboseLogUnknownDimensionSources(
293 const GraphDef& graph,
294 const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
295 input_properties_map,
296 const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
297 output_properties_map) {
298 if (!VLOG_IS_ON(2)) {
299 return;
300 }
301
302 VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
303
304 // Find all nodes in the graph for which we
305 // do not have any unknown dimensions in their inputs, but
306 // we have some unknown dimensions in their outputs.
307 std::map<string, int> op_to_count;
308 for (const NodeDef& node : graph.node()) {
309 const auto& input_properties = input_properties_map.at(node.name());
310 const auto& output_properties = output_properties_map.at(node.name());
311
312 bool has_unknown_inputs = false;
313 for (const auto& input_prop : input_properties) {
314 if (HasAnyUnknownDimensions(input_prop.shape())) {
315 has_unknown_inputs = true;
316 break;
317 }
318 }
319
320 if (has_unknown_inputs) {
321 continue;
322 }
323
324 for (const auto& output_prop : output_properties) {
325 if (HasAnyUnknownDimensions(output_prop.shape())) {
326 string inputs = "input_shapes=[";
327 for (const auto& input_prop : input_properties) {
328 inputs += PartialTensorShape::DebugString(input_prop.shape());
329 }
330 inputs += "]";
331
332 string outputs = "output_shapes=[";
333 for (const auto& output_prop : output_properties) {
334 outputs += PartialTensorShape::DebugString(output_prop.shape());
335 }
336 outputs += "]";
337
338 VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
339 << inputs << ", " << outputs;
340
341 op_to_count[node.op()]++;
342
343 // don't log again for this node
344 break;
345 }
346 }
347 }
348 VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
349 << "(format: <op_type> (<count>)):";
350 for (const auto& p : op_to_count) {
351 VLOG(2) << p.first << " (" << p.second << ")";
352 }
353 }
354
IsShapeFullyDefinedIntegerVectorOrScalar(InferenceContext * ic,const ShapeHandle & shape,const ShapeHandle & tensor_as_shape,const DataType & dtype)355 bool IsShapeFullyDefinedIntegerVectorOrScalar(
356 InferenceContext* ic, const ShapeHandle& shape,
357 const ShapeHandle& tensor_as_shape, const DataType& dtype) {
358 if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
359 !ic->FullyDefined(tensor_as_shape) ||
360 (dtype != DT_INT32 && dtype != DT_INT64)) {
361 return false;
362 }
363 return true;
364 }
365
366 // Returned tensor's shape is like `shape`, and its values and dtype are from
367 // `tensor_as_shape` and `dtype`.
MakeTensorProtoFromShape(InferenceContext * ic,const ShapeHandle & shape,const ShapeHandle & tensor_as_shape,const DataType & dtype)368 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
369 const ShapeHandle& shape,
370 const ShapeHandle& tensor_as_shape,
371 const DataType& dtype) {
372 TensorProto tensor_proto;
373 tensor_proto.set_dtype(dtype);
374 auto* shape_proto = tensor_proto.mutable_tensor_shape();
375 if (ic->Rank(shape) == 1) {
376 shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
377 }
378 // For a scalar tensor, tensor_shape field will be left empty; no dim.
379 for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
380 int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
381 if (dtype == DT_INT32) {
382 tensor_proto.add_int_val(value);
383 } else {
384 tensor_proto.add_int64_val(value);
385 }
386 }
387 return tensor_proto;
388 }
389
390 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
MakeConstNodeDefFromTensorProto(InferenceContext * ic,const TensorProto & tensor_proto,const DataType & dtype)391 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
392 const TensorProto& tensor_proto,
393 const DataType& dtype) {
394 NodeDef const_node;
395 const_node.set_name("const_from_shape");
396 const_node.set_op("Const");
397 auto* attr = const_node.mutable_attr();
398 (*attr)["dtype"].set_type(dtype);
399 auto* tensor = (*attr)["value"].mutable_tensor();
400 *tensor = tensor_proto;
401 return const_node;
402 }
403
404 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
405 // and dtype = `dtype`.
MakeConstNodeDefFromShape(InferenceContext * ic,const ShapeHandle & shape,const ShapeHandle & tensor_as_shape,const DataType & dtype)406 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
407 const ShapeHandle& shape,
408 const ShapeHandle& tensor_as_shape,
409 const DataType& dtype) {
410 return MakeConstNodeDefFromTensorProto(
411 ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
412 }
413
414 } // namespace
415
416 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
417 // dequeued in (roughly) topological order. Propagating shapes following a
418 // topological ordering isn't required for correctness but helps speed things up
419 // since it avoids processing the same node multiple times as its inputs
420 // information is refined.
421 class TopoQueue {
422 public:
TopoQueue(const std::vector<const NodeDef * > & topo_order)423 explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
424 : topo_order_(TopoOrder(topo_order)) {}
425
push(const NodeDef * n)426 void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
427
pop()428 const NodeDef* pop() {
429 CHECK(!empty());
430 auto it = queue_.begin();
431 const NodeDef* n = it->first;
432 queue_.erase(it);
433 return n;
434 }
435
empty() const436 bool empty() const { return queue_.empty(); }
size() const437 std::size_t size() const { return queue_.size(); }
438
439 private:
440 using NodeAndId = std::pair<const NodeDef*, int>;
441 // Graph nodes are created in (roughly) topological order. Therefore we can
442 // use their id to ensure they're sorted topologically.
443 struct OrderByIdAscending {
operator ()tensorflow::grappler::TopoQueue::OrderByIdAscending444 bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
445 return lhs.second < rhs.second;
446 }
447 };
448
TopoOrder(const std::vector<const NodeDef * > & topo_order) const449 const std::unordered_map<const NodeDef*, int> TopoOrder(
450 const std::vector<const NodeDef*>& topo_order) const {
451 std::unordered_map<const NodeDef*, int> map;
452 map.reserve(topo_order.size());
453 for (int i = 0; i < topo_order.size(); ++i) {
454 map.emplace(topo_order[i], i);
455 }
456 return map;
457 }
458
459 const std::unordered_map<const NodeDef*, int> topo_order_;
460 std::set<NodeAndId, OrderByIdAscending> queue_;
461 };
462
IsNumericType(const DataType dtype)463 bool IsNumericType(const DataType dtype) {
464 static const gtl::FlatSet<DataType>* const kRealNumberTypes =
465 CHECK_NOTNULL((new gtl::FlatSet<DataType>{
466 // Floating point.
467 DT_BFLOAT16,
468 DT_HALF,
469 DT_FLOAT,
470 DT_DOUBLE,
471 // Int / UInt.
472 DT_INT8,
473 DT_INT16,
474 DT_INT32,
475 DT_INT64,
476 DT_UINT8,
477 DT_UINT16,
478 DT_UINT32,
479 DT_UINT64,
480 // Quantized Int.
481 DT_QINT8,
482 DT_QUINT8,
483 DT_QINT16,
484 DT_QUINT16,
485 DT_QINT32,
486 // Bool.
487 DT_BOOL,
488 }));
489 return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
490 }
491
IsWhiteListedOpTypeForEvaluateNode(const string & op_type)492 bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
493 static const gtl::FlatSet<string>* const kOpTpeWhitelist =
494 CHECK_NOTNULL((new gtl::FlatSet<string>{
495 // Unary arithmetic ops
496 "Floor",
497 "Round",
498 "Sqrt",
499 "Square",
500 "Sign",
501 // Binary arithmetic ops
502 "Add",
503 "Div",
504 "FloorDiv",
505 "FloorMod",
506 "Greater",
507 "GreaterEqual",
508 "Less",
509 "LessEqual",
510 "LogicalAnd",
511 "LogicalNot",
512 "LogicalOr",
513 "Maximum",
514 "Minimum",
515 "Mod",
516 "Mul",
517 "NotEqual",
518 "QuantizedAdd",
519 "QuantizedMul",
520 "SquareDifference",
521 "Sub",
522 "TruncateDiv",
523 "TruncateMod",
524 "RealDiv",
525 // N-ary arithemtic ops
526 "AddN",
527 // Others
528 "StridedSlice",
529 "OnesLike",
530 "ZerosLike",
531 "Concat",
532 "ConcatV2",
533 "Split",
534 "Range",
535 "Fill",
536 "Cast",
537 }));
538 return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
539 }
540
541 // Processes symbolic shapes.
542 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
543 // shape refiner which creates new handles every time it processes an unknown
544 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
545 // unknown shape/dimension of a given node.
546 class SymbolicShapeRefiner {
547 public:
SymbolicShapeRefiner(const GraphView & graph,const std::unordered_map<string,std::unordered_set<int>> & fed_ports,const bool aggressive_shape_inference)548 explicit SymbolicShapeRefiner(
549 const GraphView& graph,
550 const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
551 const bool aggressive_shape_inference)
552 : graph_(graph),
553 function_library_(OpRegistry::Global(), graph.graph()->library()),
554 fed_ports_(fed_ports),
555 aggressive_shape_inference_(aggressive_shape_inference) {
556 graph_def_version_ = graph.graph()->versions().producer();
557 node_to_context_.reserve(graph.graph()->node_size());
558 }
559
graph() const560 const GraphView& graph() const { return graph_; }
561
562 struct NodeContext {
563 const OpRegistrationData* op_data;
564 DataTypeVector input_types;
565 DataTypeVector output_types;
566 std::unique_ptr<InferenceContext> inference_context;
567 // Additional info for propagating tensor values and tensor shapes.
568 std::vector<const TensorProto*> input_tensor_protos;
569 std::vector<const TensorProto*> output_tensor_protos;
570 std::vector<ShapeHandle> output_tensors_as_shapes;
571 };
572
GetNodeContext(const NodeDef * node)573 NodeContext* GetNodeContext(const NodeDef* node) {
574 auto it = node_to_context_.find(node);
575 if (it == node_to_context_.end()) {
576 return nullptr;
577 }
578 return &it->second;
579 }
580
GetContext(const NodeDef * node)581 InferenceContext* GetContext(const NodeDef* node) {
582 auto it = node_to_context_.find(node);
583 if (it == node_to_context_.end()) {
584 return nullptr;
585 }
586 return it->second.inference_context.get();
587 }
588
589 // Forward the shapes from the function input nodes to
590 // the argument nodes (which are Placeholder nodes), then
591 // perform shape inference on the function body.
592 //
593 // Propagate shape information of final function body node
594 // to function node `function_node`.
595 //
596 // In the event of an error, UpdateNode will simply set `function_node`'s
597 // output shape to be Unknown.
UpdateFunction(const NodeDef * function_node)598 Status UpdateFunction(const NodeDef* function_node) {
599 auto it = fun_to_grappler_function_item_.find(function_node->op());
600 if (it == fun_to_grappler_function_item_.end()) {
601 return errors::InvalidArgument(
602 function_node->op(),
603 " was not previously added to SymbolicShapeRefiner.");
604 }
605
606 // Copy (not reference) so that changes we make here (e.g., replacing
607 // Placeholder with Const) don't affect one in
608 // fun_to_grappler_function_item_.
609 GrapplerFunctionItem grappler_function_item = it->second;
610 MutableGraphView gv(&grappler_function_item.graph);
611
612 // Forward shapes from function input nodes to argument nodes.
613 for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
614 auto& fun_input = grappler_function_item.input(i);
615 if (fun_input.placeholders.size() > 1) {
616 // TODO(jmdecker): Handle case with multiple input placeholders
617 return errors::Unimplemented(
618 "Input arguments with multiple placeholders are not yet "
619 "supported.");
620 }
621 NodeDef* fun_node = gv.GetNode(fun_input.input_name);
622 const TensorId input_tensor = ParseTensorName(function_node->input(i));
623
624 if (IsControlInput(input_tensor)) {
625 return errors::FailedPrecondition(
626 "Function inputs should not contain control nodes.");
627 }
628
629 const NodeDef* input_node = graph_.GetNode(input_tensor.node());
630 if (input_node == nullptr) {
631 return errors::FailedPrecondition(input_tensor.node(),
632 " was not found in the graph.");
633 }
634
635 InferenceContext* input_ic = GetContext(input_node);
636 if (input_ic == nullptr) {
637 return errors::FailedPrecondition(
638 "Inference context has not been created for ", input_tensor.node());
639 }
640
641 int output_port_num = input_tensor.index();
642 AttrValue attr_output_shape;
643 TensorShapeProto proto;
644 const auto& handle = input_ic->output(output_port_num);
645 input_ic->ShapeHandleToProto(handle, &proto);
646 // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
647 for (int i = 0; i < proto.dim_size(); i++) {
648 if (proto.dim(i).size() < -1) {
649 proto.mutable_dim(i)->set_size(-1);
650 }
651 }
652 *attr_output_shape.mutable_shape() = proto;
653 (*fun_node->mutable_attr())["shape"] = attr_output_shape;
654 }
655
656 // Replace input Placeholders with Consts, if values are known. Note that
657 // we don't check exceptions here as it's done in the above loop.
658 auto* ctx = GetNodeContext(function_node);
659 auto* ic = ctx->inference_context.get();
660 for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
661 const string& input = function_node->input(i);
662 const string& node_name = NodeName(input);
663 const NodeDef* input_node = graph_.GetNode(node_name);
664 if (IsConstant(*input_node)) {
665 TF_CHECK_OK(
666 ReplaceInputWithConst(*input_node, i, &grappler_function_item));
667 } else if (ctx->input_tensor_protos.size() > i &&
668 ctx->input_tensor_protos[i] != nullptr) {
669 NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
670 ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
671 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
672 &grappler_function_item));
673 } else if (ic->input_tensors_as_shapes().size() > i &&
674 IsShapeFullyDefinedIntegerVectorOrScalar(
675 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
676 ctx->input_types[i])) {
677 // We have fully defined input_tensors_as_shapes for this input; use it
678 // as a const input to the function node.
679 NodeDef const_input_node = MakeConstNodeDefFromShape(
680 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
681 ctx->input_types[i]);
682 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
683 &grappler_function_item));
684 }
685 }
686
687 // Perform inference on function body.
688 GraphProperties gp(grappler_function_item);
689 TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
690
691 // Add return nodes for output shapes.
692 int output = 0;
693 ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
694 ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
695 nullptr);
696 for (auto const& out_arg : grappler_function_item.outputs()) {
697 if (out_arg.output_nodes.size() > 1) {
698 // TODO(jmdecker): Handle case of multiple output tensors
699 return errors::Unimplemented(
700 "Output arguments with multiple output tensors are not yet "
701 "supported.");
702 }
703
704 // It is guaranteed that output_tensors does not contain any control
705 // inputs, so port_id >= 0.
706 TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]);
707
708 const NodeDef* retnode = gv.GetNode(out_tensor.node());
709 if (retnode == nullptr) {
710 return errors::FailedPrecondition(
711 "Unable to find return function_node ", out_tensor.node(), " for ",
712 function_node->name());
713 }
714
715 auto output_properties = gp.GetOutputProperties(retnode->name());
716 if (out_tensor.index() >= output_properties.size()) {
717 return errors::InvalidArgument(
718 out_tensor.ToString(), " has invalid position ", out_tensor.index(),
719 " (output_properties.size() = ", output_properties.size(), ").");
720 }
721 auto const& outprop = output_properties[out_tensor.index()];
722 const TensorShapeProto& shape = outprop.shape();
723 ShapeHandle out;
724 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
725 ic->set_output(output, out);
726 if (outprop.has_value()) {
727 // Forward tensor value to output_tensors_as_shape.
728 Tensor tensor;
729 if (tensor.FromProto(outprop.value())) {
730 MaybeTensorValueToShape(ic, tensor,
731 &ctx->output_tensors_as_shapes[output]);
732 const_tensors_to_propagate_.push_back(outprop.value());
733 ctx->output_tensor_protos[output] =
734 &const_tensors_to_propagate_.back();
735 }
736 }
737 output++;
738 }
739
740 return Status::OK();
741 }
742
743 // Prepares input shapes/values/handles, then runs shape inference, and
744 // finally sets output shapes/values/handles.
UpdateNode(const NodeDef * node,bool * refined)745 Status UpdateNode(const NodeDef* node, bool* refined) {
746 NodeContext* ctx = GetNodeContext(node);
747 if (ctx == nullptr) {
748 TF_RETURN_IF_ERROR(AddNode(node));
749 ctx = CHECK_NOTNULL(GetNodeContext(node));
750 *refined = true;
751 }
752
753 // Check if the shapes of the nodes in the fan-in of this node have changed,
754 // and if they have, update the node input shapes.
755 InferenceContext* ic = ctx->inference_context.get();
756 std::vector<Tensor> const_values(ic->num_inputs());
757 std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
758 std::vector<ShapeHandle> input_tensors_as_shapes(ic->num_inputs());
759 ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
760
761 for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
762 const GraphView::InputPort port(node, dst_input);
763 const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
764 int src_output = fanin.port_id;
765 const NodeDef* src = fanin.node;
766 NodeContext* src_ctx = GetNodeContext(src);
767 InferenceContext* src_ic = src_ctx->inference_context.get();
768 if (src_ctx == nullptr) {
769 return errors::FailedPrecondition(
770 "Input ", dst_input, " ('", src->name(), "') for '", node->name(),
771 "' was not previously added to SymbolicShapeRefiner.");
772 }
773
774 if (src_output >= src_ic->num_outputs()) {
775 return errors::OutOfRange("src_output = ", src_output,
776 ", but num_outputs is only ",
777 src_ic->num_outputs());
778 }
779
780 // Propagate input node's NodeContext info to the current node's
781 // NodeContext:
782 // output_tensor_protos to input_tensor_protos and input_tensors, and
783 // output_tensors_as_shapes to input_tensors_as_shapes.
784
785 if (src_ctx->output_tensors_as_shapes.size() > src_output) {
786 input_tensors_as_shapes[dst_input] =
787 src_ctx->output_tensors_as_shapes[src_output];
788 }
789
790 if (src_ctx->output_tensor_protos.size() > src_output) {
791 auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
792 if (tensor_proto != nullptr &&
793 const_values[dst_input].FromProto(*tensor_proto)) {
794 input_tensors[dst_input] = &const_values[dst_input];
795 ctx->input_tensor_protos[dst_input] = tensor_proto;
796
797 if (!ic->FullyDefined(input_tensors_as_shapes[dst_input])) {
798 // Shape from a Const is not fully defined when the Const has
799 // value -1 (e.g., Reshape(x, Const(-1)) to reshape an arbitrary
800 // tensor x to a vector).
801 // It's possible that the same Const with -1 is used in many
802 // places, but that doesn't mean the resultant shapes are
803 // identical. e.g., x1 = Reshape(x, c) and y1 = Reshape(y, c),
804 // where c is -1. In this case, shape inference yields both x1 and
805 // y1 as rank 1, size unknown, but still the shapes of x1 and y1
806 // can be different. (even if we use different Const(-1) for x1
807 // and x2, graph optimzier may merge them to single Const through
808 // duplicate removal.)
809 // If we reuse output_tensors_as_shapes to input_tensors_as_shapes
810 // by copying ShapeHandle, they share the same Shape object, and
811 // SymbolicShapeManager, later in InferStatically(), assigns the
812 // same symbolic dim value (unique value < -1); in the above
813 // Reshape example, the shapes of x1 and y1 become, for example,
814 // [-278] and graph optimizer may yield incorrect output 'cause it
815 // assumes x1 and y1 have the same shape.
816 // To prevent this, we re-create a ShapeHandle from the Const
817 // tensor, instead of reusing output_tensors_as_shapes (so that
818 // ShapeHandles of the const fanouts have the same values,
819 // but different Shape objects -- SymbolicShapeManager assigns
820 // different symbol id to each fanout shape).
821 // TODO(dyoon): clean up the way values are propagated.
822 MaybeTensorValueToShape(ic, const_values[dst_input],
823 &input_tensors_as_shapes[dst_input]);
824 }
825 }
826 }
827
828 // NOTE: we check only shape is refined; we do not (yet) check whether
829 // tensor value is refined.
830 if (!*refined &&
831 !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
832 *refined = true;
833 }
834 ic->SetInput(dst_input, src_ic->output(src_output));
835
836 if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
837 // The input value may have changed. Since we have no way to know if
838 // that's indeed the case, err on the safe side.
839 *refined = true;
840 }
841
842 // Also propagate handle shape and dtype of edges which are carrying
843 // resource handles.
844 if (ctx->input_types[dst_input] == DT_RESOURCE) {
845 auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
846 if (!outputs) continue;
847 auto* inputs = ic->input_handle_shapes_and_types(dst_input);
848
849 if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
850 *refined = true;
851 ic->set_input_handle_shapes_and_types(dst_input, *outputs);
852 }
853 }
854
855 // Make sure we schedule the fanout of resources (which have no input)
856 // whenever the resources are updated.
857 *refined |= ic->num_inputs() == 0;
858
859 if (!*refined) {
860 // No input shape has changed, we're done.
861 return Status::OK();
862 }
863
864 ic->set_input_tensors(input_tensors);
865 ic->set_input_tensors_as_shapes(input_tensors_as_shapes);
866
867 // Properly handle function nodes.
868 if (ctx->op_data && ctx->op_data->is_function_op) {
869 // TODO(jmdecker): Detect if the input shapes have changed for this
870 // function. Note that when we hit a function call node, refined will be
871 // true, as the updates to the call node will have changed, even if it's
872 // the same function being called twice with the same input shapes.
873 // Example: simple_function.pbtxt
874 auto s = UpdateFunction(node);
875 if (s.ok()) {
876 return Status::OK();
877 } else {
878 VLOG(1) << "UpdateFunction failed for " << node->op()
879 << ". Defaulting to ShapeUnknown.\n"
880 << s.ToString();
881 }
882 }
883
884 // Update the shapes of the outputs.
885 return InferShapes(*node, ctx);
886 }
887
SetUnknownShape(const NodeDef * node,int output_port)888 Status SetUnknownShape(const NodeDef* node, int output_port) {
889 shape_inference::ShapeHandle shape =
890 GetUnknownOutputShape(node, output_port);
891 InferenceContext* ctx = GetContext(node);
892 if (ctx == nullptr) {
893 return errors::InvalidArgument("Missing context");
894 }
895 ctx->set_output(output_port, shape);
896 return Status::OK();
897 }
898
899 struct ShapeId {
900 const NodeDef* node;
901 int port_id;
operator ==tensorflow::grappler::SymbolicShapeRefiner::ShapeId902 bool operator==(const ShapeId& other) const {
903 return node == other.node && port_id == other.port_id;
904 }
905 };
906 struct HashShapeId {
operator ()tensorflow::grappler::SymbolicShapeRefiner::HashShapeId907 std::size_t operator()(const ShapeId& shp) const {
908 return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
909 }
910 };
911
912 struct DimId {
913 const NodeDef* node;
914 int port_id;
915 int dim_index;
operator ==tensorflow::grappler::SymbolicShapeRefiner::DimId916 bool operator==(const DimId& other) const {
917 return node == other.node && port_id == other.port_id &&
918 dim_index == other.dim_index;
919 }
920 };
921
922 struct HashDimId {
operator ()tensorflow::grappler::SymbolicShapeRefiner::HashDimId923 std::size_t operator()(const DimId& dim) const {
924 return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
925 dim.dim_index;
926 }
927 };
928
929 // 'port_index' as the union of shape1 and shape2.
OutputAsUnion(const NodeDef * node,int port_index,ShapeHandle shape1,ShapeHandle shape2)930 ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
931 ShapeHandle shape1, ShapeHandle shape2) {
932 if (shape1.SameHandle(shape2)) {
933 return shape1;
934 }
935 InferenceContext* ctx = GetContext(node);
936 ShapeHandle relaxed = shape1;
937 const int rank = ctx->Rank(shape1);
938 if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
939 relaxed = GetUnknownOutputShape(node, port_index);
940 } else {
941 for (int d = 0; d < rank; ++d) {
942 if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
943 int64 val1 = ctx->Value(ctx->Dim(shape1, d));
944 int64 val2 = ctx->Value(ctx->Dim(shape2, d));
945 if (val1 != val2 || (val1 < 0 && val2 < 0)) {
946 DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
947 TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
948 }
949 }
950 }
951 }
952 return relaxed;
953 }
954
EquivalentShapes(ShapeHandle s1,ShapeHandle s2) const955 bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
956 if (s1.SameHandle(s2)) {
957 return true;
958 }
959 if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
960 return false;
961 }
962 if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
963 return true;
964 }
965 const int rank = InferenceContext::Rank(s1);
966 for (int i = 0; i < rank; ++i) {
967 if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
968 InferenceContext::DimKnownRank(s2, i))) {
969 int64 val1 =
970 InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
971 int64 val2 =
972 InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
973 if (val1 >= 0 && val2 >= 0 && val1 == val2) {
974 continue;
975 }
976 return false;
977 }
978 }
979 return true;
980 }
981
982 // Return true if the annotated shape is compatible with shape inference
983 // result. Examples:
984 // Inferred shape: ?, annotated shape: [10, 10] -> true;
985 // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
986 // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
987 // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
CompatibleShapes(ShapeHandle inferred_shape,ShapeHandle annotated_shape) const988 bool CompatibleShapes(ShapeHandle inferred_shape,
989 ShapeHandle annotated_shape) const {
990 if (inferred_shape.SameHandle(annotated_shape)) {
991 return true;
992 }
993 if (!InferenceContext::RankKnown(inferred_shape)) {
994 return true;
995 }
996 if (InferenceContext::Rank(inferred_shape) !=
997 InferenceContext::Rank(annotated_shape)) {
998 return false;
999 }
1000 const int rank = InferenceContext::Rank(inferred_shape);
1001 for (int i = 0; i < rank; ++i) {
1002 if (!InferenceContext::DimKnownRank(inferred_shape, i)
1003 .SameHandle(
1004 InferenceContext::DimKnownRank(annotated_shape, i))) {
1005 int64 val1 = InferenceContext::Value(
1006 InferenceContext::DimKnownRank(inferred_shape, i));
1007 int64 val2 = InferenceContext::Value(
1008 InferenceContext::DimKnownRank(annotated_shape, i));
1009 if (val1 >= 0 && val1 != val2) {
1010 return false;
1011 }
1012 }
1013 }
1014 return true;
1015 }
1016
EquivalentShapesAndTypes(const std::vector<ShapeAndType> & st1,const std::vector<ShapeAndType> & st2) const1017 bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1018 const std::vector<ShapeAndType>& st2) const {
1019 if (st1.size() != st2.size()) {
1020 return false;
1021 }
1022 for (int i = 0; i < st1.size(); ++i) {
1023 const ShapeAndType& s1 = st1[i];
1024 const ShapeAndType& s2 = st2[i];
1025 if (s1.dtype != s2.dtype) {
1026 return false;
1027 }
1028 if (!EquivalentShapes(s1.shape, s2.shape)) {
1029 return false;
1030 }
1031 }
1032 return true;
1033 }
1034
AddFunction(const NodeDef * function_node)1035 Status AddFunction(const NodeDef* function_node) {
1036 auto it = fun_to_grappler_function_item_.find(function_node->op());
1037 if (it != fun_to_grappler_function_item_.end()) {
1038 return Status::OK();
1039 }
1040
1041 const FunctionDef* function_def =
1042 CHECK_NOTNULL(function_library_.Find(function_node->op()));
1043
1044 GrapplerFunctionItem grappler_function_item;
1045 TF_RETURN_IF_ERROR(
1046 MakeGrapplerFunctionItem(*function_def, function_library_,
1047 graph_def_version_, &grappler_function_item));
1048
1049 if (grappler_function_item.inputs().size() > function_node->input_size()) {
1050 return errors::FailedPrecondition(
1051 "Function input size should be smaller than node input size.");
1052 }
1053
1054 for (int i = grappler_function_item.inputs().size();
1055 i < function_node->input_size(); ++i) {
1056 const string& input = function_node->input(i);
1057 if (!IsControlInput(input)) {
1058 return errors::FailedPrecondition(
1059 "Found regular input (", input,
1060 ") instead of control nodes for node ", function_node->name());
1061 }
1062 }
1063
1064 fun_to_grappler_function_item_[function_def->signature().name()] =
1065 grappler_function_item;
1066
1067 return Status::OK();
1068 }
1069
AddNode(const NodeDef * node)1070 Status AddNode(const NodeDef* node) {
1071 NodeContext& node_ctx = node_to_context_[node];
1072 TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
1073
1074 if (node_ctx.op_data->is_function_op) {
1075 TF_RETURN_IF_ERROR(AddFunction(node));
1076 }
1077
1078 TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1079 &node_ctx.input_types,
1080 &node_ctx.output_types));
1081
1082 // Create the inference context for this node.
1083 const int num_inputs = node_ctx.input_types.size();
1084 std::vector<ShapeHandle> input_shapes(num_inputs);
1085 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1086 input_handle_shapes_and_types(num_inputs);
1087 std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1088 std::vector<ShapeHandle> input_tensors_as_shapes;
1089
1090 node_ctx.inference_context.reset(new InferenceContext(
1091 graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
1092 input_tensors, input_tensors_as_shapes,
1093 std::move(input_handle_shapes_and_types)));
1094 const Status s = node_ctx.inference_context->construction_status();
1095 if (!s.ok()) {
1096 node_ctx.inference_context.reset(nullptr);
1097 }
1098 return s;
1099 }
1100
1101 private:
1102 // Return the one ShapeHandle used to denote a fully unknown shape for a node
1103 // output.
GetUnknownOutputShape(const NodeDef * node,int index)1104 ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1105 ShapeId id{node, index};
1106 auto it = unknown_shapes_.find(id);
1107 if (it != unknown_shapes_.end()) {
1108 return it->second;
1109 }
1110 InferenceContext* c = GetContext(node);
1111 ShapeHandle shp = c->UnknownShape();
1112 unknown_shapes_[id] = shp;
1113 return shp;
1114 }
1115 // Return the one ShapeHandle used to denote a fully unknown dimension for a
1116 // node output.
GetUnknownOutputDim(const NodeDef * node,int index,int dim_id)1117 DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1118 int dim_id) {
1119 DimId id{node, index, dim_id};
1120 auto it = unknown_dims_.find(id);
1121 if (it != unknown_dims_.end()) {
1122 return it->second;
1123 }
1124 InferenceContext* c = GetContext(node);
1125 DimensionHandle dim = c->UnknownDim();
1126 unknown_dims_[id] = dim;
1127 return dim;
1128 }
1129
1130 // Returns true if all the output tensors have known values.
AllOutputValuesKnown(NodeContext * c)1131 bool AllOutputValuesKnown(NodeContext* c) {
1132 InferenceContext* ic = c->inference_context.get();
1133 if (c->output_tensors_as_shapes.size() < ic->num_outputs() &&
1134 c->output_tensor_protos.size() < ic->num_outputs()) {
1135 return false;
1136 } else {
1137 // Checks if we can get output value via either output_tensor_proto or
1138 // output_tensors_as_shapes.
1139 for (int i = 0; i < ic->num_outputs(); i++) {
1140 if (c->output_tensor_protos.size() > i &&
1141 c->output_tensor_protos[i] != nullptr) {
1142 continue;
1143 }
1144 if (c->output_tensors_as_shapes.size() > i &&
1145 ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1146 continue;
1147 }
1148
1149 // Unknown for output[i].
1150 return false;
1151 }
1152 }
1153 return true;
1154 }
1155
1156 // Returns true if we can infer output tensors' values -- we know values of
1157 // all the input tensors.
AllInputValuesKnown(NodeContext * c)1158 bool AllInputValuesKnown(NodeContext* c) {
1159 InferenceContext* ic = c->inference_context.get();
1160
1161 // Check inputs are fully defined and values are known.
1162 for (int i = 0; i < ic->num_inputs(); i++) {
1163 const Tensor* tensor = ic->input_tensor(i);
1164 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1165 // already converted it to ic->input_tensor(i);
1166 const ShapeHandle& input_tensors_as_shape =
1167 ic->input_tensors_as_shapes()[i];
1168 // Either input_tensor is valid or input_tensors_as_shape, which has
1169 // value of input tensors as shape format, should be fully defined.
1170 if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1171 return false;
1172 }
1173 }
1174 return true;
1175 }
1176
1177 // Returns true if we want to update output shapes and values with running
1178 // EvaluateNode() for this op, based on op type, data type, and size.
ShouldUpdateOutputShapesAndValues(NodeContext * c,int64 max_size)1179 bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
1180 InferenceContext* ic = c->inference_context.get();
1181
1182 // Due to the cost of running EvaluateNode(), we limit only to white listed
1183 // op types.
1184 if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1185 return false;
1186 }
1187
1188 // Check input dtypes are number types.
1189 for (const auto& input_type : c->input_types) {
1190 if (!IsNumericType(input_type)) {
1191 return false;
1192 }
1193 }
1194
1195 // Check output dtypes are number types.
1196 for (const auto& output_type : c->output_types) {
1197 if (!IsNumericType(output_type)) {
1198 return false;
1199 }
1200 }
1201
1202 // Check if the number of elements of each of input tensor is no larger than
1203 // the given max size.
1204 for (int i = 0; i < ic->num_inputs(); i++) {
1205 const Tensor* tensor = ic->input_tensor(i);
1206 const ShapeHandle& input_shape_handle = ic->input(i);
1207 if (tensor != nullptr) {
1208 if (tensor->NumElements() > max_size) {
1209 return false;
1210 }
1211 } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1212 return false;
1213 }
1214 }
1215
1216 // Check if we know the shape of each output tensor, and the number of
1217 // elements is larger than the given max size.
1218 for (int i = 0; i < ic->num_outputs(); i++) {
1219 const ShapeHandle& shape_handle = ic->output(i);
1220 if (!ic->FullyDefined(shape_handle) ||
1221 ic->Value(ic->NumElements(shape_handle)) > max_size) {
1222 return false;
1223 }
1224 }
1225 return true;
1226 }
1227
1228 // Create input tensors from the NodeConext.
CreateInputTensors(NodeContext * c,std::vector<Tensor> * input_tensor_vector,TensorVector * inputs)1229 void CreateInputTensors(NodeContext* c,
1230 std::vector<Tensor>* input_tensor_vector,
1231 TensorVector* inputs) {
1232 InferenceContext* ic = c->inference_context.get();
1233 for (int i = 0; i < ic->num_inputs(); i++) {
1234 if (ic->input_tensor(i)) {
1235 input_tensor_vector->at(i) = *ic->input_tensor(i);
1236 inputs->emplace_back(&input_tensor_vector->at(i));
1237 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1238 // already converted it to ic->input_tensor(i);
1239 } else {
1240 // Create Tensor from input_tensors_as_shapes, and then emplace it
1241 // back to inputs.
1242 // Note that input_tensors_as_shapes is scalar or vector.
1243 const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1244 const DataType& data_type = c->input_types[i];
1245 int32 rank = ic->Rank(shape_handle);
1246 if (rank < 1) {
1247 input_tensor_vector->at(i) = Tensor(data_type, {});
1248 } else {
1249 input_tensor_vector->at(i) = Tensor(data_type, {rank});
1250 }
1251 auto* tensor = &input_tensor_vector->at(i);
1252 if (data_type == DT_INT32) {
1253 auto flat = tensor->flat<int32>();
1254 for (int j = 0; j < rank; j++) {
1255 int32 dim = ic->Value(ic->Dim(shape_handle, j));
1256 flat(j) = dim;
1257 }
1258 } else {
1259 auto flat = tensor->flat<int64>();
1260 for (int j = 0; j < rank; j++) {
1261 int64 dim = ic->Value(ic->Dim(shape_handle, j));
1262 flat(j) = dim;
1263 }
1264 }
1265 inputs->emplace_back(tensor);
1266 }
1267 }
1268 }
1269
1270 // Run a node to infer output shapes and values, and add it to the
1271 // NodeContext.
UpdateOutputShapesAndValues(const NodeDef & node,NodeContext * c)1272 Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1273 InferenceContext* ic = c->inference_context.get();
1274
1275 // Input to EvaluateNode()
1276 TensorVector inputs;
1277 // Container for temporaily created tensor object.
1278 std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1279 CreateInputTensors(c, &input_tensor_vector, &inputs);
1280
1281 // Output for EvaluateNode() and output tensor clean up object.
1282 TensorVector outputs;
1283 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1284 for (const auto& output : outputs) {
1285 if (output.tensor) {
1286 delete output.tensor;
1287 }
1288 }
1289 });
1290
1291 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1292 &resource_mgr_, &outputs));
1293 c->output_tensors_as_shapes.resize(outputs.size());
1294 c->output_tensor_protos.resize(outputs.size(), nullptr);
1295 for (int k = 0; k < outputs.size(); k++) {
1296 const auto& t = outputs[k];
1297 // Override output shape.
1298 ShapeHandle output_shape;
1299 TF_RETURN_IF_ERROR(
1300 ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1301 if (ic->FullyDefined(ic->output(k)) &&
1302 !EquivalentShapes(ic->output(k), output_shape)) {
1303 LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1304 << ", inferred output shape "
1305 << "doesn't match for k=" << k << ": "
1306 << "ic->output(k): " << ic->DebugString(ic->output(k))
1307 << ", output_shape: " << ic->DebugString(output_shape)
1308 << " -- " << node.DebugString();
1309 }
1310 ic->set_output(k, output_shape);
1311 // Set output_tensors_as_shape.
1312 MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1313
1314 // Set output_tensor_protos.
1315 TensorProto tensor_proto;
1316 t->AsProtoTensorContent(&tensor_proto);
1317 const_tensors_to_propagate_.push_back(tensor_proto);
1318 c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1319 }
1320 return Status::OK();
1321 }
1322
1323 // Update output shapes with annotated information.
1324 // Currently only handle nodes with static shapes, i.e. shapes do not change
1325 // during execution.
1326 // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
UpdateOutputShapesUsingAnnotatedInformation(const NodeDef & node,NodeContext * c) const1327 Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1328 NodeContext* c) const {
1329 const auto& attr = node.attr();
1330 if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1331 attr.count(kOutputShapes) == 0)
1332 return Status::OK();
1333
1334 InferenceContext* ic = c->inference_context.get();
1335 int output_size = attr.at(kOutputShapes).list().shape_size();
1336
1337 for (int i = 0; i < ic->num_outputs(); i++) {
1338 // Annotated Switch node has only one output. Propagate the shape to all
1339 // the outputs.
1340 int shape_index = IsSwitch(node) ? 0 : i;
1341 if (shape_index >= output_size) {
1342 LOG(WARNING)
1343 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1344 << node.name() << ", inferred output shape size "
1345 << ic->num_outputs() << ", annotated output shape size "
1346 << output_size;
1347 break;
1348 }
1349
1350 const TensorShapeProto& shape =
1351 attr.at(kOutputShapes).list().shape(shape_index);
1352 ShapeHandle output_shape;
1353 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1354
1355 // Only use annotated shapes if the inference shape is unknown and
1356 // compatible with annotated shapes.
1357 if (!ic->FullyDefined(ic->output(i)) &&
1358 CompatibleShapes(ic->output(i), output_shape)) {
1359 VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1360 << node.name() << ", inferred output shape " << i << ": "
1361 << "ic->output(i): " << ic->DebugString(ic->output(i))
1362 << ", annotated output shape: " << ic->DebugString(output_shape)
1363 << " -- " << node.ShortDebugString();
1364 ic->set_output(i, output_shape);
1365 }
1366 }
1367
1368 return Status::OK();
1369 }
1370
MaybeUpdateNodeContextOutput(const NodeDef & node,const bool is_fed,NodeContext * c)1371 Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1372 NodeContext* c) {
1373 // Propagate tensors and shape tensors unless the node is fed.
1374 // TODO(bsteiner) We should still propagate the shapes to the ports that
1375 // aren't fed in the case of a ShapeN node.
1376
1377 InferenceContext* ic = c->inference_context.get();
1378 if (!is_fed) {
1379 if (IsConstant(node)) {
1380 c->output_tensor_protos.resize(1);
1381 const TensorProto& tensor_proto = node.attr().at("value").tensor();
1382 c->output_tensor_protos[0] = &tensor_proto;
1383 c->output_tensors_as_shapes.resize(1);
1384 MaybeTensorProtoToShape(ic, tensor_proto,
1385 &c->output_tensors_as_shapes[0]);
1386 } else if (IsRank(node)) {
1387 if (ic->RankKnown(ic->input(0))) {
1388 // Propagate rank value.
1389 int32 rank = ic->Rank(ic->input(0));
1390 const_tensors_to_propagate_.push_back(
1391 MakeIntegerScalarTensorProto(DT_INT32, rank));
1392 c->output_tensor_protos.resize(1);
1393 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1394 }
1395 } else if (IsSize(node)) {
1396 DimensionHandle size = ic->NumElements(ic->input(0));
1397 if (ic->ValueKnown(size)) {
1398 // Propagate size value.
1399 int64 sz = ic->Value(size);
1400 bool valid = false;
1401 if (node.attr().at("out_type").type() == DT_INT32) {
1402 if (sz < std::numeric_limits<int32>::max()) {
1403 const_tensors_to_propagate_.push_back(
1404 MakeIntegerScalarTensorProto(DT_INT32, sz));
1405 valid = true;
1406 }
1407 } else {
1408 const_tensors_to_propagate_.push_back(
1409 MakeIntegerScalarTensorProto(DT_INT64, sz));
1410 valid = true;
1411 }
1412 if (valid) {
1413 c->output_tensor_protos.resize(1);
1414 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1415 }
1416 }
1417 } else if (IsShape(node)) {
1418 c->output_tensors_as_shapes.resize(1);
1419 c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1420 } else if (IsShapeN(node)) {
1421 c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1422 for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1423 c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1424 }
1425 } else if (node.op() == "ConcatV2") {
1426 bool valid = true;
1427 ShapeHandle result;
1428 for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1429 ShapeHandle input = ic->input_tensors_as_shapes()[i];
1430 if (!ic->RankKnown(input)) {
1431 valid = false;
1432 break;
1433 } else if (i == 0) {
1434 result = input;
1435 } else {
1436 TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1437 }
1438 }
1439 if (valid) {
1440 c->output_tensors_as_shapes.resize(1);
1441 c->output_tensors_as_shapes[0] = result;
1442 }
1443 } else if (IsPack(node)) {
1444 // A Pack node concatenating scalars is often used to generate a shape.
1445 std::vector<DimensionHandle> dims;
1446 bool valid = true;
1447 for (int i = 0; i < ic->num_inputs(); ++i) {
1448 const Tensor* t = ic->input_tensor(i);
1449 if (t) {
1450 if (t->dims() != 0 ||
1451 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1452 valid = false;
1453 break;
1454 }
1455 int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1456 : t->scalar<int64>()();
1457 dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
1458 } else {
1459 // Don't have tensor value, but use input_tensors_as_shapes, if
1460 // possible.
1461 const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1462 if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1463 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1464 dims.push_back(ic->Dim(shape_handle, 0));
1465 } else {
1466 dims.push_back(ic->UnknownDim());
1467 }
1468 }
1469 }
1470 if (valid) {
1471 c->output_tensors_as_shapes.resize(1);
1472 c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1473 }
1474 } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1475 c->output_tensors_as_shapes.resize(1);
1476 c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
1477 if (c->input_tensor_protos[0] != nullptr) {
1478 c->output_tensor_protos.resize(1);
1479 c->output_tensor_protos[0] = c->input_tensor_protos[0];
1480 }
1481 } else if (IsSlice(node)) {
1482 ShapeHandle input = ic->input_tensors_as_shapes()[0];
1483 bool valid = ic->RankKnown(input);
1484 const Tensor* slice_offset = ic->input_tensor(1);
1485 valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1486 const Tensor* slice_size = ic->input_tensor(2);
1487 valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1488 if (valid) {
1489 int64 start = slice_offset->dtype() == DT_INT32
1490 ? slice_offset->flat<int32>()(0)
1491 : slice_offset->flat<int64>()(0);
1492 int64 size =
1493 (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1494 : slice_size->flat<int64>()(0));
1495 ShapeHandle result;
1496 if (size == -1) {
1497 TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1498 } else {
1499 int64 end = start + size;
1500 TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1501 }
1502 c->output_tensors_as_shapes.resize(1);
1503 c->output_tensors_as_shapes[0] = result;
1504 }
1505 } else if (IsStridedSlice(node)) {
1506 ShapeHandle input = ic->input_tensors_as_shapes()[0];
1507 bool valid = ic->RankKnown(input);
1508 const Tensor* slice_begin = ic->input_tensor(1);
1509 valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1510 const Tensor* slice_end = ic->input_tensor(2);
1511 valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1512 const Tensor* slice_stride = ic->input_tensor(3);
1513 valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1514
1515 if (node.attr().count("ellipsis_mask") > 0 &&
1516 node.attr().at("ellipsis_mask").i() != 0) {
1517 valid = false;
1518 }
1519 if (node.attr().count("new_axis_mask") > 0 &&
1520 node.attr().at("new_axis_mask").i() != 0) {
1521 valid = false;
1522 }
1523 if (node.attr().count("shrink_axis_mask") > 0 &&
1524 node.attr().at("shrink_axis_mask").i() != 0) {
1525 valid = false;
1526 }
1527 int begin_mask = 0;
1528 if (node.attr().count("begin_mask") > 0) {
1529 begin_mask = node.attr().at("begin_mask").i();
1530 }
1531 int end_mask = 0;
1532 if (node.attr().count("end_mask") > 0) {
1533 end_mask = node.attr().at("end_mask").i();
1534 }
1535 if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1536 valid = false;
1537 }
1538 if (valid) {
1539 int64 begin = 0;
1540 if (begin_mask == 0) {
1541 begin = slice_begin->dtype() == DT_INT32
1542 ? slice_begin->flat<int32>()(0)
1543 : slice_begin->flat<int64>()(0);
1544 }
1545 int64 end = std::numeric_limits<int64>::max();
1546 if (end_mask == 0) {
1547 end =
1548 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1549 : slice_end->flat<int64>()(0));
1550 }
1551 int64 stride = slice_stride->dtype() == DT_INT32
1552 ? slice_stride->flat<int32>()(0)
1553 : slice_stride->flat<int64>()(0);
1554 ShapeHandle result;
1555 TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1556 c->output_tensors_as_shapes.resize(1);
1557 c->output_tensors_as_shapes[0] = result;
1558 }
1559 }
1560 }
1561
1562 if (aggressive_shape_inference_) {
1563 // Update output shapes with annotated information. This is optional.
1564 UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1565
1566 // Update output tensor values using EvaluateNode() if we can.
1567 // Due to the cost of EvaluateNode(), we run it only for certain op types
1568 // (white listed) and small integer tensors.
1569
1570 const int max_element_size = 17; // Max up to 4x4 matrix or similar.
1571 if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1572 !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1573 return Status::OK();
1574 }
1575 UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
1576 }
1577 return Status::OK();
1578 }
1579
InferShapes(const NodeDef & node,NodeContext * c)1580 Status InferShapes(const NodeDef& node, NodeContext* c) {
1581 // Infer the shapes of output tensors.
1582 if (!c->op_data || c->op_data->shape_inference_fn == nullptr) {
1583 // There is nothing more we can infer, annotate outputs with unknown
1584 // shapes
1585 return c->inference_context->Run(shape_inference::UnknownShape);
1586 }
1587
1588 TF_RETURN_IF_ERROR(
1589 c->inference_context->Run(c->op_data->shape_inference_fn));
1590
1591 Status status = Status::OK();
1592 auto it = fed_ports_.find(node.name());
1593 const bool is_fed = it != fed_ports_.end();
1594 if (is_fed) {
1595 // It is possible to feed node output ports with tensors of any shape: as
1596 // a result, the shape of a fed port is completely unknown.
1597 for (const int output_port : it->second) {
1598 status.Update(SetUnknownShape(&node, output_port));
1599 }
1600 }
1601
1602 // Update NodeContext output fields after shape inference function runs.
1603 status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1604
1605 return status;
1606 }
1607
1608 private:
IsIntegerVector(const Tensor & tensor)1609 bool IsIntegerVector(const Tensor& tensor) {
1610 if (tensor.dims() == 1 &&
1611 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1612 return true;
1613 }
1614 return false;
1615 }
1616
IsIntegerScalar(const Tensor & tensor)1617 bool IsIntegerScalar(const Tensor& tensor) {
1618 if (tensor.dims() == 0 &&
1619 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1620 tensor.NumElements() == 1) {
1621 return true;
1622 }
1623 return false;
1624 }
1625
MakeIntegerScalarTensorProto(const DataType dtype,const int64 val)1626 TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1627 const int64 val) {
1628 TensorProto tensor_proto;
1629 tensor_proto.set_dtype(dtype);
1630 // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1631 tensor_proto.mutable_tensor_shape();
1632 if (dtype == DT_INT32) {
1633 tensor_proto.add_int_val(val);
1634 } else if (dtype == DT_INT64) {
1635 tensor_proto.add_int64_val(val);
1636 }
1637 return tensor_proto;
1638 }
1639
MaybeTensorProtoToShape(InferenceContext * ic,const TensorProto & tensor_proto,ShapeHandle * tensors_as_shapes)1640 bool MaybeTensorProtoToShape(InferenceContext* ic,
1641 const TensorProto& tensor_proto,
1642 ShapeHandle* tensors_as_shapes) {
1643 // Skip if dtype is not integer.
1644 if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1645 return false;
1646 }
1647 // Skip if shape is neither scalar nor vector.
1648 if (tensor_proto.tensor_shape().unknown_rank() ||
1649 tensor_proto.tensor_shape().dim_size() > 1) {
1650 return false;
1651 }
1652 Tensor tensor;
1653 if (!tensor.FromProto(tensor_proto)) {
1654 return false;
1655 }
1656 return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1657 }
1658
MaybeTensorValueToShape(InferenceContext * ic,const Tensor & tensor,ShapeHandle * tensors_as_shapes)1659 bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1660 ShapeHandle* tensors_as_shapes) {
1661 // Integer tensors of rank one can also be interpreted as a shape
1662 // provided all their values are >= -1.
1663 if (IsIntegerVector(tensor)) {
1664 bool has_values_smaller_than_minus_1 = false;
1665 std::vector<DimensionHandle> dims;
1666 for (int i = 0; i < tensor.NumElements(); i++) {
1667 int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
1668 : tensor.flat<int64>()(i);
1669 has_values_smaller_than_minus_1 |= (value < -1);
1670 dims.push_back(value < 0 ? ic->UnknownDim() : ic->MakeDim(value));
1671 }
1672 if (!has_values_smaller_than_minus_1) {
1673 *tensors_as_shapes = ic->MakeShape(dims);
1674 }
1675 } else if (IsIntegerScalar(tensor)) {
1676 // Scalar constant.
1677 int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
1678 : tensor.flat<int64>()(0);
1679 // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
1680 // It's a limitation as we use ShapeHandle as a means to pass values.
1681 if (value >= -1) {
1682 *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
1683 return true;
1684 }
1685 }
1686 return false;
1687 }
1688
1689 const GraphView& graph_;
1690 int graph_def_version_;
1691 std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
1692 std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
1693 std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
1694 std::unordered_map<string, GrapplerFunctionItem>
1695 fun_to_grappler_function_item_;
1696 FunctionLibraryDefinition function_library_;
1697 const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
1698 // Store TensorProtos for tensor value propagation. Note that we use list, not
1699 // vector, as we use pointers to the TensorProtos in this container. Vector
1700 // may resize and copy the objects into a new buffer, then the existing
1701 // pointers become dangling pointers.
1702 std::list<TensorProto> const_tensors_to_propagate_;
1703
1704 // For more aggressive shape and value inference.
1705 bool aggressive_shape_inference_;
1706 ResourceMgr resource_mgr_;
1707 };
1708
1709 // Keep track of shapes and dimensions in a graph.
1710 // In particular, use disjoint sets to track equivalence between shapes and
1711 // dims, and consolidate the information globally.
1712 class SymbolicShapeManager {
1713 public:
SymbolicShapeManager()1714 SymbolicShapeManager() {}
1715
Merge(ShapeHandle s1,ShapeHandle s2)1716 Status Merge(ShapeHandle s1, ShapeHandle s2) {
1717 if (!s1.IsSet() || !s2.IsSet()) {
1718 return Status::OK();
1719 }
1720 TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
1721 if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
1722 CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
1723 for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
1724 TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
1725 InferenceContext::DimKnownRank(s2, i)));
1726 }
1727 }
1728 return Status::OK();
1729 }
Merge(DimensionHandle d1,DimensionHandle d2)1730 Status Merge(DimensionHandle d1, DimensionHandle d2) {
1731 if (!d1.IsSet() || !d2.IsSet()) {
1732 return Status::OK();
1733 }
1734 return dims_.Merge(d1, d2);
1735 }
1736
AsTensorProperties(const ShapeHandle & shape,const DataType & type,OpInfo::TensorProperties * properties)1737 void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
1738 OpInfo::TensorProperties* properties) {
1739 properties->set_dtype(type);
1740 ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
1741 if (!InferenceContext::RankKnown(actual_shape)) {
1742 properties->mutable_shape()->set_unknown_rank(true);
1743 } else {
1744 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
1745 shape_inference::DimensionHandle dim =
1746 InferenceContext::DimKnownRank(actual_shape, j);
1747 int64 d = dims_.GetMergedValue(dim);
1748 properties->mutable_shape()->add_dim()->set_size(d);
1749 }
1750 }
1751 }
1752
1753 private:
1754 DisjointSet<shape_inference::ShapeHandle> shapes_;
1755 DisjointSet<shape_inference::DimensionHandle> dims_;
1756 };
1757
RelaxEnqueueShapesAndMergeTypes(SymbolicShapeRefiner * shape_refiner,const NodeDef * qnode,const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * queue_shapes_and_types)1758 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
1759 SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
1760 const std::vector<ShapeAndType>& shapes_and_types,
1761 std::vector<ShapeAndType>* queue_shapes_and_types) {
1762 if (shapes_and_types.size() != queue_shapes_and_types->size()) {
1763 return errors::InvalidArgument(
1764 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
1765 " vs ", queue_shapes_and_types->size());
1766 }
1767 for (size_t i = 0; i < shapes_and_types.size(); ++i) {
1768 const ShapeAndType& a = shapes_and_types[i];
1769 ShapeAndType& b = (*queue_shapes_and_types)[i];
1770 if (a.dtype != b.dtype) {
1771 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
1772 i, ": ", DataTypeString(a.dtype), " vs ",
1773 DataTypeString(b.dtype));
1774 }
1775
1776 b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
1777 }
1778 return Status::OK();
1779 }
1780
1781 // Compute the output shape of the merge node as the union of the available
1782 // input shapes.
UpdateMerge(SymbolicShapeRefiner * shape_refiner,const NodeDef * node,bool * new_shapes) const1783 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
1784 const NodeDef* node,
1785 bool* new_shapes) const {
1786 InferenceContext* ic = shape_refiner->GetContext(node);
1787 if (!ic) {
1788 // Now we can run shape inference
1789 TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
1790 ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
1791 *new_shapes = true;
1792
1793 // Infer the shape of the second output once and for all since it never
1794 // changes.
1795 ShapeHandle out1 = ic->Scalar();
1796 ic->set_output(1, out1);
1797 }
1798
1799 ShapeHandle out;
1800 const std::vector<ShapeAndType>* out_handle = nullptr;
1801 bool out_initialized = false;
1802 for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
1803 *node, /*include_controlling_edges=*/false)) {
1804 InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
1805 if (!src_ic) {
1806 // Handling a loop for the first time, the back edge won't have any shape
1807 // info.
1808 continue;
1809 }
1810 ShapeHandle input = src_ic->output(fanin.src.port_id);
1811 ic->SetInput(fanin.dst.port_id, input);
1812 auto* input_handle =
1813 src_ic->output_handle_shapes_and_types(fanin.src.port_id);
1814 if (input_handle)
1815 ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
1816 if (!out_initialized) {
1817 out_initialized = true;
1818 out = input;
1819 out_handle = input_handle;
1820 } else {
1821 // Note here only out, not out_handle, is modified.
1822 out = shape_refiner->OutputAsUnion(node, 0, input, out);
1823 }
1824 }
1825
1826 if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
1827 ic->set_output(0, out);
1828 if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
1829 *new_shapes = true;
1830 }
1831
1832 return Status::OK();
1833 }
1834
1835 // Manually propagate the input shape for Enter nodes.
UpdateEnter(SymbolicShapeRefiner * shape_refiner,const NodeDef * node,bool * new_shapes)1836 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
1837 const NodeDef* node, bool* new_shapes) {
1838 InferenceContext* ic = shape_refiner->GetContext(node);
1839 if (!ic) {
1840 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
1841 ic = shape_refiner->GetContext(node);
1842 }
1843
1844 GraphView::InputPort port(node, 0);
1845 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
1846
1847 InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
1848 ShapeHandle input = src_ic->output(fanin.port_id);
1849 if (!ic->output(0).SameHandle(input)) {
1850 ic->SetInput(0, input);
1851 ic->set_output(0, input);
1852 *new_shapes = true;
1853 }
1854 auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
1855 if (outputs) {
1856 ic->set_input_handle_shapes_and_types(0, *outputs);
1857 ic->set_output_handle_shapes_and_types(0, *outputs);
1858 *new_shapes = true;
1859 }
1860 return Status::OK();
1861 }
1862
UpdateShapes(SymbolicShapeRefiner * shape_refiner,const std::unordered_map<const NodeDef *,const NodeDef * > & resource_handles,const NodeDef * n,bool * new_shapes) const1863 Status GraphProperties::UpdateShapes(
1864 SymbolicShapeRefiner* shape_refiner,
1865 const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1866 const NodeDef* n, bool* new_shapes) const {
1867 if (IsEnter(*n)) {
1868 // The Enter shape function always forwards an UnknownShape, so do the right
1869 // thing here.
1870 TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
1871 } else if (IsMerge(*n)) {
1872 // Properly handle merge nodes.
1873 TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
1874 } else if (IsEnqueue(*n)) {
1875 // Make sure the shapes of enqueued tensors are propagated to the queue
1876 // itself.
1877 TF_RETURN_IF_ERROR(
1878 UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
1879 } else if (IsQueue(*n)) {
1880 // Set shapes and types of Queue ops, if needed.
1881 TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
1882 } else {
1883 // Rely on regular TF shape refinement for all the other nodes.
1884 // UpdateNode calls UpdateFunction if a function node is detected.
1885 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
1886 }
1887
1888 return Status::OK();
1889 }
1890
1891 // Propagates the shapes in the transitive fan-out of <new_shapes>.
PropagateShapes(SymbolicShapeRefiner * shape_refiner,TopoQueue * new_shapes,const std::unordered_map<const NodeDef *,const NodeDef * > & resource_handles,int num_loops) const1892 Status GraphProperties::PropagateShapes(
1893 SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
1894 const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1895 int num_loops) const {
1896 // Limit the number of iterations to prevent infinite loops in the presence of
1897 // incorrect shape functions. The algorithm should converge in at most
1898 // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
1899 // The same applies to resources.
1900 VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
1901 << num_loops << " loops and " << resource_handles.size()
1902 << " resources" << std::endl;
1903
1904 const int64 max_loop_length = item_.graph.node_size();
1905 const int64 max_rank = 4;
1906 const int64 max_loop_iterations =
1907 max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
1908 const int64 num_queues = resource_handles.size();
1909 const int64 max_resource_iterations = num_queues * num_queues * max_rank;
1910
1911 int64 num_resource_iterations = 0;
1912 do {
1913 int64 num_loop_iterations = 0;
1914 while (!new_shapes->empty() &&
1915 num_loop_iterations++ < max_loop_iterations) {
1916 const NodeDef* n = new_shapes->pop();
1917 bool updated = false;
1918 TF_RETURN_IF_ERROR(
1919 UpdateShapes(shape_refiner, resource_handles, n, &updated));
1920 if (updated) {
1921 for (const auto& fanout : shape_refiner->graph().GetFanouts(
1922 *n, /*include_controlled_nodes=*/false)) {
1923 new_shapes->push(fanout.node);
1924 }
1925 // Make sure the corresponding queue nodes are (re)processed.
1926 if (IsEnqueue(*n)) {
1927 auto it = resource_handles.find(n);
1928 if (it != resource_handles.end()) {
1929 new_shapes->push(it->second);
1930 }
1931 }
1932 }
1933 }
1934 } while (!new_shapes->empty() &&
1935 num_resource_iterations++ < max_resource_iterations);
1936
1937 if (!new_shapes->empty()) {
1938 return errors::Internal("Shape inference failed to converge");
1939 }
1940
1941 return Status::OK();
1942 }
1943
UpdateQueue(const NodeDef * queue_node,SymbolicShapeRefiner * shape_refiner,bool * new_shapes)1944 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
1945 SymbolicShapeRefiner* shape_refiner,
1946 bool* new_shapes) {
1947 auto* ctx = shape_refiner->GetNodeContext(queue_node);
1948 if (!ctx) {
1949 TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
1950 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
1951 }
1952 auto* ic = ctx->inference_context.get();
1953
1954 auto* outputs = ic->output_handle_shapes_and_types(0);
1955 if (outputs) {
1956 // Shapes and types are already set, presumably by Enqueue ops.
1957 return shape_refiner->UpdateNode(queue_node, new_shapes);
1958 }
1959
1960 if (queue_node->attr().count("shapes") <= 0 ||
1961 queue_node->attr().count("component_types") <= 0 ||
1962 queue_node->attr().at("shapes").list().shape_size() !=
1963 queue_node->attr().at("component_types").list().type_size()) {
1964 // Errors in shapes and component_types attr.
1965 return shape_refiner->UpdateNode(queue_node, new_shapes);
1966 }
1967
1968 // Extract types and shapes from Queue attr.
1969 const auto& shapes = queue_node->attr().at("shapes").list().shape();
1970 const auto& types = queue_node->attr().at("component_types").list().type();
1971 std::vector<ShapeAndType> shapes_and_types;
1972 for (int i = 0; i < types.size(); i++) {
1973 const auto& shape = shapes[i];
1974 ShapeHandle shape_handle;
1975 TF_RETURN_IF_ERROR(
1976 ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
1977 DataType data_type =
1978 queue_node->attr().at("component_types").list().type(i);
1979 ShapeAndType shape_and_type(shape_handle, data_type);
1980 shapes_and_types.push_back(shape_and_type);
1981 }
1982 ic->set_output_handle_shapes_and_types(0, shapes_and_types);
1983
1984 // Queue node is updated with output_handle_shapes_and_types, so set
1985 // new_shapes and ignore it from UpdateNoe().
1986 *new_shapes = true;
1987 bool dummy_new_shapes = false;
1988 return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
1989 }
1990
UpdateEnqueue(const NodeDef * enqueue_node,const std::unordered_map<const NodeDef *,const NodeDef * > & resource_handles,SymbolicShapeRefiner * shape_refiner,bool * new_shapes)1991 Status GraphProperties::UpdateEnqueue(
1992 const NodeDef* enqueue_node,
1993 const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1994 SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
1995 auto ctx = shape_refiner->GetNodeContext(enqueue_node);
1996 if (!ctx) {
1997 TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
1998 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
1999 }
2000
2001 auto it = resource_handles.find(enqueue_node);
2002 if (it == resource_handles.end()) {
2003 // The corresponding queue was not found, there isn't much we can do.
2004 return Status::OK();
2005 }
2006 const NodeDef* qnode = it->second;
2007 auto qctx = shape_refiner->GetContext(qnode);
2008 if (!qctx) {
2009 return Status::OK();
2010 }
2011 auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2012
2013 // TODO(bsteiner): handle EnqueueMany as well.
2014 std::vector<ShapeAndType> shapes_and_types;
2015 for (int i = 1; i < ctx->input_types.size(); ++i) {
2016 GraphView::InputPort inp(enqueue_node, i);
2017 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2018 InferenceContext* in = shape_refiner->GetContext(fanin.node);
2019 ShapeHandle input = in->output(fanin.port_id);
2020 ctx->inference_context->SetInput(i, input);
2021 shapes_and_types.push_back({input, ctx->input_types[i]});
2022 }
2023
2024 if (queue_handle_data == nullptr) {
2025 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2026 *new_shapes = true;
2027 } else {
2028 TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2029 shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2030 *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2031 shapes_and_types);
2032 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2033 }
2034
2035 return Status::OK();
2036 }
2037
InferStatically(bool assume_valid_feeds,bool aggressive_shape_inference)2038 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2039 bool aggressive_shape_inference) {
2040 FunctionLibraryDefinition function_library(OpRegistry::Global(),
2041 item_.graph.library());
2042 std::unordered_map<string, std::unordered_set<int>> fed_ports;
2043 if (!assume_valid_feeds) {
2044 for (const auto& feed : item_.feed) {
2045 SafeTensorId tensor_id = ParseTensorName(feed.first);
2046 fed_ports[tensor_id.node()].insert(tensor_id.index());
2047 }
2048 }
2049
2050 GraphView graph_view(&item_.graph);
2051
2052 // List the resources and the nodes using them. Also collect the Merge nodes,
2053 // fed nodes, and primary inputs.
2054 std::unordered_map<const NodeDef*,
2055 std::pair<std::unordered_set<const NodeDef*>,
2056 std::unordered_set<const NodeDef*>>>
2057 resources;
2058 std::unordered_set<const NodeDef*> merge_nodes;
2059 std::unordered_set<const NodeDef*> fed_nodes;
2060 std::unordered_set<const NodeDef*> primary_inputs;
2061 int num_loops = 0;
2062 for (const NodeDef& node : item_.graph.node()) {
2063 if (IsQueue(node)) {
2064 for (const GraphView::InputPort& fanout :
2065 graph_view.GetFanouts(node, false)) {
2066 if (IsEnter(*fanout.node)) {
2067 const NodeDef& enter = *fanout.node;
2068 for (const GraphView::InputPort& fanout :
2069 graph_view.GetFanouts(enter, false)) {
2070 if (IsEnqueue(*fanout.node)) {
2071 resources[&node].first.insert(fanout.node);
2072 } else if (IsDequeue(*fanout.node)) {
2073 resources[&node].second.insert(fanout.node);
2074 }
2075 }
2076 } else {
2077 if (IsEnqueue(*fanout.node)) {
2078 resources[&node].first.insert(fanout.node);
2079 } else if (IsDequeue(*fanout.node)) {
2080 resources[&node].second.insert(fanout.node);
2081 }
2082 }
2083 }
2084 }
2085 if (NumNonControlInputs(node) == 0) {
2086 primary_inputs.insert(&node);
2087 } else if (IsMerge(node)) {
2088 merge_nodes.insert(&node);
2089 } else if (IsNextIteration(node)) {
2090 ++num_loops;
2091 }
2092 if (fed_ports.find(node.name()) != fed_ports.end()) {
2093 fed_nodes.insert(&node);
2094 }
2095 }
2096
2097 std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
2098 std::vector<TopologicalDependency> extra_deps;
2099 for (const auto& resource : resources) {
2100 for (const NodeDef* src : resource.second.first) {
2101 resource_handles[src] = resource.first;
2102 for (const NodeDef* dst : resource.second.second) {
2103 // Add control edges from enqueue to dequeue nodes to ensure they are
2104 // processed in their logical order.
2105 extra_deps.emplace_back(src, dst);
2106 }
2107 }
2108 }
2109
2110 std::vector<const NodeDef*> topo_order;
2111 Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2112 if (!s.ok()) {
2113 if (extra_deps.empty()) {
2114 return s;
2115 } else {
2116 // There is a loop between queues: we'll just use the graph topological
2117 // order. This will make the shape inference less precise but since this
2118 // isn't common it's not worth to figure out where to break the loop and
2119 // do a proper relaxation.
2120 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2121 }
2122 }
2123
2124 SymbolicShapeRefiner refiner(graph_view, fed_ports,
2125 aggressive_shape_inference);
2126
2127 TopoQueue new_shapes(topo_order);
2128 // Also seed the propagation of shapes in the fanout of primary inputs.
2129 for (const NodeDef* node : primary_inputs) {
2130 new_shapes.push(node);
2131 }
2132 // Also seed the propagation of shapes in the fanout of fed nodes.
2133 for (const NodeDef* node : fed_nodes) {
2134 new_shapes.push(node);
2135 }
2136 // Propagate shapes normally.
2137 TF_RETURN_IF_ERROR(
2138 PropagateShapes(&refiner, &new_shapes, resource_handles, num_loops));
2139
2140 // Track shapes globally across the graph.
2141 std::unique_ptr<SymbolicShapeManager> shape_manager =
2142 absl::make_unique<SymbolicShapeManager>();
2143 bool found_error = false;
2144 for (const NodeDef& node : item_.graph.node()) {
2145 auto node_ctx = refiner.GetContext(&node);
2146 if (!node_ctx) {
2147 continue;
2148 }
2149 // Skip any information that comes from fed nodes.
2150 if (fed_ports.find(node.name()) != fed_ports.end()) {
2151 VLOG(2) << "Skipping feed node shape: " << node.name();
2152 continue;
2153 }
2154 for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2155 if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2156 .ok()) {
2157 found_error = true;
2158 break;
2159 }
2160 }
2161 for (const auto& merged_dims : node_ctx->MergedDims()) {
2162 if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2163 found_error = true;
2164 break;
2165 }
2166 }
2167 if (found_error) {
2168 // The shapes aren't consistent, we can't infer safely: discard all the
2169 // information discovered so far.
2170 shape_manager = absl::make_unique<SymbolicShapeManager>();
2171 break;
2172 }
2173 }
2174
2175 for (const NodeDef& node : item_.graph.node()) {
2176 VLOG(3) << "Filling in graph properties for node: " << node.name();
2177 auto ctx = refiner.GetNodeContext(&node);
2178 if (!ctx) {
2179 continue;
2180 }
2181
2182 auto* ic = ctx->inference_context.get();
2183
2184 // Fill input properties.
2185 {
2186 auto& input_properties = input_properties_[node.name()];
2187
2188 // Should always be empty, node names in graph are supposed to be unique.
2189 CHECK_EQ(input_properties.size(), 0);
2190
2191 input_properties.resize(ic->num_inputs());
2192 GraphView::InputPort input(&node, -1);
2193 for (int i = 0; i < ic->num_inputs(); ++i) {
2194 shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2195 &input_properties[i]);
2196 input.port_id = i;
2197 GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2198 // Export tensor value to input_properties.value.
2199 if (IsConstant(*fanin.node)) {
2200 const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
2201 *input_properties[i].mutable_value() = raw_val;
2202 } else if (ctx->input_tensor_protos.size() > i &&
2203 ctx->input_tensor_protos[i] != nullptr) {
2204 *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2205 } else if (ic->input_tensors_as_shapes().size() > i &&
2206 IsShapeFullyDefinedIntegerVectorOrScalar(
2207 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2208 ctx->input_types[i])) {
2209 *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2210 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2211 ctx->input_types[i]);
2212 }
2213 }
2214 }
2215
2216 // Fill output properties.
2217 {
2218 auto& output_properties = output_properties_[node.name()];
2219
2220 // Should always be empty, node names in graph are supposed to be unique.
2221 CHECK_EQ(output_properties.size(), 0);
2222
2223 output_properties.resize(ic->num_outputs());
2224 for (int i = 0; i < ic->num_outputs(); ++i) {
2225 shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2226 &output_properties[i]);
2227 // Export tensor value to output_properties.value.
2228 if (IsConstant(node)) {
2229 const TensorProto& raw_val = node.attr().at("value").tensor();
2230 *output_properties[i].mutable_value() = raw_val;
2231 } else if (ctx->output_tensor_protos.size() > i &&
2232 ctx->output_tensor_protos[i] != nullptr) {
2233 *output_properties[i].mutable_value() = *ctx->output_tensor_protos[i];
2234 } else if (ctx->output_tensors_as_shapes.size() > i &&
2235 IsShapeFullyDefinedIntegerVectorOrScalar(
2236 ic, ic->output(i), ctx->output_tensors_as_shapes[i],
2237 ctx->output_types[i])) {
2238 *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2239 ic, ic->output(i), ctx->output_tensors_as_shapes[i],
2240 ctx->output_types[i]);
2241 }
2242 }
2243 }
2244 }
2245
2246 // Help trace the unknown dimensions to their origins.
2247 VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2248 output_properties_);
2249
2250 return Status::OK();
2251 }
2252
InferDynamically(Cluster * cluster)2253 Status GraphProperties::InferDynamically(Cluster* cluster) {
2254 TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2255
2256 // Runs the model once to collect the shapes in the cost model.
2257 RunMetadata metadata;
2258 TF_RETURN_IF_ERROR(
2259 cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2260
2261 return InferFromCostGraph(metadata.cost_graph());
2262 }
2263
AnnotateOutputShapes(GraphDef * output_graph_def) const2264 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
2265 *output_graph_def = item_.graph;
2266 for (int i = 0; i < output_graph_def->node_size(); i++) {
2267 auto node = output_graph_def->mutable_node(i);
2268 AttrValue attr_output_shape;
2269 auto tensor_properties = GetOutputProperties(node->name());
2270 for (const auto& tensor_property : tensor_properties) {
2271 *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
2272 }
2273 (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
2274 }
2275 return Status::OK();
2276 }
2277
InferFromCostGraph(const CostGraphDef & cost_graph)2278 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2279 if (cost_graph.node_size() == 0) {
2280 LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2281 }
2282 std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2283 std::unordered_map<string, const NodeDef*> name_to_node; // Empty
2284 for (auto& node : cost_graph.node()) {
2285 name_to_cost[node.name()] = &node;
2286
2287 std::vector<OpInfo::TensorProperties> output_properties;
2288 for (const auto& out : node.output_info()) {
2289 OpInfo::TensorProperties properties;
2290 properties.set_dtype(out.dtype());
2291 *properties.mutable_shape() = out.shape();
2292 output_properties.push_back(properties);
2293 }
2294 output_properties_[node.name()] = output_properties;
2295 }
2296
2297 for (const auto& node : item_.graph.node()) {
2298 // Skip the nodes that are not in the cost graph: these are nodes that
2299 // aren't run, because they aren't in the intersection of transitive fan-in
2300 // of a fetch node and the transitive fan-out of an input, or nodes that
2301 // were optimized away by the optimizer.
2302 auto it = name_to_cost.find(node.name());
2303 if (it == name_to_cost.end()) {
2304 continue;
2305 }
2306 std::vector<OpInfo::TensorProperties> inputs =
2307 FindInputFeatures(node, name_to_cost, name_to_node);
2308
2309 input_properties_[node.name()] = inputs;
2310 }
2311 return Status::OK();
2312 }
2313
HasInputProperties(const string & node_name) const2314 bool GraphProperties::HasInputProperties(const string& node_name) const {
2315 return input_properties_.find(node_name) != input_properties_.end();
2316 }
2317
HasOutputProperties(const string & node_name) const2318 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2319 return output_properties_.find(node_name) != output_properties_.end();
2320 }
2321
2322 const std::vector<OpInfo::TensorProperties>&
GetInputProperties(const string & node_name) const2323 GraphProperties::GetInputProperties(const string& node_name) const {
2324 auto it = input_properties_.find(node_name);
2325 if (it != input_properties_.end()) {
2326 return it->second;
2327 }
2328 return missing_properties_;
2329 }
2330
2331 const std::vector<OpInfo::TensorProperties>&
GetOutputProperties(const string & node_name) const2332 GraphProperties::GetOutputProperties(const string& node_name) const {
2333 auto it = output_properties_.find(node_name);
2334 if (it != output_properties_.end()) {
2335 return it->second;
2336 }
2337 return missing_properties_;
2338 }
2339
ClearInputProperties(const string & node_name)2340 void GraphProperties::ClearInputProperties(const string& node_name) {
2341 input_properties_.erase(node_name);
2342 }
ClearOutputProperties(const string & node_name)2343 void GraphProperties::ClearOutputProperties(const string& node_name) {
2344 output_properties_.erase(node_name);
2345 }
2346
2347 } // end namespace grappler
2348 } // end namespace tensorflow
2349