1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/costs/utils.h"
31 #include "tensorflow/core/grappler/mutable_graph_view.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/functions.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 namespace {
45 
46 using shape_inference::DimensionHandle;
47 using shape_inference::InferenceContext;
48 using shape_inference::ShapeAndType;
49 using shape_inference::ShapeHandle;
50 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
51 
52 // A large value standing in for UnknownDim when a Const supplies a dim value.
53 // Some ops treat "-1" specially, different from UnknownDim:
54 // e.g., shape input to Reshape op.
55 const int64_t kUnknownDimFromConst = INT64_MAX;
56 
57 // Skip const value instantiation if the number of elements in a const tensor
58 // is greater than this threshold.
59 const int kThresholdToSkipConstTensorInstantiation = 128;
60 
61 template <typename Handle>
62 struct HashHandle {
63   std::size_t operator()(const Handle& h) const { return h.Handle(); }
64 };
65 template <typename Handle>
66 struct CompareHandle {
67   bool operator()(const Handle& h1, const Handle& h2) const {
68     return h1.SameHandle(h2);
69   }
70 };
71 
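// Maps a handle type to the concrete object tracked by the DisjointSet below:
// ShapeHandle maps to itself, and DimensionHandle maps to an int64 dimension
// value (negative values denote unknown/symbolic dimensions).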
72 template <typename Handle>
73 struct HandleToObject {};
74 template <>
75 struct HandleToObject<ShapeHandle> {
76   typedef ShapeHandle Object;
77 
78   static ShapeHandle Unknown() { return ShapeHandle(); }
79 };
80 
81 template <>
82 struct HandleToObject<DimensionHandle> {
83   typedef int64 Object;
84 
85   static int64 Unknown() { return -1; }
86 };
87 
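// Processor defines, per handle type, how a handle's current value is
// extracted and how two values are combined when their handles are merged.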
88 template <typename Handle>
89 struct Processor {};
90 
91 template <>
92 struct Processor<ShapeHandle> {
93   // Extract the shape or dim denoted by the handle.
94   void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
95   // Merge the shapes or dims.
96   Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
97     if (InferenceContext::RankKnown(*result)) {
98       // The result was initialized in a previous merge to a shape of known
99       // rank; make sure we preserve that information.
100       return Status::OK();
101     }
102     if (InferenceContext::RankKnown(h1)) {
103       *result = h1;
104     } else {
105       *result = h2;
106     }
107     return Status::OK();
108   }
109 };
110 
111 template <>
112 struct Processor<DimensionHandle> {
113   // Assign a negative id to unknown dimensions, starting at -2 (-1 is
114   // reserved by TensorFlow).
115   void ExtractValue(DimensionHandle d, int64* result) {
116     if (!InferenceContext::ValueKnown(d)) {
117       *result = -counter;
118       counter++;
119     } else {
120       int64_t val = InferenceContext::Value(d);
121       if (val >= 0) {
122         *result = val;
123       } else {
124         // A shape inference function generated an invalid dimension handle.
125         // Use a symbolic dimension to encode this.
126         *result = -counter;
127         counter++;
128       }
129     }
130   }
131 
132   // Merge the dimensions d1 and d2. Return the known shape if there is one,
133   // otherwise look for a symbolic shape. If there is no symbolic shape and no
134   // known shape, the shape is fully unknown, so return -1.
135   Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
136     const int64_t dim1 = InferenceContext::Value(d1);
137     const int64_t dim2 = InferenceContext::Value(d2);
138 
139     if (dim1 >= 0 && dim2 >= 0) {
140       CHECK_EQ(dim1, dim2);
141       return RefineDim(dim1, result);
142     } else if (dim1 >= 0 && dim2 < 0) {
143       return RefineDim(dim1, result);
144     } else if (dim1 < 0 && dim2 >= 0) {
145       return RefineDim(dim2, result);
146     } else if (dim1 < -1) {
147       return RefineDim(dim1, result);
148     } else if (dim2 < -1) {
149       return RefineDim(dim2, result);
150     } else {
151       CHECK_EQ(dim1, dim2);
152       CHECK_EQ(-1, dim1);
153       return RefineDim(-1, result);
154     }
155     return Status::OK();
156   }
157 
158  private:
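  // Refines *result with `dim`: a known (>= 0) dimension takes precedence over
  // a symbolic one, and two conflicting known dimensions are an error.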
159   Status RefineDim(int64_t dim, int64* result) {
160     if (*result >= 0) {
161       if (!(*result == dim || dim < 0)) {
162         return errors::InvalidArgument("Inconsistent dimensions detected");
163       }
164     } else if (dim >= 0) {
165       *result = dim;
166     } else if (dim < *result) {
167       *result = dim;
168     }
169     return Status::OK();
170   }
171 
172   int64 counter = 2;
173 };
174 
175 // Traditional disjoint-set data structure with path compression.
176 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
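// Merge() performs union by rank and combines the values attached to the two
// sets via Processor<Handle>::Merge(); GetMergedValue() returns the value
// stored at the representative of a handle's set.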
177 template <typename Handle>
178 class DisjointSet {
179  public:
180   DisjointSet() {}
181   ~DisjointSet() {
182     for (auto rep : nodes_) {
183       delete rep.second;
184     }
185   }
186 
187   Status Merge(Handle x, Handle y);
188   const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
189 
190  private:
191   // All the handles that belong to the same set are part of the same tree, and
192   // ultimately represented by the root of that tree.
193   struct Rep {
194     // Parent in the tree used to encode the set.
195     Rep* parent;
196     // Rank in the tree, used to figure out how to compress the path to the root
197     // of the tree.
198     int rank;
199     // The handle.
200     typename HandleToObject<Handle>::Object value;
201   };
202 
203   // Create a new set for the value if none exists, or return its representative
204   // node otherwise.
205   Rep* Find(Handle value);
206 
207  private:
208   Processor<Handle> processor_;
209   absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
210       nodes_;
211 };
212 
213 template <typename Handle>
214 const typename HandleToObject<Handle>::Object
215 DisjointSet<Handle>::GetMergedValue(Handle value) {
216   Rep* rep = Find(value);
217   if (!rep) {
218     // We don't know anything about this handle.
219     return HandleToObject<Handle>::Unknown();
220   }
221   return rep->value;
222 }
223 
224 template <typename Handle>
225 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
226   Rep* x_root = Find(x);
227   Rep* y_root = Find(y);
228 
229   // x and y are already in the same set
230   if (x_root == y_root) {
231     return Status::OK();
232   }
233   // x and y are not in the same set, so we merge them.
234   // Use the occasion to strengthen what we know about the handles by merging
235   // the information about the two subsets.
236   if (x_root->rank < y_root->rank) {
237     TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
238     x_root->parent = y_root;
239   } else if (x_root->rank > y_root->rank) {
240     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
241     y_root->parent = x_root;
242   } else {
243     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
244     // Arbitrarily make one root the new parent
245     y_root->parent = x_root;
246     x_root->rank = x_root->rank + 1;
247   }
248   return Status::OK();
249 }
250 
251 template <typename Handle>
252 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
253   auto it = nodes_.find(value);
254   if (it == nodes_.end()) {
255     // This is the first time we process this handle, create an entry for it.
256     Rep* node = new Rep;
257     node->parent = node;
258     node->rank = 0;
259     processor_.ExtractValue(value, &node->value);
260     nodes_[value] = node;
261     return node;
262   }
263   // Return the representative for the set, which is the root of the tree. Apply
264   // path compression to speed up future queries.
265   Rep* node = it->second;
266   Rep* root = node->parent;
267   while (root != root->parent) {
268     root = root->parent;
269   }
270   while (node->parent != root) {
271     Rep* next = node->parent;
272     node->parent = root;
273     node = next;
274   }
275   return root;
276 }
277 
278 // TODO(dyoon): Move many helper functions in this file (including those within
279 // SymbolicShapeRefiner class) to shared utils.
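// Helpers that match queue Enqueue/Dequeue ops by op-name substring while
// excluding the *Many variants, which operate on an extra batch dimension.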
280 bool IsEnqueue(const NodeDef& n) {
281   return (n.op().find("Enqueue") != string::npos &&
282           n.op().find("EnqueueMany") == string::npos);
283 }
284 
285 bool IsDequeue(const NodeDef& n) {
286   return (n.op().find("Dequeue") != string::npos &&
287           n.op().find("DequeueMany") == string::npos);
288 }
289 
290 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
291   if (proto.unknown_rank()) {
292     return true;
293   }
294   for (const auto& dim : proto.dim()) {
295     if (dim.size() < 0) {
296       return true;
297     }
298   }
299   return false;
300 }
301 
302 // This really should be done in an external debugging tool
303 void VerboseLogUnknownDimensionSources(
304     const GraphDef& graph,
305     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
306         input_properties_map,
307     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
308         output_properties_map) {
309   if (!VLOG_IS_ON(2)) {
310     return;
311   }
312 
313   VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
314 
315   // Find all nodes in the graph for which we
316   // do not have any unknown dimensions in their inputs, but
317   // we have some unknown dimensions in their outputs.
318   std::map<string, int> op_to_count;
319   for (const NodeDef& node : graph.node()) {
320     const auto& input_properties = input_properties_map.at(node.name());
321     const auto& output_properties = output_properties_map.at(node.name());
322 
323     bool has_unknown_inputs = false;
324     for (const auto& input_prop : input_properties) {
325       if (HasAnyUnknownDimensions(input_prop.shape())) {
326         has_unknown_inputs = true;
327         break;
328       }
329     }
330 
331     if (has_unknown_inputs) {
332       continue;
333     }
334 
335     for (const auto& output_prop : output_properties) {
336       if (HasAnyUnknownDimensions(output_prop.shape())) {
337         string inputs = "input_shapes=[";
338         for (const auto& input_prop : input_properties) {
339           inputs += PartialTensorShape::DebugString(input_prop.shape());
340         }
341         inputs += "]";
342 
343         string outputs = "output_shapes=[";
344         for (const auto& output_prop : output_properties) {
345           outputs += PartialTensorShape::DebugString(output_prop.shape());
346         }
347         outputs += "]";
348 
349         VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
350                 << inputs << ", " << outputs;
351 
352         op_to_count[node.op()]++;
353 
354         // don't log again for this node
355         break;
356       }
357     }
358   }
359   VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
360           << "(format: <op_type> (<count>)):";
361   for (const auto& p : op_to_count) {
362     VLOG(2) << p.first << " (" << p.second << ")";
363   }
364 }
365 
366 // Helper function to convert kUnknownDimFromConst into UnknownDim.
367 std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
368     InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
369   std::vector<ShapeHandle> converted_shapes(shapes.size());
370   for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
371     const auto& shape = shapes[i];
372     if (!ic->RankKnown(shape)) {
373       converted_shapes[i] = shape;
374       continue;
375     }
376     bool just_copy = true;
377     std::vector<DimensionHandle> dims;
378     for (int32_t i = 0; i < ic->Rank(shape); ++i) {
379       DimensionHandle dim = ic->Dim(shape, i);
380       if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
381         just_copy = false;
382         dims.push_back(ic->UnknownDim());
383       } else {
384         dims.push_back(dim);
385       }
386     }
387     if (just_copy) {
388       converted_shapes[i] = shape;
389       continue;
390     }
391     converted_shapes[i] = ic->MakeShape(dims);
392   }
393   return converted_shapes;
394 }
395 
396 // Returned tensor's shape is like `shape`, and its values and dtype are from
397 // `tensor_as_shape` and `dtype`.
398 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
399                                      const ShapeHandle& shape,
400                                      const ShapeHandle& tensor_as_shape,
401                                      const DataType& dtype) {
402   TensorProto tensor_proto;
403   tensor_proto.set_dtype(dtype);
404   auto* shape_proto = tensor_proto.mutable_tensor_shape();
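  // If `shape` is a vector (rank 1), the result is a 1-D tensor whose length
  // equals the rank of `tensor_as_shape`; a scalar `shape` adds no dim.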
405   if (ic->Rank(shape) == 1) {
406     shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
407   }
408   // For a scalar tensor, tensor_shape field will be left empty; no dim.
409   for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
410     int64_t value = ic->Value(ic->Dim(tensor_as_shape, i));
411     if (dtype == DT_INT32) {
412       tensor_proto.add_int_val(value);
413     } else {
414       tensor_proto.add_int64_val(value);
415     }
416   }
417   return tensor_proto;
418 }
419 
420 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
421 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
422                                         const TensorProto& tensor_proto,
423                                         const DataType& dtype) {
424   NodeDef const_node;
425   const_node.set_name("const_from_shape");
426   const_node.set_op("Const");
427   auto* attr = const_node.mutable_attr();
428   (*attr)["dtype"].set_type(dtype);
429   auto* tensor = (*attr)["value"].mutable_tensor();
430   *tensor = tensor_proto;
431   return const_node;
432 }
433 
434 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
435 // and dtype = `dtype`.
436 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
437                                   const ShapeHandle& shape,
438                                   const ShapeHandle& tensor_as_shape,
439                                   const DataType& dtype) {
440   return MakeConstNodeDefFromTensorProto(
441       ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
442 }
443 
444 bool IsNumericType(const DataType dtype) {
445   static const gtl::FlatSet<DataType>* const kRealNumberTypes =
446       CHECK_NOTNULL((new gtl::FlatSet<DataType>{
447           // Floating point.
448           DT_BFLOAT16,
449           DT_HALF,
450           DT_FLOAT,
451           DT_DOUBLE,
452           // Int / UInt.
453           DT_INT8,
454           DT_INT16,
455           DT_INT32,
456           DT_INT64,
457           DT_UINT8,
458           DT_UINT16,
459           DT_UINT32,
460           DT_UINT64,
461           // Quantized Int.
462           DT_QINT8,
463           DT_QUINT8,
464           DT_QINT16,
465           DT_QUINT16,
466           DT_QINT32,
467           // Bool.
468           DT_BOOL,
469       }));
470   return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
471 }
472 
473 // Returns the number of elements in the input (const) tensor.
474 // -1 if the tensor has no shape or unknown rank.
475 int64_t NumElementsFromTensorProto(const TensorProto& tensor_proto) {
476   if (!tensor_proto.has_tensor_shape()) {
477     return -1;
478   }
479   const auto& tensor_shape_proto = tensor_proto.tensor_shape();
480   if (tensor_shape_proto.unknown_rank()) {
481     return -1;
482   }
483   int64_t num_elements = 1;
484   for (const auto& dim : tensor_shape_proto.dim()) {
485     // Note that in some cases, dim.size() can be zero (e.g., empty vector).
486     num_elements *= dim.size();
487   }
488   return num_elements;
489 }
490 
491 }  // namespace
492 
493 // Note that tensor_as_shape input should not include kUnknownDimFromConst.
494 // This function checks for kUnknownDimFromConst, but only logs a WARNING.
495 // When checking input_tensors_as_shapes_to_propagate or
496 // output_tensors_as_shapes, which may include kUnknownDimFromConst, convert
497 // them with ReplaceUnknownDimFromConstWithUnknownDim() first.
498 bool IsShapeFullyDefinedIntegerVectorOrScalar(
499     InferenceContext* ic, const ShapeHandle& shape,
500     const ShapeHandle& tensor_as_shape, const DataType& dtype) {
501   if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
502       !ic->FullyDefined(tensor_as_shape) ||
503       (dtype != DT_INT32 && dtype != DT_INT64)) {
504     return false;
505   }
506   // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
507   for (int32_t i = 0; i < ic->Rank(tensor_as_shape); ++i) {
508     DimensionHandle dim = ic->Dim(tensor_as_shape, i);
509     if (ic->Value(dim) == kUnknownDimFromConst) {
510       LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
511                    << "tensor_as_shape input includes kUnknownDimFromConst -- "
512                    << ic->DebugString(tensor_as_shape);
513       return false;
514     }
515   }
516   return true;
517 }
518 
519 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
520 // dequeued in (roughly) topological order. Propagating shapes following a
521 // topological ordering isn't required for correctness but helps speed things up
522 // since it avoids processing the same node multiple times as the information
523 // about its inputs is refined.
524 class TopoQueue {
525  public:
526   explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
527       : topo_order_(TopoOrder(topo_order)) {}
528 
529   void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
530 
531   const NodeDef* pop() {
532     CHECK(!empty());
533     auto it = queue_.begin();
534     const NodeDef* n = it->first;
535     queue_.erase(it);
536     return n;
537   }
538 
539   bool empty() const { return queue_.empty(); }
540   std::size_t size() const { return queue_.size(); }
541 
542  private:
543   using NodeAndId = std::pair<const NodeDef*, int>;
544   // Graph nodes are created in (roughly) topological order. Therefore we can
545   // use their id to ensure they're sorted topologically.
546   struct OrderByIdAscending {
547     bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
548       return lhs.second < rhs.second;
549     }
550   };
551 
552   const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
553       const std::vector<const NodeDef*>& topo_order) const {
554     absl::flat_hash_map<const NodeDef*, int> map;
555     map.reserve(topo_order.size());
556     for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
557          ++i) {
558       map.emplace(topo_order[i], i);
559     }
560     return map;
561   }
562 
563   const absl::flat_hash_map<const NodeDef*, int> topo_order_;
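  // Ordered by topological index, so pop() returns the enqueued node that
  // appears earliest in the topological order.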
564   std::set<NodeAndId, OrderByIdAscending> queue_;
565 };
566 
567 
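// Ops in this allowlist are considered safe to evaluate at shape-inference
// time (used with aggressive shape inference to materialize constant values).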
568 bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
569   static const gtl::FlatSet<string>* const kOpTpeAllowlist =
570       CHECK_NOTNULL((new gtl::FlatSet<string>{
571           // Unary arithmetic ops
572           "Floor",
573           "Round",
574           "Sqrt",
575           "Square",
576           "Sign",
577           // Binary arithmetic ops
578           "Add",
579           "AddV2",
580           "Div",
581           "FloorDiv",
582           "FloorMod",
583           "Greater",
584           "GreaterEqual",
585           "Less",
586           "LessEqual",
587           "LogicalAnd",
588           "LogicalNot",
589           "LogicalOr",
590           "Maximum",
591           "Minimum",
592           "Mod",
593           "Mul",
594           "NotEqual",
595           "QuantizedAdd",
596           "QuantizedMul",
597           "SquareDifference",
598           "Sub",
599           "TruncateDiv",
600           "TruncateMod",
601           "RealDiv",
602           // N-ary arithmetic ops
603           "AddN",
604           // Others
605           "StridedSlice",
606           "OnesLike",
607           "ZerosLike",
608           "Concat",
609           "ConcatV2",
610           "Split",
611           "Range",
612           "Fill",
613           "Cast",
614       }));
615   return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
616 }
617 
618 // A dimension size of -1 represents an unknown dimension, while sizes less
619 // than -1 represent unknown symbolic dimensions (e.g. the shape [-5, 5, -1,
620 // -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
621 // we need to normalize them: mark all values < -1 as "unknown" (-1).
622 static void NormalizeShapeForOutput(TensorShapeProto* shape) {
623   for (int i = 0; i < shape->dim_size(); i++) {
624     if (shape->dim(i).size() < -1) {
625       VLOG(2) << "Normalizing dimension: " << i << " from "
626               << shape->dim(i).size() << " to -1";
627       shape->mutable_dim(i)->set_size(-1);
628     }
629   }
630 }
631 
632 // Processes symbolic shapes.
633 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
634 // shape refiner which creates new handles every time it processes an unknown
635 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
636 // unknown shape/dimension of a given node.
637 class SymbolicShapeRefiner {
638  public:
639   explicit SymbolicShapeRefiner(
640       const GraphView& graph,
641       const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
642       const bool aggressive_shape_inference)
643       : graph_(graph),
644         function_library_(OpRegistry::Global(), graph.graph()->library()),
645         fed_ports_(fed_ports),
646         aggressive_shape_inference_(aggressive_shape_inference) {
647     graph_def_version_ = graph.graph()->versions().producer();
648     node_to_context_.reserve(graph.graph()->node_size());
649   }
650 
651   const GraphView& graph() const { return graph_; }
652 
653   struct NodeContext {
654     const OpRegistrationData* op_data;
655     DataTypeVector input_types;
656     DataTypeVector output_types;
657     std::unique_ptr<InferenceContext> inference_context;
658     // Additional info for propagating tensor values and tensor shapes.
659     std::vector<const TensorProto*> input_tensor_protos;
660     std::vector<const TensorProto*> output_tensor_protos;
661     // This is the same as inference_context->input_tensors_as_shapes, except
662     // that some UnknownDims (-1) can be kUnknownDimFromConst.
663     std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
664     std::vector<ShapeHandle> output_tensors_as_shapes;
665 
666     // Output shapes incompatible between annotation and shape inference.
667     bool shape_incompatible = false;
668 
669     // Similar to DebugString() in InferenceContext, but prints out
670     // kUnknownDimFromConst properly.
671     std::string StringifyShapeHandle(ShapeHandle s) {
672       auto* ic = inference_context.get();
673       if (ic->RankKnown(s)) {
674         std::vector<std::string> vals;
675         for (int i = 0; i < ic->Rank(s); i++) {
676           DimensionHandle d = ic->Dim(s, i);
677           if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
678             vals.push_back("?(Const)");
679           } else {
680             vals.push_back(ic->DebugString(d));
681           }
682         }
683         return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
684       } else {
685         return "?";
686       }
687     }
688 
689     std::string DebugString(const NodeDef& node) {
690       std::string output;
691       auto* ic = inference_context.get();
692       absl::StrAppend(
693           &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
694           (ic->num_inputs() > 1 ? " inputs and " : " input and "),
695           ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
696       if (op_data->is_function_op) {
697         absl::StrAppend(&output, " (function op)");
698       }
699       absl::StrAppend(&output, ": \n");
700 
701       for (int i = 0; i < ic->num_inputs(); i++) {
702         absl::StrAppend(&output, " input [", i, "] ", node.input(i),
703                         " -- type: ", DataTypeString(input_types.at(i)),
704                         ", shape: ", ic->DebugString(ic->input(i)),
705                         ", tensor: ");
706         Tensor t1;
707         int input_tensor_protos_size = input_tensor_protos.size();
708         if (input_tensor_protos_size > i &&
709             input_tensor_protos.at(i) != nullptr &&
710             t1.FromProto(*input_tensor_protos.at(i))) {
711           absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
712         } else {
713           absl::StrAppend(&output, " null, tensor_as_shape: ");
714         }
715         int input_tensors_as_shapes_to_propagate_size =
716             input_tensors_as_shapes_to_propagate.size();
717         if (input_tensors_as_shapes_to_propagate_size > i) {
718           absl::StrAppend(
719               &output,
720               StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
721               "\n");
722         } else {
723           absl::StrAppend(&output, " null\n");
724         }
725       }
726       for (int i = 0; i < ic->num_outputs(); i++) {
727         absl::StrAppend(&output, " output [", i,
728                         "] -- type: ", DataTypeString(output_types.at(i)),
729                         ", shape: ", ic->DebugString(ic->output(i)),
730                         ", tensor: ");
731         Tensor t2;
732         int output_tensor_protos_size = output_tensor_protos.size();
733         if (output_tensor_protos_size > i &&
734             output_tensor_protos.at(i) != nullptr &&
735             t2.FromProto(*output_tensor_protos.at(i))) {
736           absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
737         } else {
738           absl::StrAppend(&output, " null, tensor_as_shape: ");
739         }
740         int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
741         if (output_tensors_as_shapes_size > i) {
742           absl::StrAppend(&output,
743                           StringifyShapeHandle(output_tensors_as_shapes.at(i)),
744                           "\n");
745         } else {
746           absl::StrAppend(&output, " null\n");
747         }
748       }
749       return output;
750     }
751   };
752 
753   NodeContext* GetNodeContext(const NodeDef* node) {
754     auto it = node_to_context_.find(node);
755     if (it == node_to_context_.end()) {
756       return nullptr;
757     }
758     return &it->second;
759   }
760 
761   InferenceContext* GetContext(const NodeDef* node) {
762     auto it = node_to_context_.find(node);
763     if (it == node_to_context_.end()) {
764       return nullptr;
765     }
766     return it->second.inference_context.get();
767   }
768 
769   // Forward the shapes from the function input nodes, PartitionedCalls or
770   // StatefulPartitionedCall to
771   // the argument nodes (which are Placeholder nodes), then
772   // perform shape inference on the function body.
773   //
774   // Propagate shape information of final function body node
775   // to function node `function_node`.
776   //
777   // In the event of an error, UpdateNode will simply set `function_node`'s
778   // output shape to be Unknown.
779   Status UpdateFunction(const NodeDef* function_node) {
780     NameAttrList function;
781     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
782     auto it = fun_to_grappler_function_item_.find(function.name());
783     if (it == fun_to_grappler_function_item_.end()) {
784       return errors::InvalidArgument(
785           function.name(),
786           " was not previously added to SymbolicShapeRefiner.");
787     }
788 
789     const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
790         it->second;
791     if (!maybe_grappler_function_item.has_value()) {
792       VLOG(3) << "Skip failed to instantiate function call: function_name="
793               << function.name();
794 
795       auto* ctx = GetNodeContext(function_node);
796       auto* ic = ctx->inference_context.get();
797       for (int i = 0; i < ic->num_outputs(); ++i) {
798         TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
799       }
800 
801       return Status::OK();
802     }
803 
804     // Copy (not reference) so that changes we make here (e.g., replacing
805     // _Arg with Const and _Retval with Identity) don't affect one in
806     // fun_to_grappler_function_item_.
807     GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
808     MutableGraphView gv(&grappler_function_item.graph);
809 
810     // Forward shapes from function input nodes to argument nodes.
811     for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
812          ++i) {
813       auto& fun_input = grappler_function_item.input(i);
814       NodeDef* fun_node = gv.GetNode(fun_input.node_name);
815       const TensorId input_tensor = ParseTensorName(function_node->input(i));
816 
817       if (IsControlInput(input_tensor)) {
818         return errors::FailedPrecondition(
819             "Function inputs should not contain control nodes.");
820       }
821 
822       const NodeDef* input_node = graph_.GetNode(input_tensor.node());
823       if (input_node == nullptr) {
824         return errors::FailedPrecondition(input_tensor.node(),
825                                           " was not found in the graph.");
826       }
827 
828       InferenceContext* input_ic = GetContext(input_node);
829       if (input_ic == nullptr) {
830         return errors::FailedPrecondition(
831             "Inference context has not been created for ", input_tensor.node());
832       }
833 
834       int output_port_num = input_tensor.index();
835       AttrValue attr_output_shape;
836       TensorShapeProto proto;
837       const auto handle = input_ic->output(output_port_num);
838       input_ic->ShapeHandleToProto(handle, &proto);
839       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
840       NormalizeShapeForOutput(&proto);
841       // _Arg op's output shape uses _output_shapes attr.
842       AttrValue output_attr;
843       output_attr.mutable_list()->add_shape()->Swap(&proto);
844       (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
845 
846       // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
847       // _handle_shapes attr for its shapes and dtypes.
848       if (fun_input.data_type == DT_RESOURCE) {
849         auto* shapes_and_types =
850             input_ic->output_handle_shapes_and_types(output_port_num);
851         if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
852           AttrValue dtype_attr;
853           AttrValue shape_attr;
854           for (const auto& shape_and_type : *shapes_and_types) {
855             const auto& dtype = shape_and_type.dtype;
856             const auto& shape_handle = shape_and_type.shape;
857             dtype_attr.mutable_list()->add_type(dtype);
858             input_ic->ShapeHandleToProto(
859                 shape_handle, shape_attr.mutable_list()->add_shape());
860           }
861           (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
862           (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
863         } else {
864           // Note that we do not return error here, even if the input node does
865           // not have shapes_and_types. Within the function, we cannot infer the
866           // output shape of the DT_RESOURCE input; hence, potentially unknown
867           // shapes/dims in the function output shapes.
868           VLOG(2)
869               << "A function node (" << function_node->name()
870               << ") has input with DT_RESOURCE, but the input node does not "
871               << "have shapes_and_types information: \n"
872               << "function_node: " << function_node->ShortDebugString() << "\n"
873               << "function input: " << i
874               << ", input node's output: " << output_port_num << "\n"
875               << "input node: " << input_node->ShortDebugString();
876         }
877       }
878     }
879 
880     // ReplaceInputWithConst() may break GraphView's internal node mapping
881     // structure; hence, we separately build node name to NodeDef* map, for the
882     // output nodes (before GraphView becomes invalid). Note that we use string,
883     // not string_view.
884     absl::flat_hash_map<std::string, NodeDef*> output_nodes;
885     for (const auto& output_arg : grappler_function_item.outputs()) {
886       output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
887     }
888 
889     // Replace input nodes with Consts, if values are known. Note that
890     // we don't check exceptions here as it's done in the above loop.
891     auto* ctx = GetNodeContext(function_node);
892     auto* ic = ctx->inference_context.get();
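    // Iterate in reverse: ReplaceInputWithConst() removes the replaced input
    // from grappler_function_item, so processing from the back keeps the
    // indices of the remaining inputs valid.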
893     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
894       const string& input = function_node->input(i);
895       const string node_name = NodeName(input);
896       const NodeDef* input_node = graph_.GetNode(node_name);
897       if (IsConstant(*input_node)) {
898         TF_CHECK_OK(
899             ReplaceInputWithConst(*input_node, i, &grappler_function_item));
900       } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
901                  ctx->input_tensor_protos[i] != nullptr) {
902         NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
903             ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
904         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
905                                           &grappler_function_item));
906       } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
907                  IsShapeFullyDefinedIntegerVectorOrScalar(
908                      ic, ic->input(i), ic->input_tensors_as_shapes()[i],
909                      ctx->input_types[i])) {
910         // We have fully defined input_tensors_as_shapes for this input; use it
911         // as a const input to the function node.
912         NodeDef const_input_node = MakeConstNodeDefFromShape(
913             ic, ic->input(i), ic->input_tensors_as_shapes()[i],
914             ctx->input_types[i]);
915         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
916                                           &grappler_function_item));
917       }
918     }
919     // node_name to NodeDef* map in GraphView gv can be broken due to
920     // ReplaceInputWithConst(). gv should not be used after this.
921 
922     // Replace output _Retval nodes with Identity nodes. _Retval is a system op
923     // without outputs or a registered shape function.
924     for (const auto& output_arg : grappler_function_item.outputs()) {
925       NodeDef* output_node = output_nodes[output_arg.node_name];
926       DCHECK_EQ(output_node->op(), "_Retval");
927       output_node->set_op("Identity");
928       output_node->mutable_attr()->erase("index");
929     }
930 
931     // Perform inference on function body.
932     GraphProperties gp(grappler_function_item);
933     TF_RETURN_IF_ERROR(gp.InferStatically(
934         /*assume_valid_feeds=*/true,
935         /*aggressive_shape_inference=*/aggressive_shape_inference_,
936         /*include_tensor_values=*/true));
937 
938     // Add return nodes for output shapes.
939     int output = 0;
940     ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
941     ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
942                                      nullptr);
943     for (auto const& out_arg : grappler_function_item.outputs()) {
944       // It is guaranteed that output_tensors does not contain any control
945       // inputs, so port_id >= 0.
946       TensorId out_tensor = ParseTensorName(out_arg.node_name);
947 
948       if (output_nodes.count(out_tensor.node()) <= 0) {
949         return errors::FailedPrecondition(
950             "Unable to find return function_node ", out_tensor.node(), " for ",
951             function_node->name());
952       }
953       const NodeDef* retnode = output_nodes[out_tensor.node()];
954 
955       auto output_properties = gp.GetOutputProperties(retnode->name());
956       int output_properties_size = output_properties.size();
957       if (out_tensor.index() >= output_properties_size) {
958         return errors::InvalidArgument(
959             out_tensor.ToString(), " has invalid position ", out_tensor.index(),
960             " (output_properties.size() = ", output_properties.size(), ").");
961       }
962       auto& outprop = output_properties[out_tensor.index()];
963       TensorShapeProto shape = outprop.shape();
964       NormalizeShapeForOutput(&shape);
965       ShapeHandle out;
966       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
967       ic->set_output(output, out);
968       if (outprop.has_value()) {
969         // Forward tensor value to output_tensors_as_shape.
970         MaybeTensorProtoToShape(ic, outprop.value(),
971                                 &ctx->output_tensors_as_shapes[output]);
972         const_tensors_to_propagate_.push_back(outprop.value());
973         ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
974       }
975       output++;
976     }
977 
978     return Status::OK();
979   }
980 
981   // Prepares input shapes/values/handles, then runs shape inference, and
982   // finally sets output shapes/values/handles.
983   Status UpdateNode(const NodeDef* node, bool* refined) {
984     NodeContext* ctx = GetNodeContext(node);
985     if (ctx == nullptr) {
986       TF_RETURN_IF_ERROR(AddNode(node));
987       ctx = CHECK_NOTNULL(GetNodeContext(node));
988       *refined = true;
989     }
990 
991     // Check if the shapes of the nodes in the fan-in of this node have changed,
992     // and if they have, update the node input shapes.
993     InferenceContext* ic = ctx->inference_context.get();
994     ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
995     ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
996 
997     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
998       const GraphView::InputPort port(node, dst_input);
999       const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
1000       int src_output = fanin.port_id;
1001       const NodeDef* src = fanin.node;
1002       NodeContext* src_ctx = GetNodeContext(src);
1003       if (src_ctx == nullptr) {
1004         return errors::FailedPrecondition(
1005             "Input ", dst_input, " for '", node->name(),
1006             "' was not previously added to SymbolicShapeRefiner.");
1007       }
1008 
1009       InferenceContext* src_ic = src_ctx->inference_context.get();
1010       if (src_output >= src_ic->num_outputs()) {
1011         return errors::OutOfRange("src_output = ", src_output,
1012                                   ", but num_outputs is only ",
1013                                   src_ic->num_outputs());
1014       }
1015 
1016       // Propagate input node's NodeContext info to the current node's
1017       // NodeContext:
1018       // output_tensor_protos to input_tensor_protos and input_tensors, and
1019       // output_tensors_as_shapes to input_tensors_as_shapes.
1020       if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
1021           src_output) {
1022         ctx->input_tensors_as_shapes_to_propagate[dst_input] =
1023             src_ctx->output_tensors_as_shapes[src_output];
1024       }
1025 
1026       if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
1027         const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
1028         if (tensor_proto != nullptr) {
1029           ctx->input_tensor_protos[dst_input] = tensor_proto;
1030         }
1031       }
1032 
1033       // NOTE: we only check whether the shape is refined; we do not (yet)
1034       // check whether the tensor value is refined.
1035       if (!*refined &&
1036           !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
1037         *refined = true;
1038       }
1039       ic->SetInput(dst_input, src_ic->output(src_output));
1040 
1041       if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
1042         // The input value may have changed. Since we have no way to know if
1043         // that's indeed the case, err on the safe side.
1044         *refined = true;
1045       }
1046 
1047       // Also propagate handle shape and dtype of edges which are carrying
1048       // resource handles.
1049       if (ctx->input_types[dst_input] == DT_RESOURCE) {
1050         auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
1051         if (!outputs) continue;
1052         auto* inputs = ic->input_handle_shapes_and_types(dst_input);
1053 
1054         if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
1055           *refined = true;
1056         ic->set_input_handle_shapes_and_types(dst_input, *outputs);
1057       }
1058     }
1059 
1060     // Make sure we schedule the fanout of resources (which have no input)
1061     // whenever the resources are updated.
1062     *refined |= ic->num_inputs() == 0;
1063 
1064     if (!*refined) {
1065       // No input shape has changed, we're done.
1066       return Status::OK();
1067     }
1068 
1069     // Convert all kUnknownDimFromConst to -1 for shape inference.
1070     ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
1071         ic, ctx->input_tensors_as_shapes_to_propagate));
1072     // Note: UpdateFunction uses input_tensors_as_shapes and
1073     // input_tensor_protos (not the Tensor object) for input values.
1074     // so for function nodes, we don't need to convert TensorProtos
1075     // to Tensors here. If the current op is not a function op, we convert
1076     // TensorProtos to Tensors before calling InferShapes.
1077 
1078     // Properly handle function nodes.
1079     if (ctx->op_data && ctx->op_data->is_function_op) {
1080       // TODO(jmdecker): Detect if the input shapes have changed for this
1081       // function. Note that when we hit a function call node, refined will be
1082       // true, as the updates to the call node will have changed, even if it's
1083       // the same function being called twice with the same input shapes.
1084       // Example: simple_function.pbtxt
1085       if (aggressive_shape_inference_) {
1086         // If output shapes are annotated, use it and skip UpdateFunction();
1087         // it can be very expensive when a function node has nested function
1088         // nodes internally. One downside with this approach is that we do not
1089         // get output values or output shapes as tensor from function node.
1090         auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
1091         if (s.ok() && AllOutputShapesKnown(ctx)) {
1092           return Status::OK();
1093         }
1094         // If shape annotation was not available, incomplete, or incompatible,
1095         // fall through to call UpdateFunction().
1096       }
1097       auto s = UpdateFunction(node);
1098       if (s.ok()) {
1099         return Status::OK();
1100       } else {
1101         VLOG(1) << "UpdateFunction failed for " << node->op()
1102                 << ". Defaulting to ShapeUnknown.\n"
1103                 << s.ToString();
1104       }
1105     }
1106 
1107     //  Construct Tensors for constant inputs used by shape functions.
1108     std::vector<Tensor> const_values(ic->num_inputs());
1109     std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
1110     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1111       const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
1112       if (tensor_proto != nullptr &&
1113           // Skip if the const tensor is too large.
1114           NumElementsFromTensorProto(*tensor_proto) <=
1115               kThresholdToSkipConstTensorInstantiation &&
1116           const_values[dst_input].FromProto(*tensor_proto)) {
1117         input_tensors[dst_input] = &const_values[dst_input];
1118       }
1119     }
1120     ic->set_input_tensors(input_tensors);
1121 
1122     // Update the shapes of the outputs.
1123     return InferShapes(*node, ctx);
1124   }
1125 
1126   Status SetUnknownShape(const NodeDef* node, int output_port) {
1127     shape_inference::ShapeHandle shape =
1128         GetUnknownOutputShape(node, output_port);
1129     InferenceContext* ctx = GetContext(node);
1130     if (ctx == nullptr) {
1131       return errors::InvalidArgument("Missing context");
1132     }
1133     ctx->set_output(output_port, shape);
1134     return Status::OK();
1135   }
1136 
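  // Keys identifying a node output (ShapeId) or a dimension of a node output
  // (DimId); used below to memoize the "unknown" shape and dimension handles
  // so the same handle is reused across passes.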
1137   struct ShapeId {
1138     const NodeDef* node;
1139     int port_id;
1140     bool operator==(const ShapeId& other) const {
1141       return node == other.node && port_id == other.port_id;
1142     }
1143   };
1144   struct HashShapeId {
1145     std::size_t operator()(const ShapeId& shp) const {
1146       return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
1147     }
1148   };
1149 
1150   struct DimId {
1151     const NodeDef* node;
1152     int port_id;
1153     int dim_index;
1154     bool operator==(const DimId& other) const {
1155       return node == other.node && port_id == other.port_id &&
1156              dim_index == other.dim_index;
1157     }
1158   };
1159 
1160   struct HashDimId {
1161     std::size_t operator()(const DimId& dim) const {
1162       return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
1163              dim.dim_index;
1164     }
1165   };
1166 
1167   // Shape of output 'port_index' of 'node' as the union of shape1 and shape2.
1168   ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
1169                             ShapeHandle shape1, ShapeHandle shape2) {
1170     if (shape1.SameHandle(shape2)) {
1171       return shape1;
1172     }
1173     InferenceContext* ctx = GetContext(node);
1174     ShapeHandle relaxed = shape1;
1175     const int rank = ctx->Rank(shape1);
1176     if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
1177       relaxed = GetUnknownOutputShape(node, port_index);
1178     } else {
1179       for (int d = 0; d < rank; ++d) {
1180         if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
1181           int64_t val1 = ctx->Value(ctx->Dim(shape1, d));
1182           int64_t val2 = ctx->Value(ctx->Dim(shape2, d));
1183           if (val1 != val2 || (val1 < 0 && val2 < 0)) {
1184             DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
1185             TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
1186           }
1187         }
1188       }
1189     }
1190     return relaxed;
1191   }
1192 
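  // Returns true if s1 and s2 are equivalent: same handle, both of unknown
  // rank, or same rank with each dimension pair either sharing a handle or
  // having the same known value.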
1193   bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
1194     if (s1.SameHandle(s2)) {
1195       return true;
1196     }
1197     if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
1198       return false;
1199     }
1200     if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
1201       return true;
1202     }
1203     const int rank = InferenceContext::Rank(s1);
1204     for (int i = 0; i < rank; ++i) {
1205       if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
1206               InferenceContext::DimKnownRank(s2, i))) {
1207         int64_t val1 =
1208             InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
1209         int64_t val2 =
1210             InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
1211         if (val1 >= 0 && val2 >= 0 && val1 == val2) {
1212           continue;
1213         }
1214         return false;
1215       }
1216     }
1217     return true;
1218   }
1219 
1220   // Return true if the annotated shape is compatible with shape inference
1221   // result. Examples:
1222   // Inferred shape: ?, annotated shape: [10, 10] -> true;
1223   // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
1224   // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
1225   // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
1226   bool CompatibleShapes(ShapeHandle inferred_shape,
1227                         ShapeHandle annotated_shape) const {
1228     if (inferred_shape.SameHandle(annotated_shape)) {
1229       return true;
1230     }
1231     if (!InferenceContext::RankKnown(inferred_shape)) {
1232       return true;
1233     }
1234     if (InferenceContext::Rank(inferred_shape) !=
1235         InferenceContext::Rank(annotated_shape)) {
1236       return false;
1237     }
1238     const int rank = InferenceContext::Rank(inferred_shape);
1239     for (int i = 0; i < rank; ++i) {
1240       if (!InferenceContext::DimKnownRank(inferred_shape, i)
1241                .SameHandle(
1242                    InferenceContext::DimKnownRank(annotated_shape, i))) {
1243         int64_t val1 = InferenceContext::Value(
1244             InferenceContext::DimKnownRank(inferred_shape, i));
1245         int64_t val2 = InferenceContext::Value(
1246             InferenceContext::DimKnownRank(annotated_shape, i));
1247         if (val1 >= 0 && val1 != val2) {
1248           return false;
1249         }
1250       }
1251     }
1252     return true;
1253   }
1254 
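  // Returns true if both shapes have the same rank and identical dimension
  // values (symbolic values included).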
1255   bool SameShapes(ShapeHandle inferred_shape,
1256                   ShapeHandle annotated_shape) const {
1257     if (inferred_shape.SameHandle(annotated_shape)) {
1258       return true;
1259     }
1260     if (InferenceContext::Rank(inferred_shape) !=
1261         InferenceContext::Rank(annotated_shape)) {
1262       return false;
1263     }
1264     const int rank = InferenceContext::Rank(inferred_shape);
1265     for (int i = 0; i < rank; ++i) {
1266       int64_t val1 = InferenceContext::Value(
1267           InferenceContext::DimKnownRank(inferred_shape, i));
1268       int64_t val2 = InferenceContext::Value(
1269           InferenceContext::DimKnownRank(annotated_shape, i));
1270       if (val1 != val2) {
1271         return false;
1272       }
1273     }
1274     return true;
1275   }
1276 
1277   bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1278                                 const std::vector<ShapeAndType>& st2) const {
1279     if (st1.size() != st2.size()) {
1280       return false;
1281     }
1282     for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
1283       const ShapeAndType& s1 = st1[i];
1284       const ShapeAndType& s2 = st2[i];
1285       if (s1.dtype != s2.dtype) {
1286         return false;
1287       }
1288       if (!EquivalentShapes(s1.shape, s2.shape)) {
1289         return false;
1290       }
1291     }
1292     return true;
1293   }
1294 
1295   Status AddFunction(const NodeDef* function_node, NameAttrList function) {
1296     auto it = fun_to_grappler_function_item_.find(function.name());
1297     if (it != fun_to_grappler_function_item_.end()) {
1298       return Status::OK();
1299     }
1300 
1301     const FunctionDef* function_def =
1302         CHECK_NOTNULL(function_library_.Find(function.name()));
1303     GrapplerFunctionItem grappler_function_item;
1304     Status function_instantiated =
1305         MakeGrapplerFunctionItem(*function_def, function_library_,
1306                                  graph_def_version_, &grappler_function_item);
1307 
1308     // If function instantiation failed we will skip it during shape inference.
1309     if (!function_instantiated.ok()) {
1310       VLOG(3) << "Failed to instantiate a function. Error: "
1311               << function_instantiated.error_message();
1312       fun_to_grappler_function_item_[function_def->signature().name()] =
1313           absl::nullopt;
1314       return Status::OK();
1315     }
1316 
1317     if (static_cast<int>(grappler_function_item.inputs().size()) >
1318         function_node->input_size()) {
1319       return errors::FailedPrecondition(
1320           "Function input size should not be larger than node input size.");
1321     }
1322 
1323     for (int i = grappler_function_item.inputs().size(),
1324              end = function_node->input_size();
1325          i < end; ++i) {
1326       const string& input = function_node->input(i);
1327       if (!IsControlInput(input)) {
1328         return errors::FailedPrecondition(
1329             "Found regular input (", input,
1330             ") instead of control nodes for node ", function_node->name());
1331       }
1332     }
1333 
1334     fun_to_grappler_function_item_[function_def->signature().name()] =
1335         grappler_function_item;
1336 
1337     return Status::OK();
1338   }
1339 
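  // Creates the NodeContext for `node`: looks up its op registration (and its
  // function body for function call ops), resolves the input/output types,
  // and constructs an InferenceContext to be filled in by shape inference.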
1340   Status AddNode(const NodeDef* node) {
1341     NodeContext& node_ctx = node_to_context_[node];
1342     NameAttrList function;
1343     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1344 
1345     // For PartitionedCall, op_data represents the function info.
1346     TF_RETURN_IF_ERROR(
1347         function_library_.LookUp(function.name(), &node_ctx.op_data));
1348 
1349     if (node_ctx.op_data->is_function_op) {
1350       TF_RETURN_IF_ERROR(AddFunction(node, function));
1351     }
1352 
1353     TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1354                                          &node_ctx.input_types,
1355                                          &node_ctx.output_types));
1356 
1357     // Create the inference context for this node.
1358     const int num_inputs = node_ctx.input_types.size();
1359     std::vector<ShapeHandle> input_shapes(num_inputs);
1360     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1361         input_handle_shapes_and_types(num_inputs);
1362     std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1363     std::vector<ShapeHandle> input_tensors_as_shapes;
1364 
1365     node_ctx.inference_context.reset(new InferenceContext(
1366         graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
1367         input_tensors, input_tensors_as_shapes,
1368         std::move(input_handle_shapes_and_types)));
1369     const Status s = node_ctx.inference_context->construction_status();
1370     if (!s.ok()) {
1371       node_ctx.inference_context.reset(nullptr);
1372     }
1373     return s;
1374   }
1375 
1376  private:
1377   // Return the one ShapeHandle used to denote a fully unknown shape for a node
1378   // output.
1379   ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1380     ShapeId id{node, index};
1381     auto it = unknown_shapes_.find(id);
1382     if (it != unknown_shapes_.end()) {
1383       return it->second;
1384     }
1385     InferenceContext* c = GetContext(node);
1386     ShapeHandle shp = c->UnknownShape();
1387     unknown_shapes_[id] = shp;
1388     return shp;
1389   }
1390   // Return the one ShapeHandle used to denote a fully unknown dimension for a
1391   // node output.
1392   DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1393                                       int dim_id) {
1394     DimId id{node, index, dim_id};
1395     auto it = unknown_dims_.find(id);
1396     if (it != unknown_dims_.end()) {
1397       return it->second;
1398     }
1399     InferenceContext* c = GetContext(node);
1400     DimensionHandle dim = c->UnknownDim();
1401     unknown_dims_[id] = dim;
1402     return dim;
1403   }
1404 
1405   // Returns true if all the output tensors have known values.
1406   bool AllOutputValuesKnown(NodeContext* c) {
1407     InferenceContext* ic = c->inference_context.get();
1408     int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
1409     int c_output_tensor_protos_size = c->output_tensor_protos.size();
1410     if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
1411         c_output_tensor_protos_size < ic->num_outputs()) {
1412       return false;
1413     } else {
1414       // Checks if we can get output value via either output_tensor_proto or
1415       // output_tensors_as_shapes.
1416       for (int i = 0; i < ic->num_outputs(); i++) {
1417         if (c_output_tensor_protos_size > i &&
1418             c->output_tensor_protos[i] != nullptr) {
1419           continue;
1420         }
1421         if (c_output_tensors_as_shapes_size > i &&
1422             ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1423           bool no_unknown_dim_from_const = true;
1424           for (int32_t j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]);
1425                ++j) {
1426             const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
1427             if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
1428               no_unknown_dim_from_const = false;
1429               break;
1430             }
1431           }
1432           if (no_unknown_dim_from_const) {
1433             continue;
1434           }
1435         }
1436         return false;
1437       }
1438     }
1439     return true;
1440   }
1441 
1442   // Returns true if all the output shapes are known.
1443   bool AllOutputShapesKnown(NodeContext* c) {
1444     InferenceContext* ic = c->inference_context.get();
1445     // Checks if all the output shapes are fully defined.
1446     for (int i = 0; i < ic->num_outputs(); i++) {
1447       if (!ic->FullyDefined(ic->output(i))) {
1448         return false;
1449       }
1450     }
1451     return true;
1452   }
1453 
1454   // Returns true if we can infer output tensors' values -- we know values of
1455   // all the input tensors.
1456   bool AllInputValuesKnown(NodeContext* c) {
1457     InferenceContext* ic = c->inference_context.get();
1458 
1459     // Check inputs are fully defined and values are known.
1460     for (int i = 0; i < ic->num_inputs(); i++) {
1461       const Tensor* tensor = ic->input_tensor(i);
1462       // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1463       // already converted it to ic->input_tensor(i);
1464       const ShapeHandle& input_tensors_as_shape =
1465           ic->input_tensors_as_shapes()[i];
1466       // Either input_tensor must be valid, or input_tensors_as_shape, which
1467       // holds the input tensor's value in shape format, must be fully defined.
1468       if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1469         return false;
1470       }
1471     }
1472     return true;
1473   }
1474 
1475   // Returns true if we want to update output shapes and values by running
1476   // EvaluateNode() for this op, based on op type, data type, and size.
1477   bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64_t max_size) {
1478     InferenceContext* ic = c->inference_context.get();
1479 
1480     // Due to the cost of running EvaluateNode(), we only run it for
1481     // allow-listed op types.
1482     if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1483       return false;
1484     }
1485 
1486     // Check input dtypes are number types.
1487     for (const auto& input_type : c->input_types) {
1488       if (!IsNumericType(input_type)) {
1489         return false;
1490       }
1491     }
1492 
1493     // Check output dtypes are number types.
1494     for (const auto& output_type : c->output_types) {
1495       if (!IsNumericType(output_type)) {
1496         return false;
1497       }
1498     }
1499 
1500     // Check that the number of elements in each input tensor is no larger than
1501     // the given max size.
1502     for (int i = 0; i < ic->num_inputs(); i++) {
1503       const Tensor* tensor = ic->input_tensor(i);
1504       const ShapeHandle& input_shape_handle = ic->input(i);
1505       if (tensor != nullptr) {
1506         if (tensor->NumElements() > max_size) {
1507           return false;
1508         }
1509       } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1510         return false;
1511       }
1512     }
1513 
1514     // Check that the shape of each output tensor is known, and that its number
1515     // of elements is no larger than the given max size.
1516     for (int i = 0; i < ic->num_outputs(); i++) {
1517       const ShapeHandle& shape_handle = ic->output(i);
1518       if (!ic->FullyDefined(shape_handle) ||
1519           ic->Value(ic->NumElements(shape_handle)) > max_size) {
1520         return false;
1521       }
1522     }
1523     return true;
1524   }
1525 
1526   // Create input tensors from the NodeContext.
1527   void CreateInputTensors(NodeContext* c,
1528                           std::vector<Tensor>* input_tensor_vector,
1529                           TensorVector* inputs) {
1530     InferenceContext* ic = c->inference_context.get();
1531     for (int i = 0; i < ic->num_inputs(); i++) {
1532       if (ic->input_tensor(i)) {
1533         input_tensor_vector->at(i) = *ic->input_tensor(i);
1534         inputs->emplace_back(&input_tensor_vector->at(i));
1535         // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1536         // already converted it to ic->input_tensor(i);
1537       } else {
1538         // Create Tensor from input_tensors_as_shapes, and then emplace it
1539         // back to inputs.
1540         // Note that input_tensors_as_shapes is scalar or vector.
1541         const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1542         const DataType& data_type = c->input_types[i];
1543         int32_t rank = ic->Rank(shape_handle);
1544         if (rank < 1) {
1545           input_tensor_vector->at(i) = Tensor(data_type, {});
1546         } else {
1547           input_tensor_vector->at(i) = Tensor(data_type, {rank});
1548         }
1549         auto* tensor = &input_tensor_vector->at(i);
1550         if (data_type == DT_INT32) {
1551           auto flat = tensor->flat<int32>();
1552           for (int j = 0; j < rank; j++) {
1553             int32_t dim = ic->Value(ic->Dim(shape_handle, j));
1554             flat(j) = dim;
1555           }
1556         } else {
1557           auto flat = tensor->flat<int64>();
1558           for (int j = 0; j < rank; j++) {
1559             int64_t dim = ic->Value(ic->Dim(shape_handle, j));
1560             flat(j) = dim;
1561           }
1562         }
1563         inputs->emplace_back(tensor);
1564       }
1565     }
1566   }
1567 
1568   // Run a node to infer output shapes and values, and add it to the
1569   // NodeContext.
1570   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1571     InferenceContext* ic = c->inference_context.get();
1572 
1573     // Input to EvaluateNode()
1574     TensorVector inputs;
1575     // Container for temporarily created tensor object.
1576     std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1577     CreateInputTensors(c, &input_tensor_vector, &inputs);
1578 
1579     // Output for EvaluateNode() and output tensor clean up object.
1580     TensorVector outputs;
1581     auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1582       for (const auto& output : outputs) {
1583         if (output.tensor) {
1584           delete output.tensor;
1585         }
1586       }
1587     });
1588 
1589     TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1590                                     &resource_mgr_, &outputs));
1591     c->output_tensors_as_shapes.resize(outputs.size());
1592     c->output_tensor_protos.resize(outputs.size(), nullptr);
1593     for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1594       const auto& t = outputs[k];
1595       // Override output shape.
1596       ShapeHandle output_shape;
1597       TF_RETURN_IF_ERROR(
1598           ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1599       if (ic->FullyDefined(ic->output(k)) &&
1600           !EquivalentShapes(ic->output(k), output_shape)) {
1601         LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1602                      << ", inferred output shape "
1603                      << "doesn't match for k=" << k << ": "
1604                      << "ic->output(k): " << ic->DebugString(ic->output(k))
1605                      << ", output_shape: " << ic->DebugString(output_shape)
1606                      << " -- " << node.DebugString();
1607       }
1608       ic->set_output(k, output_shape);
1609       // Set output_tensors_as_shape.
1610       MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1611 
1612       // Set output_tensor_protos.
1613       TensorProto tensor_proto;
1614       t->AsProtoTensorContent(&tensor_proto);
1615       const_tensors_to_propagate_.push_back(tensor_proto);
1616       c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1617     }
1618     return Status::OK();
1619   }
1620 
1621   // Update output shapes with annotated information.
1622   // Currently only handles nodes with static shapes, i.e., shapes that do not
1623   // change during execution.
1624   // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1625   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1626                                                      NodeContext* c) const {
1627     const auto& attr = node.attr();
1628     if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1629         attr.count(kOutputShapes) == 0)
1630       return Status::OK();
1631 
1632     InferenceContext* ic = c->inference_context.get();
1633     int output_size = attr.at(kOutputShapes).list().shape_size();
1634 
1635     for (int i = 0; i < ic->num_outputs(); i++) {
1636       // An annotated Switch node has only one output shape annotation;
1637       // propagate it to all the outputs.
1638       int shape_index = IsSwitch(node) ? 0 : i;
1639       if (shape_index >= output_size) {
1640         LOG(WARNING)
1641             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1642             << node.name() << ", inferred output shape size "
1643             << ic->num_outputs() << ", annotated output shape size "
1644             << output_size;
1645         break;
1646       }
1647 
1648       const TensorShapeProto& shape =
1649           attr.at(kOutputShapes).list().shape(shape_index);
1650       if (shape.dim().empty()) continue;
1651 
1652       ShapeHandle output_shape;
1653       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1654 
1655       // Check if annotated shapes are incompatible with inferred shapes.
1656       if ((ic->FullyDefined(ic->output(i)) &&
1657            !SameShapes(ic->output(i), output_shape)) ||
1658           (!ic->FullyDefined(ic->output(i)) &&
1659            !CompatibleShapes(ic->output(i), output_shape))) {
1660         LOG(WARNING)
1661             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1662             << node.name() << ", inferred output shape "
1663             << "doesn't match for i=" << i << ": "
1664             << "ic->output(k): " << ic->DebugString(ic->output(i))
1665             << ", annotated output shape: " << ic->DebugString(output_shape)
1666             << " -- " << node.DebugString();
1667         c->shape_incompatible = true;
1668       }
1669 
1670       // Only use annotated shapes if the inference shape is unknown and
1671       // compatible with annotated shapes.
1672       if (!ic->FullyDefined(ic->output(i)) &&
1673           CompatibleShapes(ic->output(i), output_shape)) {
1674         VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1675                 << node.name() << ", inferred output shape " << i << ": "
1676                 << "ic->output(i): " << ic->DebugString(ic->output(i))
1677                 << ", annotated output shape: " << ic->DebugString(output_shape)
1678                 << " -- " << node.ShortDebugString();
1679         ic->set_output(i, output_shape);
1680       }
1681     }
1682 
1683     return Status::OK();
1684   }
1685 
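  // Fills in output_tensor_protos / output_tensors_as_shapes for ops whose
  // output values can be derived without executing the op (Const, Rank, Size,
  // Shape, ShapeN, ConcatV2, Pack, Identity, Slice, StridedSlice), and, when
  // aggressive_shape_inference_ is enabled, further refines outputs using
  // annotated shapes or by evaluating the node.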
1686   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1687                                       NodeContext* c) {
1688     // Propagate tensors and shape tensors unless the node is fed.
1689     // TODO(bsteiner) We should still propagate the shapes to the ports that
1690     // aren't fed in the case of a ShapeN node.
1691 
1692     // Note that when propagating tensors_as_shapes, we use
1693     // c->input_tensors_as_shapes_to_propagate instead of
1694     // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1695     // UnknownDim is from Const tensor, and it is propagated through shape
1696     // inference. Before calling shape functions, we convert it to UnknownDim,
1697     // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1698     // inference through UnknownDim from Const.
1699     InferenceContext* ic = c->inference_context.get();
1700     if (!is_fed) {
1701       if (IsConstant(node)) {
1702         const TensorProto& tensor_proto = node.attr().at("value").tensor();
1703         c->output_tensor_protos.resize(1);
1704         c->output_tensor_protos[0] = &tensor_proto;
1705         c->output_tensors_as_shapes.resize(1);
1706         MaybeTensorProtoToShape(ic, tensor_proto,
1707                                 &c->output_tensors_as_shapes[0]);
1708       } else if (IsRank(node)) {
1709         if (ic->RankKnown(ic->input(0))) {
1710           // Propagate rank value.
1711           int32_t rank = ic->Rank(ic->input(0));
1712           const_tensors_to_propagate_.push_back(
1713               MakeIntegerScalarTensorProto(DT_INT32, rank));
1714           c->output_tensor_protos.resize(1);
1715           c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1716         }
1717       } else if (IsSize(node)) {
1718         DimensionHandle size = ic->NumElements(ic->input(0));
1719         if (ic->ValueKnown(size)) {
1720           // Propagate size value.
1721           int64_t sz = ic->Value(size);
1722           bool valid = false;
1723           if (node.attr().at("out_type").type() == DT_INT32) {
1724             if (sz < std::numeric_limits<int32>::max()) {
1725               const_tensors_to_propagate_.push_back(
1726                   MakeIntegerScalarTensorProto(DT_INT32, sz));
1727               valid = true;
1728             }
1729           } else {
1730             const_tensors_to_propagate_.push_back(
1731                 MakeIntegerScalarTensorProto(DT_INT64, sz));
1732             valid = true;
1733           }
1734           if (valid) {
1735             c->output_tensor_protos.resize(1);
1736             c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1737           }
1738         }
1739       } else if (IsShape(node)) {
1740         c->output_tensors_as_shapes.resize(1);
1741         c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1742       } else if (IsShapeN(node)) {
1743         c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1744         for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1745           c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1746         }
1747       } else if (node.op() == "ConcatV2") {
1748         bool valid = true;
1749         ShapeHandle result;
1750         for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1751           ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1752           if (!ic->RankKnown(input)) {
1753             valid = false;
1754             break;
1755           } else if (i == 0) {
1756             result = input;
1757           } else {
1758             TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1759           }
1760         }
1761         if (valid) {
1762           c->output_tensors_as_shapes.resize(1);
1763           c->output_tensors_as_shapes[0] = result;
1764         }
1765       } else if (IsPack(node)) {
1766         // A Pack node concatenating scalars is often used to generate a shape.
1767         std::vector<DimensionHandle> dims;
1768         bool valid = true;
1769         for (int i = 0; i < ic->num_inputs(); ++i) {
1770           const Tensor* t = ic->input_tensor(i);
1771           if (t) {
1772             if (t->dims() != 0 ||
1773                 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1774               valid = false;
1775               break;
1776             }
1777             int64_t size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1778                                                   : t->scalar<int64>()();
1779             dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1780                                     : ic->MakeDim(size));
1781           } else {
1782             // Don't have tensor value, but use input_tensors_as_shapes, if
1783             // possible.
1784             const ShapeHandle& shape_handle =
1785                 c->input_tensors_as_shapes_to_propagate[i];
1786             if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1787                 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1788               dims.push_back(ic->Dim(shape_handle, 0));
1789             } else {
1790             // This is not from a Const, but as it shouldn't be used as a symbolic
1791               // unknown dim for different ops, we use kUnknownDimFromConst.
1792               dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1793             }
1794           }
1795         }
1796         if (valid) {
1797           c->output_tensors_as_shapes.resize(1);
1798           c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1799         }
1800       } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1801         c->output_tensors_as_shapes.resize(1);
1802         c->output_tensors_as_shapes[0] =
1803             c->input_tensors_as_shapes_to_propagate[0];
1804         if (c->input_tensor_protos[0] != nullptr) {
1805           c->output_tensor_protos.resize(1);
1806           c->output_tensor_protos[0] = c->input_tensor_protos[0];
1807         }
1808       } else if (IsSlice(node)) {
1809         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1810         bool valid = ic->RankKnown(input);
1811         const Tensor* slice_offset = ic->input_tensor(1);
1812         valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1813         const Tensor* slice_size = ic->input_tensor(2);
1814         valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1815         if (valid) {
1816           int64_t start = slice_offset->dtype() == DT_INT32
1817                               ? slice_offset->flat<int32>()(0)
1818                               : slice_offset->flat<int64>()(0);
1819           int64_t size =
1820               (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1821                                                : slice_size->flat<int64>()(0));
1822           ShapeHandle result;
1823           if (size == -1) {
1824             TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1825           } else {
1826             int64_t end = start + size;
1827             TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1828           }
1829           c->output_tensors_as_shapes.resize(1);
1830           c->output_tensors_as_shapes[0] = result;
1831         }
1832       } else if (IsStridedSlice(node)) {
1833         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1834         bool valid = ic->RankKnown(input);
1835         const Tensor* slice_begin = ic->input_tensor(1);
1836         valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1837         const Tensor* slice_end = ic->input_tensor(2);
1838         valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1839         const Tensor* slice_stride = ic->input_tensor(3);
1840         valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1841 
1842         if (node.attr().count("ellipsis_mask") > 0 &&
1843             node.attr().at("ellipsis_mask").i() != 0) {
1844           valid = false;
1845         }
1846         if (node.attr().count("new_axis_mask") > 0 &&
1847             node.attr().at("new_axis_mask").i() != 0) {
1848           valid = false;
1849         }
1850         if (node.attr().count("shrink_axis_mask") > 0 &&
1851             node.attr().at("shrink_axis_mask").i() != 0) {
1852           valid = false;
1853         }
1854         int begin_mask = 0;
1855         if (node.attr().count("begin_mask") > 0) {
1856           begin_mask = node.attr().at("begin_mask").i();
1857         }
1858         int end_mask = 0;
1859         if (node.attr().count("end_mask") > 0) {
1860           end_mask = node.attr().at("end_mask").i();
1861         }
1862         if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1863           valid = false;
1864         }
1865         if (valid) {
1866           int64_t begin = 0;
1867           if (begin_mask == 0) {
1868             begin = slice_begin->dtype() == DT_INT32
1869                         ? slice_begin->flat<int32>()(0)
1870                         : slice_begin->flat<int64>()(0);
1871           }
1872           int64_t end = std::numeric_limits<int64>::max();
1873           if (end_mask == 0) {
1874             end =
1875                 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1876                                                 : slice_end->flat<int64>()(0));
1877           }
1878           int64_t stride = slice_stride->dtype() == DT_INT32
1879                                ? slice_stride->flat<int32>()(0)
1880                                : slice_stride->flat<int64>()(0);
1881           ShapeHandle result;
1882           TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1883           c->output_tensors_as_shapes.resize(1);
1884           c->output_tensors_as_shapes[0] = result;
1885         }
1886       }
1887     }
1888 
1889     if (aggressive_shape_inference_) {
1890       // Update output shapes with annotated information. This is optional.
1891       UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1892 
1893       // Update output tensor values using EvaluateNode() if we can.
1894       // Due to the cost of EvaluateNode(), we run it only for certain op types
1895       // (allow-listed) and small integer tensors.
1896 
1897       const int max_element_size = 17;  // Max up to 4x4 matrix or similar.
1898       if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1899           !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1900         return Status::OK();
1901       }
1902       UpdateOutputShapesAndValues(node, c).IgnoreError();  // This is optional.
1903     }
1904     return Status::OK();
1905   }
1906 
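  // Runs the op's registered shape inference function (falling back to
  // UnknownShape on failure), marks fed output ports as unknown, and then
  // updates the NodeContext outputs via MaybeUpdateNodeContextOutput().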
1907   Status InferShapes(const NodeDef& node, NodeContext* c) {
1908     // Infer the shapes of output tensors.
1909     if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1910         !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1911       // Annotate outputs with unknown shapes. Update output shapes with
1912       // annotated information later on if available.
1913       // Note that shape inference function may return an error, but we ignore
1914       // it, and use UnknownShape in that case.
1915       TF_RETURN_IF_ERROR(
1916           c->inference_context->Run(shape_inference::UnknownShape));
1917     }
1918     Status status = Status::OK();
1919     auto it = fed_ports_.find(node.name());
1920     const bool is_fed = it != fed_ports_.end();
1921     if (is_fed) {
1922       // It is possible to feed node output ports with tensors of any shape: as
1923       // a result, the shape of a fed port is completely unknown.
1924       for (const int output_port : it->second) {
1925         status.Update(SetUnknownShape(&node, output_port));
1926       }
1927     }
1928 
1929     // Update NodeContext output fields after shape inference function runs.
1930     status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1931 
1932     return status;
1933   }
1934 
1935  private:
1936   bool IsIntegerVector(const Tensor& tensor) {
1937     if (tensor.dims() == 1 &&
1938         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1939       return true;
1940     }
1941     return false;
1942   }
1943 
1944   bool IsIntegerScalar(const Tensor& tensor) {
1945     if (tensor.dims() == 0 &&
1946         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1947         tensor.NumElements() == 1) {
1948       return true;
1949     }
1950     return false;
1951   }
1952 
1953   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1954                                            const int64_t val) {
1955     TensorProto tensor_proto;
1956     tensor_proto.set_dtype(dtype);
1957     // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1958     tensor_proto.mutable_tensor_shape();
1959     if (dtype == DT_INT32) {
1960       tensor_proto.add_int_val(val);
1961     } else if (dtype == DT_INT64) {
1962       tensor_proto.add_int64_val(val);
1963     }
1964     return tensor_proto;
1965   }
1966 
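  // Converts a small integer scalar or vector TensorProto into a ShapeHandle
  // through MaybeTensorValueToShape(). Returns false for non-integer dtypes,
  // large tensors, and tensors of rank > 1, leaving the output untouched.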
1967   bool MaybeTensorProtoToShape(InferenceContext* ic,
1968                                const TensorProto& tensor_proto,
1969                                ShapeHandle* tensors_as_shapes) {
1970     // Skip if dtype is not integer.
1971     if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1972       return false;
1973     }
1974     // Skip if the const tensor is too large.
1975     if (NumElementsFromTensorProto(tensor_proto) >
1976         kThresholdToSkipConstTensorInstantiation) {
1977       return false;
1978     }
1979     // Skip if shape is neither scalar nor vector.
1980     if (tensor_proto.tensor_shape().unknown_rank() ||
1981         tensor_proto.tensor_shape().dim_size() > 1) {
1982       return false;
1983     }
1984     Tensor tensor;
1985     if (!tensor.FromProto(tensor_proto)) {
1986       return false;
1987     }
1988     return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1989   }
1990 
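  // Interprets an integer tensor as a shape. For example, the vector
  // [2, -1, 3] becomes a rank-3 shape whose middle dim is marked with
  // kUnknownDimFromConst, and the scalar -1 becomes a fully unknown shape.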
1991   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1992                                ShapeHandle* tensors_as_shapes) {
1993     // Integer tensors of rank one can also be interpreted as a shape
1994     // provided all their values are >= -1.
1995 
1996     if (IsIntegerVector(tensor)) {
1997       bool has_values_smaller_than_minus_1 = false;
1998       std::vector<DimensionHandle> dims;
1999       for (int i = 0; i < tensor.NumElements(); i++) {
2000         int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
2001                                                    : tensor.flat<int64>()(i);
2002         has_values_smaller_than_minus_1 |= (value < -1);
2003         // Mark this as UnknownDim from Const.
2004         dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
2005                                  : ic->MakeDim(value));
2006       }
2007 
2008       if (!has_values_smaller_than_minus_1) {
2009         *tensors_as_shapes = ic->MakeShape(dims);
2010         return true;
2011       }
2012     } else if (IsIntegerScalar(tensor)) {
2013       // Scalar constant.
2014       int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
2015                                                  : tensor.flat<int64>()(0);
2016       if (value == -1) {
2017         // Scalar value -1 represents an unknown shape. If we tried to
2018         // MakeShape(MakeDim) with it, we would get a vector of unknown size.
2019         *tensors_as_shapes = ic->UnknownShape();
2020         return true;
2021       } else if (value >= 0) {
2022         // In principle, values can be < -1, but MakeDim() fails with a value < -1.
2023         // This is a limitation of using ShapeHandle as a means to pass values.
2024         *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
2025         return true;
2026       }
2027     }
2028     return false;
2029   }
2030 
2031   const GraphView& graph_;
2032   int graph_def_version_;
2033   absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2034   absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2035   absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2036   // Store function instantiations only for valid functions. If function
2037   // instantiation failed, the entry holds an `absl::nullopt`.
2038   absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2039       fun_to_grappler_function_item_;
2040   FunctionLibraryDefinition function_library_;
2041   const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2042   // Store TensorProtos for tensor value propagation. Note that we use deque,
2043   // not vector, as we use pointers to the TensorProtos in this container.
2044   // A vector may resize and copy the objects into a new buffer, leaving the
2045   // existing pointers dangling.
2046   std::deque<TensorProto> const_tensors_to_propagate_;
2047 
2048   // For more aggressive shape and value inference.
2049   bool aggressive_shape_inference_;
2050   ResourceMgr resource_mgr_;
2051 };
2052 
2053 // Keep track of shapes and dimensions in a graph.
2054 // In particular, use disjoint sets to track equivalence between shapes and
2055 // dims, and consolidate the information globally.
2056 class SymbolicShapeManager {
2057  public:
2058   SymbolicShapeManager() {}
2059 
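  // Places the two shapes in the same equivalence class and, when both ranks
  // are known, merges their dimensions pairwise. For example, merging [-1, 10]
  // with [4, -1] records that the corresponding dims of the two shapes must
  // resolve to the same values.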
2060   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2061     if (!s1.IsSet() || !s2.IsSet()) {
2062       return Status::OK();
2063     }
2064     TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2065     if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2066       CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2067       for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2068         TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2069                                        InferenceContext::DimKnownRank(s2, i)));
2070       }
2071     }
2072     return Status::OK();
2073   }
2074   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2075     if (!d1.IsSet() || !d2.IsSet()) {
2076       return Status::OK();
2077     }
2078     return dims_.Merge(d1, d2);
2079   }
2080 
2081   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2082                           OpInfo::TensorProperties* properties) {
2083     properties->set_dtype(type);
2084     ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2085     if (!InferenceContext::RankKnown(actual_shape)) {
2086       properties->mutable_shape()->set_unknown_rank(true);
2087     } else {
2088       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2089         shape_inference::DimensionHandle dim =
2090             InferenceContext::DimKnownRank(actual_shape, j);
2091         int64_t d = dims_.GetMergedValue(dim);
2092         properties->mutable_shape()->add_dim()->set_size(d);
2093       }
2094     }
2095   }
2096 
2097   // Returns merged shape with merged dimensions.
2098   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2099     const auto& actual_shape = shapes_.GetMergedValue(s);
2100     if (!InferenceContext::RankKnown(actual_shape)) {
2101       return ic->UnknownShape();
2102     } else {
2103       std::vector<DimensionHandle> dims;
2104       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2105         shape_inference::DimensionHandle dim =
2106             InferenceContext::DimKnownRank(actual_shape, j);
2107         int64_t d = dims_.GetMergedValue(dim);
2108         // The symbolic shape manager may have made some dims < -1, which causes
2109         // errors when creating a Dimension.
2110         if (d < -1) {
2111           d = -1;
2112         }
2113         dims.push_back(ic->MakeDim(d));
2114       }
2115       return ic->MakeShape(dims);
2116     }
2117   }
2118 
2119  private:
2120   DisjointSet<shape_inference::ShapeHandle> shapes_;
2121   DisjointSet<shape_inference::DimensionHandle> dims_;
2122 };
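
// A minimal usage sketch for SymbolicShapeManager (hypothetical handles;
// assumes an InferenceContext* ic obtained from some NodeContext):
//   SymbolicShapeManager mgr;
//   Status s = mgr.Merge(ic->output(0), ic->input(1));  // same symbolic shape
//   OpInfo::TensorProperties props;
//   mgr.AsTensorProperties(ic->output(0), DT_FLOAT, &props);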
2123 
2124 // Checks whether there is any conflict in merged shapes and dims in
2125 // SymbolicShapeManager.
2126 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2127                                     SymbolicShapeRefiner* refiner,
2128                                     SymbolicShapeManager* shape_manager) {
2129   if (!VLOG_IS_ON(1)) {
2130     return Status::OK();
2131   }
2132 
2133   VLOG(1) << "Checking for any conflicts in shapes and dimensions ...";
2134   int64_t num_incompatible_shapes = 0;
2135   for (const NodeDef& node : graph_def.node()) {
2136     auto ctx = refiner->GetNodeContext(&node);
2137     if (!ctx) {
2138       continue;
2139     }
2140     auto* ic = ctx->inference_context.get();
2141     for (int i = 0; i < ic->num_inputs(); ++i) {
2142       const auto& shape = ic->input(i);
2143       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2144       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2145         num_incompatible_shapes++;
2146         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2147                 << "for node " << node.name() << " input (" << i << ") "
2148                 << ic->DebugString(shape)
2149                 << " vs. merged: " << ic->DebugString(merged_shape);
2150       }
2151     }
2152     for (int i = 0; i < ic->num_outputs(); ++i) {
2153       const auto& shape = ic->output(i);
2154       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2155       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2156         num_incompatible_shapes++;
2157         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2158                 << "for node " << node.name() << " output (" << i << ") "
2159                 << ic->DebugString(shape)
2160                 << " vs. merged: " << ic->DebugString(merged_shape);
2161       }
2162     }
2163   }
2164   if (num_incompatible_shapes > 0) {
2165     VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2166             << " incompatible shapes from SymbolicShapeManager.";
2167   } else {
2168     VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2169   }
2170 
2171   return Status::OK();
2172 }
2173 
2174 // Log shape inference and its merged shapes.
2175 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2176                                     SymbolicShapeRefiner* refiner,
2177                                     SymbolicShapeManager* shape_manager) {
2178   // As logging all the nodes would generate too many lines, we by default
2179   // skip this detailed logging. Users may add nodes of interest to
2180   // node_names_for_logging to enable detailed logging.
2181   absl::flat_hash_set<std::string> node_names_for_logging = {};
2182   if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2183     return Status::OK();
2184   }
2185 
2186   auto should_log = [&node_names_for_logging](std::string node_name) {
2187     return node_names_for_logging.find(node_name) !=
2188            node_names_for_logging.end();
2189   };
2190 
2191   for (const NodeDef& node : graph_def.node()) {
2192     if (!should_log(node.name())) {
2193       continue;
2194     }
2195     auto ctx = refiner->GetNodeContext(&node);
2196     if (!ctx) {
2197       continue;
2198     }
2199     auto* ic = ctx->inference_context.get();
2200     VLOG(3) << "Shape inference for node : " << node.name();
2201     VLOG(3) << ctx->DebugString(node);
2202     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2203     for (int i = 0; i < ic->num_inputs(); ++i) {
2204       absl::StrAppend(
2205           &merged_shapes, " input[", i, "] -- ",
2206           ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2207           "\n");
2208     }
2209     for (int i = 0; i < ic->num_outputs(); ++i) {
2210       absl::StrAppend(
2211           &merged_shapes, " output[", i, "] -- ",
2212           ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2213           "\n");
2214     }
2215     VLOG(3) << merged_shapes;
2216     VLOG(3) << "--------------------------------";
2217     VLOG(3) << "";
2218   }
2219 
2220   return Status::OK();
2221 }
2222 
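// Relaxes the shapes recorded for a queue with the shapes of newly enqueued
// tensors: both lists must have the same length and matching dtypes, and each
// recorded shape is replaced by the union (OutputAsUnion) of the old and new
// shapes.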
2223 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2224     SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2225     const std::vector<ShapeAndType>& shapes_and_types,
2226     std::vector<ShapeAndType>* queue_shapes_and_types) {
2227   if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2228     return errors::InvalidArgument(
2229         "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2230         "  vs ", queue_shapes_and_types->size());
2231   }
2232   for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2233     const ShapeAndType& a = shapes_and_types[i];
2234     ShapeAndType& b = (*queue_shapes_and_types)[i];
2235     if (a.dtype != b.dtype) {
2236       return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2237                                      i, ": ", DataTypeString(a.dtype), " vs ",
2238                                      DataTypeString(b.dtype));
2239     }
2240 
2241     b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2242   }
2243   return Status::OK();
2244 }
2245 
2246 // Compute the output shape of the merge node as the union of the available
2247 // input shapes.
2248 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2249                                     const NodeDef* node,
2250                                     bool* new_shapes) const {
2251   InferenceContext* ic = shape_refiner->GetContext(node);
2252   if (!ic) {
2253     // Now we can run shape inference
2254     TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2255     ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2256     *new_shapes = true;
2257 
2258     // Infer the shape of the second output once and for all since it never
2259     // changes.
2260     ShapeHandle out1 = ic->Scalar();
2261     if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2262   }
2263 
2264   ShapeHandle out;
2265   const std::vector<ShapeAndType>* out_handle = nullptr;
2266   bool out_initialized = false;
2267   for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2268            *node, /*include_controlling_edges=*/false)) {
2269     InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2270     if (!src_ic) {
2271       // Handling a loop for the first time, the back edge won't have any shape
2272       // info.
2273       continue;
2274     }
2275     ShapeHandle input = src_ic->output(fanin.src.port_id);
2276     ic->SetInput(fanin.dst.port_id, input);
2277     auto* input_handle =
2278         src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2279     if (input_handle)
2280       ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2281     if (!out_initialized) {
2282       out_initialized = true;
2283       out = input;
2284       out_handle = input_handle;
2285     } else {
2286       // Note here only out, not out_handle, is modified.
2287       out = shape_refiner->OutputAsUnion(node, 0, input, out);
2288     }
2289   }
2290 
2291   if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2292     ic->set_output(0, out);
2293     if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2294     *new_shapes = true;
2295   }
2296 
2297   return Status::OK();
2298 }
2299 
2300 // Manually propagate the input shape for Enter nodes.
2301 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2302                                     const NodeDef* node, bool* new_shapes) {
2303   InferenceContext* ic = shape_refiner->GetContext(node);
2304   if (!ic) {
2305     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2306     ic = shape_refiner->GetContext(node);
2307   }
2308 
2309   GraphView::InputPort port(node, 0);
2310   GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2311 
2312   InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2313   ShapeHandle input = src_ic->output(fanin.port_id);
2314   if (!ic->output(0).SameHandle(input)) {
2315     ic->SetInput(0, input);
2316     ic->set_output(0, input);
2317     *new_shapes = true;
2318   }
2319   auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2320   if (outputs) {
2321     ic->set_input_handle_shapes_and_types(0, *outputs);
2322     ic->set_output_handle_shapes_and_types(0, *outputs);
2323     *new_shapes = true;
2324   }
2325   return Status::OK();
2326 }
2327 
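// Dispatches shape refinement for a single node: Enter, Merge, Enqueue, and
// Queue nodes receive special handling, while all other nodes go through the
// regular SymbolicShapeRefiner::UpdateNode() path.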
2328 Status GraphProperties::UpdateShapes(
2329     SymbolicShapeRefiner* shape_refiner,
2330     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2331     const NodeDef* n, bool* new_shapes) const {
2332   if (IsEnter(*n)) {
2333     // The Enter shape function always forwards an UnknownShape, so do the right
2334     // thing here.
2335     TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2336   } else if (IsMerge(*n)) {
2337     // Properly handle merge nodes.
2338     TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2339   } else if (IsEnqueue(*n)) {
2340     // Make sure the shapes of enqueued tensors are propagated to the queue
2341     // itself.
2342     TF_RETURN_IF_ERROR(
2343         UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2344   } else if (IsQueue(*n)) {
2345     // Set shapes and types of Queue ops, if needed.
2346     TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2347   } else {
2348     // Rely on regular TF shape refinement for all the other nodes.
2349     // UpdateNode calls UpdateFunction if a function node is detected.
2350     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2351   }
2352 
2353   return Status::OK();
2354 }
2355 
2356 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2357 Status GraphProperties::PropagateShapes(
2358     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2359     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2360     int num_loops) const {
2361   // Limit the number of iterations to prevent infinite loops in the presence of
2362   // incorrect shape functions. The algorithm should converge in at most
2363   // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
2364   // The same applies to resources.
2365   VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2366           << num_loops << " loops and " << resource_handles.size()
2367           << " resources" << std::endl;
2368 
2369   const int64_t max_loop_length = item_.graph.node_size();
2370   const int64_t max_rank = 4;
2371   const int64_t max_loop_iterations =
2372       max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
2373   const int64_t num_queues = resource_handles.size();
2374   const int64_t max_resource_iterations = num_queues * num_queues * max_rank;
2375 
2376   int64_t num_resource_iterations = 0;
2377   do {
2378     int64_t num_loop_iterations = 0;
2379     while (!new_shapes->empty() &&
2380            num_loop_iterations++ < max_loop_iterations) {
2381       const NodeDef* n = new_shapes->pop();
2382       bool updated = false;
2383       TF_RETURN_IF_ERROR(
2384           UpdateShapes(shape_refiner, resource_handles, n, &updated));
2385       if (updated) {
2386         for (const auto& fanout : shape_refiner->graph().GetFanouts(
2387                  *n, /*include_controlled_nodes=*/false)) {
2388           new_shapes->push(fanout.node);
2389         }
2390         // Make sure the corresponding queue nodes are (re)processed.
2391         if (IsEnqueue(*n)) {
2392           auto it = resource_handles.find(n);
2393           if (it != resource_handles.end()) {
2394             new_shapes->push(it->second);
2395           }
2396         }
2397       }
2398     }
2399   } while (!new_shapes->empty() &&
2400            num_resource_iterations++ < max_resource_iterations);
2401 
2402   if (!new_shapes->empty()) {
2403     return errors::Internal("Shape inference failed to converge");
2404   }
2405 
2406   return Status::OK();
2407 }
2408 
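// Seeds the shapes and types of a Queue op from its "shapes" and
// "component_types" attributes when they have not already been set by Enqueue
// ops, then re-runs shape inference on the queue node.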
2409 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2410                                     SymbolicShapeRefiner* shape_refiner,
2411                                     bool* new_shapes) {
2412   auto* ctx = shape_refiner->GetNodeContext(queue_node);
2413   if (!ctx) {
2414     TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2415     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2416   }
2417   auto* ic = ctx->inference_context.get();
2418 
2419   auto* outputs = ic->output_handle_shapes_and_types(0);
2420   if (outputs) {
2421     // Shapes and types are already set, presumably by Enqueue ops.
2422     return shape_refiner->UpdateNode(queue_node, new_shapes);
2423   }
2424 
2425   if (queue_node->attr().count("shapes") <= 0 ||
2426       queue_node->attr().count("component_types") <= 0 ||
2427       queue_node->attr().at("shapes").list().shape_size() !=
2428           queue_node->attr().at("component_types").list().type_size()) {
2429     // Errors in shapes and component_types attr.
2430     return shape_refiner->UpdateNode(queue_node, new_shapes);
2431   }
2432 
2433   // Extract types and shapes from Queue attr.
2434   const auto& shapes = queue_node->attr().at("shapes").list().shape();
2435   const auto& types = queue_node->attr().at("component_types").list().type();
2436   std::vector<ShapeAndType> shapes_and_types;
2437   for (int i = 0; i < types.size(); i++) {
2438     const auto& shape = shapes[i];
2439     ShapeHandle shape_handle;
2440     TF_RETURN_IF_ERROR(
2441         ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2442     DataType data_type =
2443         queue_node->attr().at("component_types").list().type(i);
2444     ShapeAndType shape_and_type(shape_handle, data_type);
2445     shapes_and_types.push_back(shape_and_type);
2446   }
2447   ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2448 
2449   // The queue node is updated with output_handle_shapes_and_types, so set
2450   // new_shapes here and ignore the result from UpdateNode().
2451   *new_shapes = true;
2452   bool dummy_new_shapes = false;
2453   return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2454 }
2455 
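// Propagates the shapes and types of enqueued tensors to the corresponding
// queue's resource handle, either seeding them or relaxing the existing ones
// via RelaxEnqueueShapesAndMergeTypes().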
2456 Status GraphProperties::UpdateEnqueue(
2457     const NodeDef* enqueue_node,
2458     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2459     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2460   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2461   if (!ctx) {
2462     TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2463     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2464   }
2465 
2466   auto it = resource_handles.find(enqueue_node);
2467   if (it == resource_handles.end()) {
2468     // The corresponding queue was not found, there isn't much we can do.
2469     return Status::OK();
2470   }
2471   const NodeDef* qnode = it->second;
2472   auto qctx = shape_refiner->GetContext(qnode);
2473   if (!qctx) {
2474     return Status::OK();
2475   }
2476   auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2477 
2478   // TODO(bsteiner): handle EnqueueMany as well.
2479   std::vector<ShapeAndType> shapes_and_types;
2480   for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2481     GraphView::InputPort inp(enqueue_node, i);
2482     GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2483     InferenceContext* in = shape_refiner->GetContext(fanin.node);
2484     ShapeHandle input = in->output(fanin.port_id);
2485     ctx->inference_context->SetInput(i, input);
2486     shapes_and_types.push_back({input, ctx->input_types[i]});
2487   }
2488 
2489   if (queue_handle_data == nullptr) {
2490     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2491     *new_shapes = true;
2492   } else {
2493     TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2494         shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2495     *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2496                                                             shapes_and_types);
2497     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2498   }
2499 
2500   return Status::OK();
2501 }
2502 
2503 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2504                                         bool aggressive_shape_inference,
2505                                         bool include_input_tensor_values,
2506                                         bool include_output_tensor_values) {
2507   FunctionLibraryDefinition function_library(OpRegistry::Global(),
2508                                              item_.graph.library());
2509   absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
2510   if (!assume_valid_feeds) {
2511     for (const auto& feed : item_.feed) {
2512       SafeTensorId tensor_id = ParseTensorName(feed.first);
2513       fed_ports[tensor_id.node()].insert(tensor_id.index());
2514     }
2515   }
2516 
2517   GraphView graph_view(&item_.graph);
2518 
2519   // List the resources and the nodes using them. Also collect the Merge nodes,
2520   // fed nodes, and primary inputs.
2521   absl::flat_hash_map<const NodeDef*,
2522                       std::pair<absl::flat_hash_set<const NodeDef*>,
2523                                 absl::flat_hash_set<const NodeDef*>>>
2524       resources;
2525   absl::flat_hash_set<const NodeDef*> merge_nodes;
2526   absl::flat_hash_set<const NodeDef*> fed_nodes;
2527   absl::flat_hash_set<const NodeDef*> primary_inputs;
2528   int num_loops = 0;
2529   for (const NodeDef& node : item_.graph.node()) {
2530     if (IsQueue(node)) {
2531       for (const GraphView::InputPort& fanout :
2532            graph_view.GetFanouts(node, false)) {
2533         if (IsEnter(*fanout.node)) {
2534           const NodeDef& enter = *fanout.node;
2535           for (const GraphView::InputPort& fanout :
2536                graph_view.GetFanouts(enter, false)) {
2537             if (IsEnqueue(*fanout.node)) {
2538               resources[&node].first.insert(fanout.node);
2539             } else if (IsDequeue(*fanout.node)) {
2540               resources[&node].second.insert(fanout.node);
2541             }
2542           }
2543         } else {
2544           if (IsEnqueue(*fanout.node)) {
2545             resources[&node].first.insert(fanout.node);
2546           } else if (IsDequeue(*fanout.node)) {
2547             resources[&node].second.insert(fanout.node);
2548           }
2549         }
2550       }
2551     }
2552     if (!HasRegularInputs(node)) {
2553       primary_inputs.insert(&node);
2554     } else if (IsMerge(node)) {
2555       merge_nodes.insert(&node);
2556     } else if (IsNextIteration(node)) {
2557       ++num_loops;
2558     }
2559     if (fed_ports.find(node.name()) != fed_ports.end()) {
2560       fed_nodes.insert(&node);
2561     }
2562   }
2563 
2564   absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
2565   std::vector<TopologicalDependency> extra_deps;
2566   for (const auto& resource : resources) {
2567     for (const NodeDef* src : resource.second.first) {
2568       resource_handles[src] = resource.first;
2569       for (const NodeDef* dst : resource.second.second) {
2570         // Add control edges from enqueue to dequeue nodes to ensure they are
2571         // processed in their logical order.
2572         extra_deps.emplace_back(src, dst);
2573       }
2574     }
2575   }
2576 
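  // Compute a topological order over the graph, augmented with the artificial
  // enqueue->dequeue dependencies gathered above, so that queue shapes are
  // known before the corresponding dequeue nodes are visited.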
2577   std::vector<const NodeDef*> topo_order;
2578   Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2579   if (!s.ok()) {
2580     if (extra_deps.empty()) {
2581       return s;
2582     } else {
2583       // There is a loop between queues: we'll just use the graph's plain
2584       // topological order. This makes the shape inference less precise, but
2585       // since this case is uncommon it isn't worth figuring out where to
2586       // break the loop and doing a proper relaxation.
2587       TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2588     }
2589   }
2590 
2591   // Heap-allocate the SymbolicShapeRefiner so that it does not consume a
2592   // large amount of stack space.
2593   auto refiner = absl::make_unique<SymbolicShapeRefiner>(
2594       graph_view, fed_ports, aggressive_shape_inference);
2595 
2596   TopoQueue new_shapes(topo_order);
2597   // Also seed the propagation of shapes in the fanout of primary inputs.
2598   for (const NodeDef* node : primary_inputs) {
2599     new_shapes.push(node);
2600   }
2601   // Also seed the propagation of shapes in the fanout of fed nodes.
2602   for (const NodeDef* node : fed_nodes) {
2603     new_shapes.push(node);
2604   }
2605   // Propagate shapes normally.
2606   TF_RETURN_IF_ERROR(
2607       PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
2608 
2609   // Track shapes globally across the graph.
2610   std::unique_ptr<SymbolicShapeManager> shape_manager =
2611       absl::make_unique<SymbolicShapeManager>();
2612   bool found_error = false;
2613   for (const NodeDef& node : item_.graph.node()) {
2614     auto node_ctx = refiner->GetContext(&node);
2615     if (!node_ctx) {
2616       continue;
2617     }
2618     // Skip any information that comes from fed nodes.
2619     if (fed_ports.find(node.name()) != fed_ports.end()) {
2620       VLOG(2) << "Skipping feed node shape: " << node.name();
2621       continue;
2622     }
2623     for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2624       if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2625                .ok()) {
2626         found_error = true;
2627         break;
2628       }
2629     }
2630     for (const auto& merged_dims : node_ctx->MergedDims()) {
2631       if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2632         found_error = true;
2633         break;
2634       }
2635     }
2636     if (found_error) {
2637       // The shapes aren't consistent, we can't infer safely: discard all the
2638       // information discovered so far.
2639       shape_manager = absl::make_unique<SymbolicShapeManager>();
2640       break;
2641     }
2642   }
2643 
2644   TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
2645                                                   shape_manager.get()));
2646 
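  // Materialize the symbolic shapes tracked by shape_manager into concrete
  // OpInfo::TensorProperties for every node's inputs and outputs, optionally
  // attaching known tensor values (e.g. from Const inputs) as well.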
2647   for (const NodeDef& node : item_.graph.node()) {
2648     VLOG(4) << "Filling in graph properties for node: " << node.name();
2649     auto ctx = refiner->GetNodeContext(&node);
2650     if (!ctx) {
2651       continue;
2652     }
2653 
2654     auto* ic = ctx->inference_context.get();
2655 
2656     // Fill input properties.
2657     {
2658       auto& input_properties = input_properties_[node.name()];
2659 
2660       // Should always be empty; node names in a graph are supposed to be unique.
2661       CHECK_EQ(input_properties.size(), 0);
2662 
2663       input_properties.resize(ic->num_inputs());
2664       GraphView::InputPort input(&node, -1);
2665       for (int i = 0; i < ic->num_inputs(); ++i) {
2666         shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2667                                           &input_properties[i]);
2668         input.port_id = i;
2669         GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2670         if (include_input_tensor_values) {
2671           // Export tensor value to input_properties.value.
2672           if (IsConstant(*fanin.node)) {
2673             const TensorProto& raw_val =
2674                 fanin.node->attr().at("value").tensor();
2675             *input_properties[i].mutable_value() = raw_val;
2676           } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
2677                      ctx->input_tensor_protos[i] != nullptr) {
2678             *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2679           } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
2680                          i &&
2681                      IsShapeFullyDefinedIntegerVectorOrScalar(
2682                          ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2683                          ctx->input_types[i])) {
2684             *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2685                 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2686                 ctx->input_types[i]);
2687           }
2688         }
2689       }
2690     }
2691 
2692     // Fill output properties.
2693     {
2694       auto& output_properties = output_properties_[node.name()];
2695 
2696       // Should always be empty; node names in a graph are supposed to be unique.
2697       CHECK_EQ(output_properties.size(), 0);
2698 
2699       output_properties.resize(ic->num_outputs());
2700       for (int i = 0; i < ic->num_outputs(); ++i) {
2701         shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2702                                           &output_properties[i]);
2703         auto converted_output_tensors_as_shapes =
2704             ReplaceUnknownDimFromConstWithUnknownDim(
2705                 ic, ctx->output_tensors_as_shapes);
2706         if (include_output_tensor_values) {
2707           // Export tensor value to output_properties.value.
2708           if (IsConstant(node)) {
2709             // TODO(rmlarsen): Eliminate this copy.
2710             const TensorProto& raw_val = node.attr().at("value").tensor();
2711             *output_properties[i].mutable_value() = raw_val;
2712           } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
2713                      ctx->output_tensor_protos[i] != nullptr) {
2714             *output_properties[i].mutable_value() =
2715                 *ctx->output_tensor_protos[i];
2716           } else if (static_cast<int>(
2717                          converted_output_tensors_as_shapes.size()) > i &&
2718                      IsShapeFullyDefinedIntegerVectorOrScalar(
2719                          ic, ic->output(i),
2720                          converted_output_tensors_as_shapes[i],
2721                          ctx->output_types[i])) {
2722             *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2723                 ic, ic->output(i), converted_output_tensors_as_shapes[i],
2724                 ctx->output_types[i]);
2725           }
2726         }
2727       }
2728     }
2729 
2730     if (aggressive_shape_inference && ctx->shape_incompatible)
2731       incompatible_shape_nodes_.insert(node.name());
2732   }
2733 
2734   if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
2735     LOG(WARNING) << incompatible_shape_nodes_.size()
2736                  << " nodes have incompatible output shapes.";
2737 
2738   // Help trace the unknown dimensions to their origins.
2739   VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2740                                     output_properties_);
2741 
2742   TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
2743                                                   shape_manager.get()));
2744 
2745   return Status::OK();
2746 }
2747 
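// InferDynamically runs the graph once on the given cluster and derives the
// input/output properties from the cost graph recorded in the run metadata.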
2748 Status GraphProperties::InferDynamically(Cluster* cluster) {
2749   TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2750 
2751   // Runs the model once to collect the shapes in the cost model.
2752   RunMetadata metadata;
2753   TF_RETURN_IF_ERROR(
2754       cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2755 
2756   return InferFromCostGraph(metadata.cost_graph());
2757 }
2758 
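// AnnotateOutputShapes copies the graph and stores each node's inferred output
// shapes in its "_output_shapes" attribute. Unless allow_symbolic_shapes is
// set, shapes are first normalized via NormalizeShapeForOutput (presumably
// mapping symbolic dimensions to unknown, i.e. -1).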
2759 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
2760                                              bool allow_symbolic_shapes) const {
2761   *output_graph_def = item_.graph;
2762   for (int i = 0; i < output_graph_def->node_size(); i++) {
2763     auto node = output_graph_def->mutable_node(i);
2764     AttrValue attr_output_shape;
2765     auto tensor_properties = GetOutputProperties(node->name());
2766     for (const auto& tensor_property : tensor_properties) {
2767       TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
2768       *proto = tensor_property.shape();
2769       if (!allow_symbolic_shapes) {
2770         NormalizeShapeForOutput(proto);
2771       }
2772     }
2773     (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
2774   }
2775   return Status::OK();
2776 }
2777 
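// InferFromCostGraph fills in each node's output properties from the shapes
// and dtypes recorded in a CostGraphDef (typically produced by an actual run),
// and derives its input properties by looking up its fanins in the cost graph.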
2778 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2779   if (cost_graph.node_size() == 0) {
2780     LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2781   }
2782   std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2783   std::unordered_map<string, const NodeDef*> name_to_node;  // Intentionally left empty.
2784   for (auto& node : cost_graph.node()) {
2785     name_to_cost[node.name()] = &node;
2786 
2787     std::vector<OpInfo::TensorProperties> output_properties;
2788     for (const auto& out : node.output_info()) {
2789       OpInfo::TensorProperties properties;
2790       properties.set_dtype(out.dtype());
2791       *properties.mutable_shape() = out.shape();
2792       output_properties.push_back(properties);
2793     }
2794     output_properties_[node.name()] = output_properties;
2795   }
2796 
2797   for (const auto& node : item_.graph.node()) {
2798     // Skip nodes that are not in the cost graph: either they aren't run
2799     // because they fall outside the intersection of the transitive fan-in of a
2800     // fetch node and the transitive fan-out of an input, or they were
2801     // optimized away by the optimizer.
2802     auto it = name_to_cost.find(node.name());
2803     if (it == name_to_cost.end()) {
2804       continue;
2805     }
2806     std::vector<OpInfo::TensorProperties> inputs =
2807         FindInputFeatures(node, name_to_cost, name_to_node);
2808 
2809     input_properties_[node.name()] = inputs;
2810   }
2811   return Status::OK();
2812 }
2813 
2814 bool GraphProperties::HasInputProperties(const string& node_name) const {
2815   return input_properties_.find(node_name) != input_properties_.end();
2816 }
2817 
2818 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2819   return output_properties_.find(node_name) != output_properties_.end();
2820 }
2821 
2822 const std::vector<OpInfo::TensorProperties>&
2823 GraphProperties::GetInputProperties(const string& node_name) const {
2824   auto it = input_properties_.find(node_name);
2825   if (it != input_properties_.end()) {
2826     return it->second;
2827   }
2828   return missing_properties_;
2829 }
2830 
2831 const std::vector<OpInfo::TensorProperties>&
2832 GraphProperties::GetOutputProperties(const string& node_name) const {
2833   auto it = output_properties_.find(node_name);
2834   if (it != output_properties_.end()) {
2835     return it->second;
2836   }
2837   return missing_properties_;
2838 }
2839 
2840 void GraphProperties::ClearInputProperties(const string& node_name) {
2841   input_properties_.erase(node_name);
2842 }
2843 void GraphProperties::ClearOutputProperties(const string& node_name) {
2844   output_properties_.erase(node_name);
2845 }
2846 
2847 }  // end namespace grappler
2848 }  // end namespace tensorflow
2849