1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/costs/utils.h"
31 #include "tensorflow/core/grappler/mutable_graph_view.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/functions.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40
41 namespace tensorflow {
42 namespace grappler {
43
44 namespace {
45
46 using shape_inference::DimensionHandle;
47 using shape_inference::InferenceContext;
48 using shape_inference::ShapeAndType;
49 using shape_inference::ShapeHandle;
50 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
51
52 // A large value for UnknownDim from Const used as a dim value in shape.
53 // Some ops treat "-1" specially, different from UnknownDim:
54 // e.g., shape input to Reshape op.
55 const int64_t kUnknownDimFromConst = INT64_MAX;
56
57 // Skip const value instantiation if the number of elements in a const tensor
58 // is greater than this threshold.
59 const int kThresholdToSkipConstTensorInstantiation = 128;
60
61 template <typename Handle>
62 struct HashHandle {
63 std::size_t operator()(const Handle& h) const { return h.Handle(); }
64 };
65 template <typename Handle>
66 struct CompareHandle {
67 bool operator()(const Handle& h1, const Handle& h2) const {
68 return h1.SameHandle(h2);
69 }
70 };
71
72 template <typename Handle>
73 struct HandleToObject {};
74 template <>
75 struct HandleToObject<ShapeHandle> {
76 typedef ShapeHandle Object;
77
78 static ShapeHandle Unknown() { return ShapeHandle(); }
79 };
80
81 template <>
82 struct HandleToObject<DimensionHandle> {
83 typedef int64 Object;
84
85 static int64 Unknown() { return -1; }
86 };
87
88 template <typename Handle>
89 struct Processor {};
90
91 template <>
92 struct Processor<ShapeHandle> {
93 // Extract the shape or dim denoted by the handle.
94 void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
95 // Merge the shapes or dims.
96 Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
97 if (InferenceContext::RankKnown(*result)) {
98 // The result was initialized in a previous merge to a shape of known
99 // rank, make sure we preserve that information.
100 return Status::OK();
101 }
102 if (InferenceContext::RankKnown(h1)) {
103 *result = h1;
104 } else {
105 *result = h2;
106 }
107 return Status::OK();
108 }
109 };
110
111 template <>
112 struct Processor<DimensionHandle> {
113 // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
114 // reserved by TensorFlow).
115 void ExtractValue(DimensionHandle d, int64* result) {
116 if (!InferenceContext::ValueKnown(d)) {
117 *result = -counter;
118 counter++;
119 } else {
120 int64_t val = InferenceContext::Value(d);
121 if (val >= 0) {
122 *result = val;
123 } else {
124 // A shape inference function generated an invalid dimension handle.
125 // Use a symbolic dimension to encode this.
126 *result = -counter;
127 counter++;
128 }
129 }
130 }
131
132 // Merge the dimensions d1 and d2. Return the known dimension if there is
133 // one; otherwise look for a symbolic dimension. If there is neither a
134 // symbolic nor a known dimension, the dimension is fully unknown, so return -1.
135 Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
136 const int64_t dim1 = InferenceContext::Value(d1);
137 const int64_t dim2 = InferenceContext::Value(d2);
138
139 if (dim1 >= 0 && dim2 >= 0) {
140 CHECK_EQ(dim1, dim2);
141 return RefineDim(dim1, result);
142 } else if (dim1 >= 0 && dim2 < 0) {
143 return RefineDim(dim1, result);
144 } else if (dim1 < 0 && dim2 >= 0) {
145 return RefineDim(dim2, result);
146 } else if (dim1 < -1) {
147 return RefineDim(dim1, result);
148 } else if (dim2 < -1) {
149 return RefineDim(dim2, result);
150 } else {
151 CHECK_EQ(dim1, dim2);
152 CHECK_EQ(-1, dim1);
153 return RefineDim(-1, result);
154 }
155 return Status::OK();
156 }
157
158 private:
159 Status RefineDim(int64_t dim, int64* result) {
160 if (*result >= 0) {
161 if (!(*result == dim || dim < 0)) {
162 return errors::InvalidArgument("Inconsistent dimensions detected");
163 }
164 } else if (dim >= 0) {
165 *result = dim;
166 } else if (dim < *result) {
167 *result = dim;
168 }
169 return Status::OK();
170 }
171
172 int64 counter = 2;
173 };
174
175 // Traditional disjoint-set data structure with path compression.
176 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
177 template <typename Handle>
178 class DisjointSet {
179 public:
180 DisjointSet() {}
181 ~DisjointSet() {
182 for (auto rep : nodes_) {
183 delete rep.second;
184 }
185 }
186
187 Status Merge(Handle x, Handle y);
188 const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
189
190 private:
191 // All the handles that belong to the same set are part of the same tree, and
192 // ultimately represented by the root of that tree.
193 struct Rep {
194 // Parent in the tree used to encode the set.
195 Rep* parent;
196 // Rank of the tree, used by Merge to decide which root becomes the new
197 // parent (union by rank), keeping the trees shallow.
198 int rank;
199 // The handle.
200 typename HandleToObject<Handle>::Object value;
201 };
202
203 // Create a new set for the value if none exists, or return its representative
204 // node otherwise.
205 Rep* Find(Handle value);
206
207 private:
208 Processor<Handle> processor_;
209 absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
210 nodes_;
211 };
212
213 template <typename Handle>
214 const typename HandleToObject<Handle>::Object
215 DisjointSet<Handle>::GetMergedValue(Handle value) {
216 Rep* rep = Find(value);
217 if (!rep) {
218 // We don't know anything about this handle.
219 return HandleToObject<Handle>::Unknown();
220 }
221 return rep->value;
222 }
223
224 template <typename Handle>
225 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
226 Rep* x_root = Find(x);
227 Rep* y_root = Find(y);
228
229 // x and y are already in the same set
230 if (x_root == y_root) {
231 return Status::OK();
232 }
233 // x and y are not in same set, so we merge them
234 // Use the occasion to strengthen what we know about the handle by merging
235 // the information about the two subsets.
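// Union by rank: the root with the smaller rank is attached under the root
// with the larger rank, which keeps the trees shallow.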
236 if (x_root->rank < y_root->rank) {
237 TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
238 x_root->parent = y_root;
239 } else if (x_root->rank > y_root->rank) {
240 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
241 y_root->parent = x_root;
242 } else {
243 TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
244 // Arbitrarily make one root the new parent
245 y_root->parent = x_root;
246 x_root->rank = x_root->rank + 1;
247 }
248 return Status::OK();
249 }
250
251 template <typename Handle>
252 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
253 auto it = nodes_.find(value);
254 if (it == nodes_.end()) {
255 // This is the first time we process this handle, create an entry for it.
256 Rep* node = new Rep;
257 node->parent = node;
258 node->rank = 0;
259 processor_.ExtractValue(value, &node->value);
260 nodes_[value] = node;
261 return node;
262 }
263 // Return the representative for the set, which is the root of the tree. Apply
264 // path compression to speed up future queries.
265 Rep* node = it->second;
266 Rep* root = node->parent;
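// Walk up the parent chain to find the root, then re-point every node on the
// path directly at the root (path compression).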
267 while (root != root->parent) {
268 root = root->parent;
269 }
270 while (node->parent != root) {
271 Rep* next = node->parent;
272 node->parent = root;
273 node = next;
274 }
275 return root;
276 }
277
278 // TODO(dyoon): Move many helper functions in this file (including those within
279 // SymbolicShapeRefiner class) to shared utils.
280 bool IsEnqueue(const NodeDef& n) {
281 return (n.op().find("Enqueue") != string::npos &&
282 n.op().find("EnqueueMany") == string::npos);
283 }
284
285 bool IsDequeue(const NodeDef& n) {
286 return (n.op().find("Dequeue") != string::npos &&
287 n.op().find("DequeueMany") == string::npos);
288 }
289
290 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
291 if (proto.unknown_rank()) {
292 return true;
293 }
294 for (const auto& dim : proto.dim()) {
295 if (dim.size() < 0) {
296 return true;
297 }
298 }
299 return false;
300 }
301
302 // This really should be done in an external debugging tool
303 void VerboseLogUnknownDimensionSources(
304 const GraphDef& graph,
305 const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
306 input_properties_map,
307 const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
308 output_properties_map) {
309 if (!VLOG_IS_ON(2)) {
310 return;
311 }
312
313 VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
314
315 // Find all nodes in the graph for which we
316 // do not have any unknown dimensions in their inputs, but
317 // we have some unknown dimensions in their outputs.
318 std::map<string, int> op_to_count;
319 for (const NodeDef& node : graph.node()) {
320 const auto& input_properties = input_properties_map.at(node.name());
321 const auto& output_properties = output_properties_map.at(node.name());
322
323 bool has_unknown_inputs = false;
324 for (const auto& input_prop : input_properties) {
325 if (HasAnyUnknownDimensions(input_prop.shape())) {
326 has_unknown_inputs = true;
327 break;
328 }
329 }
330
331 if (has_unknown_inputs) {
332 continue;
333 }
334
335 for (const auto& output_prop : output_properties) {
336 if (HasAnyUnknownDimensions(output_prop.shape())) {
337 string inputs = "input_shapes=[";
338 for (const auto& input_prop : input_properties) {
339 inputs += PartialTensorShape::DebugString(input_prop.shape());
340 }
341 inputs += "]";
342
343 string outputs = "output_shapes=[";
344 for (const auto& output_prop : output_properties) {
345 outputs += PartialTensorShape::DebugString(output_prop.shape());
346 }
347 outputs += "]";
348
349 VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
350 << inputs << ", " << outputs;
351
352 op_to_count[node.op()]++;
353
354 // don't log again for this node
355 break;
356 }
357 }
358 }
359 VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
360 << "(format: <op_type> (<count>)):";
361 for (const auto& p : op_to_count) {
362 VLOG(2) << p.first << " (" << p.second << ")";
363 }
364 }
365
366 // Helper function to convert kUnknownDimFromConst into UnknownDim.
367 std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
368 InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
369 std::vector<ShapeHandle> converted_shapes(shapes.size());
370 for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
371 const auto& shape = shapes[i];
372 if (!ic->RankKnown(shape)) {
373 converted_shapes[i] = shape;
374 continue;
375 }
376 bool just_copy = true;
377 std::vector<DimensionHandle> dims;
378 for (int32_t i = 0; i < ic->Rank(shape); ++i) {
379 DimensionHandle dim = ic->Dim(shape, i);
380 if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
381 just_copy = false;
382 dims.push_back(ic->UnknownDim());
383 } else {
384 dims.push_back(dim);
385 }
386 }
387 if (just_copy) {
388 converted_shapes[i] = shape;
389 continue;
390 }
391 converted_shapes[i] = ic->MakeShape(dims);
392 }
393 return converted_shapes;
394 }
395
396 // Returned tensor's shape is like `shape`, and its values and dtype are from
397 // `tensor_as_shape` and `dtype`.
398 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
399 const ShapeHandle& shape,
400 const ShapeHandle& tensor_as_shape,
401 const DataType& dtype) {
402 TensorProto tensor_proto;
403 tensor_proto.set_dtype(dtype);
404 auto* shape_proto = tensor_proto.mutable_tensor_shape();
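// For a vector result, the single dimension's size is the number of values,
// i.e. the rank of tensor_as_shape.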
405 if (ic->Rank(shape) == 1) {
406 shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
407 }
408 // For a scalar tensor, tensor_shape field will be left empty; no dim.
409 for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
410 int64_t value = ic->Value(ic->Dim(tensor_as_shape, i));
411 if (dtype == DT_INT32) {
412 tensor_proto.add_int_val(value);
413 } else {
414 tensor_proto.add_int64_val(value);
415 }
416 }
417 return tensor_proto;
418 }
419
420 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
421 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
422 const TensorProto& tensor_proto,
423 const DataType& dtype) {
424 NodeDef const_node;
425 const_node.set_name("const_from_shape");
426 const_node.set_op("Const");
427 auto* attr = const_node.mutable_attr();
428 (*attr)["dtype"].set_type(dtype);
429 auto* tensor = (*attr)["value"].mutable_tensor();
430 *tensor = tensor_proto;
431 return const_node;
432 }
433
434 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
435 // and dtype = `dtype`.
436 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
437 const ShapeHandle& shape,
438 const ShapeHandle& tensor_as_shape,
439 const DataType& dtype) {
440 return MakeConstNodeDefFromTensorProto(
441 ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
442 }
443
444 bool IsNumericType(const DataType dtype) {
445 static const gtl::FlatSet<DataType>* const kRealNumberTypes =
446 CHECK_NOTNULL((new gtl::FlatSet<DataType>{
447 // Floating point.
448 DT_BFLOAT16,
449 DT_HALF,
450 DT_FLOAT,
451 DT_DOUBLE,
452 // Int / UInt.
453 DT_INT8,
454 DT_INT16,
455 DT_INT32,
456 DT_INT64,
457 DT_UINT8,
458 DT_UINT16,
459 DT_UINT32,
460 DT_UINT64,
461 // Quantized Int.
462 DT_QINT8,
463 DT_QUINT8,
464 DT_QINT16,
465 DT_QUINT16,
466 DT_QINT32,
467 // Bool.
468 DT_BOOL,
469 }));
470 return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
471 }
472
473 // Returns the number of elements in the input (const) tensor.
474 // -1 if the tensor has no shape or unknown rank.
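// Since the return type is unsigned, -1 wraps to a very large value, so
// callers comparing against a small threshold treat unknown sizes as too big.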
475 uint64 NumElementsFromTensorProto(const TensorProto& tensor_proto) {
476 if (!tensor_proto.has_tensor_shape()) {
477 return -1;
478 }
479 const auto& tensor_shape_proto = tensor_proto.tensor_shape();
480 if (tensor_shape_proto.unknown_rank()) {
481 return -1;
482 }
483 int64_t num_elements = 1;
484 for (const auto& dim : tensor_shape_proto.dim()) {
485 // Note that in some cases, dim.size() can be zero (e.g., empty vector).
486 num_elements *= dim.size();
487 }
488 return num_elements;
489 }
490
491 } // namespace
492
493 // Note that tensor_as_shape input should not include kUnknownDimFromConst.
494 // This function checks for kUnknownDimFromConst, but only logs a WARNING.
495 // If checking input_tensors_as_shapes_to_propagate or output_tensors_as_shapes,
496 // which may include kUnknownDimFromConst, convert them using
497 // ReplaceUnknownDimFromConstWithUnknownDim() first.
498 bool IsShapeFullyDefinedIntegerVectorOrScalar(
499 InferenceContext* ic, const ShapeHandle& shape,
500 const ShapeHandle& tensor_as_shape, const DataType& dtype) {
501 if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
502 !ic->FullyDefined(tensor_as_shape) ||
503 (dtype != DT_INT32 && dtype != DT_INT64)) {
504 return false;
505 }
506 // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
507 for (int32_t i = 0; i < ic->Rank(tensor_as_shape); ++i) {
508 DimensionHandle dim = ic->Dim(tensor_as_shape, i);
509 if (ic->Value(dim) == kUnknownDimFromConst) {
510 LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
511 << "tensor_as_shape input includes kUnknownDimFromConst -- "
512 << ic->DebugString(tensor_as_shape);
513 return false;
514 }
515 }
516 return true;
517 }
518
519 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
520 // dequeued in (roughly) topological order. Propagating shapes following a
521 // topological ordering isn't required for correctness but helps speed things up
522 // since it avoids processing the same node multiple times as its input
523 // information is refined.
524 class TopoQueue {
525 public:
526 explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
527 : topo_order_(TopoOrder(topo_order)) {}
528
529 void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
530
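// Returns the queued node with the smallest topological index; queue_ is an
// ordered set keyed by that index.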
531 const NodeDef* pop() {
532 CHECK(!empty());
533 auto it = queue_.begin();
534 const NodeDef* n = it->first;
535 queue_.erase(it);
536 return n;
537 }
538
539 bool empty() const { return queue_.empty(); }
540 std::size_t size() const { return queue_.size(); }
541
542 private:
543 using NodeAndId = std::pair<const NodeDef*, int>;
544 // Graph nodes are created in (roughly) topological order. Therefore we can
545 // use their id to ensure they're sorted topologically.
546 struct OrderByIdAscending {
547 bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
548 return lhs.second < rhs.second;
549 }
550 };
551
552 const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
553 const std::vector<const NodeDef*>& topo_order) const {
554 absl::flat_hash_map<const NodeDef*, int> map;
555 map.reserve(topo_order.size());
556 for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
557 ++i) {
558 map.emplace(topo_order[i], i);
559 }
560 return map;
561 }
562
563 const absl::flat_hash_map<const NodeDef*, int> topo_order_;
564 std::set<NodeAndId, OrderByIdAscending> queue_;
565 };
566
567
568 bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
569 static const gtl::FlatSet<string>* const kOpTypeAllowlist =
570 CHECK_NOTNULL((new gtl::FlatSet<string>{
571 // Unary arithmetic ops
572 "Floor",
573 "Round",
574 "Sqrt",
575 "Square",
576 "Sign",
577 // Binary arithmetic ops
578 "Add",
579 "AddV2",
580 "Div",
581 "FloorDiv",
582 "FloorMod",
583 "Greater",
584 "GreaterEqual",
585 "Less",
586 "LessEqual",
587 "LogicalAnd",
588 "LogicalNot",
589 "LogicalOr",
590 "Maximum",
591 "Minimum",
592 "Mod",
593 "Mul",
594 "NotEqual",
595 "QuantizedAdd",
596 "QuantizedMul",
597 "SquareDifference",
598 "Sub",
599 "TruncateDiv",
600 "TruncateMod",
601 "RealDiv",
602 // N-ary arithmetic ops
603 "AddN",
604 // Others
605 "StridedSlice",
606 "OnesLike",
607 "ZerosLike",
608 "Concat",
609 "ConcatV2",
610 "Split",
611 "Range",
612 "Fill",
613 "Cast",
614 }));
615 return kOpTypeAllowlist->find(op_type) != kOpTypeAllowlist->end();
616 }
617
618 // Negative shape size of '-1' represents unknown, while negative shape sizes
619 // less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
620 // -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
621 // we need to normalize them: mark all values <-1 as "unknown" (-1).
622 static void NormalizeShapeForOutput(TensorShapeProto* shape) {
623 for (int i = 0; i < shape->dim_size(); i++) {
624 if (shape->dim(i).size() < -1) {
625 VLOG(2) << "Normalizing dimension: " << i << " from "
626 << shape->dim(i).size() << " to -1";
627 shape->mutable_dim(i)->set_size(-1);
628 }
629 }
630 }
631
632 // Processes symbolic shapes.
633 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
634 // shape refiner which creates new handles every time it processes an unknown
635 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
636 // unknown shape/dimension of a given node.
637 class SymbolicShapeRefiner {
638 public:
639 explicit SymbolicShapeRefiner(
640 const GraphView& graph,
641 const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
642 const bool aggressive_shape_inference)
643 : graph_(graph),
644 function_library_(OpRegistry::Global(), graph.graph()->library()),
645 fed_ports_(fed_ports),
646 aggressive_shape_inference_(aggressive_shape_inference) {
647 graph_def_version_ = graph.graph()->versions().producer();
648 node_to_context_.reserve(graph.graph()->node_size());
649 }
650
651 const GraphView& graph() const { return graph_; }
652
653 struct NodeContext {
654 const OpRegistrationData* op_data;
655 DataTypeVector input_types;
656 DataTypeVector output_types;
657 std::unique_ptr<InferenceContext> inference_context;
658 // Additional info for propagating tensor values and tensor shapes.
659 std::vector<const TensorProto*> input_tensor_protos;
660 std::vector<const TensorProto*> output_tensor_protos;
661 // This is the same as inference_context->input_tensors_as_shapes, except
662 // that some UnknownDims (-1) can be kUnknownDimFromConst.
663 std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
664 std::vector<ShapeHandle> output_tensors_as_shapes;
665
666 // Output shapes incompatible between annotation and shape inference.
667 bool shape_incompatible = false;
668
669 // Similar to DebugString() in InferenceContext, but prints out
670 // kUnknownDimFromConst properly.
671 std::string StringifyShapeHandle(ShapeHandle s) {
672 auto* ic = inference_context.get();
673 if (ic->RankKnown(s)) {
674 std::vector<std::string> vals;
675 for (int i = 0; i < ic->Rank(s); i++) {
676 DimensionHandle d = ic->Dim(s, i);
677 if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
678 vals.push_back("?(Const)");
679 } else {
680 vals.push_back(ic->DebugString(d));
681 }
682 }
683 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
684 } else {
685 return "?";
686 }
687 }
688
689 std::string DebugString(const NodeDef& node) {
690 std::string output;
691 auto* ic = inference_context.get();
692 absl::StrAppend(
693 &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
694 (ic->num_inputs() > 1 ? " inputs and " : " input and "),
695 ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
696 if (op_data->is_function_op) {
697 absl::StrAppend(&output, " (function op)");
698 }
699 absl::StrAppend(&output, ": \n");
700
701 for (int i = 0; i < ic->num_inputs(); i++) {
702 absl::StrAppend(&output, " input [", i, "] ", node.input(i),
703 " -- type: ", DataTypeString(input_types.at(i)),
704 ", shape: ", ic->DebugString(ic->input(i)),
705 ", tensor: ");
706 Tensor t1;
707 int input_tensor_protos_size = input_tensor_protos.size();
708 if (input_tensor_protos_size > i &&
709 input_tensor_protos.at(i) != nullptr &&
710 t1.FromProto(*input_tensor_protos.at(i))) {
711 absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
712 } else {
713 absl::StrAppend(&output, " null, tensor_as_shape: ");
714 }
715 int input_tensors_as_shapes_to_propagate_size =
716 input_tensors_as_shapes_to_propagate.size();
717 if (input_tensors_as_shapes_to_propagate_size > i) {
718 absl::StrAppend(
719 &output,
720 StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
721 "\n");
722 } else {
723 absl::StrAppend(&output, " null\n");
724 }
725 }
726 for (int i = 0; i < ic->num_outputs(); i++) {
727 absl::StrAppend(&output, " output [", i,
728 "] -- type: ", DataTypeString(output_types.at(i)),
729 ", shape: ", ic->DebugString(ic->output(i)),
730 ", tensor: ");
731 Tensor t2;
732 int output_tensor_protos_size = output_tensor_protos.size();
733 if (output_tensor_protos_size > i &&
734 output_tensor_protos.at(i) != nullptr &&
735 t2.FromProto(*output_tensor_protos.at(i))) {
736 absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
737 } else {
738 absl::StrAppend(&output, " null, tensor_as_shape: ");
739 }
740 int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
741 if (output_tensors_as_shapes_size > i) {
742 absl::StrAppend(&output,
743 StringifyShapeHandle(output_tensors_as_shapes.at(i)),
744 "\n");
745 } else {
746 absl::StrAppend(&output, " null\n");
747 }
748 }
749 return output;
750 }
751 };
752
753 NodeContext* GetNodeContext(const NodeDef* node) {
754 auto it = node_to_context_.find(node);
755 if (it == node_to_context_.end()) {
756 return nullptr;
757 }
758 return &it->second;
759 }
760
761 InferenceContext* GetContext(const NodeDef* node) {
762 auto it = node_to_context_.find(node);
763 if (it == node_to_context_.end()) {
764 return nullptr;
765 }
766 return it->second.inference_context.get();
767 }
768
769 // Forwards the shapes from the function input nodes (PartitionedCall or
770 // StatefulPartitionedCall) to the function's argument nodes, which are
771 // Placeholder nodes, and then performs shape inference on the function
772 // body.
773 //
774 // Propagate shape information of final function body node
775 // to function node `function_node`.
776 //
777 // In the event of an error, UpdateNode will simply set `function_node`'s
778 // output shape to be Unknown.
779 Status UpdateFunction(const NodeDef* function_node) {
780 NameAttrList function;
781 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
782 auto it = fun_to_grappler_function_item_.find(function.name());
783 if (it == fun_to_grappler_function_item_.end()) {
784 return errors::InvalidArgument(
785 function.name(),
786 " was not previously added to SymbolicShapeRefiner.");
787 }
788
789 const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
790 it->second;
791 if (!maybe_grappler_function_item.has_value()) {
792 VLOG(3) << "Skip failed to instantiate function call: function_name="
793 << function.name();
794
795 auto* ctx = GetNodeContext(function_node);
796 auto* ic = ctx->inference_context.get();
797 for (int i = 0; i < ic->num_outputs(); ++i) {
798 TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
799 }
800
801 return Status::OK();
802 }
803
804 // Copy (not reference) so that changes we make here (e.g., replacing
805 // _Arg with Const and _Retval with Identity) don't affect one in
806 // fun_to_grappler_function_item_.
807 GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
808 MutableGraphView gv(&grappler_function_item.graph);
809
810 // Forward shapes from function input nodes to argument nodes.
811 for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
812 ++i) {
813 auto& fun_input = grappler_function_item.input(i);
814 NodeDef* fun_node = gv.GetNode(fun_input.node_name);
815 const TensorId input_tensor = ParseTensorName(function_node->input(i));
816
817 if (IsControlInput(input_tensor)) {
818 return errors::FailedPrecondition(
819 "Function inputs should not contain control nodes.");
820 }
821
822 const NodeDef* input_node = graph_.GetNode(input_tensor.node());
823 if (input_node == nullptr) {
824 return errors::FailedPrecondition(input_tensor.node(),
825 " was not found in the graph.");
826 }
827
828 InferenceContext* input_ic = GetContext(input_node);
829 if (input_ic == nullptr) {
830 return errors::FailedPrecondition(
831 "Inference context has not been created for ", input_tensor.node());
832 }
833
834 int output_port_num = input_tensor.index();
835 AttrValue attr_output_shape;
836 TensorShapeProto proto;
837 const auto handle = input_ic->output(output_port_num);
838 input_ic->ShapeHandleToProto(handle, &proto);
839 // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
840 NormalizeShapeForOutput(&proto);
841 // _Arg op's output shape uses _output_shapes attr.
842 AttrValue output_attr;
843 output_attr.mutable_list()->add_shape()->Swap(&proto);
844 (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
845
846 // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
847 // _handle_shapes attr for its shapes and dtypes.
848 if (fun_input.data_type == DT_RESOURCE) {
849 auto* shapes_and_types =
850 input_ic->output_handle_shapes_and_types(output_port_num);
851 if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
852 AttrValue dtype_attr;
853 AttrValue shape_attr;
854 for (const auto& shape_and_type : *shapes_and_types) {
855 const auto& dtype = shape_and_type.dtype;
856 const auto& shape_handle = shape_and_type.shape;
857 dtype_attr.mutable_list()->add_type(dtype);
858 input_ic->ShapeHandleToProto(
859 shape_handle, shape_attr.mutable_list()->add_shape());
860 }
861 (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
862 (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
863 } else {
864 // Note that we do not return error here, even if the input node does
865 // not have shapes_and_types. Within the function, we cannot infer the
866 // output shape of the DT_RESOURCE input; hence, potentially unknown
867 // shapes/dims in the function output shapes.
868 VLOG(2)
869 << "A function node (" << function_node->name()
870 << ") has input with DT_RESOURCE, but the input node does not "
871 << "have shapes_and_types information: \n"
872 << "function_node: " << function_node->ShortDebugString() << "\n"
873 << "function input: " << i
874 << ", input node's output: " << output_port_num << "\n"
875 << "input node: " << input_node->ShortDebugString();
876 }
877 }
878 }
879
880 // ReplaceInputWithConst() may break GraphView's internal node mapping
881 // structure; hence, we separately build node name to NodeDef* map, for the
882 // output nodes (before GraphView becomes invalid). Note that we use string,
883 // not string_view.
884 absl::flat_hash_map<std::string, NodeDef*> output_nodes;
885 for (const auto& output_arg : grappler_function_item.outputs()) {
886 output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
887 }
888
889 // Replace input nodes with Consts, if values are known. Note that
890 // we don't check exceptions here as it's done in the above loop.
891 auto* ctx = GetNodeContext(function_node);
892 auto* ic = ctx->inference_context.get();
893 for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
894 const string& input = function_node->input(i);
895 const string node_name = NodeName(input);
896 const NodeDef* input_node = graph_.GetNode(node_name);
897 if (IsConstant(*input_node)) {
898 TF_CHECK_OK(
899 ReplaceInputWithConst(*input_node, i, &grappler_function_item));
900 } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
901 ctx->input_tensor_protos[i] != nullptr) {
902 NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
903 ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
904 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
905 &grappler_function_item));
906 } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
907 IsShapeFullyDefinedIntegerVectorOrScalar(
908 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
909 ctx->input_types[i])) {
910 // We have fully defined input_tensors_as_shapes for this input; use it
911 // as a const input to the function node.
912 NodeDef const_input_node = MakeConstNodeDefFromShape(
913 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
914 ctx->input_types[i]);
915 TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
916 &grappler_function_item));
917 }
918 }
919 // node_name to NodeDef* map in GraphView gv can be broken due to
920 // ReplaceInputWithConst(). gv should not be used after this.
921
922 // Replace output _Retval nodes with Identity nodes. _Retval is a system op
923 // without outputs or a registered shape function.
924 for (const auto& output_arg : grappler_function_item.outputs()) {
925 NodeDef* output_node = output_nodes[output_arg.node_name];
926 DCHECK_EQ(output_node->op(), "_Retval");
927 output_node->set_op("Identity");
928 output_node->mutable_attr()->erase("index");
929 }
930
931 // Perform inference on function body.
932 GraphProperties gp(grappler_function_item);
933 TF_RETURN_IF_ERROR(gp.InferStatically(
934 /*assume_valid_feeds=*/true,
935 /*aggressive_shape_inference=*/aggressive_shape_inference_,
936 /*include_tensor_values=*/true));
937
938 // Add return nodes for output shapes.
939 int output = 0;
940 ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
941 ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
942 nullptr);
943 for (auto const& out_arg : grappler_function_item.outputs()) {
944 // It is guaranteed that output_tensors does not contain any control
945 // inputs, so port_id >= 0.
946 TensorId out_tensor = ParseTensorName(out_arg.node_name);
947
948 if (output_nodes.count(out_tensor.node()) <= 0) {
949 return errors::FailedPrecondition(
950 "Unable to find return function_node ", out_tensor.node(), " for ",
951 function_node->name());
952 }
953 const NodeDef* retnode = output_nodes[out_tensor.node()];
954
955 auto output_properties = gp.GetOutputProperties(retnode->name());
956 int output_properties_size = output_properties.size();
957 if (out_tensor.index() >= output_properties_size) {
958 return errors::InvalidArgument(
959 out_tensor.ToString(), " has invalid position ", out_tensor.index(),
960 " (output_properties.size() = ", output_properties.size(), ").");
961 }
962 auto& outprop = output_properties[out_tensor.index()];
963 TensorShapeProto shape = outprop.shape();
964 NormalizeShapeForOutput(&shape);
965 ShapeHandle out;
966 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
967 ic->set_output(output, out);
968 if (outprop.has_value()) {
969 // Forward tensor value to output_tensors_as_shape.
970 MaybeTensorProtoToShape(ic, outprop.value(),
971 &ctx->output_tensors_as_shapes[output]);
972 const_tensors_to_propagate_.push_back(outprop.value());
973 ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
974 }
975 output++;
976 }
977
978 return Status::OK();
979 }
980
981 // Prepares input shapes/values/handles, then runs shape inference, and
982 // finally sets output shapes/values/handles.
983 Status UpdateNode(const NodeDef* node, bool* refined) {
984 NodeContext* ctx = GetNodeContext(node);
985 if (ctx == nullptr) {
986 TF_RETURN_IF_ERROR(AddNode(node));
987 ctx = CHECK_NOTNULL(GetNodeContext(node));
988 *refined = true;
989 }
990
991 // Check if the shapes of the nodes in the fan-in of this node have changed,
992 // and if they have, update the node input shapes.
993 InferenceContext* ic = ctx->inference_context.get();
994 ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
995 ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
996
997 for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
998 const GraphView::InputPort port(node, dst_input);
999 const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
1000 int src_output = fanin.port_id;
1001 const NodeDef* src = fanin.node;
1002 NodeContext* src_ctx = GetNodeContext(src);
1003 if (src_ctx == nullptr) {
1004 return errors::FailedPrecondition(
1005 "Input ", dst_input, " for '", node->name(),
1006 "' was not previously added to SymbolicShapeRefiner.");
1007 }
1008
1009 InferenceContext* src_ic = src_ctx->inference_context.get();
1010 if (src_output >= src_ic->num_outputs()) {
1011 return errors::OutOfRange("src_output = ", src_output,
1012 ", but num_outputs is only ",
1013 src_ic->num_outputs());
1014 }
1015
1016 // Propagate input node's NodeContext info to the current node's
1017 // NodeContext:
1018 // output_tensor_protos to input_tensor_protos and input_tensors, and
1019 // output_tensors_as_shapes to input_tensors_as_shapes.
1020 if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
1021 src_output) {
1022 ctx->input_tensors_as_shapes_to_propagate[dst_input] =
1023 src_ctx->output_tensors_as_shapes[src_output];
1024 }
1025
1026 if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
1027 const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
1028 if (tensor_proto != nullptr) {
1029 ctx->input_tensor_protos[dst_input] = tensor_proto;
1030 }
1031 }
1032
1033 // NOTE: we only check whether the shape is refined; we do not (yet) check
1034 // whether the tensor value is refined.
1035 if (!*refined &&
1036 !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
1037 *refined = true;
1038 }
1039 ic->SetInput(dst_input, src_ic->output(src_output));
1040
1041 if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
1042 // The input value may have changed. Since we have no way to know if
1043 // that's indeed the case, err on the safe side.
1044 *refined = true;
1045 }
1046
1047 // Also propagate handle shape and dtype of edges which are carrying
1048 // resource handles.
1049 if (ctx->input_types[dst_input] == DT_RESOURCE) {
1050 auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
1051 if (!outputs) continue;
1052 auto* inputs = ic->input_handle_shapes_and_types(dst_input);
1053
1054 if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
1055 *refined = true;
1056 ic->set_input_handle_shapes_and_types(dst_input, *outputs);
1057 }
1058 }
1059
1060 // Make sure we schedule the fanout of resources (which have no input)
1061 // whenever the resources are updated.
1062 *refined |= ic->num_inputs() == 0;
1063
1064 if (!*refined) {
1065 // No input shape has changed, we're done.
1066 return Status::OK();
1067 }
1068
1069 // Convert all kUnknownDimFromConst to -1 for shape inference.
1070 ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
1071 ic, ctx->input_tensors_as_shapes_to_propagate));
1072 // Note: UpdateFunction uses input_tensors_as_shapes and
1073 // input_tensor_protos (not the Tensor object) for input values.
1074 // so for function nodes, we don't need to convert TensorProtos
1075 // to Tensors here. If the current op is not a function op, we convert
1076 // TensorProtos to Tensors before calling InferShapes.
1077
1078 // Properly handle function nodes.
1079 if (ctx->op_data && ctx->op_data->is_function_op) {
1080 // TODO(jmdecker): Detect if the input shapes have changed for this
1081 // function. Note that when we hit a function call node, refined will be
1082 // true, as the updates to the call node will have changed, even if it's
1083 // the same function being called twice with the same input shapes.
1084 // Example: simple_function.pbtxt
1085 if (aggressive_shape_inference_) {
1086 // If output shapes are annotated, use it and skip UpdateFunction();
1087 // it can be very expensive when a function node has nested function
1088 // nodes internally. One downside with this approach is that we do not
1089 // get output values or output shapes as tensor from function node.
1090 auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
1091 if (s.ok() && AllOutputShapesKnown(ctx)) {
1092 return Status::OK();
1093 }
1094 // If shape annotation was not available, incomplete, or incompatible,
1095 // fall through to call UpdateFunction().
1096 }
1097 auto s = UpdateFunction(node);
1098 if (s.ok()) {
1099 return Status::OK();
1100 } else {
1101 VLOG(1) << "UpdateFunction failed for " << node->op()
1102 << ". Defaulting to ShapeUnknown.\n"
1103 << s.ToString();
1104 }
1105 }
1106
1107 // Construct Tensors for constant inputs used by shape functions.
1108 std::vector<Tensor> const_values(ic->num_inputs());
1109 std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
1110 for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1111 const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
1112 if (tensor_proto != nullptr &&
1113 // Skip if the const tensor is too large.
1114 NumElementsFromTensorProto(*tensor_proto) <=
1115 kThresholdToSkipConstTensorInstantiation &&
1116 const_values[dst_input].FromProto(*tensor_proto)) {
1117 input_tensors[dst_input] = &const_values[dst_input];
1118 }
1119 }
1120 ic->set_input_tensors(input_tensors);
1121
1122 // Update the shapes of the outputs.
1123 return InferShapes(*node, ctx);
1124 }
1125
1126 Status SetUnknownShape(const NodeDef* node, int output_port) {
1127 shape_inference::ShapeHandle shape =
1128 GetUnknownOutputShape(node, output_port);
1129 InferenceContext* ctx = GetContext(node);
1130 if (ctx == nullptr) {
1131 return errors::InvalidArgument("Missing context");
1132 }
1133 ctx->set_output(output_port, shape);
1134 return Status::OK();
1135 }
1136
1137 struct ShapeId {
1138 const NodeDef* node;
1139 int port_id;
1140 bool operator==(const ShapeId& other) const {
1141 return node == other.node && port_id == other.port_id;
1142 }
1143 };
1144 struct HashShapeId {
1145 std::size_t operator()(const ShapeId& shp) const {
1146 return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
1147 }
1148 };
1149
1150 struct DimId {
1151 const NodeDef* node;
1152 int port_id;
1153 int dim_index;
1154 bool operator==(const DimId& other) const {
1155 return node == other.node && port_id == other.port_id &&
1156 dim_index == other.dim_index;
1157 }
1158 };
1159
1160 struct HashDimId {
1161 std::size_t operator()(const DimId& dim) const {
1162 return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
1163 dim.dim_index;
1164 }
1165 };
1166
1167 // Compute the shape of the output at 'port_index' as the union of shape1 and shape2.
1168 ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
1169 ShapeHandle shape1, ShapeHandle shape2) {
1170 if (shape1.SameHandle(shape2)) {
1171 return shape1;
1172 }
1173 InferenceContext* ctx = GetContext(node);
1174 ShapeHandle relaxed = shape1;
1175 const int rank = ctx->Rank(shape1);
1176 if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
1177 relaxed = GetUnknownOutputShape(node, port_index);
1178 } else {
1179 for (int d = 0; d < rank; ++d) {
1180 if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
1181 int64_t val1 = ctx->Value(ctx->Dim(shape1, d));
1182 int64_t val2 = ctx->Value(ctx->Dim(shape2, d));
1183 if (val1 != val2 || (val1 < 0 && val2 < 0)) {
1184 DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
1185 TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
1186 }
1187 }
1188 }
1189 }
1190 return relaxed;
1191 }
1192
1193 bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
1194 if (s1.SameHandle(s2)) {
1195 return true;
1196 }
1197 if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
1198 return false;
1199 }
1200 if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
1201 return true;
1202 }
1203 const int rank = InferenceContext::Rank(s1);
1204 for (int i = 0; i < rank; ++i) {
1205 if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
1206 InferenceContext::DimKnownRank(s2, i))) {
1207 int64_t val1 =
1208 InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
1209 int64_t val2 =
1210 InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
1211 if (val1 >= 0 && val2 >= 0 && val1 == val2) {
1212 continue;
1213 }
1214 return false;
1215 }
1216 }
1217 return true;
1218 }
1219
1220 // Return true if the annotated shape is compatible with shape inference
1221 // result. Examples:
1222 // Inferred shape: ?, annotated shape: [10, 10] -> true;
1223 // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
1224 // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
1225 // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
1226 bool CompatibleShapes(ShapeHandle inferred_shape,
1227 ShapeHandle annotated_shape) const {
1228 if (inferred_shape.SameHandle(annotated_shape)) {
1229 return true;
1230 }
1231 if (!InferenceContext::RankKnown(inferred_shape)) {
1232 return true;
1233 }
1234 if (InferenceContext::Rank(inferred_shape) !=
1235 InferenceContext::Rank(annotated_shape)) {
1236 return false;
1237 }
1238 const int rank = InferenceContext::Rank(inferred_shape);
1239 for (int i = 0; i < rank; ++i) {
1240 if (!InferenceContext::DimKnownRank(inferred_shape, i)
1241 .SameHandle(
1242 InferenceContext::DimKnownRank(annotated_shape, i))) {
1243 int64_t val1 = InferenceContext::Value(
1244 InferenceContext::DimKnownRank(inferred_shape, i));
1245 int64_t val2 = InferenceContext::Value(
1246 InferenceContext::DimKnownRank(annotated_shape, i));
1247 if (val1 >= 0 && val1 != val2) {
1248 return false;
1249 }
1250 }
1251 }
1252 return true;
1253 }
1254
1255 bool SameShapes(ShapeHandle inferred_shape,
1256 ShapeHandle annotated_shape) const {
1257 if (inferred_shape.SameHandle(annotated_shape)) {
1258 return true;
1259 }
1260 if (InferenceContext::Rank(inferred_shape) !=
1261 InferenceContext::Rank(annotated_shape)) {
1262 return false;
1263 }
1264 const int rank = InferenceContext::Rank(inferred_shape);
1265 for (int i = 0; i < rank; ++i) {
1266 int64_t val1 = InferenceContext::Value(
1267 InferenceContext::DimKnownRank(inferred_shape, i));
1268 int64_t val2 = InferenceContext::Value(
1269 InferenceContext::DimKnownRank(annotated_shape, i));
1270 if (val1 != val2) {
1271 return false;
1272 }
1273 }
1274 return true;
1275 }
1276
1277 bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1278 const std::vector<ShapeAndType>& st2) const {
1279 if (st1.size() != st2.size()) {
1280 return false;
1281 }
1282 for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
1283 const ShapeAndType& s1 = st1[i];
1284 const ShapeAndType& s2 = st2[i];
1285 if (s1.dtype != s2.dtype) {
1286 return false;
1287 }
1288 if (!EquivalentShapes(s1.shape, s2.shape)) {
1289 return false;
1290 }
1291 }
1292 return true;
1293 }
1294
1295 Status AddFunction(const NodeDef* function_node, NameAttrList function) {
1296 auto it = fun_to_grappler_function_item_.find(function.name());
1297 if (it != fun_to_grappler_function_item_.end()) {
1298 return Status::OK();
1299 }
1300
1301 const FunctionDef* function_def =
1302 CHECK_NOTNULL(function_library_.Find(function.name()));
1303 GrapplerFunctionItem grappler_function_item;
1304 Status function_instantiated =
1305 MakeGrapplerFunctionItem(*function_def, function_library_,
1306 graph_def_version_, &grappler_function_item);
1307
1308 // If function instantiation failed we will skip it during shape inference.
1309 if (!function_instantiated.ok()) {
1310 VLOG(3) << "Failed to instantiate a function. Error: "
1311 << function_instantiated.error_message();
1312 fun_to_grappler_function_item_[function_def->signature().name()] =
1313 absl::nullopt;
1314 return Status::OK();
1315 }
1316
1317 if (static_cast<int>(grappler_function_item.inputs().size()) >
1318 function_node->input_size()) {
1319 return errors::FailedPrecondition(
1320 "Function input size should be smaller than node input size.");
1321 }
1322
1323 for (int i = grappler_function_item.inputs().size(),
1324 end = function_node->input_size();
1325 i < end; ++i) {
1326 const string& input = function_node->input(i);
1327 if (!IsControlInput(input)) {
1328 return errors::FailedPrecondition(
1329 "Found regular input (", input,
1330 ") instead of control nodes for node ", function_node->name());
1331 }
1332 }
1333
1334 fun_to_grappler_function_item_[function_def->signature().name()] =
1335 grappler_function_item;
1336
1337 return Status::OK();
1338 }
1339
1340 Status AddNode(const NodeDef* node) {
1341 NodeContext& node_ctx = node_to_context_[node];
1342 NameAttrList function;
1343 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1344
1345 // For PartitionedCall, op_data represents the function info.
1346 TF_RETURN_IF_ERROR(
1347 function_library_.LookUp(function.name(), &node_ctx.op_data));
1348
1349 if (node_ctx.op_data->is_function_op) {
1350 TF_RETURN_IF_ERROR(AddFunction(node, function));
1351 }
1352
1353 TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1354 &node_ctx.input_types,
1355 &node_ctx.output_types));
1356
1357 // Create the inference context for this node.
1358 const int num_inputs = node_ctx.input_types.size();
1359 std::vector<ShapeHandle> input_shapes(num_inputs);
1360 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1361 input_handle_shapes_and_types(num_inputs);
1362 std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1363 std::vector<ShapeHandle> input_tensors_as_shapes;
1364
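// Input shapes start out unknown here; UpdateNode() fills them in later from
// the fan-in nodes' inferred output shapes.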
1365 node_ctx.inference_context.reset(new InferenceContext(
1366 graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
1367 input_tensors, input_tensors_as_shapes,
1368 std::move(input_handle_shapes_and_types)));
1369 const Status s = node_ctx.inference_context->construction_status();
1370 if (!s.ok()) {
1371 node_ctx.inference_context.reset(nullptr);
1372 }
1373 return s;
1374 }
1375
1376 private:
1377 // Return the one ShapeHandle used to denote a fully unknown shape for a node
1378 // output.
1379 ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1380 ShapeId id{node, index};
1381 auto it = unknown_shapes_.find(id);
1382 if (it != unknown_shapes_.end()) {
1383 return it->second;
1384 }
1385 InferenceContext* c = GetContext(node);
1386 ShapeHandle shp = c->UnknownShape();
1387 unknown_shapes_[id] = shp;
1388 return shp;
1389 }
1390 // Return the one ShapeHandle used to denote a fully unknown dimension for a
1391 // node output.
1392 DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1393 int dim_id) {
1394 DimId id{node, index, dim_id};
1395 auto it = unknown_dims_.find(id);
1396 if (it != unknown_dims_.end()) {
1397 return it->second;
1398 }
1399 InferenceContext* c = GetContext(node);
1400 DimensionHandle dim = c->UnknownDim();
1401 unknown_dims_[id] = dim;
1402 return dim;
1403 }
1404
1405 // Returns true if all the output tensors have known values.
1406 bool AllOutputValuesKnown(NodeContext* c) {
1407 InferenceContext* ic = c->inference_context.get();
1408 int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
1409 int c_output_tensor_protos_size = c->output_tensor_protos.size();
1410 if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
1411 c_output_tensor_protos_size < ic->num_outputs()) {
1412 return false;
1413 } else {
1414 // Checks if we can get output value via either output_tensor_proto or
1415 // output_tensors_as_shapes.
1416 for (int i = 0; i < ic->num_outputs(); i++) {
1417 if (c_output_tensor_protos_size > i &&
1418 c->output_tensor_protos[i] != nullptr) {
1419 continue;
1420 }
1421 if (c_output_tensors_as_shapes_size > i &&
1422 ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1423 bool no_unknown_dim_from_const = true;
1424 for (int32_t j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]);
1425 ++j) {
1426 const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
1427 if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
1428 no_unknown_dim_from_const = false;
1429 break;
1430 }
1431 }
1432 if (no_unknown_dim_from_const) {
1433 continue;
1434 }
1435 }
1436 return false;
1437 }
1438 }
1439 return true;
1440 }
1441
1442 // Returns true if all the output shapes are known.
1443 bool AllOutputShapesKnown(NodeContext* c) {
1444 InferenceContext* ic = c->inference_context.get();
1445 // Checks if all the output shapes are fully defined.
1446 for (int i = 0; i < ic->num_outputs(); i++) {
1447 if (!ic->FullyDefined(ic->output(i))) {
1448 return false;
1449 }
1450 }
1451 return true;
1452 }
1453
1454 // Returns true if we can infer output tensors' values -- we know values of
1455 // all the input tensors.
1456 bool AllInputValuesKnown(NodeContext* c) {
1457 InferenceContext* ic = c->inference_context.get();
1458
1459 // Check inputs are fully defined and values are known.
1460 for (int i = 0; i < ic->num_inputs(); i++) {
1461 const Tensor* tensor = ic->input_tensor(i);
1462 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1463 // already converted it to ic->input_tensor(i);
1464 const ShapeHandle& input_tensors_as_shape =
1465 ic->input_tensors_as_shapes()[i];
1466 // Either input_tensor is valid or input_tensors_as_shape, which has
1467 // value of input tensors as shape format, should be fully defined.
1468 if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1469 return false;
1470 }
1471 }
1472 return true;
1473 }
1474
1475 // Returns true if we want to update output shapes and values with running
1476 // EvaluateNode() for this op, based on op type, data type, and size.
1477 bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64_t max_size) {
1478 InferenceContext* ic = c->inference_context.get();
1479
1480 // Due to the cost of running EvaluateNode(), we restrict evaluation to
1481 // allowlisted op types.
1482 if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1483 return false;
1484 }
1485
1486 // Check input dtypes are number types.
1487 for (const auto& input_type : c->input_types) {
1488 if (!IsNumericType(input_type)) {
1489 return false;
1490 }
1491 }
1492
1493 // Check output dtypes are number types.
1494 for (const auto& output_type : c->output_types) {
1495 if (!IsNumericType(output_type)) {
1496 return false;
1497 }
1498 }
1499
1500     // Check that the number of elements in each input tensor is no larger than
1501     // the given max size.
1502 for (int i = 0; i < ic->num_inputs(); i++) {
1503 const Tensor* tensor = ic->input_tensor(i);
1504 const ShapeHandle& input_shape_handle = ic->input(i);
1505 if (tensor != nullptr) {
1506 if (tensor->NumElements() > max_size) {
1507 return false;
1508 }
1509 } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1510 return false;
1511 }
1512 }
1513
1514     // Check that the shape of each output tensor is fully known and that its
1515     // number of elements is no larger than the given max size.
1516 for (int i = 0; i < ic->num_outputs(); i++) {
1517 const ShapeHandle& shape_handle = ic->output(i);
1518 if (!ic->FullyDefined(shape_handle) ||
1519 ic->Value(ic->NumElements(shape_handle)) > max_size) {
1520 return false;
1521 }
1522 }
1523 return true;
1524 }
1525
1526 // Create input tensors from the NodeContext.
1527   void CreateInputTensors(NodeContext* c,
1528 std::vector<Tensor>* input_tensor_vector,
1529 TensorVector* inputs) {
1530 InferenceContext* ic = c->inference_context.get();
1531 for (int i = 0; i < ic->num_inputs(); i++) {
1532 if (ic->input_tensor(i)) {
1533 input_tensor_vector->at(i) = *ic->input_tensor(i);
1534 inputs->emplace_back(&input_tensor_vector->at(i));
1535 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1536 // already converted it to ic->input_tensor(i);
1537 } else {
1538 // Create Tensor from input_tensors_as_shapes, and then emplace it
1539 // back to inputs.
1540 // Note that input_tensors_as_shapes is scalar or vector.
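        // For example, a fully defined DT_INT32 shape handle [2, 3, 4] becomes
        // a rank-1 int32 tensor of length 3 with values {2, 3, 4}, while a
        // handle with rank < 1 becomes a 0-d tensor.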
1541 const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1542 const DataType& data_type = c->input_types[i];
1543 int32_t rank = ic->Rank(shape_handle);
1544 if (rank < 1) {
1545 input_tensor_vector->at(i) = Tensor(data_type, {});
1546 } else {
1547 input_tensor_vector->at(i) = Tensor(data_type, {rank});
1548 }
1549 auto* tensor = &input_tensor_vector->at(i);
1550 if (data_type == DT_INT32) {
1551 auto flat = tensor->flat<int32>();
1552 for (int j = 0; j < rank; j++) {
1553 int32_t dim = ic->Value(ic->Dim(shape_handle, j));
1554 flat(j) = dim;
1555 }
1556 } else {
1557 auto flat = tensor->flat<int64>();
1558 for (int j = 0; j < rank; j++) {
1559 int64_t dim = ic->Value(ic->Dim(shape_handle, j));
1560 flat(j) = dim;
1561 }
1562 }
1563 inputs->emplace_back(tensor);
1564 }
1565 }
1566 }
1567
1568 // Run a node to infer output shapes and values, and add it to the
1569 // NodeContext.
1570   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1571 InferenceContext* ic = c->inference_context.get();
1572
1573 // Input to EvaluateNode()
1574 TensorVector inputs;
1575 // Container for temporarily created tensor object.
1576 std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1577 CreateInputTensors(c, &input_tensor_vector, &inputs);
1578
1579 // Output for EvaluateNode() and output tensor clean up object.
1580 TensorVector outputs;
1581 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1582 for (const auto& output : outputs) {
1583 if (output.tensor) {
1584 delete output.tensor;
1585 }
1586 }
1587 });
1588
1589 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1590 &resource_mgr_, &outputs));
1591 c->output_tensors_as_shapes.resize(outputs.size());
1592 c->output_tensor_protos.resize(outputs.size(), nullptr);
1593 for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1594 const auto& t = outputs[k];
1595 // Override output shape.
1596 ShapeHandle output_shape;
1597 TF_RETURN_IF_ERROR(
1598 ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1599 if (ic->FullyDefined(ic->output(k)) &&
1600 !EquivalentShapes(ic->output(k), output_shape)) {
1601 LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1602 << ", inferred output shape "
1603 << "doesn't match for k=" << k << ": "
1604 << "ic->output(k): " << ic->DebugString(ic->output(k))
1605 << ", output_shape: " << ic->DebugString(output_shape)
1606 << " -- " << node.DebugString();
1607 }
1608 ic->set_output(k, output_shape);
1609 // Set output_tensors_as_shape.
1610 MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1611
1612 // Set output_tensor_protos.
1613 TensorProto tensor_proto;
1614 t->AsProtoTensorContent(&tensor_proto);
1615 const_tensors_to_propagate_.push_back(tensor_proto);
1616 c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1617 }
1618 return Status::OK();
1619 }
1620
1621 // Update output shapes with annotated information.
1622 // Currently only handle nodes with static shapes, i.e. shapes do not change
1623 // during execution.
1624 // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
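  // The annotation is carried in the kOutputSame and kOutputShapes node attrs;
  // an annotated shape is only adopted when the inferred shape is unknown but
  // compatible with it, and a mismatch with a fully defined inferred shape is
  // flagged via NodeContext::shape_incompatible.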
1625   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1626 NodeContext* c) const {
1627 const auto& attr = node.attr();
1628 if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1629 attr.count(kOutputShapes) == 0)
1630 return Status::OK();
1631
1632 InferenceContext* ic = c->inference_context.get();
1633 int output_size = attr.at(kOutputShapes).list().shape_size();
1634
1635 for (int i = 0; i < ic->num_outputs(); i++) {
1636     // An annotated Switch node has only one annotated output shape. Propagate
1637     // it to all the outputs.
1638 int shape_index = IsSwitch(node) ? 0 : i;
1639 if (shape_index >= output_size) {
1640 LOG(WARNING)
1641 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1642 << node.name() << ", inferred output shape size "
1643 << ic->num_outputs() << ", annotated output shape size "
1644 << output_size;
1645 break;
1646 }
1647
1648 const TensorShapeProto& shape =
1649 attr.at(kOutputShapes).list().shape(shape_index);
1650 if (shape.dim().empty()) continue;
1651
1652 ShapeHandle output_shape;
1653 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1654
1655 // Check if annotated shapes are incompatible with inferred shapes.
1656 if ((ic->FullyDefined(ic->output(i)) &&
1657 !SameShapes(ic->output(i), output_shape)) ||
1658 (!ic->FullyDefined(ic->output(i)) &&
1659 !CompatibleShapes(ic->output(i), output_shape))) {
1660 LOG(WARNING)
1661 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1662 << node.name() << ", inferred output shape "
1663 << "doesn't match for i=" << i << ": "
1664 << "ic->output(k): " << ic->DebugString(ic->output(i))
1665 << ", annotated output shape: " << ic->DebugString(output_shape)
1666 << " -- " << node.DebugString();
1667 c->shape_incompatible = true;
1668 }
1669
1670 // Only use annotated shapes if the inference shape is unknown and
1671 // compatible with annotated shapes.
1672 if (!ic->FullyDefined(ic->output(i)) &&
1673 CompatibleShapes(ic->output(i), output_shape)) {
1674 VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1675 << node.name() << ", inferred output shape " << i << ": "
1676 << "ic->output(i): " << ic->DebugString(ic->output(i))
1677 << ", annotated output shape: " << ic->DebugString(output_shape)
1678 << " -- " << node.ShortDebugString();
1679 ic->set_output(i, output_shape);
1680 }
1681 }
1682
1683 return Status::OK();
1684 }
1685
1686   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1687 NodeContext* c) {
1688 // Propagate tensors and shape tensors unless the node is fed.
1689 // TODO(bsteiner) We should still propagate the shapes to the ports that
1690 // aren't fed in the case of a ShapeN node.
1691
1692 // Note that when propagating tensors_as_shapes, we use
1693     // c->input_tensors_as_shapes_to_propagate instead of
1694 // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1695 // UnknownDim is from Const tensor, and it is propagated through shape
1696 // inference. Before calling shape functions, we convert it to UnknownDim,
1697 // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1698 // inference through UnknownDim from Const.
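    // For example, if a Const holding [2, -1, 3] feeds a Reshape's shape input,
    // the -1 is represented here as kUnknownDimFromConst, so two such -1 dims
    // coming from different Const tensors are never unified as the same
    // symbolic dimension.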
1699 InferenceContext* ic = c->inference_context.get();
1700 if (!is_fed) {
1701 if (IsConstant(node)) {
1702 const TensorProto& tensor_proto = node.attr().at("value").tensor();
1703 c->output_tensor_protos.resize(1);
1704 c->output_tensor_protos[0] = &tensor_proto;
1705 c->output_tensors_as_shapes.resize(1);
1706 MaybeTensorProtoToShape(ic, tensor_proto,
1707 &c->output_tensors_as_shapes[0]);
1708 } else if (IsRank(node)) {
1709 if (ic->RankKnown(ic->input(0))) {
1710 // Propagate rank value.
1711 int32_t rank = ic->Rank(ic->input(0));
1712 const_tensors_to_propagate_.push_back(
1713 MakeIntegerScalarTensorProto(DT_INT32, rank));
1714 c->output_tensor_protos.resize(1);
1715 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1716 }
1717 } else if (IsSize(node)) {
1718 DimensionHandle size = ic->NumElements(ic->input(0));
1719 if (ic->ValueKnown(size)) {
1720 // Propagate size value.
1721 int64_t sz = ic->Value(size);
1722 bool valid = false;
1723 if (node.attr().at("out_type").type() == DT_INT32) {
1724 if (sz < std::numeric_limits<int32>::max()) {
1725 const_tensors_to_propagate_.push_back(
1726 MakeIntegerScalarTensorProto(DT_INT32, sz));
1727 valid = true;
1728 }
1729 } else {
1730 const_tensors_to_propagate_.push_back(
1731 MakeIntegerScalarTensorProto(DT_INT64, sz));
1732 valid = true;
1733 }
1734 if (valid) {
1735 c->output_tensor_protos.resize(1);
1736 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1737 }
1738 }
1739 } else if (IsShape(node)) {
1740 c->output_tensors_as_shapes.resize(1);
1741 c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1742 } else if (IsShapeN(node)) {
1743 c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1744 for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1745 c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1746 }
1747 } else if (node.op() == "ConcatV2") {
1748 bool valid = true;
1749 ShapeHandle result;
1750 for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1751 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1752 if (!ic->RankKnown(input)) {
1753 valid = false;
1754 break;
1755 } else if (i == 0) {
1756 result = input;
1757 } else {
1758 TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1759 }
1760 }
1761 if (valid) {
1762 c->output_tensors_as_shapes.resize(1);
1763 c->output_tensors_as_shapes[0] = result;
1764 }
1765 } else if (IsPack(node)) {
1766 // A Pack node concatenating scalars is often used to generate a shape.
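        // For example, Pack(h, w, 3) with scalar int32 inputs produces the
        // shape value [h, w, 3]; inputs whose scalar value is unknown (or
        // negative) become kUnknownDimFromConst dims.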
1767 std::vector<DimensionHandle> dims;
1768 bool valid = true;
1769 for (int i = 0; i < ic->num_inputs(); ++i) {
1770 const Tensor* t = ic->input_tensor(i);
1771 if (t) {
1772 if (t->dims() != 0 ||
1773 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1774 valid = false;
1775 break;
1776 }
1777 int64_t size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1778 : t->scalar<int64>()();
1779 dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1780 : ic->MakeDim(size));
1781 } else {
1782 // Don't have tensor value, but use input_tensors_as_shapes, if
1783 // possible.
1784 const ShapeHandle& shape_handle =
1785 c->input_tensors_as_shapes_to_propagate[i];
1786 if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1787 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1788 dims.push_back(ic->Dim(shape_handle, 0));
1789 } else {
1790             // This is not from a Const, but as it shouldn't be used as a symbolic
1791             // unknown dim shared across different ops, we use kUnknownDimFromConst.
1792 dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1793 }
1794 }
1795 }
1796 if (valid) {
1797 c->output_tensors_as_shapes.resize(1);
1798 c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1799 }
1800 } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1801 c->output_tensors_as_shapes.resize(1);
1802 c->output_tensors_as_shapes[0] =
1803 c->input_tensors_as_shapes_to_propagate[0];
1804 if (c->input_tensor_protos[0] != nullptr) {
1805 c->output_tensor_protos.resize(1);
1806 c->output_tensor_protos[0] = c->input_tensor_protos[0];
1807 }
1808 } else if (IsSlice(node)) {
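        // Slicing a shape vector: when the begin and size scalars are known,
        // the output value is the corresponding subrange of the input's
        // tensors_as_shapes, e.g. Slice([2, 3, 5, 7], begin=1, size=2) -> [3, 5];
        // size == -1 takes everything from begin to the end.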
1809 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1810 bool valid = ic->RankKnown(input);
1811 const Tensor* slice_offset = ic->input_tensor(1);
1812 valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1813 const Tensor* slice_size = ic->input_tensor(2);
1814 valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1815 if (valid) {
1816 int64_t start = slice_offset->dtype() == DT_INT32
1817 ? slice_offset->flat<int32>()(0)
1818 : slice_offset->flat<int64>()(0);
1819 int64_t size =
1820 (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1821 : slice_size->flat<int64>()(0));
1822 ShapeHandle result;
1823 if (size == -1) {
1824 TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1825 } else {
1826 int64_t end = start + size;
1827 TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1828 }
1829 c->output_tensors_as_shapes.resize(1);
1830 c->output_tensors_as_shapes[0] = result;
1831 }
1832 } else if (IsStridedSlice(node)) {
1833 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1834 bool valid = ic->RankKnown(input);
1835 const Tensor* slice_begin = ic->input_tensor(1);
1836 valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1837 const Tensor* slice_end = ic->input_tensor(2);
1838 valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1839 const Tensor* slice_stride = ic->input_tensor(3);
1840 valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1841
1842 if (node.attr().count("ellipsis_mask") > 0 &&
1843 node.attr().at("ellipsis_mask").i() != 0) {
1844 valid = false;
1845 }
1846 if (node.attr().count("new_axis_mask") > 0 &&
1847 node.attr().at("new_axis_mask").i() != 0) {
1848 valid = false;
1849 }
1850 if (node.attr().count("shrink_axis_mask") > 0 &&
1851 node.attr().at("shrink_axis_mask").i() != 0) {
1852 valid = false;
1853 }
1854 int begin_mask = 0;
1855 if (node.attr().count("begin_mask") > 0) {
1856 begin_mask = node.attr().at("begin_mask").i();
1857 }
1858 int end_mask = 0;
1859 if (node.attr().count("end_mask") > 0) {
1860 end_mask = node.attr().at("end_mask").i();
1861 }
1862 if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1863 valid = false;
1864 }
1865 if (valid) {
1866 int64_t begin = 0;
1867 if (begin_mask == 0) {
1868 begin = slice_begin->dtype() == DT_INT32
1869 ? slice_begin->flat<int32>()(0)
1870 : slice_begin->flat<int64>()(0);
1871 }
1872 int64_t end = std::numeric_limits<int64>::max();
1873 if (end_mask == 0) {
1874 end =
1875 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1876 : slice_end->flat<int64>()(0));
1877 }
1878 int64_t stride = slice_stride->dtype() == DT_INT32
1879 ? slice_stride->flat<int32>()(0)
1880 : slice_stride->flat<int64>()(0);
1881 ShapeHandle result;
1882 TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1883 c->output_tensors_as_shapes.resize(1);
1884 c->output_tensors_as_shapes[0] = result;
1885 }
1886 }
1887 }
1888
1889 if (aggressive_shape_inference_) {
1890 // Update output shapes with annotated information. This is optional.
1891 UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1892
1893 // Update output tensor values using EvaluateNode() if we can.
1894     // Due to the cost of EvaluateNode(), we run it only for certain op types
1895     // (allowlisted) and small integer tensors.
1896
1897 const int max_element_size = 17; // Max up to 4x4 matrix or similar.
1898 if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1899 !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1900 return Status::OK();
1901 }
1902 UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
1903 }
1904 return Status::OK();
1905 }
1906
1907   Status InferShapes(const NodeDef& node, NodeContext* c) {
1908 // Infer the shapes of output tensors.
1909 if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1910 !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1911 // Annotate outputs with unknown shapes. Update output shapes with
1912 // annotated information later on if available.
1913     // Note that the shape inference function may return an error, but we
1914     // ignore it and use UnknownShape in that case.
1915 TF_RETURN_IF_ERROR(
1916 c->inference_context->Run(shape_inference::UnknownShape));
1917 }
1918 Status status = Status::OK();
1919 auto it = fed_ports_.find(node.name());
1920 const bool is_fed = it != fed_ports_.end();
1921 if (is_fed) {
1922 // It is possible to feed node output ports with tensors of any shape: as
1923 // a result, the shape of a fed port is completely unknown.
1924 for (const int output_port : it->second) {
1925 status.Update(SetUnknownShape(&node, output_port));
1926 }
1927 }
1928
1929 // Update NodeContext output fields after shape inference function runs.
1930 status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1931
1932 return status;
1933 }
1934
1935 private:
1936   bool IsIntegerVector(const Tensor& tensor) {
1937 if (tensor.dims() == 1 &&
1938 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1939 return true;
1940 }
1941 return false;
1942 }
1943
1944   bool IsIntegerScalar(const Tensor& tensor) {
1945 if (tensor.dims() == 0 &&
1946 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1947 tensor.NumElements() == 1) {
1948 return true;
1949 }
1950 return false;
1951 }
1952
1953   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1954 const int64_t val) {
1955 TensorProto tensor_proto;
1956 tensor_proto.set_dtype(dtype);
1957 // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1958 tensor_proto.mutable_tensor_shape();
1959 if (dtype == DT_INT32) {
1960 tensor_proto.add_int_val(val);
1961 } else if (dtype == DT_INT64) {
1962 tensor_proto.add_int64_val(val);
1963 }
1964 return tensor_proto;
1965 }
1966
1967   bool MaybeTensorProtoToShape(InferenceContext* ic,
1968 const TensorProto& tensor_proto,
1969 ShapeHandle* tensors_as_shapes) {
1970 // Skip if dtype is not integer.
1971 if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1972 return false;
1973 }
1974 // Skip if the const tensor is too large.
1975 if (NumElementsFromTensorProto(tensor_proto) >
1976 kThresholdToSkipConstTensorInstantiation) {
1977 return false;
1978 }
1979 // Skip if shape is neither scalar nor vector.
1980 if (tensor_proto.tensor_shape().unknown_rank() ||
1981 tensor_proto.tensor_shape().dim_size() > 1) {
1982 return false;
1983 }
1984 Tensor tensor;
1985 if (!tensor.FromProto(tensor_proto)) {
1986 return false;
1987 }
1988 return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1989 }
1990
1991   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1992 ShapeHandle* tensors_as_shapes) {
1993 // Integer tensors of rank one can also be interpreted as a shape
1994 // provided all their values are >= -1.
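    // For example, an int32 vector {4, -1, 8} maps to the shape [4, ?, 8],
    // with the -1 recorded as kUnknownDimFromConst; a vector containing a
    // value < -1 is rejected.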
1995
1996 if (IsIntegerVector(tensor)) {
1997 bool has_values_smaller_than_minus_1 = false;
1998 std::vector<DimensionHandle> dims;
1999 for (int i = 0; i < tensor.NumElements(); i++) {
2000 int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
2001 : tensor.flat<int64>()(i);
2002 has_values_smaller_than_minus_1 |= (value < -1);
2003 // Mark this as UnknownDim from Const.
2004 dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
2005 : ic->MakeDim(value));
2006 }
2007
2008 if (!has_values_smaller_than_minus_1) {
2009 *tensors_as_shapes = ic->MakeShape(dims);
2010 return true;
2011 }
2012 } else if (IsIntegerScalar(tensor)) {
2013 // Scalar constant.
2014 int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
2015 : tensor.flat<int64>()(0);
2016 if (value == -1) {
2017       // A scalar value of -1 represents an unknown shape. If we tried to
2018       // MakeShape(MakeDim) with it, we would get a vector of unknown size.
2019 *tensors_as_shapes = ic->UnknownShape();
2020 return true;
2021 } else if (value >= 0) {
2022       // Values < -1 are possible in principle, but MakeDim() fails for them;
2023       // it's a limitation of using ShapeHandle as a means to pass values.
2024 *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
2025 return true;
2026 }
2027 }
2028 return false;
2029 }
2030
2031 const GraphView& graph_;
2032 int graph_def_version_;
2033 absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2034 absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2035 absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2036   // Store function instantiations only for valid functions. If function
2037   // instantiation failed, the entry holds `absl::nullopt`.
2038 absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2039 fun_to_grappler_function_item_;
2040 FunctionLibraryDefinition function_library_;
2041 const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2042 // Store TensorProtos for tensor value propagation. Note that we use deque,
2043 // not vector, as we use pointers to the TensorProtos in this container.
2044   // A vector may resize and copy its objects into a new buffer, leaving the
2045   // existing pointers dangling.
2046 std::deque<TensorProto> const_tensors_to_propagate_;
2047
2048 // For more aggressive shape and value inference.
2049 bool aggressive_shape_inference_;
2050 ResourceMgr resource_mgr_;
2051 };
2052
2053 // Keep track of shapes and dimensions in a graph.
2054 // In particular, use disjoint sets to track equivalence between shapes and
2055 // dims, and consolidate the information globally.
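// For example, if shape inference concludes that node A's output and node B's
// input refer to the same symbolic shape, Merge() places their ShapeHandles
// (and, when both ranks are known, the corresponding DimensionHandles) in the
// same disjoint set, and AsTensorProperties() later reports the merged value.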
2056 class SymbolicShapeManager {
2057 public:
2058   SymbolicShapeManager() {}
2059
2060   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2061 if (!s1.IsSet() || !s2.IsSet()) {
2062 return Status::OK();
2063 }
2064 TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2065 if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2066 CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2067 for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2068 TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2069 InferenceContext::DimKnownRank(s2, i)));
2070 }
2071 }
2072 return Status::OK();
2073 }
2074   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2075 if (!d1.IsSet() || !d2.IsSet()) {
2076 return Status::OK();
2077 }
2078 return dims_.Merge(d1, d2);
2079 }
2080
2081   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2082 OpInfo::TensorProperties* properties) {
2083 properties->set_dtype(type);
2084 ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2085 if (!InferenceContext::RankKnown(actual_shape)) {
2086 properties->mutable_shape()->set_unknown_rank(true);
2087 } else {
2088 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2089 shape_inference::DimensionHandle dim =
2090 InferenceContext::DimKnownRank(actual_shape, j);
2091 int64_t d = dims_.GetMergedValue(dim);
2092 properties->mutable_shape()->add_dim()->set_size(d);
2093 }
2094 }
2095 }
2096
2097 // Returns merged shape with merged dimensions.
2098   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2099 const auto& actual_shape = shapes_.GetMergedValue(s);
2100 if (!InferenceContext::RankKnown(actual_shape)) {
2101 return ic->UnknownShape();
2102 } else {
2103 std::vector<DimensionHandle> dims;
2104 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2105 shape_inference::DimensionHandle dim =
2106 InferenceContext::DimKnownRank(actual_shape, j);
2107 int64_t d = dims_.GetMergedValue(dim);
2108         // The symbolic shape manager may have made some dims < -1, which
2109         // causes errors when creating a Dimension.
2110 if (d < -1) {
2111 d = -1;
2112 }
2113 dims.push_back(ic->MakeDim(d));
2114 }
2115 return ic->MakeShape(dims);
2116 }
2117 }
2118
2119 private:
2120 DisjointSet<shape_inference::ShapeHandle> shapes_;
2121 DisjointSet<shape_inference::DimensionHandle> dims_;
2122 };
2123
2124 // Checks whether there is any conflict in merged shapes and dims in
2125 // SymbolicShapeManager.
2126 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2127 SymbolicShapeRefiner* refiner,
2128 SymbolicShapeManager* shape_manager) {
2129 if (!VLOG_IS_ON(1)) {
2130 return Status::OK();
2131 }
2132
2133   VLOG(1) << "Checking for any conflicts in shapes and dimensions ...";
2134 int64_t num_incompatible_shapes = 0;
2135 for (const NodeDef& node : graph_def.node()) {
2136 auto ctx = refiner->GetNodeContext(&node);
2137 if (!ctx) {
2138 continue;
2139 }
2140 auto* ic = ctx->inference_context.get();
2141 for (int i = 0; i < ic->num_inputs(); ++i) {
2142 const auto& shape = ic->input(i);
2143 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2144 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2145 num_incompatible_shapes++;
2146 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2147 << "for node " << node.name() << " input (" << i << ") "
2148 << ic->DebugString(shape)
2149 << " vs. merged: " << ic->DebugString(merged_shape);
2150 }
2151 }
2152 for (int i = 0; i < ic->num_outputs(); ++i) {
2153 const auto& shape = ic->output(i);
2154 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2155 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2156 num_incompatible_shapes++;
2157 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2158 << "for node " << node.name() << " output (" << i << ") "
2159 << ic->DebugString(shape)
2160 << " vs. merged: " << ic->DebugString(merged_shape);
2161 }
2162 }
2163 }
2164 if (num_incompatible_shapes > 0) {
2165 VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2166 << " incompatible shapes from SymbolicShapeManager.";
2167 } else {
2168 VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2169 }
2170
2171 return Status::OK();
2172 }
2173
2174 // Log shape inference and its merged shapes.
2175 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2176 SymbolicShapeRefiner* refiner,
2177 SymbolicShapeManager* shape_manager) {
2178   // Logging all the nodes would generate too many lines, so this detailed
2179   // logging is skipped by default. Users may add nodes of interest to
2180   // node_names_for_logging to enable it.
2181 absl::flat_hash_set<std::string> node_names_for_logging = {};
2182 if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2183 return Status::OK();
2184 }
2185
2186 auto should_log = [&node_names_for_logging](std::string node_name) {
2187 return node_names_for_logging.find(node_name) !=
2188 node_names_for_logging.end();
2189 };
2190
2191 for (const NodeDef& node : graph_def.node()) {
2192 if (!should_log(node.name())) {
2193 continue;
2194 }
2195 auto ctx = refiner->GetNodeContext(&node);
2196 if (!ctx) {
2197 continue;
2198 }
2199 auto* ic = ctx->inference_context.get();
2200 VLOG(3) << "Shape inference for node : " << node.name();
2201 VLOG(3) << ctx->DebugString(node);
2202     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2203 for (int i = 0; i < ic->num_inputs(); ++i) {
2204 absl::StrAppend(
2205 &merged_shapes, " input[", i, "] -- ",
2206 ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2207 "\n");
2208 }
2209 for (int i = 0; i < ic->num_outputs(); ++i) {
2210 absl::StrAppend(
2211 &merged_shapes, " output[", i, "] -- ",
2212 ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2213 "\n");
2214 }
2215 VLOG(3) << merged_shapes;
2216 VLOG(3) << "--------------------------------";
2217 VLOG(3) << "";
2218 }
2219
2220 return Status::OK();
2221 }
2222
2223 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2224 SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2225 const std::vector<ShapeAndType>& shapes_and_types,
2226 std::vector<ShapeAndType>* queue_shapes_and_types) {
2227 if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2228 return errors::InvalidArgument(
2229 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2230 " vs ", queue_shapes_and_types->size());
2231 }
2232 for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2233 const ShapeAndType& a = shapes_and_types[i];
2234 ShapeAndType& b = (*queue_shapes_and_types)[i];
2235 if (a.dtype != b.dtype) {
2236 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2237 i, ": ", DataTypeString(a.dtype), " vs ",
2238 DataTypeString(b.dtype));
2239 }
2240
2241 b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2242 }
2243 return Status::OK();
2244 }
2245
2246 // Compute the output shape of the merge node as the union of the available
2247 // input shapes.
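// For example, a Merge whose known fanins have shapes [?, 32] and [16, 32]
// gets the relaxed (union) output shape [?, 32]; fanins with no inference
// context yet (a back edge seen for the first time) are simply skipped.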
2248 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2249 const NodeDef* node,
2250 bool* new_shapes) const {
2251 InferenceContext* ic = shape_refiner->GetContext(node);
2252 if (!ic) {
2253 // Now we can run shape inference
2254 TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2255 ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2256 *new_shapes = true;
2257
2258 // Infer the shape of the second output once and for all since it never
2259 // changes.
2260 ShapeHandle out1 = ic->Scalar();
2261 if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2262 }
2263
2264 ShapeHandle out;
2265 const std::vector<ShapeAndType>* out_handle = nullptr;
2266 bool out_initialized = false;
2267 for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2268 *node, /*include_controlling_edges=*/false)) {
2269 InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2270 if (!src_ic) {
2271 // Handling a loop for the first time, the back edge won't have any shape
2272 // info.
2273 continue;
2274 }
2275 ShapeHandle input = src_ic->output(fanin.src.port_id);
2276 ic->SetInput(fanin.dst.port_id, input);
2277 auto* input_handle =
2278 src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2279 if (input_handle)
2280 ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2281 if (!out_initialized) {
2282 out_initialized = true;
2283 out = input;
2284 out_handle = input_handle;
2285 } else {
2286 // Note here only out, not out_handle, is modified.
2287 out = shape_refiner->OutputAsUnion(node, 0, input, out);
2288 }
2289 }
2290
2291 if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2292 ic->set_output(0, out);
2293 if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2294 *new_shapes = true;
2295 }
2296
2297 return Status::OK();
2298 }
2299
2300 // Manually propagate the input shape for Enter nodes.
2301 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2302 const NodeDef* node, bool* new_shapes) {
2303 InferenceContext* ic = shape_refiner->GetContext(node);
2304 if (!ic) {
2305 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2306 ic = shape_refiner->GetContext(node);
2307 }
2308
2309 GraphView::InputPort port(node, 0);
2310 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2311
2312 InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2313 ShapeHandle input = src_ic->output(fanin.port_id);
2314 if (!ic->output(0).SameHandle(input)) {
2315 ic->SetInput(0, input);
2316 ic->set_output(0, input);
2317 *new_shapes = true;
2318 }
2319 auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2320 if (outputs) {
2321 ic->set_input_handle_shapes_and_types(0, *outputs);
2322 ic->set_output_handle_shapes_and_types(0, *outputs);
2323 *new_shapes = true;
2324 }
2325 return Status::OK();
2326 }
2327
2328 Status GraphProperties::UpdateShapes(
2329 SymbolicShapeRefiner* shape_refiner,
2330 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2331 const NodeDef* n, bool* new_shapes) const {
2332 if (IsEnter(*n)) {
2333 // The Enter shape function always forwards an UnknownShape, so do the right
2334 // thing here.
2335 TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2336 } else if (IsMerge(*n)) {
2337 // Properly handle merge nodes.
2338 TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2339 } else if (IsEnqueue(*n)) {
2340 // Make sure the shapes of enqueued tensors are propagated to the queue
2341 // itself.
2342 TF_RETURN_IF_ERROR(
2343 UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2344 } else if (IsQueue(*n)) {
2345 // Set shapes and types of Queue ops, if needed.
2346 TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2347 } else {
2348 // Rely on regular TF shape refinement for all the other nodes.
2349 // UpdateNode calls UpdateFunction if a function node is detected.
2350 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2351 }
2352
2353 return Status::OK();
2354 }
2355
2356 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2357 Status GraphProperties::PropagateShapes(
2358 SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2359 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2360 int num_loops) const {
2361 // Limit the number of iterations to prevent infinite loops in the presence of
2362 // incorrect shape functions. The algorithm should converge in at most
2363 // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
2364 // The same applies to resources.
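  // For example, with a 1000-node graph and 2 nested loops, the bound below is
  // max_loop_iterations = 4 * 1000 * max(1, 2 * 2) = 16000 dequeues per pass.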
2365 VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2366 << num_loops << " loops and " << resource_handles.size()
2367 << " resources" << std::endl;
2368
2369 const int64_t max_loop_length = item_.graph.node_size();
2370 const int64_t max_rank = 4;
2371 const int64_t max_loop_iterations =
2372 max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
2373 const int64_t num_queues = resource_handles.size();
2374 const int64_t max_resource_iterations = num_queues * num_queues * max_rank;
2375
2376 int64_t num_resource_iterations = 0;
2377 do {
2378 int64_t num_loop_iterations = 0;
2379 while (!new_shapes->empty() &&
2380 num_loop_iterations++ < max_loop_iterations) {
2381 const NodeDef* n = new_shapes->pop();
2382 bool updated = false;
2383 TF_RETURN_IF_ERROR(
2384 UpdateShapes(shape_refiner, resource_handles, n, &updated));
2385 if (updated) {
2386 for (const auto& fanout : shape_refiner->graph().GetFanouts(
2387 *n, /*include_controlled_nodes=*/false)) {
2388 new_shapes->push(fanout.node);
2389 }
2390 // Make sure the corresponding queue nodes are (re)processed.
2391 if (IsEnqueue(*n)) {
2392 auto it = resource_handles.find(n);
2393 if (it != resource_handles.end()) {
2394 new_shapes->push(it->second);
2395 }
2396 }
2397 }
2398 }
2399 } while (!new_shapes->empty() &&
2400 num_resource_iterations++ < max_resource_iterations);
2401
2402 if (!new_shapes->empty()) {
2403 return errors::Internal("Shape inference failed to converge");
2404 }
2405
2406 return Status::OK();
2407 }
2408
2409 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2410 SymbolicShapeRefiner* shape_refiner,
2411 bool* new_shapes) {
2412 auto* ctx = shape_refiner->GetNodeContext(queue_node);
2413 if (!ctx) {
2414 TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2415 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2416 }
2417 auto* ic = ctx->inference_context.get();
2418
2419 auto* outputs = ic->output_handle_shapes_and_types(0);
2420 if (outputs) {
2421 // Shapes and types are already set, presumably by Enqueue ops.
2422 return shape_refiner->UpdateNode(queue_node, new_shapes);
2423 }
2424
2425 if (queue_node->attr().count("shapes") <= 0 ||
2426 queue_node->attr().count("component_types") <= 0 ||
2427 queue_node->attr().at("shapes").list().shape_size() !=
2428 queue_node->attr().at("component_types").list().type_size()) {
2429 // Errors in shapes and component_types attr.
2430 return shape_refiner->UpdateNode(queue_node, new_shapes);
2431 }
2432
2433 // Extract types and shapes from Queue attr.
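  // For example, a queue declared with shapes = [[2,3], [4]] and
  // component_types = [DT_FLOAT, DT_INT32] yields the handle data
  // {([2,3], DT_FLOAT), ([4], DT_INT32)} on output port 0.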
2434 const auto& shapes = queue_node->attr().at("shapes").list().shape();
2435 const auto& types = queue_node->attr().at("component_types").list().type();
2436 std::vector<ShapeAndType> shapes_and_types;
2437 for (int i = 0; i < types.size(); i++) {
2438 const auto& shape = shapes[i];
2439 ShapeHandle shape_handle;
2440 TF_RETURN_IF_ERROR(
2441 ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2442 DataType data_type =
2443 queue_node->attr().at("component_types").list().type(i);
2444 ShapeAndType shape_and_type(shape_handle, data_type);
2445 shapes_and_types.push_back(shape_and_type);
2446 }
2447 ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2448
2449   // The queue node is updated with output_handle_shapes_and_types, so set
2450   // new_shapes here and ignore the new_shapes result from UpdateNode().
2451 *new_shapes = true;
2452 bool dummy_new_shapes = false;
2453 return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2454 }
2455
2456 Status GraphProperties::UpdateEnqueue(
2457 const NodeDef* enqueue_node,
2458 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2459 SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2460 auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2461 if (!ctx) {
2462 TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2463 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2464 }
2465
2466 auto it = resource_handles.find(enqueue_node);
2467 if (it == resource_handles.end()) {
2468 // The corresponding queue was not found, there isn't much we can do.
2469 return Status::OK();
2470 }
2471 const NodeDef* qnode = it->second;
2472 auto qctx = shape_refiner->GetContext(qnode);
2473 if (!qctx) {
2474 return Status::OK();
2475 }
2476 auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2477
2478 // TODO(bsteiner): handle EnqueueMany as well.
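  // Input 0 of an enqueue op is the queue's resource handle, so the tensors
  // being enqueued start at input 1.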
2479 std::vector<ShapeAndType> shapes_and_types;
2480 for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2481 GraphView::InputPort inp(enqueue_node, i);
2482 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2483 InferenceContext* in = shape_refiner->GetContext(fanin.node);
2484 ShapeHandle input = in->output(fanin.port_id);
2485 ctx->inference_context->SetInput(i, input);
2486 shapes_and_types.push_back({input, ctx->input_types[i]});
2487 }
2488
2489 if (queue_handle_data == nullptr) {
2490 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2491 *new_shapes = true;
2492 } else {
2493 TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2494 shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2495 *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2496 shapes_and_types);
2497 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2498 }
2499
2500 return Status::OK();
2501 }
2502
2503 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2504 bool aggressive_shape_inference,
2505 bool include_input_tensor_values,
2506 bool include_output_tensor_values) {
2507 FunctionLibraryDefinition function_library(OpRegistry::Global(),
2508 item_.graph.library());
2509 absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
2510 if (!assume_valid_feeds) {
2511 for (const auto& feed : item_.feed) {
2512 SafeTensorId tensor_id = ParseTensorName(feed.first);
2513 fed_ports[tensor_id.node()].insert(tensor_id.index());
2514 }
2515 }
2516
2517 GraphView graph_view(&item_.graph);
2518
2519 // List the resources and the nodes using them. Also collect the Merge nodes,
2520 // fed nodes, and primary inputs.
2521 absl::flat_hash_map<const NodeDef*,
2522 std::pair<absl::flat_hash_set<const NodeDef*>,
2523 absl::flat_hash_set<const NodeDef*>>>
2524 resources;
2525 absl::flat_hash_set<const NodeDef*> merge_nodes;
2526 absl::flat_hash_set<const NodeDef*> fed_nodes;
2527 absl::flat_hash_set<const NodeDef*> primary_inputs;
2528 int num_loops = 0;
2529 for (const NodeDef& node : item_.graph.node()) {
2530 if (IsQueue(node)) {
2531 for (const GraphView::InputPort& fanout :
2532 graph_view.GetFanouts(node, false)) {
2533 if (IsEnter(*fanout.node)) {
2534 const NodeDef& enter = *fanout.node;
2535 for (const GraphView::InputPort& fanout :
2536 graph_view.GetFanouts(enter, false)) {
2537 if (IsEnqueue(*fanout.node)) {
2538 resources[&node].first.insert(fanout.node);
2539 } else if (IsDequeue(*fanout.node)) {
2540 resources[&node].second.insert(fanout.node);
2541 }
2542 }
2543 } else {
2544 if (IsEnqueue(*fanout.node)) {
2545 resources[&node].first.insert(fanout.node);
2546 } else if (IsDequeue(*fanout.node)) {
2547 resources[&node].second.insert(fanout.node);
2548 }
2549 }
2550 }
2551 }
2552 if (!HasRegularInputs(node)) {
2553 primary_inputs.insert(&node);
2554 } else if (IsMerge(node)) {
2555 merge_nodes.insert(&node);
2556 } else if (IsNextIteration(node)) {
2557 ++num_loops;
2558 }
2559 if (fed_ports.find(node.name()) != fed_ports.end()) {
2560 fed_nodes.insert(&node);
2561 }
2562 }
2563
2564 absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
2565 std::vector<TopologicalDependency> extra_deps;
2566 for (const auto& resource : resources) {
2567 for (const NodeDef* src : resource.second.first) {
2568 resource_handles[src] = resource.first;
2569 for (const NodeDef* dst : resource.second.second) {
2570 // Add control edges from enqueue to dequeue nodes to ensure they are
2571 // processed in their logical order.
2572 extra_deps.emplace_back(src, dst);
2573 }
2574 }
2575 }
2576
2577 std::vector<const NodeDef*> topo_order;
2578 Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2579 if (!s.ok()) {
2580 if (extra_deps.empty()) {
2581 return s;
2582 } else {
2583 // There is a loop between queues: we'll just use the graph topological
2584 // order. This will make the shape inference less precise but since this
2585       // isn't common it's not worth figuring out where to break the loop
2586       // and doing a proper relaxation.
2587 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2588 }
2589 }
2590
2591 // Heap-allocate SymbolicShapeRefiner in order to not consume a large amount
2592 // of stack space.
2593 auto refiner = absl::make_unique<SymbolicShapeRefiner>(
2594 graph_view, fed_ports, aggressive_shape_inference);
2595
2596 TopoQueue new_shapes(topo_order);
2597 // Also seed the propagation of shapes in the fanout of primary inputs.
2598 for (const NodeDef* node : primary_inputs) {
2599 new_shapes.push(node);
2600 }
2601 // Also seed the propagation of shapes in the fanout of fed nodes.
2602 for (const NodeDef* node : fed_nodes) {
2603 new_shapes.push(node);
2604 }
2605 // Propagate shapes normally.
2606 TF_RETURN_IF_ERROR(
2607 PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
2608
2609 // Track shapes globally across the graph.
2610 std::unique_ptr<SymbolicShapeManager> shape_manager =
2611 absl::make_unique<SymbolicShapeManager>();
2612 bool found_error = false;
2613 for (const NodeDef& node : item_.graph.node()) {
2614 auto node_ctx = refiner->GetContext(&node);
2615 if (!node_ctx) {
2616 continue;
2617 }
2618 // Skip any information that comes from fed nodes.
2619 if (fed_ports.find(node.name()) != fed_ports.end()) {
2620 VLOG(2) << "Skipping feed node shape: " << node.name();
2621 continue;
2622 }
2623 for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2624 if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2625 .ok()) {
2626 found_error = true;
2627 break;
2628 }
2629 }
2630 for (const auto& merged_dims : node_ctx->MergedDims()) {
2631 if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2632 found_error = true;
2633 break;
2634 }
2635 }
2636 if (found_error) {
2637 // The shapes aren't consistent, we can't infer safely: discard all the
2638 // information discovered so far.
2639 shape_manager = absl::make_unique<SymbolicShapeManager>();
2640 break;
2641 }
2642 }
2643
2644 TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
2645 shape_manager.get()));
2646
2647 for (const NodeDef& node : item_.graph.node()) {
2648 VLOG(4) << "Filling in graph properties for node: " << node.name();
2649 auto ctx = refiner->GetNodeContext(&node);
2650 if (!ctx) {
2651 continue;
2652 }
2653
2654 auto* ic = ctx->inference_context.get();
2655
2656 // Fill input properties.
2657 {
2658 auto& input_properties = input_properties_[node.name()];
2659
2660       // Should always be empty; node names in the graph are supposed to be unique.
2661 CHECK_EQ(input_properties.size(), 0);
2662
2663 input_properties.resize(ic->num_inputs());
2664 GraphView::InputPort input(&node, -1);
2665 for (int i = 0; i < ic->num_inputs(); ++i) {
2666 shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2667 &input_properties[i]);
2668 input.port_id = i;
2669 GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2670 if (include_input_tensor_values) {
2671 // Export tensor value to input_properties.value.
2672 if (IsConstant(*fanin.node)) {
2673 const TensorProto& raw_val =
2674 fanin.node->attr().at("value").tensor();
2675 *input_properties[i].mutable_value() = raw_val;
2676 } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
2677 ctx->input_tensor_protos[i] != nullptr) {
2678 *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2679 } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
2680 i &&
2681 IsShapeFullyDefinedIntegerVectorOrScalar(
2682 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2683 ctx->input_types[i])) {
2684 *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2685 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2686 ctx->input_types[i]);
2687 }
2688 }
2689 }
2690 }
2691
2692 // Fill output properties.
2693 {
2694 auto& output_properties = output_properties_[node.name()];
2695
2696       // Should always be empty; node names in the graph are supposed to be unique.
2697 CHECK_EQ(output_properties.size(), 0);
2698
2699 output_properties.resize(ic->num_outputs());
2700 for (int i = 0; i < ic->num_outputs(); ++i) {
2701 shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2702 &output_properties[i]);
2703 auto converted_output_tensors_as_shapes =
2704 ReplaceUnknownDimFromConstWithUnknownDim(
2705 ic, ctx->output_tensors_as_shapes);
2706 if (include_output_tensor_values) {
2707 // Export tensor value to output_properties.value.
2708 if (IsConstant(node)) {
2709 // TODO(rmlarsen): Eliminate this copy.
2710 const TensorProto& raw_val = node.attr().at("value").tensor();
2711 *output_properties[i].mutable_value() = raw_val;
2712 } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
2713 ctx->output_tensor_protos[i] != nullptr) {
2714 *output_properties[i].mutable_value() =
2715 *ctx->output_tensor_protos[i];
2716 } else if (static_cast<int>(
2717 converted_output_tensors_as_shapes.size()) > i &&
2718 IsShapeFullyDefinedIntegerVectorOrScalar(
2719 ic, ic->output(i),
2720 converted_output_tensors_as_shapes[i],
2721 ctx->output_types[i])) {
2722 *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2723 ic, ic->output(i), converted_output_tensors_as_shapes[i],
2724 ctx->output_types[i]);
2725 }
2726 }
2727 }
2728 }
2729
2730 if (aggressive_shape_inference && ctx->shape_incompatible)
2731 incompatible_shape_nodes_.insert(node.name());
2732 }
2733
2734 if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
2735 LOG(WARNING) << incompatible_shape_nodes_.size()
2736 << " nodes have incompatible output shapes.";
2737
2738 // Help trace the unknown dimensions to their origins.
2739 VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2740 output_properties_);
2741
2742 TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
2743 shape_manager.get()));
2744
2745 return Status::OK();
2746 }
2747
2748 Status GraphProperties::InferDynamically(Cluster* cluster) {
2749 TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2750
2751 // Runs the model once to collect the shapes in the cost model.
2752 RunMetadata metadata;
2753 TF_RETURN_IF_ERROR(
2754 cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2755
2756 return InferFromCostGraph(metadata.cost_graph());
2757 }
2758
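// Writes the inferred output shapes back into a copy of the graph as each
// node's "_output_shapes" attr; unless allow_symbolic_shapes is set, shapes
// are first normalized via NormalizeShapeForOutput().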
2759 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
2760 bool allow_symbolic_shapes) const {
2761 *output_graph_def = item_.graph;
2762 for (int i = 0; i < output_graph_def->node_size(); i++) {
2763 auto node = output_graph_def->mutable_node(i);
2764 AttrValue attr_output_shape;
2765 auto tensor_properties = GetOutputProperties(node->name());
2766 for (const auto& tensor_property : tensor_properties) {
2767 TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
2768 *proto = tensor_property.shape();
2769 if (!allow_symbolic_shapes) {
2770 NormalizeShapeForOutput(proto);
2771 }
2772 }
2773 (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
2774 }
2775 return Status::OK();
2776 }
2777
2778 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2779 if (cost_graph.node_size() == 0) {
2780 LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2781 }
2782 std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2783 std::unordered_map<string, const NodeDef*> name_to_node; // Empty
2784 for (auto& node : cost_graph.node()) {
2785 name_to_cost[node.name()] = &node;
2786
2787 std::vector<OpInfo::TensorProperties> output_properties;
2788 for (const auto& out : node.output_info()) {
2789 OpInfo::TensorProperties properties;
2790 properties.set_dtype(out.dtype());
2791 *properties.mutable_shape() = out.shape();
2792 output_properties.push_back(properties);
2793 }
2794 output_properties_[node.name()] = output_properties;
2795 }
2796
2797 for (const auto& node : item_.graph.node()) {
2798 // Skip the nodes that are not in the cost graph: these are nodes that
2799 // aren't run, because they aren't in the intersection of transitive fan-in
2800 // of a fetch node and the transitive fan-out of an input, or nodes that
2801 // were optimized away by the optimizer.
2802 auto it = name_to_cost.find(node.name());
2803 if (it == name_to_cost.end()) {
2804 continue;
2805 }
2806 std::vector<OpInfo::TensorProperties> inputs =
2807 FindInputFeatures(node, name_to_cost, name_to_node);
2808
2809 input_properties_[node.name()] = inputs;
2810 }
2811 return Status::OK();
2812 }
2813
2814 bool GraphProperties::HasInputProperties(const string& node_name) const {
2815 return input_properties_.find(node_name) != input_properties_.end();
2816 }
2817
2818 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2819 return output_properties_.find(node_name) != output_properties_.end();
2820 }
2821
2822 const std::vector<OpInfo::TensorProperties>&
2823 GraphProperties::GetInputProperties(const string& node_name) const {
2824 auto it = input_properties_.find(node_name);
2825 if (it != input_properties_.end()) {
2826 return it->second;
2827 }
2828 return missing_properties_;
2829 }
2830
2831 const std::vector<OpInfo::TensorProperties>&
2832 GraphProperties::GetOutputProperties(const string& node_name) const {
2833 auto it = output_properties_.find(node_name);
2834 if (it != output_properties_.end()) {
2835 return it->second;
2836 }
2837 return missing_properties_;
2838 }
2839
2840 void GraphProperties::ClearInputProperties(const string& node_name) {
2841 input_properties_.erase(node_name);
2842 }
2843 void GraphProperties::ClearOutputProperties(const string& node_name) {
2844 output_properties_.erase(node_name);
2845 }
2846
2847 } // end namespace grappler
2848 } // end namespace tensorflow
2849