1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/costs/utils.h"
31 #include "tensorflow/core/grappler/mutable_graph_view.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/functions.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40
41 namespace tensorflow {
42 namespace grappler {
43
44 namespace {
45
46 using shape_inference::DimensionHandle;
47 using shape_inference::InferenceContext;
48 using shape_inference::ShapeAndType;
49 using shape_inference::ShapeHandle;
50 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
51
52 // A large value used in place of UnknownDim when an unknown dim from a Const
53 // is used as a dim value in a shape. Some ops treat "-1" specially,
54 // differently from UnknownDim: e.g., the shape input to the Reshape op.
55 const int64 kUnknownDimFromConst = INT64_MAX;
56
57 template <typename Handle>
58 struct HashHandle {
59 std::size_t operator()(const Handle& h) const { return h.Handle(); }
60 };
61 template <typename Handle>
62 struct CompareHandle {
63 bool operator()(const Handle& h1, const Handle& h2) const {
64 return h1.SameHandle(h2);
65 }
66 };
67
68 template <typename Handle>
69 struct HandleToObject {};
70 template <>
71 struct HandleToObject<ShapeHandle> {
72 typedef ShapeHandle Object;
73
74 static ShapeHandle Unknown() { return ShapeHandle(); }
75 };
76
77 template <>
78 struct HandleToObject<DimensionHandle> {
79 typedef int64 Object;
80
81 static int64 Unknown() { return -1; }
82 };
83
84 template <typename Handle>
85 struct Processor {};
86
87 template <>
88 struct Processor<ShapeHandle> {
89 // Extract the shape or dim denoted by the handle.
90 void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
91 // Merge the shapes or dims.
92 Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
93 if (InferenceContext::RankKnown(*result)) {
94 // The result was initialized in a previous merge to a shape of known
95 // rank, make sure we preserve that information.
96 return Status::OK();
97 }
98 if (InferenceContext::RankKnown(h1)) {
99 *result = h1;
100 } else {
101 *result = h2;
102 }
103 return Status::OK();
104 }
105 };
106
107 template <>
108 struct Processor<DimensionHandle> {
109 // Assign a negative id to unknown dimensions, starting at -2 (the id -1 is
110 // already reserved by TensorFlow to mean "unknown").
111 void ExtractValue(DimensionHandle d, int64* result) {
112 if (!InferenceContext::ValueKnown(d)) {
113 *result = -counter;
114 counter++;
115 } else {
116 int64 val = InferenceContext::Value(d);
117 if (val >= 0) {
118 *result = val;
119 } else {
120 // A shape inference function generated an invalid dimension handle.
121 // Use a symbolic dimension to encode this.
122 *result = -counter;
123 counter++;
124 }
125 }
126 }
127
128 // Merge the dimensions d1 and d2. Return the known shape if there is one,
129 // otherwise look for a symbolic shape. If there is no symbolic shape and no
130 // known shape, the shape is fully unknown, so return -1.
131 Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
132 const int64 dim1 = InferenceContext::Value(d1);
133 const int64 dim2 = InferenceContext::Value(d2);
134
135 if (dim1 >= 0 && dim2 >= 0) {
136 CHECK_EQ(dim1, dim2);
137 return RefineDim(dim1, result);
138 } else if (dim1 >= 0 && dim2 < 0) {
139 return RefineDim(dim1, result);
140 } else if (dim1 < 0 && dim2 >= 0) {
141 return RefineDim(dim2, result);
142 } else if (dim1 < -1) {
143 return RefineDim(dim1, result);
144 } else if (dim2 < -1) {
145 return RefineDim(dim2, result);
146 } else {
147 CHECK_EQ(dim1, dim2);
148 CHECK_EQ(-1, dim1);
149 return RefineDim(-1, result);
150 }
151 return Status::OK();
152 }
153
154 private:
155 Status RefineDim(int64 dim, int64* result) {
156 if (*result >= 0) {
157 if (!(*result == dim || dim < 0)) {
158 return errors::InvalidArgument("Inconsistent dimensions detected");
159 }
160 } else if (dim >= 0) {
161 *result = dim;
162 } else if (dim < *result) {
163 *result = dim;
164 }
165 return Status::OK();
166 }
167
168 int64 counter = 2;
169 };
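// A minimal sketch of how the routines above are used (illustrative only;
// `proc` and the DimensionHandles are assumed to exist):
//   Processor<DimensionHandle> proc;
//   int64 merged;
//   proc.ExtractValue(unknown_dim, &merged);   // assigns a symbolic id, e.g. -2
//   proc.Merge(known_8, unknown_dim, &merged); // a known value wins: merged == 8
// Known dims take precedence over symbolic ids (< -1), which take precedence
// over the fully unknown value -1; inconsistent known dims are rejected.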
170
171 // Traditional Disjoint-Set datastructure with path compression.
172 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
173 template <typename Handle>
174 class DisjointSet {
175 public:
176 DisjointSet() {}
177 ~DisjointSet() {
178 for (auto rep : nodes_) {
179 delete rep.second;
180 }
181 }
182
183 Status Merge(Handle x, Handle y);
184 const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
185
186 private:
187 // All the handles that belong to the same set are part of the same tree, and
188 // ultimately represented by the root of that tree.
189 struct Rep {
190 // Parent in the tree used to encode the set.
191 Rep* parent;
192 // Rank of the tree, used to decide which root becomes the parent when two
193 // trees are merged (union by rank).
194 int rank;
195 // The handle.
196 typename HandleToObject<Handle>::Object value;
197 };
198
199 // Create a new set for the value if none exists, or return its representative
200 // node otherwise.
201 Rep* Find(Handle value);
202
203 private:
204 Processor<Handle> processor_;
205 absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
206 nodes_;
207 };
208
209 template <typename Handle>
210 const typename HandleToObject<Handle>::Object
211 DisjointSet<Handle>::GetMergedValue(Handle value) {
212 Rep* rep = Find(value);
213 if (!rep) {
214 // We don't know anything about this handle.
215 return HandleToObject<Handle>::Unknown();
216 }
217 return rep->value;
218 }
219
220 template <typename Handle>
221 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
222 Rep* x_root = Find(x);
223 Rep* y_root = Find(y);
224
225 // x and y are already in the same set
226 if (x_root == y_root) {
227 return Status::OK();
228 }
229 // x and y are not in same set, so we merge them
230 // Use the occasion to strengthen what we know about the handle by merging the
231 // information about the 2 subsets.
232 if (x_root->rank < y_root->rank) {
233 TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
234 x_root->parent = y_root;
235 } else if (x_root->rank > y_root->rank) {
236 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
237 y_root->parent = x_root;
238 } else {
239 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
240 // Arbitrarily make one root the new parent
241 y_root->parent = x_root;
242 x_root->rank = x_root->rank + 1;
243 }
244 return Status::OK();
245 }
246
247 template <typename Handle>
248 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
249 auto it = nodes_.find(value);
250 if (it == nodes_.end()) {
251 // This is the first time we process this handle, create an entry for it.
252 Rep* node = new Rep;
253 node->parent = node;
254 node->rank = 0;
255 processor_.ExtractValue(value, &node->value);
256 nodes_[value] = node;
257 return node;
258 }
259 // Return the representative for the set, which is the root of the tree. Apply
260 // path compression to speed up future queries.
261 Rep* node = it->second;
262 Rep* root = node->parent;
263 while (root != root->parent) {
264 root = root->parent;
265 }
266 while (node->parent != root) {
267 Rep* next = node->parent;
268 node->parent = root;
269 node = next;
270 }
271 return root;
272 }
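// A short usage sketch of DisjointSet (illustrative; the ShapeHandles a, b, c
// are assumed to exist):
//   DisjointSet<ShapeHandle> shapes;
//   TF_CHECK_OK(shapes.Merge(a, b));  // a and b now share one representative
//   TF_CHECK_OK(shapes.Merge(b, c));  // transitively, a, b and c are merged
//   ShapeHandle rep = shapes.GetMergedValue(a);  // same object for b and c
// The representative keeps the best information seen so far (e.g. a shape of
// known rank is preferred over an unknown one, per Processor<ShapeHandle>).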
273
274 // TODO(dyoon): Move many helper functions in this file (including those within
275 // SymbolicShapeRefiner class) to shared utils.
276 bool IsEnqueue(const NodeDef& n) {
277 return (n.op().find("Enqueue") != string::npos &&
278 n.op().find("EnqueueMany") == string::npos);
279 }
280
281 bool IsDequeue(const NodeDef& n) {
282 return (n.op().find("Dequeue") != string::npos &&
283 n.op().find("DequeueMany") == string::npos);
284 }
285
286 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
287 if (proto.unknown_rank()) {
288 return true;
289 }
290 for (const auto& dim : proto.dim()) {
291 if (dim.size() < 0) {
292 return true;
293 }
294 }
295 return false;
296 }
297
298 // This really should be done in an external debugging tool
299 void VerboseLogUnknownDimensionSources(
300 const GraphDef& graph,
301 const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
302 input_properties_map,
303 const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
304 output_properties_map) {
305 if (!VLOG_IS_ON(2)) {
306 return;
307 }
308
309 VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
310
311 // Find all nodes in the graph for which we
312 // do not have any unknown dimensions in their inputs, but
313 // we have some unknown dimensions in their outputs.
314 std::map<string, int> op_to_count;
315 for (const NodeDef& node : graph.node()) {
316 const auto& input_properties = input_properties_map.at(node.name());
317 const auto& output_properties = output_properties_map.at(node.name());
318
319 bool has_unknown_inputs = false;
320 for (const auto& input_prop : input_properties) {
321 if (HasAnyUnknownDimensions(input_prop.shape())) {
322 has_unknown_inputs = true;
323 break;
324 }
325 }
326
327 if (has_unknown_inputs) {
328 continue;
329 }
330
331 for (const auto& output_prop : output_properties) {
332 if (HasAnyUnknownDimensions(output_prop.shape())) {
333 string inputs = "input_shapes=[";
334 for (const auto& input_prop : input_properties) {
335 inputs += PartialTensorShape::DebugString(input_prop.shape());
336 }
337 inputs += "]";
338
339 string outputs = "output_shapes=[";
340 for (const auto& output_prop : output_properties) {
341 outputs += PartialTensorShape::DebugString(output_prop.shape());
342 }
343 outputs += "]";
344
345 VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
346 << inputs << ", " << outputs;
347
348 op_to_count[node.op()]++;
349
350 // don't log again for this node
351 break;
352 }
353 }
354 }
355 VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
356 << "(format: <op_type> (<count>)):";
357 for (const auto& p : op_to_count) {
358 VLOG(2) << p.first << " (" << p.second << ")";
359 }
360 }
361
362 // Helper function to convert kUnknownDimFromConst into UnknownDim.
363 std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
364 InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
365 std::vector<ShapeHandle> converted_shapes(shapes.size());
366 for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
367 const auto& shape = shapes[i];
368 if (!ic->RankKnown(shape)) {
369 converted_shapes[i] = shape;
370 continue;
371 }
372 bool just_copy = true;
373 std::vector<DimensionHandle> dims;
374 for (int32 j = 0; j < ic->Rank(shape); ++j) {
375 DimensionHandle dim = ic->Dim(shape, j);
376 if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
377 just_copy = false;
378 dims.push_back(ic->UnknownDim());
379 } else {
380 dims.push_back(dim);
381 }
382 }
383 if (just_copy) {
384 converted_shapes[i] = shape;
385 continue;
386 }
387 converted_shapes[i] = ic->MakeShape(dims);
388 }
389 return converted_shapes;
390 }
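// Illustrative use of the helper above (handles assumed): given a
// tensor-as-shape [2, kUnknownDimFromConst, 3],
//   std::vector<ShapeHandle> converted =
//       ReplaceUnknownDimFromConstWithUnknownDim(ic, {shape});
// converted[0] is [2, ?, 3]: the sentinel becomes a regular UnknownDim so that
// downstream shape functions do not mistake it for a huge concrete dimension.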
391
392 // Returned tensor's shape is like `shape`, and its values and dtype are from
393 // `tensor_as_shape` and `dtype`.
394 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
395 const ShapeHandle& shape,
396 const ShapeHandle& tensor_as_shape,
397 const DataType& dtype) {
398 TensorProto tensor_proto;
399 tensor_proto.set_dtype(dtype);
400 auto* shape_proto = tensor_proto.mutable_tensor_shape();
401 if (ic->Rank(shape) == 1) {
402 shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
403 }
404 // For a scalar tensor, tensor_shape field will be left empty; no dim.
405 for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
406 int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
407 if (dtype == DT_INT32) {
408 tensor_proto.add_int_val(value);
409 } else {
410 tensor_proto.add_int64_val(value);
411 }
412 }
413 return tensor_proto;
414 }
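// Example of the encoding above (illustrative values): with dtype DT_INT32,
// shape = [3] and tensor_as_shape = [5, 7, 9], the returned proto is roughly
//   dtype: DT_INT32
//   tensor_shape { dim { size: 3 } }
//   int_val: 5  int_val: 7  int_val: 9
// i.e. a vector whose length is the rank of tensor_as_shape and whose values
// are its dimension sizes; a scalar `shape` leaves tensor_shape empty.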
415
416 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
417 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
418 const TensorProto& tensor_proto,
419 const DataType& dtype) {
420 NodeDef const_node;
421 const_node.set_name("const_from_shape");
422 const_node.set_op("Const");
423 auto* attr = const_node.mutable_attr();
424 (*attr)["dtype"].set_type(dtype);
425 auto* tensor = (*attr)["value"].mutable_tensor();
426 *tensor = tensor_proto;
427 return const_node;
428 }
429
430 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
431 // and dtype = `dtype`.
432 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
433 const ShapeHandle& shape,
434 const ShapeHandle& tensor_as_shape,
435 const DataType& dtype) {
436 return MakeConstNodeDefFromTensorProto(
437 ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
438 }
439
440 } // namespace
441
442 // Note that the tensor_as_shape input should not include kUnknownDimFromConst.
443 // This function checks for kUnknownDimFromConst, but only logs a WARNING.
444 // When checking input_tensors_as_shapes_to_propagate or
445 // output_tensors_as_shapes, which may include kUnknownDimFromConst,
446 // convert them with ReplaceUnknownDimFromConstWithUnknownDim() first.
447 bool IsShapeFullyDefinedIntegerVectorOrScalar(
448 InferenceContext* ic, const ShapeHandle& shape,
449 const ShapeHandle& tensor_as_shape, const DataType& dtype) {
450 if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
451 !ic->FullyDefined(tensor_as_shape) ||
452 (dtype != DT_INT32 && dtype != DT_INT64)) {
453 return false;
454 }
455 // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
456 for (int32 i = 0; i < ic->Rank(tensor_as_shape); ++i) {
457 DimensionHandle dim = ic->Dim(tensor_as_shape, i);
458 if (ic->Value(dim) == kUnknownDimFromConst) {
459 LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
460 << "tensor_as_shape input includes kUnknownDimFromConst -- "
461 << ic->DebugString(tensor_as_shape);
462 return false;
463 }
464 }
465 return true;
466 }
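// Illustrative outcomes of the check above (handles assumed, "?" = unknown):
//   shape [3], tensor_as_shape [5, 7, 9], dtype DT_INT32          -> true
//   shape of unknown rank, or rank > 1                            -> false
//   tensor_as_shape [5, ?, 9] (not fully defined)                 -> false
//   dtype DT_FLOAT (not an integer type)                          -> false
//   tensor_as_shape [5, kUnknownDimFromConst, 9]                  -> false, with
//                                                                    a WARNING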
467
468 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
469 // dequeued in (roughly) topological order. Propagating shapes following a
470 // topological ordering isn't required for correctness but helps speed things up
471 // since it avoids processing the same node multiple times as its inputs'
472 // information is refined.
473 class TopoQueue {
474 public:
475 explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
476 : topo_order_(TopoOrder(topo_order)) {}
477
478 void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
479
480 const NodeDef* pop() {
481 CHECK(!empty());
482 auto it = queue_.begin();
483 const NodeDef* n = it->first;
484 queue_.erase(it);
485 return n;
486 }
487
488 bool empty() const { return queue_.empty(); }
489 std::size_t size() const { return queue_.size(); }
490
491 private:
492 using NodeAndId = std::pair<const NodeDef*, int>;
493 // Graph nodes are created in (roughly) topological order. Therefore we can
494 // use their id to ensure they're sorted topologically.
495 struct OrderByIdAscending {
496 bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
497 return lhs.second < rhs.second;
498 }
499 };
500
501 const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
502 const std::vector<const NodeDef*>& topo_order) const {
503 absl::flat_hash_map<const NodeDef*, int> map;
504 map.reserve(topo_order.size());
505 for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
506 ++i) {
507 map.emplace(topo_order[i], i);
508 }
509 return map;
510 }
511
512 const absl::flat_hash_map<const NodeDef*, int> topo_order_;
513 std::set<NodeAndId, OrderByIdAscending> queue_;
514 };
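// A brief usage sketch of TopoQueue (illustrative; `topo_order` is assumed to
// come from a topological sort of the graph, and node_a precedes node_b in it):
//   TopoQueue ready(topo_order);
//   ready.push(node_b);              // nodes may be pushed in any order
//   ready.push(node_a);
//   const NodeDef* n = ready.pop();  // node_a: dequeued in topological order
// Pushing the same node twice is harmless: the underlying std::set keeps a
// single (node, id) entry.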
515
516 bool IsNumericType(const DataType dtype) {
517 static const gtl::FlatSet<DataType>* const kRealNumberTypes =
518 CHECK_NOTNULL((new gtl::FlatSet<DataType>{
519 // Floating point.
520 DT_BFLOAT16,
521 DT_HALF,
522 DT_FLOAT,
523 DT_DOUBLE,
524 // Int / UInt.
525 DT_INT8,
526 DT_INT16,
527 DT_INT32,
528 DT_INT64,
529 DT_UINT8,
530 DT_UINT16,
531 DT_UINT32,
532 DT_UINT64,
533 // Quantized Int.
534 DT_QINT8,
535 DT_QUINT8,
536 DT_QINT16,
537 DT_QUINT16,
538 DT_QINT32,
539 // Bool.
540 DT_BOOL,
541 }));
542 return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
543 }
544
545 bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
546 static const gtl::FlatSet<string>* const kOpTypeAllowlist =
547 CHECK_NOTNULL((new gtl::FlatSet<string>{
548 // Unary arithmetic ops
549 "Floor",
550 "Round",
551 "Sqrt",
552 "Square",
553 "Sign",
554 // Binary arithmetic ops
555 "Add",
556 "AddV2",
557 "Div",
558 "FloorDiv",
559 "FloorMod",
560 "Greater",
561 "GreaterEqual",
562 "Less",
563 "LessEqual",
564 "LogicalAnd",
565 "LogicalNot",
566 "LogicalOr",
567 "Maximum",
568 "Minimum",
569 "Mod",
570 "Mul",
571 "NotEqual",
572 "QuantizedAdd",
573 "QuantizedMul",
574 "SquareDifference",
575 "Sub",
576 "TruncateDiv",
577 "TruncateMod",
578 "RealDiv",
579 // N-ary arithmetic ops
580 "AddN",
581 // Others
582 "StridedSlice",
583 "OnesLike",
584 "ZerosLike",
585 "Concat",
586 "ConcatV2",
587 "Split",
588 "Range",
589 "Fill",
590 "Cast",
591 }));
592 return kOpTypeAllowlist->find(op_type) != kOpTypeAllowlist->end();
593 }
594
595 // Negative shape size of '-1' represents unknown, while negative shape sizes
596 // less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
597 // -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
598 // we need to normalize them: mark all values <-1 as "unknown" (-1).
599 static void NormalizeShapeForOutput(TensorShapeProto* shape) {
600 for (int i = 0; i < shape->dim_size(); i++) {
601 if (shape->dim(i).size() < -1) {
602 VLOG(2) << "Normalizing dimension: " << i << " from "
603 << shape->dim(i).size() << " to -1";
604 shape->mutable_dim(i)->set_size(-1);
605 }
606 }
607 }
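// Example of the normalization above: a proto holding the symbolic shape
// [-5, 5, -1, -5] becomes [-1, 5, -1, -1]. The pairing information carried by
// the repeated -5 (both dims equal) is deliberately dropped, since output
// protos only distinguish known sizes from "unknown" (-1).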
608
609 // Processes symbolic shapes.
610 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
611 // shape refiner which creates new handles every time it processes an unknown
612 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
613 // unknown shape/dimension of a given node.
614 class SymbolicShapeRefiner {
615 public:
616 explicit SymbolicShapeRefiner(
617 const GraphView& graph,
618 const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
619 const bool aggressive_shape_inference)
620 : graph_(graph),
621 function_library_(OpRegistry::Global(), graph.graph()->library()),
622 fed_ports_(fed_ports),
623 aggressive_shape_inference_(aggressive_shape_inference) {
624 graph_def_version_ = graph.graph()->versions().producer();
625 node_to_context_.reserve(graph.graph()->node_size());
626 }
627
628 const GraphView& graph() const { return graph_; }
629
630 struct NodeContext {
631 const OpRegistrationData* op_data;
632 DataTypeVector input_types;
633 DataTypeVector output_types;
634 std::unique_ptr<InferenceContext> inference_context;
635 // Additional info for propagating tensor values and tensor shapes.
636 std::vector<const TensorProto*> input_tensor_protos;
637 std::vector<const TensorProto*> output_tensor_protos;
638 // This is the same as inference_context->input_tensors_as_shapes, except
639 // that some UnknownDims (-1) can be kUnknownDimFromConst.
640 std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
641 std::vector<ShapeHandle> output_tensors_as_shapes;
642
643 // Output shapes incompatible between annotation and shape inference.
644 bool shape_incompatible = false;
645
646 // Similar to DebugString() in InferenceContext, but prints out
647 // kUnknownDimFromConst properly.
648 std::string StringifyShapeHandle(ShapeHandle s) {
649 auto* ic = inference_context.get();
650 if (ic->RankKnown(s)) {
651 std::vector<std::string> vals;
652 for (int i = 0; i < ic->Rank(s); i++) {
653 DimensionHandle d = ic->Dim(s, i);
654 if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
655 vals.push_back("?(Const)");
656 } else {
657 vals.push_back(ic->DebugString(d));
658 }
659 }
660 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
661 } else {
662 return "?";
663 }
664 }
665
666 std::string DebugString(const NodeDef& node) {
667 std::string output;
668 auto* ic = inference_context.get();
669 absl::StrAppend(
670 &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
671 (ic->num_inputs() > 1 ? " inputs and " : " input and "),
672 ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
673 if (op_data->is_function_op) {
674 absl::StrAppend(&output, " (function op)");
675 }
676 absl::StrAppend(&output, ": \n");
677
678 for (int i = 0; i < ic->num_inputs(); i++) {
679 absl::StrAppend(&output, " input [", i, "] ", node.input(i),
680 " -- type: ", DataTypeString(input_types.at(i)),
681 ", shape: ", ic->DebugString(ic->input(i)),
682 ", tensor: ");
683 Tensor t1;
684 int input_tensor_protos_size = input_tensor_protos.size();
685 if (input_tensor_protos_size > i &&
686 input_tensor_protos.at(i) != nullptr &&
687 t1.FromProto(*input_tensor_protos.at(i))) {
688 absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
689 } else {
690 absl::StrAppend(&output, " null, tensor_as_shape: ");
691 }
692 int input_tensors_as_shapes_to_propagate_size =
693 input_tensors_as_shapes_to_propagate.size();
694 if (input_tensors_as_shapes_to_propagate_size > i) {
695 absl::StrAppend(
696 &output,
697 StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
698 "\n");
699 } else {
700 absl::StrAppend(&output, " null\n");
701 }
702 }
703 for (int i = 0; i < ic->num_outputs(); i++) {
704 absl::StrAppend(&output, " output [", i,
705 "] -- type: ", DataTypeString(output_types.at(i)),
706 ", shape: ", ic->DebugString(ic->output(i)),
707 ", tensor: ");
708 Tensor t2;
709 int output_tensor_protos_size = output_tensor_protos.size();
710 if (output_tensor_protos_size > i &&
711 output_tensor_protos.at(i) != nullptr &&
712 t2.FromProto(*output_tensor_protos.at(i))) {
713 absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
714 } else {
715 absl::StrAppend(&output, " null, tensor_as_shape: ");
716 }
717 int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
718 if (output_tensors_as_shapes_size > i) {
719 absl::StrAppend(&output,
720 StringifyShapeHandle(output_tensors_as_shapes.at(i)),
721 "\n");
722 } else {
723 absl::StrAppend(&output, " null\n");
724 }
725 }
726 return output;
727 }
728 };
729
730 NodeContext* GetNodeContext(const NodeDef* node) {
731 auto it = node_to_context_.find(node);
732 if (it == node_to_context_.end()) {
733 return nullptr;
734 }
735 return &it->second;
736 }
737
738 InferenceContext* GetContext(const NodeDef* node) {
739 auto it = node_to_context_.find(node);
740 if (it == node_to_context_.end()) {
741 return nullptr;
742 }
743 return it->second.inference_context.get();
744 }
745
746 // Forward the shapes from the inputs of the function call node
747 // (PartitionedCall or StatefulPartitionedCall) to the function's argument
748 // nodes (which are Placeholder nodes), then perform shape inference on the
749 // function body.
750 //
751 // Propagate shape information of final function body node
752 // to function node `function_node`.
753 //
754 // In the event of an error, UpdateNode will simply set `function_node`'s
755 // output shape to be Unknown.
756 Status UpdateFunction(const NodeDef* function_node) {
757 NameAttrList function;
758 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
759 auto it = fun_to_grappler_function_item_.find(function.name());
760 if (it == fun_to_grappler_function_item_.end()) {
761 return errors::InvalidArgument(
762 function.name(),
763 " was not previously added to SymbolicShapeRefiner.");
764 }
765
766 const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
767 it->second;
768 if (!maybe_grappler_function_item.has_value()) {
769 VLOG(3) << "Skip failed to instantiate function call: function_name="
770 << function.name();
771
772 auto* ctx = GetNodeContext(function_node);
773 auto* ic = ctx->inference_context.get();
774 for (int i = 0; i < ic->num_outputs(); ++i) {
775 TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
776 }
777
778 return Status::OK();
779 }
780
781 // Copy (not reference) so that changes we make here (e.g., replacing
782 // _Arg with Const and _Retval with Identity) don't affect one in
783 // fun_to_grappler_function_item_.
784 GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
785 MutableGraphView gv(&grappler_function_item.graph);
786
787 // Forward shapes from function input nodes to argument nodes.
788 for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
789 ++i) {
790 auto& fun_input = grappler_function_item.input(i);
791 NodeDef* fun_node = gv.GetNode(fun_input.node_name);
792 const TensorId input_tensor = ParseTensorName(function_node->input(i));
793
794 if (IsControlInput(input_tensor)) {
795 return errors::FailedPrecondition(
796 "Function inputs should not contain control nodes.");
797 }
798
799 const NodeDef* input_node = graph_.GetNode(input_tensor.node());
800 if (input_node == nullptr) {
801 return errors::FailedPrecondition(input_tensor.node(),
802 " was not found in the graph.");
803 }
804
805 InferenceContext* input_ic = GetContext(input_node);
806 if (input_ic == nullptr) {
807 return errors::FailedPrecondition(
808 "Inference context has not been created for ", input_tensor.node());
809 }
810
811 int output_port_num = input_tensor.index();
812 AttrValue attr_output_shape;
813 TensorShapeProto proto;
814 const auto handle = input_ic->output(output_port_num);
815 input_ic->ShapeHandleToProto(handle, &proto);
816 // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
817 NormalizeShapeForOutput(&proto);
818 // _Arg op's output shape uses _output_shapes attr.
819 AttrValue output_attr;
820 output_attr.mutable_list()->add_shape()->Swap(&proto);
821 (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
822
823 // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
824 // _handle_shapes attr for its shapes and dtypes.
825 if (fun_input.data_type == DT_RESOURCE) {
826 auto* shapes_and_types =
827 input_ic->output_handle_shapes_and_types(output_port_num);
828 if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
829 AttrValue dtype_attr;
830 AttrValue shape_attr;
831 for (const auto& shape_and_type : *shapes_and_types) {
832 const auto& dtype = shape_and_type.dtype;
833 const auto& shape_handle = shape_and_type.shape;
834 dtype_attr.mutable_list()->add_type(dtype);
835 input_ic->ShapeHandleToProto(
836 shape_handle, shape_attr.mutable_list()->add_shape());
837 }
838 (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
839 (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
840 } else {
841 // Note that we do not return error here, even if the input node does
842 // not have shapes_and_types. Within the function, we cannot infer the
843 // output shape of the DT_RESOURCE input; hence, potentially unknown
844 // shapes/dims in the function output shapes.
845 VLOG(2)
846 << "A function node (" << function_node->name()
847 << ") has input with DT_RESOURCE, but the input node does not "
848 << "have shapes_and_types information: \n"
849 << "function_node: " << function_node->ShortDebugString() << "\n"
850 << "function input: " << i
851 << ", input node's output: " << output_port_num << "\n"
852 << "input node: " << input_node->ShortDebugString();
853 }
854 }
855 }
856
857 // ReplaceInputWithConst() may break GraphView's internal node mapping
858 // structure; hence, we separately build node name to NodeDef* map, for the
859 // output nodes (before GraphView becomes invalid). Note that we use string,
860 // not string_view.
861 absl::flat_hash_map<std::string, NodeDef*> output_nodes;
862 for (const auto& output_arg : grappler_function_item.outputs()) {
863 output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
864 }
865
866 // Replace input nodes with Consts, if values are known. Note that
867 // we don't check exceptions here as it's done in the above loop.
868 auto* ctx = GetNodeContext(function_node);
869 auto* ic = ctx->inference_context.get();
870 for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
871 const string& input = function_node->input(i);
872 const string node_name = NodeName(input);
873 const NodeDef* input_node = graph_.GetNode(node_name);
874 if (IsConstant(*input_node)) {
875 TF_CHECK_OK(
876 ReplaceInputWithConst(*input_node, i, &grappler_function_item));
877 } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
878 ctx->input_tensor_protos[i] != nullptr) {
879 NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
880 ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
881 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
882 &grappler_function_item));
883 } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
884 IsShapeFullyDefinedIntegerVectorOrScalar(
885 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
886 ctx->input_types[i])) {
887 // We have fully defined input_tensors_as_shapes for this input; use it
888 // as a const input to the function node.
889 NodeDef const_input_node = MakeConstNodeDefFromShape(
890 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
891 ctx->input_types[i]);
892 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
893 &grappler_function_item));
894 }
895 }
896 // node_name to NodeDef* map in GraphView gv can be broken due to
897 // ReplaceInputWithConst(). gv should not be used after this.
898
899 // Replace output _Retval nodes with Identity nodes. _Retval is a system op
900 // without outputs and without a registered shape function.
901 for (const auto& output_arg : grappler_function_item.outputs()) {
902 NodeDef* output_node = output_nodes[output_arg.node_name];
903 DCHECK_EQ(output_node->op(), "_Retval");
904 output_node->set_op("Identity");
905 output_node->mutable_attr()->erase("index");
906 }
907
908 // Perform inference on function body.
909 GraphProperties gp(grappler_function_item);
910 TF_RETURN_IF_ERROR(gp.InferStatically(
911 /*assume_valid_feeds=*/true,
912 /*aggressive_shape_inference=*/aggressive_shape_inference_,
913 /*include_tensor_values=*/true));
914
915 // Add return nodes for output shapes.
916 int output = 0;
917 ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
918 ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
919 nullptr);
920 for (auto const& out_arg : grappler_function_item.outputs()) {
921 // It is guaranteed that output_tensors does not contain any control
922 // inputs, so port_id >= 0.
923 TensorId out_tensor = ParseTensorName(out_arg.node_name);
924
925 if (output_nodes.count(out_tensor.node()) <= 0) {
926 return errors::FailedPrecondition(
927 "Unable to find return function_node ", out_tensor.node(), " for ",
928 function_node->name());
929 }
930 const NodeDef* retnode = output_nodes[out_tensor.node()];
931
932 auto output_properties = gp.GetOutputProperties(retnode->name());
933 int output_properties_size = output_properties.size();
934 if (out_tensor.index() >= output_properties_size) {
935 return errors::InvalidArgument(
936 out_tensor.ToString(), " has invalid position ", out_tensor.index(),
937 " (output_properties.size() = ", output_properties.size(), ").");
938 }
939 auto& outprop = output_properties[out_tensor.index()];
940 TensorShapeProto shape = outprop.shape();
941 NormalizeShapeForOutput(&shape);
942 ShapeHandle out;
943 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
944 ic->set_output(output, out);
945 if (outprop.has_value()) {
946 // Forward tensor value to output_tensors_as_shape.
947 MaybeTensorProtoToShape(ic, outprop.value(),
948 &ctx->output_tensors_as_shapes[output]);
949 const_tensors_to_propagate_.push_back(outprop.value());
950 ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
951 }
952 output++;
953 }
954
955 return Status::OK();
956 }
957
958 // Prepares input shapes/values/handles, then runs shape inference, and
959 // finally sets output shapes/values/handles.
960 Status UpdateNode(const NodeDef* node, bool* refined) {
961 NodeContext* ctx = GetNodeContext(node);
962 if (ctx == nullptr) {
963 TF_RETURN_IF_ERROR(AddNode(node));
964 ctx = CHECK_NOTNULL(GetNodeContext(node));
965 *refined = true;
966 }
967
968 // Check if the shapes of the nodes in the fan-in of this node have changed,
969 // and if they have, update the node input shapes.
970 InferenceContext* ic = ctx->inference_context.get();
971 ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
972 ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
973
974 for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
975 const GraphView::InputPort port(node, dst_input);
976 const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
977 int src_output = fanin.port_id;
978 const NodeDef* src = fanin.node;
979 NodeContext* src_ctx = GetNodeContext(src);
980 if (src_ctx == nullptr) {
981 return errors::FailedPrecondition(
982 "Input ", dst_input, " for '", node->name(),
983 "' was not previously added to SymbolicShapeRefiner.");
984 }
985
986 InferenceContext* src_ic = src_ctx->inference_context.get();
987 if (src_output >= src_ic->num_outputs()) {
988 return errors::OutOfRange("src_output = ", src_output,
989 ", but num_outputs is only ",
990 src_ic->num_outputs());
991 }
992
993 // Propagate input node's NodeContext info to the current node's
994 // NodeContext:
995 // output_tensor_protos to input_tensor_protos and input_tensors, and
996 // output_tensors_as_shapes to input_tensors_as_shapes.
997 if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
998 src_output) {
999 ctx->input_tensors_as_shapes_to_propagate[dst_input] =
1000 src_ctx->output_tensors_as_shapes[src_output];
1001 }
1002
1003 if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
1004 const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
1005 if (tensor_proto != nullptr) {
1006 ctx->input_tensor_protos[dst_input] = tensor_proto;
1007 }
1008 }
1009
1010 // NOTE: we only check whether the shape is refined; we do not (yet) check
1011 // whether the tensor value is refined.
1012 if (!*refined &&
1013 !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
1014 *refined = true;
1015 }
1016 ic->SetInput(dst_input, src_ic->output(src_output));
1017
1018 if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
1019 // The input value may have changed. Since we have no way to know if
1020 // that's indeed the case, err on the safe side.
1021 *refined = true;
1022 }
1023
1024 // Also propagate handle shape and dtype of edges which are carrying
1025 // resource handles.
1026 if (ctx->input_types[dst_input] == DT_RESOURCE) {
1027 auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
1028 if (!outputs) continue;
1029 auto* inputs = ic->input_handle_shapes_and_types(dst_input);
1030
1031 if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
1032 *refined = true;
1033 ic->set_input_handle_shapes_and_types(dst_input, *outputs);
1034 }
1035 }
1036
1037 // Make sure we schedule the fanout of resources (which have no input)
1038 // whenever the resources are updated.
1039 *refined |= ic->num_inputs() == 0;
1040
1041 if (!*refined) {
1042 // No input shape has changed, we're done.
1043 return Status::OK();
1044 }
1045
1046 // Convert all kUnknownDimFromConst to -1 for shape inference.
1047 ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
1048 ic, ctx->input_tensors_as_shapes_to_propagate));
1049 // Note: UpdateFunction uses input_tensors_as_shapes and
1050 // input_tensor_protos (not the Tensor object) for input values.
1051 // so for function nodes, we don't need to convert TensorProtos
1052 // to Tensors here. If the current op is not a function op, we convert
1053 // TensorProtos to Tensors before calling InferShapes.
1054
1055 // Properly handle function nodes.
1056 if (ctx->op_data && ctx->op_data->is_function_op) {
1057 // TODO(jmdecker): Detect if the input shapes have changed for this
1058 // function. Note that when we hit a function call node, refined will be
1059 // true, as the updates to the call node will have changed, even if it's
1060 // the same function being called twice with the same input shapes.
1061 // Example: simple_function.pbtxt
1062 if (aggressive_shape_inference_) {
1063 // If output shapes are annotated, use it and skip UpdateFunction();
1064 // it can be very expensive when a function node has nested function
1065 // nodes internally. One downside with this approach is that we do not
1066 // get output values or output shapes as tensor from function node.
1067 auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
1068 if (s.ok() && AllOutputShapesKnown(ctx)) {
1069 return Status::OK();
1070 }
1071 // If shape annotation was not available, incomplete, or incompatible,
1072 // fall through to call UpdateFunction().
1073 }
1074 auto s = UpdateFunction(node);
1075 if (s.ok()) {
1076 return Status::OK();
1077 } else {
1078 VLOG(1) << "UpdateFunction failed for " << node->op()
1079 << ". Defaulting to ShapeUnknown.\n"
1080 << s.ToString();
1081 }
1082 }
1083
1084 // Construct Tensors for constant inputs used by shape functions.
1085 std::vector<Tensor> const_values(ic->num_inputs());
1086 std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
1087 for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1088 const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
1089 if (tensor_proto != nullptr &&
1090 const_values[dst_input].FromProto(*tensor_proto)) {
1091 input_tensors[dst_input] = &const_values[dst_input];
1092 }
1093 }
1094 ic->set_input_tensors(input_tensors);
1095
1096 // Update the shapes of the outputs.
1097 return InferShapes(*node, ctx);
1098 }
1099
1100 Status SetUnknownShape(const NodeDef* node, int output_port) {
1101 shape_inference::ShapeHandle shape =
1102 GetUnknownOutputShape(node, output_port);
1103 InferenceContext* ctx = GetContext(node);
1104 if (ctx == nullptr) {
1105 return errors::InvalidArgument("Missing context");
1106 }
1107 ctx->set_output(output_port, shape);
1108 return Status::OK();
1109 }
1110
1111 struct ShapeId {
1112 const NodeDef* node;
1113 int port_id;
1114 bool operator==(const ShapeId& other) const {
1115 return node == other.node && port_id == other.port_id;
1116 }
1117 };
1118 struct HashShapeId {
1119 std::size_t operator()(const ShapeId& shp) const {
1120 return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
1121 }
1122 };
1123
1124 struct DimId {
1125 const NodeDef* node;
1126 int port_id;
1127 int dim_index;
1128 bool operator==(const DimId& other) const {
1129 return node == other.node && port_id == other.port_id &&
1130 dim_index == other.dim_index;
1131 }
1132 };
1133
1134 struct HashDimId {
1135 std::size_t operator()(const DimId& dim) const {
1136 return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
1137 dim.dim_index;
1138 }
1139 };
1140
1141 // Compute the output shape at 'port_index' as the union of shape1 and shape2.
1142 ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
1143 ShapeHandle shape1, ShapeHandle shape2) {
1144 if (shape1.SameHandle(shape2)) {
1145 return shape1;
1146 }
1147 InferenceContext* ctx = GetContext(node);
1148 ShapeHandle relaxed = shape1;
1149 const int rank = ctx->Rank(shape1);
1150 if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
1151 relaxed = GetUnknownOutputShape(node, port_index);
1152 } else {
1153 for (int d = 0; d < rank; ++d) {
1154 if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
1155 int64 val1 = ctx->Value(ctx->Dim(shape1, d));
1156 int64 val2 = ctx->Value(ctx->Dim(shape2, d));
1157 if (val1 != val2 || (val1 < 0 && val2 < 0)) {
1158 DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
1159 TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
1160 }
1161 }
1162 }
1163 }
1164 return relaxed;
1165 }
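// Illustrative relaxations performed by OutputAsUnion (handles assumed, one
// output port):
//   shape1 [2, 3] vs shape2 [2, 3]    -> [2, 3]  (equal values, kept as is)
//   shape1 [2, 3] vs shape2 [2, 4]    -> [2, ?]  (conflicting dim replaced by
//                                                 the per-port unknown dim)
//   shape1 [2, 3] vs shape2 [2, 3, 4] -> ?       (rank mismatch: the whole
//                                                 output becomes the per-port
//                                                 unknown shape)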
1166
1167 bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
1168 if (s1.SameHandle(s2)) {
1169 return true;
1170 }
1171 if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
1172 return false;
1173 }
1174 if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
1175 return true;
1176 }
1177 const int rank = InferenceContext::Rank(s1);
1178 for (int i = 0; i < rank; ++i) {
1179 if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
1180 InferenceContext::DimKnownRank(s2, i))) {
1181 int64 val1 =
1182 InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
1183 int64 val2 =
1184 InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
1185 if (val1 >= 0 && val2 >= 0 && val1 == val2) {
1186 continue;
1187 }
1188 return false;
1189 }
1190 }
1191 return true;
1192 }
1193
1194 // Return true if the annotated shape is compatible with shape inference
1195 // result. Examples:
1196 // Inferred shape: ?, annotated shape: [10, 10] -> true;
1197 // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
1198 // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
1199 // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
1200 bool CompatibleShapes(ShapeHandle inferred_shape,
1201 ShapeHandle annotated_shape) const {
1202 if (inferred_shape.SameHandle(annotated_shape)) {
1203 return true;
1204 }
1205 if (!InferenceContext::RankKnown(inferred_shape)) {
1206 return true;
1207 }
1208 if (InferenceContext::Rank(inferred_shape) !=
1209 InferenceContext::Rank(annotated_shape)) {
1210 return false;
1211 }
1212 const int rank = InferenceContext::Rank(inferred_shape);
1213 for (int i = 0; i < rank; ++i) {
1214 if (!InferenceContext::DimKnownRank(inferred_shape, i)
1215 .SameHandle(
1216 InferenceContext::DimKnownRank(annotated_shape, i))) {
1217 int64 val1 = InferenceContext::Value(
1218 InferenceContext::DimKnownRank(inferred_shape, i));
1219 int64 val2 = InferenceContext::Value(
1220 InferenceContext::DimKnownRank(annotated_shape, i));
1221 if (val1 >= 0 && val1 != val2) {
1222 return false;
1223 }
1224 }
1225 }
1226 return true;
1227 }
1228
1229 bool SameShapes(ShapeHandle inferred_shape,
1230 ShapeHandle annotated_shape) const {
1231 if (inferred_shape.SameHandle(annotated_shape)) {
1232 return true;
1233 }
1234 if (InferenceContext::Rank(inferred_shape) !=
1235 InferenceContext::Rank(annotated_shape)) {
1236 return false;
1237 }
1238 const int rank = InferenceContext::Rank(inferred_shape);
1239 for (int i = 0; i < rank; ++i) {
1240 int64 val1 = InferenceContext::Value(
1241 InferenceContext::DimKnownRank(inferred_shape, i));
1242 int64 val2 = InferenceContext::Value(
1243 InferenceContext::DimKnownRank(annotated_shape, i));
1244 if (val1 != val2) {
1245 return false;
1246 }
1247 }
1248 return true;
1249 }
1250
1251 bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1252 const std::vector<ShapeAndType>& st2) const {
1253 if (st1.size() != st2.size()) {
1254 return false;
1255 }
1256 for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
1257 const ShapeAndType& s1 = st1[i];
1258 const ShapeAndType& s2 = st2[i];
1259 if (s1.dtype != s2.dtype) {
1260 return false;
1261 }
1262 if (!EquivalentShapes(s1.shape, s2.shape)) {
1263 return false;
1264 }
1265 }
1266 return true;
1267 }
1268
1269 Status AddFunction(const NodeDef* function_node, NameAttrList function) {
1270 auto it = fun_to_grappler_function_item_.find(function.name());
1271 if (it != fun_to_grappler_function_item_.end()) {
1272 return Status::OK();
1273 }
1274
1275 const FunctionDef* function_def =
1276 CHECK_NOTNULL(function_library_.Find(function.name()));
1277 GrapplerFunctionItem grappler_function_item;
1278 Status function_instantiated =
1279 MakeGrapplerFunctionItem(*function_def, function_library_,
1280 graph_def_version_, &grappler_function_item);
1281
1282 // If function instantiation failed we will skip it during shape inference.
1283 if (!function_instantiated.ok()) {
1284 VLOG(3) << "Failed to instantiate a function. Error: "
1285 << function_instantiated.error_message();
1286 fun_to_grappler_function_item_[function_def->signature().name()] =
1287 absl::nullopt;
1288 return Status::OK();
1289 }
1290
1291 if (static_cast<int>(grappler_function_item.inputs().size()) >
1292 function_node->input_size()) {
1293 return errors::FailedPrecondition(
1294 "Function input size should be smaller than node input size.");
1295 }
1296
1297 for (int i = grappler_function_item.inputs().size(),
1298 end = function_node->input_size();
1299 i < end; ++i) {
1300 const string& input = function_node->input(i);
1301 if (!IsControlInput(input)) {
1302 return errors::FailedPrecondition(
1303 "Found regular input (", input,
1304 ") instead of control nodes for node ", function_node->name());
1305 }
1306 }
1307
1308 fun_to_grappler_function_item_[function_def->signature().name()] =
1309 grappler_function_item;
1310
1311 return Status::OK();
1312 }
1313
1314 Status AddNode(const NodeDef* node) {
1315 NodeContext& node_ctx = node_to_context_[node];
1316 NameAttrList function;
1317 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1318
1319 // For PartitionedCall, op_data represents the function info.
1320 TF_RETURN_IF_ERROR(
1321 function_library_.LookUp(function.name(), &node_ctx.op_data));
1322
1323 if (node_ctx.op_data->is_function_op) {
1324 TF_RETURN_IF_ERROR(AddFunction(node, function));
1325 }
1326
1327 TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1328 &node_ctx.input_types,
1329 &node_ctx.output_types));
1330
1331 // Create the inference context for this node.
1332 const int num_inputs = node_ctx.input_types.size();
1333 std::vector<ShapeHandle> input_shapes(num_inputs);
1334 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1335 input_handle_shapes_and_types(num_inputs);
1336 std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1337 std::vector<ShapeHandle> input_tensors_as_shapes;
1338
1339 node_ctx.inference_context.reset(new InferenceContext(
1340 graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
1341 input_tensors, input_tensors_as_shapes,
1342 std::move(input_handle_shapes_and_types)));
1343 const Status s = node_ctx.inference_context->construction_status();
1344 if (!s.ok()) {
1345 node_ctx.inference_context.reset(nullptr);
1346 }
1347 return s;
1348 }
1349
1350 private:
1351 // Return the one ShapeHandle used to denote a fully unknown shape for a node
1352 // output.
1353 ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1354 ShapeId id{node, index};
1355 auto it = unknown_shapes_.find(id);
1356 if (it != unknown_shapes_.end()) {
1357 return it->second;
1358 }
1359 InferenceContext* c = GetContext(node);
1360 ShapeHandle shp = c->UnknownShape();
1361 unknown_shapes_[id] = shp;
1362 return shp;
1363 }
1364 // Return the one DimensionHandle used to denote a fully unknown dimension
1365 // for a node output.
1366 DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1367 int dim_id) {
1368 DimId id{node, index, dim_id};
1369 auto it = unknown_dims_.find(id);
1370 if (it != unknown_dims_.end()) {
1371 return it->second;
1372 }
1373 InferenceContext* c = GetContext(node);
1374 DimensionHandle dim = c->UnknownDim();
1375 unknown_dims_[id] = dim;
1376 return dim;
1377 }
1378
1379 // Returns true if all the output tensors have known values.
1380 bool AllOutputValuesKnown(NodeContext* c) {
1381 InferenceContext* ic = c->inference_context.get();
1382 int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
1383 int c_output_tensor_protos_size = c->output_tensor_protos.size();
1384 if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
1385 c_output_tensor_protos_size < ic->num_outputs()) {
1386 return false;
1387 } else {
1388 // Checks if we can get output value via either output_tensor_proto or
1389 // output_tensors_as_shapes.
1390 for (int i = 0; i < ic->num_outputs(); i++) {
1391 if (c_output_tensor_protos_size > i &&
1392 c->output_tensor_protos[i] != nullptr) {
1393 continue;
1394 }
1395 if (c_output_tensors_as_shapes_size > i &&
1396 ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1397 bool no_unknown_dim_from_const = true;
1398 for (int32 j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]); ++j) {
1399 const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
1400 if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
1401 no_unknown_dim_from_const = false;
1402 break;
1403 }
1404 }
1405 if (no_unknown_dim_from_const) {
1406 continue;
1407 }
1408 }
1409 return false;
1410 }
1411 }
1412 return true;
1413 }
1414
1415 // Returns true if all the output shapes are known.
1416 bool AllOutputShapesKnown(NodeContext* c) {
1417 InferenceContext* ic = c->inference_context.get();
1418 // Checks if all the output shapes are fully defined.
1419 for (int i = 0; i < ic->num_outputs(); i++) {
1420 if (!ic->FullyDefined(ic->output(i))) {
1421 return false;
1422 }
1423 }
1424 return true;
1425 }
1426
1427 // Returns true if we can infer output tensors' values -- we know values of
1428 // all the input tensors.
1429 bool AllInputValuesKnown(NodeContext* c) {
1430 InferenceContext* ic = c->inference_context.get();
1431
1432 // Check inputs are fully defined and values are known.
1433 for (int i = 0; i < ic->num_inputs(); i++) {
1434 const Tensor* tensor = ic->input_tensor(i);
1435 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1436 // already converted it to ic->input_tensor(i);
1437 const ShapeHandle& input_tensors_as_shape =
1438 ic->input_tensors_as_shapes()[i];
1439 // Either input_tensor must be valid, or input_tensors_as_shape (which holds
1440 // the input tensor's value in shape form) must be fully defined.
1441 if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1442 return false;
1443 }
1444 }
1445 return true;
1446 }
1447
1448 // Returns true if we want to update output shapes and values with running
1449 // EvaluateNode() for this op, based on op type, data type, and size.
1450 bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
1451 InferenceContext* ic = c->inference_context.get();
1452
1453 // Due to the cost of running EvaluateNode(), we limit this to allowlisted
1454 // op types.
1455 if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1456 return false;
1457 }
1458
1459 // Check input dtypes are number types.
1460 for (const auto& input_type : c->input_types) {
1461 if (!IsNumericType(input_type)) {
1462 return false;
1463 }
1464 }
1465
1466 // Check output dtypes are number types.
1467 for (const auto& output_type : c->output_types) {
1468 if (!IsNumericType(output_type)) {
1469 return false;
1470 }
1471 }
1472
1473 // Check that the number of elements of each input tensor is no larger than
1474 // the given max size.
1475 for (int i = 0; i < ic->num_inputs(); i++) {
1476 const Tensor* tensor = ic->input_tensor(i);
1477 const ShapeHandle& input_shape_handle = ic->input(i);
1478 if (tensor != nullptr) {
1479 if (tensor->NumElements() > max_size) {
1480 return false;
1481 }
1482 } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1483 return false;
1484 }
1485 }
1486
1487 // Check that we know the shape of each output tensor, and that its number
1488 // of elements is no larger than the given max size.
1489 for (int i = 0; i < ic->num_outputs(); i++) {
1490 const ShapeHandle& shape_handle = ic->output(i);
1491 if (!ic->FullyDefined(shape_handle) ||
1492 ic->Value(ic->NumElements(shape_handle)) > max_size) {
1493 return false;
1494 }
1495 }
1496 return true;
1497 }
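  // Illustrative example for ShouldUpdateOutputShapesAndValues() (not from the
  // original source): with max_size = 17 (the value the caller below passes),
  // an allow-listed op whose inputs and outputs are all DT_INT32 tensors of
  // shape [2, 3] qualifies (6 elements each); an op with a DT_STRING input, an
  // output whose shape is not fully defined, or any tensor with more than 17
  // elements is rejected and left to regular shape inference.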
1498
1499 // Create input tensors from the NodeContext.
1500   void CreateInputTensors(NodeContext* c,
1501 std::vector<Tensor>* input_tensor_vector,
1502 TensorVector* inputs) {
1503 InferenceContext* ic = c->inference_context.get();
1504 for (int i = 0; i < ic->num_inputs(); i++) {
1505 if (ic->input_tensor(i)) {
1506 input_tensor_vector->at(i) = *ic->input_tensor(i);
1507 inputs->emplace_back(&input_tensor_vector->at(i));
1508 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1509 // already converted it to ic->input_tensor(i);
1510 } else {
1511 // Create Tensor from input_tensors_as_shapes, and then emplace it
1512 // back to inputs.
1513 // Note that input_tensors_as_shapes is scalar or vector.
1514 const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1515 const DataType& data_type = c->input_types[i];
1516 int32 rank = ic->Rank(shape_handle);
1517 if (rank < 1) {
1518 input_tensor_vector->at(i) = Tensor(data_type, {});
1519 } else {
1520 input_tensor_vector->at(i) = Tensor(data_type, {rank});
1521 }
1522 auto* tensor = &input_tensor_vector->at(i);
1523 if (data_type == DT_INT32) {
1524 auto flat = tensor->flat<int32>();
1525 for (int j = 0; j < rank; j++) {
1526 int32 dim = ic->Value(ic->Dim(shape_handle, j));
1527 flat(j) = dim;
1528 }
1529 } else {
1530 auto flat = tensor->flat<int64>();
1531 for (int j = 0; j < rank; j++) {
1532 int64 dim = ic->Value(ic->Dim(shape_handle, j));
1533 flat(j) = dim;
1534 }
1535 }
1536 inputs->emplace_back(tensor);
1537 }
1538 }
1539 }
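  // Illustrative example for CreateInputTensors() (not from the original
  // source): if input 0 has no materialized tensor but
  // ic->input_tensors_as_shapes()[0] is the fully defined shape [2, 3, 4] and
  // c->input_types[0] == DT_INT32, the loop above builds a rank-1 DT_INT32
  // tensor holding {2, 3, 4}; an entry with rank < 1 yields a scalar tensor.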
1540
1541 // Run a node to infer output shapes and values, and add it to the
1542 // NodeContext.
1543   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1544 InferenceContext* ic = c->inference_context.get();
1545
1546 // Input to EvaluateNode()
1547 TensorVector inputs;
1548     // Container for temporarily created tensor objects.
1549 std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1550 CreateInputTensors(c, &input_tensor_vector, &inputs);
1551
1552     // Output for EvaluateNode() and output tensor cleanup object.
1553 TensorVector outputs;
1554 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1555 for (const auto& output : outputs) {
1556 if (output.tensor) {
1557 delete output.tensor;
1558 }
1559 }
1560 });
1561
1562 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1563 &resource_mgr_, &outputs));
1564 c->output_tensors_as_shapes.resize(outputs.size());
1565 c->output_tensor_protos.resize(outputs.size(), nullptr);
1566 for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1567 const auto& t = outputs[k];
1568 // Override output shape.
1569 ShapeHandle output_shape;
1570 TF_RETURN_IF_ERROR(
1571 ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1572 if (ic->FullyDefined(ic->output(k)) &&
1573 !EquivalentShapes(ic->output(k), output_shape)) {
1574 LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1575 << ", inferred output shape "
1576 << "doesn't match for k=" << k << ": "
1577 << "ic->output(k): " << ic->DebugString(ic->output(k))
1578 << ", output_shape: " << ic->DebugString(output_shape)
1579 << " -- " << node.DebugString();
1580 }
1581 ic->set_output(k, output_shape);
1582 // Set output_tensors_as_shape.
1583 MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1584
1585 // Set output_tensor_protos.
1586 TensorProto tensor_proto;
1587 t->AsProtoTensorContent(&tensor_proto);
1588 const_tensors_to_propagate_.push_back(tensor_proto);
1589 c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1590 }
1591 return Status::OK();
1592 }
1593
1594 // Update output shapes with annotated information.
1595   // Currently this only handles nodes with static shapes, i.e. shapes that do
1596   // not change during execution.
1597 // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1598   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1599 NodeContext* c) const {
1600 const auto& attr = node.attr();
1601 if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1602 attr.count(kOutputShapes) == 0)
1603 return Status::OK();
1604
1605 InferenceContext* ic = c->inference_context.get();
1606 int output_size = attr.at(kOutputShapes).list().shape_size();
1607
1608 for (int i = 0; i < ic->num_outputs(); i++) {
1609 // Annotated Switch node has only one output. Propagate the shape to all
1610 // the outputs.
1611 int shape_index = IsSwitch(node) ? 0 : i;
1612 if (shape_index >= output_size) {
1613 LOG(WARNING)
1614 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1615 << node.name() << ", inferred output shape size "
1616 << ic->num_outputs() << ", annotated output shape size "
1617 << output_size;
1618 break;
1619 }
1620
1621 const TensorShapeProto& shape =
1622 attr.at(kOutputShapes).list().shape(shape_index);
1623 if (shape.dim().empty()) continue;
1624
1625 ShapeHandle output_shape;
1626 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1627
1628 // Check if annotated shapes are incompatible with inferred shapes.
1629 if ((ic->FullyDefined(ic->output(i)) &&
1630 !SameShapes(ic->output(i), output_shape)) ||
1631 (!ic->FullyDefined(ic->output(i)) &&
1632 !CompatibleShapes(ic->output(i), output_shape))) {
1633 LOG(WARNING)
1634 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1635 << node.name() << ", inferred output shape "
1636 << "doesn't match for i=" << i << ": "
1637 << "ic->output(k): " << ic->DebugString(ic->output(i))
1638 << ", annotated output shape: " << ic->DebugString(output_shape)
1639 << " -- " << node.DebugString();
1640 c->shape_incompatible = true;
1641 }
1642
1643 // Only use annotated shapes if the inference shape is unknown and
1644 // compatible with annotated shapes.
1645 if (!ic->FullyDefined(ic->output(i)) &&
1646 CompatibleShapes(ic->output(i), output_shape)) {
1647 VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1648 << node.name() << ", inferred output shape " << i << ": "
1649 << "ic->output(i): " << ic->DebugString(ic->output(i))
1650 << ", annotated output shape: " << ic->DebugString(output_shape)
1651 << " -- " << node.ShortDebugString();
1652 ic->set_output(i, output_shape);
1653 }
1654 }
1655
1656 return Status::OK();
1657 }
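  // Illustrative example for UpdateOutputShapesUsingAnnotatedInformation() (not
  // from the original source): a node annotated with kOutputSame = true and
  // kOutputShapes = { [16, 112, 112, 64] } whose inferred output 0 is
  // [?, 112, 112, 64] gets that output refined to [16, 112, 112, 64]; if the
  // annotation were instead [16, 56, 56, 64], the shapes would be incompatible,
  // a warning would be logged, and c->shape_incompatible would be set.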
1658
1659   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1660 NodeContext* c) {
1661 // Propagate tensors and shape tensors unless the node is fed.
1662 // TODO(bsteiner) We should still propagate the shapes to the ports that
1663 // aren't fed in the case of a ShapeN node.
1664
1665 // Note that when propagating tensors_as_shapes, we use
1666     // c->input_tensors_as_shapes_to_propagate instead of
1667 // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1668 // UnknownDim is from Const tensor, and it is propagated through shape
1669 // inference. Before calling shape functions, we convert it to UnknownDim,
1670 // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1671 // inference through UnknownDim from Const.
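    // Illustrative example (not from the original source): a Const shape
    // tensor holding {-1, 128} is propagated as [kUnknownDimFromConst, 128]
    // rather than as a fresh symbolic unknown dim, so two unrelated Consts
    // that both contain -1 are not mistakenly inferred to share the same
    // unknown dimension.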
1672 InferenceContext* ic = c->inference_context.get();
1673 if (!is_fed) {
1674 if (IsConstant(node)) {
1675 c->output_tensor_protos.resize(1);
1676 const TensorProto& tensor_proto = node.attr().at("value").tensor();
1677 c->output_tensor_protos[0] = &tensor_proto;
1678 c->output_tensors_as_shapes.resize(1);
1679 MaybeTensorProtoToShape(ic, tensor_proto,
1680 &c->output_tensors_as_shapes[0]);
1681 } else if (IsRank(node)) {
1682 if (ic->RankKnown(ic->input(0))) {
1683 // Propagate rank value.
1684 int32 rank = ic->Rank(ic->input(0));
1685 const_tensors_to_propagate_.push_back(
1686 MakeIntegerScalarTensorProto(DT_INT32, rank));
1687 c->output_tensor_protos.resize(1);
1688 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1689 }
1690 } else if (IsSize(node)) {
1691 DimensionHandle size = ic->NumElements(ic->input(0));
1692 if (ic->ValueKnown(size)) {
1693 // Propagate size value.
1694 int64 sz = ic->Value(size);
1695 bool valid = false;
1696 if (node.attr().at("out_type").type() == DT_INT32) {
1697 if (sz < std::numeric_limits<int32>::max()) {
1698 const_tensors_to_propagate_.push_back(
1699 MakeIntegerScalarTensorProto(DT_INT32, sz));
1700 valid = true;
1701 }
1702 } else {
1703 const_tensors_to_propagate_.push_back(
1704 MakeIntegerScalarTensorProto(DT_INT64, sz));
1705 valid = true;
1706 }
1707 if (valid) {
1708 c->output_tensor_protos.resize(1);
1709 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1710 }
1711 }
1712 } else if (IsShape(node)) {
1713 c->output_tensors_as_shapes.resize(1);
1714 c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1715 } else if (IsShapeN(node)) {
1716 c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1717 for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1718 c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1719 }
1720 } else if (node.op() == "ConcatV2") {
1721 bool valid = true;
1722 ShapeHandle result;
1723 for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1724 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1725 if (!ic->RankKnown(input)) {
1726 valid = false;
1727 break;
1728 } else if (i == 0) {
1729 result = input;
1730 } else {
1731 TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1732 }
1733 }
1734 if (valid) {
1735 c->output_tensors_as_shapes.resize(1);
1736 c->output_tensors_as_shapes[0] = result;
1737 }
1738 } else if (IsPack(node)) {
1739 // A Pack node concatenating scalars is often used to generate a shape.
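        // Illustrative example (not from the original source):
        //   shape = Pack(height, width, 3)
        // where height and width are scalar DT_INT32 tensors; when both
        // scalar values are known here, output_tensors_as_shapes[0] records
        // the shape [height, width, 3], with any negative value mapped to
        // kUnknownDimFromConst.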
1740 std::vector<DimensionHandle> dims;
1741 bool valid = true;
1742 for (int i = 0; i < ic->num_inputs(); ++i) {
1743 const Tensor* t = ic->input_tensor(i);
1744 if (t) {
1745 if (t->dims() != 0 ||
1746 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1747 valid = false;
1748 break;
1749 }
1750 int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1751 : t->scalar<int64>()();
1752 dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1753 : ic->MakeDim(size));
1754 } else {
1755 // Don't have tensor value, but use input_tensors_as_shapes, if
1756 // possible.
1757 const ShapeHandle& shape_handle =
1758 c->input_tensors_as_shapes_to_propagate[i];
1759 if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1760 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1761 dims.push_back(ic->Dim(shape_handle, 0));
1762 } else {
1763             // This is not from a Const, but as it shouldn't be used as a symbolic
1764             // unknown dim for different ops, we use kUnknownDimFromConst.
1765 dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1766 }
1767 }
1768 }
1769 if (valid) {
1770 c->output_tensors_as_shapes.resize(1);
1771 c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1772 }
1773 } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1774 c->output_tensors_as_shapes.resize(1);
1775 c->output_tensors_as_shapes[0] =
1776 c->input_tensors_as_shapes_to_propagate[0];
1777 if (c->input_tensor_protos[0] != nullptr) {
1778 c->output_tensor_protos.resize(1);
1779 c->output_tensor_protos[0] = c->input_tensor_protos[0];
1780 }
1781 } else if (IsSlice(node)) {
1782 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1783 bool valid = ic->RankKnown(input);
1784 const Tensor* slice_offset = ic->input_tensor(1);
1785 valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1786 const Tensor* slice_size = ic->input_tensor(2);
1787 valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1788 if (valid) {
1789 int64 start = slice_offset->dtype() == DT_INT32
1790 ? slice_offset->flat<int32>()(0)
1791 : slice_offset->flat<int64>()(0);
1792 int64 size =
1793 (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1794 : slice_size->flat<int64>()(0));
1795 ShapeHandle result;
1796 if (size == -1) {
1797 TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1798 } else {
1799 int64 end = start + size;
1800 TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1801 }
1802 c->output_tensors_as_shapes.resize(1);
1803 c->output_tensors_as_shapes[0] = result;
1804 }
1805 } else if (IsStridedSlice(node)) {
1806 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1807 bool valid = ic->RankKnown(input);
1808 const Tensor* slice_begin = ic->input_tensor(1);
1809 valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1810 const Tensor* slice_end = ic->input_tensor(2);
1811 valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1812 const Tensor* slice_stride = ic->input_tensor(3);
1813 valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1814
1815 if (node.attr().count("ellipsis_mask") > 0 &&
1816 node.attr().at("ellipsis_mask").i() != 0) {
1817 valid = false;
1818 }
1819 if (node.attr().count("new_axis_mask") > 0 &&
1820 node.attr().at("new_axis_mask").i() != 0) {
1821 valid = false;
1822 }
1823 if (node.attr().count("shrink_axis_mask") > 0 &&
1824 node.attr().at("shrink_axis_mask").i() != 0) {
1825 valid = false;
1826 }
1827 int begin_mask = 0;
1828 if (node.attr().count("begin_mask") > 0) {
1829 begin_mask = node.attr().at("begin_mask").i();
1830 }
1831 int end_mask = 0;
1832 if (node.attr().count("end_mask") > 0) {
1833 end_mask = node.attr().at("end_mask").i();
1834 }
1835 if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1836 valid = false;
1837 }
1838 if (valid) {
1839 int64 begin = 0;
1840 if (begin_mask == 0) {
1841 begin = slice_begin->dtype() == DT_INT32
1842 ? slice_begin->flat<int32>()(0)
1843 : slice_begin->flat<int64>()(0);
1844 }
1845 int64 end = std::numeric_limits<int64>::max();
1846 if (end_mask == 0) {
1847 end =
1848 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1849 : slice_end->flat<int64>()(0));
1850 }
1851 int64 stride = slice_stride->dtype() == DT_INT32
1852 ? slice_stride->flat<int32>()(0)
1853 : slice_stride->flat<int64>()(0);
1854 ShapeHandle result;
1855 TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1856 c->output_tensors_as_shapes.resize(1);
1857 c->output_tensors_as_shapes[0] = result;
1858 }
1859 }
1860 }
1861
1862 if (aggressive_shape_inference_) {
1863 // Update output shapes with annotated information. This is optional.
1864 UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1865
1866 // Update output tensor values using EvaluateNode() if we can.
1867 // Due to the cost of EvaluateNode(), we run it only for certain op types
1868       // (allow-listed) and small integer tensors.
1869
1870 const int max_element_size = 17; // Max up to 4x4 matrix or similar.
1871 if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1872 !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1873 return Status::OK();
1874 }
1875 UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
1876 }
1877 return Status::OK();
1878 }
1879
1880   Status InferShapes(const NodeDef& node, NodeContext* c) {
1881 // Infer the shapes of output tensors.
1882 if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1883 !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1884 // Annotate outputs with unknown shapes. Update output shapes with
1885 // annotated information later on if available.
1886       // Note that the shape inference function may return an error, but we
1887       // ignore it and use UnknownShape in that case.
1888 TF_RETURN_IF_ERROR(
1889 c->inference_context->Run(shape_inference::UnknownShape));
1890 }
1891 Status status = Status::OK();
1892 auto it = fed_ports_.find(node.name());
1893 const bool is_fed = it != fed_ports_.end();
1894 if (is_fed) {
1895 // It is possible to feed node output ports with tensors of any shape: as
1896 // a result, the shape of a fed port is completely unknown.
1897 for (const int output_port : it->second) {
1898 status.Update(SetUnknownShape(&node, output_port));
1899 }
1900 }
1901
1902 // Update NodeContext output fields after shape inference function runs.
1903 status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1904
1905 return status;
1906 }
1907
1908 private:
1909   bool IsIntegerVector(const Tensor& tensor) {
1910 if (tensor.dims() == 1 &&
1911 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1912 return true;
1913 }
1914 return false;
1915 }
1916
1917   bool IsIntegerScalar(const Tensor& tensor) {
1918 if (tensor.dims() == 0 &&
1919 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1920 tensor.NumElements() == 1) {
1921 return true;
1922 }
1923 return false;
1924 }
1925
1926   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1927 const int64 val) {
1928 TensorProto tensor_proto;
1929 tensor_proto.set_dtype(dtype);
1930 // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1931 tensor_proto.mutable_tensor_shape();
1932 if (dtype == DT_INT32) {
1933 tensor_proto.add_int_val(val);
1934 } else if (dtype == DT_INT64) {
1935 tensor_proto.add_int64_val(val);
1936 }
1937 return tensor_proto;
1938 }
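  // Illustrative example (not from the original source):
  //   TensorProto p = MakeIntegerScalarTensorProto(DT_INT32, 3);
  // yields dtype: DT_INT32, an empty tensor_shape (i.e. a scalar), and
  // int_val: 3; with DT_INT64 the value would go into int64_val instead.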
1939
1940   bool MaybeTensorProtoToShape(InferenceContext* ic,
1941 const TensorProto& tensor_proto,
1942 ShapeHandle* tensors_as_shapes) {
1943 // Skip if dtype is not integer.
1944 if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1945 return false;
1946 }
1947 // Skip if shape is neither scalar nor vector.
1948 if (tensor_proto.tensor_shape().unknown_rank() ||
1949 tensor_proto.tensor_shape().dim_size() > 1) {
1950 return false;
1951 }
1952 Tensor tensor;
1953 if (!tensor.FromProto(tensor_proto)) {
1954 return false;
1955 }
1956 return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1957 }
1958
1959   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1960 ShapeHandle* tensors_as_shapes) {
1961 // Integer tensors of rank one can also be interpreted as a shape
1962 // provided all their values are >= -1.
1963
1964 if (IsIntegerVector(tensor)) {
1965 bool has_values_smaller_than_minus_1 = false;
1966 std::vector<DimensionHandle> dims;
1967 for (int i = 0; i < tensor.NumElements(); i++) {
1968 int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
1969 : tensor.flat<int64>()(i);
1970 has_values_smaller_than_minus_1 |= (value < -1);
1971 // Mark this as UnknownDim from Const.
1972 dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
1973 : ic->MakeDim(value));
1974 }
1975
1976 if (!has_values_smaller_than_minus_1) {
1977 *tensors_as_shapes = ic->MakeShape(dims);
1978 return true;
1979 }
1980 } else if (IsIntegerScalar(tensor)) {
1981 // Scalar constant.
1982 int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
1983 : tensor.flat<int64>()(0);
1984 if (value == -1) {
1985         // Scalar value -1 represents an unknown shape. If we were to call
1986         // MakeShape(MakeDim) with it, we would get a vector of unknown size.
1987 *tensors_as_shapes = ic->UnknownShape();
1988 return true;
1989 } else if (value >= 0) {
1990         // In principle values can be < -1, but MakeDim() fails for a value < -1;
1991         // this is a limitation of using ShapeHandle as a means to pass values.
1992 *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
1993 return true;
1994 }
1995 }
1996 return false;
1997 }
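  // Illustrative examples for MaybeTensorValueToShape() (not from the original
  // source):
  //   int32 vector {2, -1, 4} -> shape [2, kUnknownDimFromConst, 4]
  //   int32 scalar -1         -> ic->UnknownShape()
  //   int32 scalar 7          -> shape [7]
  //   vector containing -2    -> returns false (a value < -1 is not a shape)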
1998
1999 const GraphView& graph_;
2000 int graph_def_version_;
2001 absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2002 absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2003 absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2004   // Store function instantiations only for valid functions. If function
2005   // instantiation failed, the entry holds an `absl::nullopt`.
2006 absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2007 fun_to_grappler_function_item_;
2008 FunctionLibraryDefinition function_library_;
2009 const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2010 // Store TensorProtos for tensor value propagation. Note that we use deque,
2011 // not vector, as we use pointers to the TensorProtos in this container.
2012   // A vector may resize and copy the objects into a new buffer, leaving the
2013   // existing pointers dangling.
2014 std::deque<TensorProto> const_tensors_to_propagate_;
2015
2016 // For more aggressive shape and value inference.
2017 bool aggressive_shape_inference_;
2018 ResourceMgr resource_mgr_;
2019 };
2020
2021 // Keep track of shapes and dimensions in a graph.
2022 // In particular, use disjoint sets to track equivalence between shapes and
2023 // dims, and consolidate the information globally.
2024 class SymbolicShapeManager {
2025 public:
2026   SymbolicShapeManager() {}
2027
2028   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2029 if (!s1.IsSet() || !s2.IsSet()) {
2030 return Status::OK();
2031 }
2032 TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2033 if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2034 CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2035 for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2036 TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2037 InferenceContext::DimKnownRank(s2, i)));
2038 }
2039 }
2040 return Status::OK();
2041 }
2042   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2043 if (!d1.IsSet() || !d2.IsSet()) {
2044 return Status::OK();
2045 }
2046 return dims_.Merge(d1, d2);
2047 }
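  // Illustrative example (not from the original source): merging shapes
  // [x, 3] and [5, y] records the two shapes as equivalent and, since both
  // have rank 2, also merges the dimensions pairwise (x with 5, 3 with y), so
  // AsTensorProperties() later resolves both shapes to [5, 3].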
2048
2049   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2050 OpInfo::TensorProperties* properties) {
2051 properties->set_dtype(type);
2052 ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2053 if (!InferenceContext::RankKnown(actual_shape)) {
2054 properties->mutable_shape()->set_unknown_rank(true);
2055 } else {
2056 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2057 shape_inference::DimensionHandle dim =
2058 InferenceContext::DimKnownRank(actual_shape, j);
2059 int64 d = dims_.GetMergedValue(dim);
2060 properties->mutable_shape()->add_dim()->set_size(d);
2061 }
2062 }
2063 }
2064
2065 // Returns merged shape with merged dimensions.
2066   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2067 const auto& actual_shape = shapes_.GetMergedValue(s);
2068 if (!InferenceContext::RankKnown(actual_shape)) {
2069 return ic->UnknownShape();
2070 } else {
2071 std::vector<DimensionHandle> dims;
2072 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2073 shape_inference::DimensionHandle dim =
2074 InferenceContext::DimKnownRank(actual_shape, j);
2075 int64 d = dims_.GetMergedValue(dim);
2076         // The symbolic shape manager may make some dims < -1, which causes errors
2077         // when creating a Dimension.
2078 if (d < -1) {
2079 d = -1;
2080 }
2081 dims.push_back(ic->MakeDim(d));
2082 }
2083 return ic->MakeShape(dims);
2084 }
2085 }
2086
2087 private:
2088 DisjointSet<shape_inference::ShapeHandle> shapes_;
2089 DisjointSet<shape_inference::DimensionHandle> dims_;
2090 };
2091
2092 // Checks whether there is any conflict in merged shapes and dims in
2093 // SymbolicShapeManager.
2094 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2095 SymbolicShapeRefiner* refiner,
2096 SymbolicShapeManager* shape_manager) {
2097 if (!VLOG_IS_ON(1)) {
2098 return Status::OK();
2099 }
2100
2101 VLOG(1) << "Checking any conflics in shapes and dimensions ...";
2102 int64 num_incompatible_shapes = 0;
2103 for (const NodeDef& node : graph_def.node()) {
2104 auto ctx = refiner->GetNodeContext(&node);
2105 if (!ctx) {
2106 continue;
2107 }
2108 auto* ic = ctx->inference_context.get();
2109 for (int i = 0; i < ic->num_inputs(); ++i) {
2110 const auto& shape = ic->input(i);
2111 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2112 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2113 num_incompatible_shapes++;
2114 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2115 << "for node " << node.name() << " input (" << i << ") "
2116 << ic->DebugString(shape)
2117 << " vs. merged: " << ic->DebugString(merged_shape);
2118 }
2119 }
2120 for (int i = 0; i < ic->num_outputs(); ++i) {
2121 const auto& shape = ic->output(i);
2122 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2123 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2124 num_incompatible_shapes++;
2125 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2126 << "for node " << node.name() << " output (" << i << ") "
2127 << ic->DebugString(shape)
2128 << " vs. merged: " << ic->DebugString(merged_shape);
2129 }
2130 }
2131 }
2132 if (num_incompatible_shapes > 0) {
2133 VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2134 << " incompatible shapes from SymbolicShapeManager.";
2135 } else {
2136 VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2137 }
2138
2139 return Status::OK();
2140 }
2141
2142 // Log shape inference and its merged shapes.
2143 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2144 SymbolicShapeRefiner* refiner,
2145 SymbolicShapeManager* shape_manager) {
2146   // As logging all the nodes would generate too many lines, we skip this
2147   // detailed logging by default. Users may add nodes of interest to
2148 // node_names_for_logging to enable detailed logging.
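  // For example (hypothetical node name, not from the original source):
  //   absl::flat_hash_set<std::string> node_names_for_logging =
  //       {"model/dense/MatMul"};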
2149 absl::flat_hash_set<std::string> node_names_for_logging = {};
2150 if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2151 return Status::OK();
2152 }
2153
2154 auto should_log = [&node_names_for_logging](std::string node_name) {
2155 return node_names_for_logging.find(node_name) !=
2156 node_names_for_logging.end();
2157 };
2158
2159 for (const NodeDef& node : graph_def.node()) {
2160 if (!should_log(node.name())) {
2161 continue;
2162 }
2163 auto ctx = refiner->GetNodeContext(&node);
2164 if (!ctx) {
2165 continue;
2166 }
2167 auto* ic = ctx->inference_context.get();
2168 VLOG(3) << "Shape inference for node : " << node.name();
2169 VLOG(3) << ctx->DebugString(node);
2170     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2171 for (int i = 0; i < ic->num_inputs(); ++i) {
2172 absl::StrAppend(
2173 &merged_shapes, " input[", i, "] -- ",
2174 ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2175 "\n");
2176 }
2177 for (int i = 0; i < ic->num_outputs(); ++i) {
2178 absl::StrAppend(
2179 &merged_shapes, " output[", i, "] -- ",
2180 ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2181 "\n");
2182 }
2183 VLOG(3) << merged_shapes;
2184 VLOG(3) << "--------------------------------";
2185 VLOG(3) << "";
2186 }
2187
2188 return Status::OK();
2189 }
2190
2191 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2192 SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2193 const std::vector<ShapeAndType>& shapes_and_types,
2194 std::vector<ShapeAndType>* queue_shapes_and_types) {
2195 if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2196 return errors::InvalidArgument(
2197 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2198 " vs ", queue_shapes_and_types->size());
2199 }
2200 for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2201 const ShapeAndType& a = shapes_and_types[i];
2202 ShapeAndType& b = (*queue_shapes_and_types)[i];
2203 if (a.dtype != b.dtype) {
2204 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2205 i, ": ", DataTypeString(a.dtype), " vs ",
2206 DataTypeString(b.dtype));
2207 }
2208
2209 b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2210 }
2211 return Status::OK();
2212 }
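// Illustrative example for RelaxEnqueueShapesAndMergeTypes() (not from the
// original source): if the queue currently holds {([10, 32], DT_FLOAT)} and an
// enqueue provides {([?, 32], DT_FLOAT)}, the shape is relaxed via
// OutputAsUnion to [?, 32]; if the dtypes differed, an InvalidArgument error
// would be returned instead.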
2213
2214 // Compute the output shape of the merge node as the union of the available
2215 // input shapes.
2216 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2217 const NodeDef* node,
2218 bool* new_shapes) const {
2219 InferenceContext* ic = shape_refiner->GetContext(node);
2220 if (!ic) {
2221 // Now we can run shape inference
2222 TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2223 ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2224 *new_shapes = true;
2225
2226 // Infer the shape of the second output once and for all since it never
2227 // changes.
2228 ShapeHandle out1 = ic->Scalar();
2229 if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2230 }
2231
2232 ShapeHandle out;
2233 const std::vector<ShapeAndType>* out_handle = nullptr;
2234 bool out_initialized = false;
2235 for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2236 *node, /*include_controlling_edges=*/false)) {
2237 InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2238 if (!src_ic) {
2239 // Handling a loop for the first time, the back edge won't have any shape
2240 // info.
2241 continue;
2242 }
2243 ShapeHandle input = src_ic->output(fanin.src.port_id);
2244 ic->SetInput(fanin.dst.port_id, input);
2245 auto* input_handle =
2246 src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2247 if (input_handle)
2248 ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2249 if (!out_initialized) {
2250 out_initialized = true;
2251 out = input;
2252 out_handle = input_handle;
2253 } else {
2254 // Note here only out, not out_handle, is modified.
2255 out = shape_refiner->OutputAsUnion(node, 0, input, out);
2256 }
2257 }
2258
2259 if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2260 ic->set_output(0, out);
2261 if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2262 *new_shapes = true;
2263 }
2264
2265 return Status::OK();
2266 }
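// Illustrative example for UpdateMerge() (not from the original source): for a
// Merge node whose available fanins have shapes [4, 8] and [4, ?], the first
// output becomes their union [4, ?]; a back edge whose source has no
// InferenceContext yet (first pass through a loop) is skipped and contributes
// only once its shape becomes available.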
2267
2268 // Manually propagate the input shape for Enter nodes.
2269 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2270 const NodeDef* node, bool* new_shapes) {
2271 InferenceContext* ic = shape_refiner->GetContext(node);
2272 if (!ic) {
2273 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2274 ic = shape_refiner->GetContext(node);
2275 }
2276
2277 GraphView::InputPort port(node, 0);
2278 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2279
2280 InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2281 ShapeHandle input = src_ic->output(fanin.port_id);
2282 if (!ic->output(0).SameHandle(input)) {
2283 ic->SetInput(0, input);
2284 ic->set_output(0, input);
2285 *new_shapes = true;
2286 }
2287 auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2288 if (outputs) {
2289 ic->set_input_handle_shapes_and_types(0, *outputs);
2290 ic->set_output_handle_shapes_and_types(0, *outputs);
2291 *new_shapes = true;
2292 }
2293 return Status::OK();
2294 }
2295
2296 Status GraphProperties::UpdateShapes(
2297 SymbolicShapeRefiner* shape_refiner,
2298 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2299 const NodeDef* n, bool* new_shapes) const {
2300 if (IsEnter(*n)) {
2301 // The Enter shape function always forwards an UnknownShape, so do the right
2302 // thing here.
2303 TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2304 } else if (IsMerge(*n)) {
2305 // Properly handle merge nodes.
2306 TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2307 } else if (IsEnqueue(*n)) {
2308 // Make sure the shapes of enqueued tensors are propagated to the queue
2309 // itself.
2310 TF_RETURN_IF_ERROR(
2311 UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2312 } else if (IsQueue(*n)) {
2313 // Set shapes and types of Queue ops, if needed.
2314 TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2315 } else {
2316 // Rely on regular TF shape refinement for all the other nodes.
2317 // UpdateNode calls UpdateFunction if a function node is detected.
2318 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2319 }
2320
2321 return Status::OK();
2322 }
2323
2324 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2325 Status GraphProperties::PropagateShapes(
2326 SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2327 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2328 int num_loops) const {
2329 // Limit the number of iterations to prevent infinite loops in the presence of
2330 // incorrect shape functions. The algorithm should converge in at most
2331 // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
2332 // The same applies to resources.
2333 VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2334 << num_loops << " loops and " << resource_handles.size()
2335 << " resources" << std::endl;
2336
2337 const int64 max_loop_length = item_.graph.node_size();
2338 const int64 max_rank = 4;
2339 const int64 max_loop_iterations =
2340 max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
2341 const int64 num_queues = resource_handles.size();
2342 const int64 max_resource_iterations = num_queues * num_queues * max_rank;
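  // Illustrative numbers (not from the original source): with a 1000-node
  // graph, num_loops = 2 and 3 queues, max_loop_iterations = 4 * 1000 * 4 =
  // 16000 and max_resource_iterations = 3 * 3 * 4 = 36.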
2343
2344 int64 num_resource_iterations = 0;
2345 do {
2346 int64 num_loop_iterations = 0;
2347 while (!new_shapes->empty() &&
2348 num_loop_iterations++ < max_loop_iterations) {
2349 const NodeDef* n = new_shapes->pop();
2350 bool updated = false;
2351 TF_RETURN_IF_ERROR(
2352 UpdateShapes(shape_refiner, resource_handles, n, &updated));
2353 if (updated) {
2354 for (const auto& fanout : shape_refiner->graph().GetFanouts(
2355 *n, /*include_controlled_nodes=*/false)) {
2356 new_shapes->push(fanout.node);
2357 }
2358 // Make sure the corresponding queue nodes are (re)processed.
2359 if (IsEnqueue(*n)) {
2360 auto it = resource_handles.find(n);
2361 if (it != resource_handles.end()) {
2362 new_shapes->push(it->second);
2363 }
2364 }
2365 }
2366 }
2367 } while (!new_shapes->empty() &&
2368 num_resource_iterations++ < max_resource_iterations);
2369
2370 if (!new_shapes->empty()) {
2371 return errors::Internal("Shape inference failed to converge");
2372 }
2373
2374 return Status::OK();
2375 }
2376
2377 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2378 SymbolicShapeRefiner* shape_refiner,
2379 bool* new_shapes) {
2380 auto* ctx = shape_refiner->GetNodeContext(queue_node);
2381 if (!ctx) {
2382 TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2383 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2384 }
2385 auto* ic = ctx->inference_context.get();
2386
2387 auto* outputs = ic->output_handle_shapes_and_types(0);
2388 if (outputs) {
2389 // Shapes and types are already set, presumably by Enqueue ops.
2390 return shape_refiner->UpdateNode(queue_node, new_shapes);
2391 }
2392
2393 if (queue_node->attr().count("shapes") <= 0 ||
2394 queue_node->attr().count("component_types") <= 0 ||
2395 queue_node->attr().at("shapes").list().shape_size() !=
2396 queue_node->attr().at("component_types").list().type_size()) {
2397 // Errors in shapes and component_types attr.
2398 return shape_refiner->UpdateNode(queue_node, new_shapes);
2399 }
2400
2401 // Extract types and shapes from Queue attr.
2402 const auto& shapes = queue_node->attr().at("shapes").list().shape();
2403 const auto& types = queue_node->attr().at("component_types").list().type();
2404 std::vector<ShapeAndType> shapes_and_types;
2405 for (int i = 0; i < types.size(); i++) {
2406 const auto& shape = shapes[i];
2407 ShapeHandle shape_handle;
2408 TF_RETURN_IF_ERROR(
2409 ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2410 DataType data_type =
2411 queue_node->attr().at("component_types").list().type(i);
2412 ShapeAndType shape_and_type(shape_handle, data_type);
2413 shapes_and_types.push_back(shape_and_type);
2414 }
2415 ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2416
2417   // The queue node is updated with output_handle_shapes_and_types, so set
2418   // new_shapes here and ignore the result of UpdateNode().
2419 *new_shapes = true;
2420 bool dummy_new_shapes = false;
2421 return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2422 }
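// Illustrative example for UpdateQueue() (not from the original source): a
// FIFOQueueV2 node with attrs shapes = { [2, 3] } and component_types =
// { DT_FLOAT } gets output_handle_shapes_and_types(0) set to the single entry
// ([2, 3], DT_FLOAT), which downstream Dequeue ops can then pick up.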
2423
2424 Status GraphProperties::UpdateEnqueue(
2425 const NodeDef* enqueue_node,
2426 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2427 SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2428 auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2429 if (!ctx) {
2430 TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2431 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2432 }
2433
2434 auto it = resource_handles.find(enqueue_node);
2435 if (it == resource_handles.end()) {
2436 // The corresponding queue was not found, there isn't much we can do.
2437 return Status::OK();
2438 }
2439 const NodeDef* qnode = it->second;
2440 auto qctx = shape_refiner->GetContext(qnode);
2441 if (!qctx) {
2442 return Status::OK();
2443 }
2444 auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2445
2446 // TODO(bsteiner): handle EnqueueMany as well.
2447 std::vector<ShapeAndType> shapes_and_types;
2448 for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2449 GraphView::InputPort inp(enqueue_node, i);
2450 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2451 InferenceContext* in = shape_refiner->GetContext(fanin.node);
2452 ShapeHandle input = in->output(fanin.port_id);
2453 ctx->inference_context->SetInput(i, input);
2454 shapes_and_types.push_back({input, ctx->input_types[i]});
2455 }
2456
2457 if (queue_handle_data == nullptr) {
2458 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2459 *new_shapes = true;
2460 } else {
2461 TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2462 shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2463 *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2464 shapes_and_types);
2465 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2466 }
2467
2468 return Status::OK();
2469 }
2470
2471 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2472 bool aggressive_shape_inference,
2473 bool include_input_tensor_values,
2474 bool include_output_tensor_values) {
2475 FunctionLibraryDefinition function_library(OpRegistry::Global(),
2476 item_.graph.library());
2477 absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
2478 if (!assume_valid_feeds) {
2479 for (const auto& feed : item_.feed) {
2480 SafeTensorId tensor_id = ParseTensorName(feed.first);
2481 fed_ports[tensor_id.node()].insert(tensor_id.index());
2482 }
2483 }
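  // Illustrative example (not from the original source): a feed named
  // "images:0" parses into a tensor_id with node() == "images" and
  // index() == 0, so fed_ports["images"] becomes {0}.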
2484
2485 GraphView graph_view(&item_.graph);
2486
2487 // List the resources and the nodes using them. Also collect the Merge nodes,
2488 // fed nodes, and primary inputs.
2489 absl::flat_hash_map<const NodeDef*,
2490 std::pair<absl::flat_hash_set<const NodeDef*>,
2491 absl::flat_hash_set<const NodeDef*>>>
2492 resources;
2493 absl::flat_hash_set<const NodeDef*> merge_nodes;
2494 absl::flat_hash_set<const NodeDef*> fed_nodes;
2495 absl::flat_hash_set<const NodeDef*> primary_inputs;
2496 int num_loops = 0;
2497 for (const NodeDef& node : item_.graph.node()) {
2498 if (IsQueue(node)) {
2499 for (const GraphView::InputPort& fanout :
2500 graph_view.GetFanouts(node, false)) {
2501 if (IsEnter(*fanout.node)) {
2502 const NodeDef& enter = *fanout.node;
2503 for (const GraphView::InputPort& fanout :
2504 graph_view.GetFanouts(enter, false)) {
2505 if (IsEnqueue(*fanout.node)) {
2506 resources[&node].first.insert(fanout.node);
2507 } else if (IsDequeue(*fanout.node)) {
2508 resources[&node].second.insert(fanout.node);
2509 }
2510 }
2511 } else {
2512 if (IsEnqueue(*fanout.node)) {
2513 resources[&node].first.insert(fanout.node);
2514 } else if (IsDequeue(*fanout.node)) {
2515 resources[&node].second.insert(fanout.node);
2516 }
2517 }
2518 }
2519 }
2520 if (!HasRegularInputs(node)) {
2521 primary_inputs.insert(&node);
2522 } else if (IsMerge(node)) {
2523 merge_nodes.insert(&node);
2524 } else if (IsNextIteration(node)) {
2525 ++num_loops;
2526 }
2527 if (fed_ports.find(node.name()) != fed_ports.end()) {
2528 fed_nodes.insert(&node);
2529 }
2530 }
2531
2532 absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
2533 std::vector<TopologicalDependency> extra_deps;
2534 for (const auto& resource : resources) {
2535 for (const NodeDef* src : resource.second.first) {
2536 resource_handles[src] = resource.first;
2537 for (const NodeDef* dst : resource.second.second) {
2538 // Add control edges from enqueue to dequeue nodes to ensure they are
2539 // processed in their logical order.
2540 extra_deps.emplace_back(src, dst);
2541 }
2542 }
2543 }
2544
2545 std::vector<const NodeDef*> topo_order;
2546 Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2547 if (!s.ok()) {
2548 if (extra_deps.empty()) {
2549 return s;
2550 } else {
2551 // There is a loop between queues: we'll just use the graph topological
2552 // order. This will make the shape inference less precise but since this
2553       // isn't common, it's not worth figuring out where to break the loop and
2554       // doing a proper relaxation.
2555 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2556 }
2557 }
2558
2559 // Heap-allocate SymbolicShapeRefiner in order to not consume a large amount
2560 // of stack space.
2561 auto refiner = absl::make_unique<SymbolicShapeRefiner>(
2562 graph_view, fed_ports, aggressive_shape_inference);
2563
2564 TopoQueue new_shapes(topo_order);
2565 // Also seed the propagation of shapes in the fanout of primary inputs.
2566 for (const NodeDef* node : primary_inputs) {
2567 new_shapes.push(node);
2568 }
2569 // Also seed the propagation of shapes in the fanout of fed nodes.
2570 for (const NodeDef* node : fed_nodes) {
2571 new_shapes.push(node);
2572 }
2573 // Propagate shapes normally.
2574 TF_RETURN_IF_ERROR(
2575 PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
2576
2577 // Track shapes globally across the graph.
2578 std::unique_ptr<SymbolicShapeManager> shape_manager =
2579 absl::make_unique<SymbolicShapeManager>();
2580 bool found_error = false;
2581 for (const NodeDef& node : item_.graph.node()) {
2582 auto node_ctx = refiner->GetContext(&node);
2583 if (!node_ctx) {
2584 continue;
2585 }
2586 // Skip any information that comes from fed nodes.
2587 if (fed_ports.find(node.name()) != fed_ports.end()) {
2588 VLOG(2) << "Skipping feed node shape: " << node.name();
2589 continue;
2590 }
2591 for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2592 if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2593 .ok()) {
2594 found_error = true;
2595 break;
2596 }
2597 }
2598 for (const auto& merged_dims : node_ctx->MergedDims()) {
2599 if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2600 found_error = true;
2601 break;
2602 }
2603 }
2604 if (found_error) {
2605 // The shapes aren't consistent, we can't infer safely: discard all the
2606 // information discovered so far.
2607 shape_manager = absl::make_unique<SymbolicShapeManager>();
2608 break;
2609 }
2610 }
2611
2612 TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
2613 shape_manager.get()));
2614
2615 for (const NodeDef& node : item_.graph.node()) {
2616 VLOG(4) << "Filling in graph properties for node: " << node.name();
2617 auto ctx = refiner->GetNodeContext(&node);
2618 if (!ctx) {
2619 continue;
2620 }
2621
2622 auto* ic = ctx->inference_context.get();
2623
2624 // Fill input properties.
2625 {
2626 auto& input_properties = input_properties_[node.name()];
2627
2628       // Should always be empty; node names in the graph are supposed to be unique.
2629 CHECK_EQ(input_properties.size(), 0);
2630
2631 input_properties.resize(ic->num_inputs());
2632 GraphView::InputPort input(&node, -1);
2633 for (int i = 0; i < ic->num_inputs(); ++i) {
2634 shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2635 &input_properties[i]);
2636 input.port_id = i;
2637 GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2638 if (include_input_tensor_values) {
2639 // Export tensor value to input_properties.value.
2640 if (IsConstant(*fanin.node)) {
2641 const TensorProto& raw_val =
2642 fanin.node->attr().at("value").tensor();
2643 *input_properties[i].mutable_value() = raw_val;
2644 } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
2645 ctx->input_tensor_protos[i] != nullptr) {
2646 *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2647 } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
2648 i &&
2649 IsShapeFullyDefinedIntegerVectorOrScalar(
2650 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2651 ctx->input_types[i])) {
2652 *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2653 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2654 ctx->input_types[i]);
2655 }
2656 }
2657 }
2658 }
2659
2660 // Fill output properties.
2661 {
2662 auto& output_properties = output_properties_[node.name()];
2663
2664       // Should always be empty; node names in the graph are supposed to be unique.
2665 CHECK_EQ(output_properties.size(), 0);
2666
2667 output_properties.resize(ic->num_outputs());
2668 for (int i = 0; i < ic->num_outputs(); ++i) {
2669 shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2670 &output_properties[i]);
2671 auto converted_output_tensors_as_shapes =
2672 ReplaceUnknownDimFromConstWithUnknownDim(
2673 ic, ctx->output_tensors_as_shapes);
2674 if (include_output_tensor_values) {
2675 // Export tensor value to output_properties.value.
2676 if (IsConstant(node)) {
2677 // TODO(rmlarsen): Eliminate this copy.
2678 const TensorProto& raw_val = node.attr().at("value").tensor();
2679 *output_properties[i].mutable_value() = raw_val;
2680 } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
2681 ctx->output_tensor_protos[i] != nullptr) {
2682 *output_properties[i].mutable_value() =
2683 *ctx->output_tensor_protos[i];
2684 } else if (static_cast<int>(
2685 converted_output_tensors_as_shapes.size()) > i &&
2686 IsShapeFullyDefinedIntegerVectorOrScalar(
2687 ic, ic->output(i),
2688 converted_output_tensors_as_shapes[i],
2689 ctx->output_types[i])) {
2690 *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2691 ic, ic->output(i), converted_output_tensors_as_shapes[i],
2692 ctx->output_types[i]);
2693 }
2694 }
2695 }
2696 }
2697
2698 if (aggressive_shape_inference && ctx->shape_incompatible)
2699 incompatible_shape_nodes_.insert(node.name());
2700 }
2701
2702 if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
2703 LOG(WARNING) << incompatible_shape_nodes_.size()
2704 << " nodes have incompatible output shapes.";
2705
2706 // Help trace the unknown dimensions to their origins.
2707 VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2708 output_properties_);
2709
2710 TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
2711 shape_manager.get()));
2712
2713 return Status::OK();
2714 }
2715
2716 Status GraphProperties::InferDynamically(Cluster* cluster) {
2717 TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2718
2719 // Runs the model once to collect the shapes in the cost model.
2720 RunMetadata metadata;
2721 TF_RETURN_IF_ERROR(
2722 cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2723
2724 return InferFromCostGraph(metadata.cost_graph());
2725 }
2726
2727 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
2728 bool allow_symbolic_shapes) const {
2729 *output_graph_def = item_.graph;
2730 for (int i = 0; i < output_graph_def->node_size(); i++) {
2731 auto node = output_graph_def->mutable_node(i);
2732 AttrValue attr_output_shape;
2733 auto tensor_properties = GetOutputProperties(node->name());
2734 for (const auto& tensor_property : tensor_properties) {
2735 TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
2736 *proto = tensor_property.shape();
2737 if (!allow_symbolic_shapes) {
2738 NormalizeShapeForOutput(proto);
2739 }
2740 }
2741 (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
2742 }
2743 return Status::OK();
2744 }
2745
2746 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2747 if (cost_graph.node_size() == 0) {
2748 LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2749 }
2750 std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2751 std::unordered_map<string, const NodeDef*> name_to_node; // Empty
2752 for (auto& node : cost_graph.node()) {
2753 name_to_cost[node.name()] = &node;
2754
2755 std::vector<OpInfo::TensorProperties> output_properties;
2756 for (const auto& out : node.output_info()) {
2757 OpInfo::TensorProperties properties;
2758 properties.set_dtype(out.dtype());
2759 *properties.mutable_shape() = out.shape();
2760 output_properties.push_back(properties);
2761 }
2762 output_properties_[node.name()] = output_properties;
2763 }
2764
2765 for (const auto& node : item_.graph.node()) {
2766 // Skip the nodes that are not in the cost graph: these are nodes that
2767 // aren't run, because they aren't in the intersection of transitive fan-in
2768 // of a fetch node and the transitive fan-out of an input, or nodes that
2769 // were optimized away by the optimizer.
2770 auto it = name_to_cost.find(node.name());
2771 if (it == name_to_cost.end()) {
2772 continue;
2773 }
2774 std::vector<OpInfo::TensorProperties> inputs =
2775 FindInputFeatures(node, name_to_cost, name_to_node);
2776
2777 input_properties_[node.name()] = inputs;
2778 }
2779 return Status::OK();
2780 }
2781
2782 bool GraphProperties::HasInputProperties(const string& node_name) const {
2783 return input_properties_.find(node_name) != input_properties_.end();
2784 }
2785
2786 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2787 return output_properties_.find(node_name) != output_properties_.end();
2788 }
2789
2790 const std::vector<OpInfo::TensorProperties>&
2791 GraphProperties::GetInputProperties(const string& node_name) const {
2792 auto it = input_properties_.find(node_name);
2793 if (it != input_properties_.end()) {
2794 return it->second;
2795 }
2796 return missing_properties_;
2797 }
2798
2799 const std::vector<OpInfo::TensorProperties>&
2800 GraphProperties::GetOutputProperties(const string& node_name) const {
2801 auto it = output_properties_.find(node_name);
2802 if (it != output_properties_.end()) {
2803 return it->second;
2804 }
2805 return missing_properties_;
2806 }
2807
2808 void GraphProperties::ClearInputProperties(const string& node_name) {
2809 input_properties_.erase(node_name);
2810 }
2811 void GraphProperties::ClearOutputProperties(const string& node_name) {
2812 output_properties_.erase(node_name);
2813 }
2814
2815 } // end namespace grappler
2816 } // end namespace tensorflow
2817