1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/costs/utils.h"
31 #include "tensorflow/core/grappler/mutable_graph_view.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/functions.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 namespace {
45 
46 using shape_inference::DimensionHandle;
47 using shape_inference::InferenceContext;
48 using shape_inference::ShapeAndType;
49 using shape_inference::ShapeHandle;
50 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
51 
52 // A large value used in place of UnknownDim when a Const provides a dim value
53 // in a shape. Some ops treat "-1" specially, differently from UnknownDim:
54 // e.g., the shape input to the Reshape op.
55 const int64 kUnknownDimFromConst = INT64_MAX;
56 
57 template <typename Handle>
58 struct HashHandle {
59   std::size_t operator()(const Handle& h) const { return h.Handle(); }
60 };
61 template <typename Handle>
62 struct CompareHandle {
63   bool operator()(const Handle& h1, const Handle& h2) const {
64     return h1.SameHandle(h2);
65   }
66 };
67 
68 template <typename Handle>
69 struct HandleToObject {};
70 template <>
71 struct HandleToObject<ShapeHandle> {
72   typedef ShapeHandle Object;
73 
74   static ShapeHandle Unknown() { return ShapeHandle(); }
75 };
76 
77 template <>
78 struct HandleToObject<DimensionHandle> {
79   typedef int64 Object;
80 
81   static int64 Unknown() { return -1; }
82 };
83 
84 template <typename Handle>
85 struct Processor {};
86 
87 template <>
88 struct Processor<ShapeHandle> {
89   // Extract the shape or dim denoted by the handle.
90   void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
91   // Merge the shapes or dims.
92   Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
93     if (InferenceContext::RankKnown(*result)) {
94       // The result was initialized in a previous merge to a shape of known
95       // rank, make sure we preserve that information.
96       return Status::OK();
97     }
98     if (InferenceContext::RankKnown(h1)) {
99       *result = h1;
100     } else {
101       *result = h2;
102     }
103     return Status::OK();
104   }
105 };
106 
107 template <>
108 struct Processor<DimensionHandle> {
109   // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
110   // reserved by TensorFlow).
111   void ExtractValue(DimensionHandle d, int64* result) {
112     if (!InferenceContext::ValueKnown(d)) {
113       *result = -counter;
114       counter++;
115     } else {
116       int64 val = InferenceContext::Value(d);
117       if (val >= 0) {
118         *result = val;
119       } else {
120         // A shape inference function generated an invalid dimension handle.
121         // Use a symbolic dimension to encode this.
122         *result = -counter;
123         counter++;
124       }
125     }
126   }
127 
128   // Merge the dimensions d1 and d2. Return the known dimension if there is
129   // one, otherwise look for a symbolic dimension. If there is no symbolic
130   // dimension and no known dimension, the dimension is fully unknown, so return -1.
131   Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
132     const int64 dim1 = InferenceContext::Value(d1);
133     const int64 dim2 = InferenceContext::Value(d2);
134 
135     if (dim1 >= 0 && dim2 >= 0) {
136       CHECK_EQ(dim1, dim2);
137       return RefineDim(dim1, result);
138     } else if (dim1 >= 0 && dim2 < 0) {
139       return RefineDim(dim1, result);
140     } else if (dim1 < 0 && dim2 >= 0) {
141       return RefineDim(dim2, result);
142     } else if (dim1 < -1) {
143       return RefineDim(dim1, result);
144     } else if (dim2 < -1) {
145       return RefineDim(dim2, result);
146     } else {
147       CHECK_EQ(dim1, dim2);
148       CHECK_EQ(-1, dim1);
149       return RefineDim(-1, result);
150     }
151     return Status::OK();
152   }
153 
154  private:
155   Status RefineDim(int64 dim, int64* result) {
156     if (*result >= 0) {
157       if (!(*result == dim || dim < 0)) {
158         return errors::InvalidArgument("Inconsistent dimensions detected");
159       }
160     } else if (dim >= 0) {
161       *result = dim;
162     } else if (dim < *result) {
163       *result = dim;
164     }
165     return Status::OK();
166   }
167 
168   int64 counter = 2;
169 };
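// A sketch of the merge semantics above (illustrative only, using the private
// RefineDim() helper): a known value always wins over an unknown or symbolic
// one, and conflicting known values are reported as an error.
//
//   int64 result = -1;
//   RefineDim(5, &result);   // result becomes 5.
//   RefineDim(-3, &result);  // result stays 5 (known beats symbolic).
//   RefineDim(7, &result);   // returns an InvalidArgument error.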
170 
171 // Traditional Disjoint-Set data structure with path compression.
172 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
173 template <typename Handle>
174 class DisjointSet {
175  public:
176   DisjointSet() {}
177   ~DisjointSet() {
178     for (auto rep : nodes_) {
179       delete rep.second;
180     }
181   }
182 
183   Status Merge(Handle x, Handle y);
184   const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
185 
186  private:
187   // All the handles that belong to the same set are part of the same tree, and
188   // ultimately represented by the root of that tree.
189   struct Rep {
190     // Parent in the tree used to encode the set.
191     Rep* parent;
192     // Rank in the tree, used to figure out how to compress the path to the root
193     // of the tree.
194     int rank;
195     // The handle.
196     typename HandleToObject<Handle>::Object value;
197   };
198 
199   // Create a new set for the value if none exists, or return its representative
200   // node otherwise.
201   Rep* Find(Handle value);
202 
203  private:
204   Processor<Handle> processor_;
205   absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
206       nodes_;
207 };
208 
209 template <typename Handle>
210 const typename HandleToObject<Handle>::Object
211 DisjointSet<Handle>::GetMergedValue(Handle value) {
212   Rep* rep = Find(value);
213   if (!rep) {
214     // We don't know anything about this handle.
215     return HandleToObject<Handle>::Unknown();
216   }
217   return rep->value;
218 }
219 
220 template <typename Handle>
221 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
222   Rep* x_root = Find(x);
223   Rep* y_root = Find(y);
224 
225   // x and y are already in the same set
226   if (x_root == y_root) {
227     return Status::OK();
228   }
229   // x and y are not in same set, so we merge them
230   // Use the occasion to strengthen what we know about the handle by merging the
231   // information about the 2 subsets.
232   if (x_root->rank < y_root->rank) {
233     TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
234     x_root->parent = y_root;
235   } else if (x_root->rank > y_root->rank) {
236     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
237     y_root->parent = x_root;
238   } else {
239     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
240     // Arbitrarily make one root the new parent
241     y_root->parent = x_root;
242     x_root->rank = x_root->rank + 1;
243   }
244   return Status::OK();
245 }
246 
247 template <typename Handle>
248 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
249   auto it = nodes_.find(value);
250   if (it == nodes_.end()) {
251     // This is the first time we process this handle, create an entry for it.
252     Rep* node = new Rep;
253     node->parent = node;
254     node->rank = 0;
255     processor_.ExtractValue(value, &node->value);
256     nodes_[value] = node;
257     return node;
258   }
259   // Return the representative for the set, which is the root of the tree. Apply
260   // path compression to speed up future queries.
261   Rep* node = it->second;
262   Rep* root = node->parent;
263   while (root != root->parent) {
264     root = root->parent;
265   }
266   while (node->parent != root) {
267     Rep* next = node->parent;
268     node->parent = root;
269     node = next;
270   }
271   return root;
272 }
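// Usage sketch for the DisjointSet above (assuming `ic` is an initialized
// InferenceContext and `shape_a`, `shape_b` are ShapeHandles): merging two
// dimension handles places them in the same set, and GetMergedValue() then
// reports the best known value for either handle.
//
//   DisjointSet<DimensionHandle> dims;
//   TF_CHECK_OK(dims.Merge(ic->Dim(shape_a, 0), ic->Dim(shape_b, 0)));
//   int64 d = dims.GetMergedValue(ic->Dim(shape_a, 0));
//   // d >= 0 if either dimension was known; a negative symbolic id otherwise.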
273 
274 // TODO(dyoon): Move many helper functions in this file (including those within
275 // SymbolicShapeRefiner class) to shared utils.
276 bool IsEnqueue(const NodeDef& n) {
277   return (n.op().find("Enqueue") != string::npos &&
278           n.op().find("EnqueueMany") == string::npos);
279 }
280 
281 bool IsDequeue(const NodeDef& n) {
282   return (n.op().find("Dequeue") != string::npos &&
283           n.op().find("DequeueMany") == string::npos);
284 }
285 
286 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
287   if (proto.unknown_rank()) {
288     return true;
289   }
290   for (const auto& dim : proto.dim()) {
291     if (dim.size() < 0) {
292       return true;
293     }
294   }
295   return false;
296 }
297 
298 // This really should be done in an external debugging tool
299 void VerboseLogUnknownDimensionSources(
300     const GraphDef& graph,
301     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
302         input_properties_map,
303     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
304         output_properties_map) {
305   if (!VLOG_IS_ON(2)) {
306     return;
307   }
308 
309   VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
310 
311   // Find all nodes in the graph for which we
312   // do not have any unknown dimensions in their inputs, but
313   // we have some unknown dimensions in their outputs.
314   std::map<string, int> op_to_count;
315   for (const NodeDef& node : graph.node()) {
316     const auto& input_properties = input_properties_map.at(node.name());
317     const auto& output_properties = output_properties_map.at(node.name());
318 
319     bool has_unknown_inputs = false;
320     for (const auto& input_prop : input_properties) {
321       if (HasAnyUnknownDimensions(input_prop.shape())) {
322         has_unknown_inputs = true;
323         break;
324       }
325     }
326 
327     if (has_unknown_inputs) {
328       continue;
329     }
330 
331     for (const auto& output_prop : output_properties) {
332       if (HasAnyUnknownDimensions(output_prop.shape())) {
333         string inputs = "input_shapes=[";
334         for (const auto& input_prop : input_properties) {
335           inputs += PartialTensorShape::DebugString(input_prop.shape());
336         }
337         inputs += "]";
338 
339         string outputs = "output_shapes=[";
340         for (const auto& output_prop : output_properties) {
341           outputs += PartialTensorShape::DebugString(output_prop.shape());
342         }
343         outputs += "]";
344 
345         VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
346                 << inputs << ", " << outputs;
347 
348         op_to_count[node.op()]++;
349 
350         // don't log again for this node
351         break;
352       }
353     }
354   }
355   VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
356           << "(format: <op_type> (<count>)):";
357   for (const auto& p : op_to_count) {
358     VLOG(2) << p.first << " (" << p.second << ")";
359   }
360 }
361 
362 // Helper function to convert kUnknownDimFromConst into UnknownDim.
363 std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
364     InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
365   std::vector<ShapeHandle> converted_shapes(shapes.size());
366   for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
367     const auto& shape = shapes[i];
368     if (!ic->RankKnown(shape)) {
369       converted_shapes[i] = shape;
370       continue;
371     }
372     bool just_copy = true;
373     std::vector<DimensionHandle> dims;
374     for (int32 i = 0; i < ic->Rank(shape); ++i) {
375       DimensionHandle dim = ic->Dim(shape, i);
376       if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
377         just_copy = false;
378         dims.push_back(ic->UnknownDim());
379       } else {
380         dims.push_back(dim);
381       }
382     }
383     if (just_copy) {
384       converted_shapes[i] = shape;
385       continue;
386     }
387     converted_shapes[i] = ic->MakeShape(dims);
388   }
389   return converted_shapes;
390 }
391 
392 // Returned tensor's shape is like `shape`, and its values and dtype are from
393 // `tensor_as_shape` and `dtype`.
394 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
395                                      const ShapeHandle& shape,
396                                      const ShapeHandle& tensor_as_shape,
397                                      const DataType& dtype) {
398   TensorProto tensor_proto;
399   tensor_proto.set_dtype(dtype);
400   auto* shape_proto = tensor_proto.mutable_tensor_shape();
401   if (ic->Rank(shape) == 1) {
402     shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
403   }
404   // For a scalar tensor, tensor_shape field will be left empty; no dim.
405   for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
406     int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
407     if (dtype == DT_INT32) {
408       tensor_proto.add_int_val(value);
409     } else {
410       tensor_proto.add_int64_val(value);
411     }
412   }
413   return tensor_proto;
414 }
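// Worked example for MakeTensorProtoFromShape (sketch): with dtype == DT_INT32,
// shape == [3] and tensor_as_shape == [2, 3, 4], the returned proto describes a
// rank-1 int32 tensor of size 3 holding the values {2, 3, 4}. For a scalar
// `shape`, the tensor_shape field stays empty and only the values are set.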
415 
416 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
417 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
418                                         const TensorProto& tensor_proto,
419                                         const DataType& dtype) {
420   NodeDef const_node;
421   const_node.set_name("const_from_shape");
422   const_node.set_op("Const");
423   auto* attr = const_node.mutable_attr();
424   (*attr)["dtype"].set_type(dtype);
425   auto* tensor = (*attr)["value"].mutable_tensor();
426   *tensor = tensor_proto;
427   return const_node;
428 }
429 
430 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
431 // and dtype = `dtype`.
432 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
433                                   const ShapeHandle& shape,
434                                   const ShapeHandle& tensor_as_shape,
435                                   const DataType& dtype) {
436   return MakeConstNodeDefFromTensorProto(
437       ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
438 }
439 
440 }  // namespace
441 
442 // Note that the tensor_as_shape input should not include kUnknownDimFromConst.
443 // This function checks for kUnknownDimFromConst and logs a WARNING if found.
444 // If checking input_tensors_as_shapes_to_propagate or output_tensors_as_shapes,
445 // which may include kUnknownDimFromConst, convert them first using
446 // ReplaceUnknownDimFromConstWithUnknownDim().
447 bool IsShapeFullyDefinedIntegerVectorOrScalar(
448     InferenceContext* ic, const ShapeHandle& shape,
449     const ShapeHandle& tensor_as_shape, const DataType& dtype) {
450   if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
451       !ic->FullyDefined(tensor_as_shape) ||
452       (dtype != DT_INT32 && dtype != DT_INT64)) {
453     return false;
454   }
455   // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
456   for (int32 i = 0; i < ic->Rank(tensor_as_shape); ++i) {
457     DimensionHandle dim = ic->Dim(tensor_as_shape, i);
458     if (ic->Value(dim) == kUnknownDimFromConst) {
459       LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
460                    << "tensor_as_shape input includes kUnknownDimFromConst -- "
461                    << ic->DebugString(tensor_as_shape);
462       return false;
463     }
464   }
465   return true;
466 }
467 
468 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
469 // dequeued in (roughly) topological order. Propagating shapes following a
470 // topological ordering isn't required for correctness but helps speed things up
471 // since it avoids processing the same node multiple times as its input
472 // information is refined.
473 class TopoQueue {
474  public:
475   explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
476       : topo_order_(TopoOrder(topo_order)) {}
477 
478   void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
479 
480   const NodeDef* pop() {
481     CHECK(!empty());
482     auto it = queue_.begin();
483     const NodeDef* n = it->first;
484     queue_.erase(it);
485     return n;
486   }
487 
488   bool empty() const { return queue_.empty(); }
489   std::size_t size() const { return queue_.size(); }
490 
491  private:
492   using NodeAndId = std::pair<const NodeDef*, int>;
493   // Graph nodes are created in (roughly) topological order. Therefore we can
494   // use their id to ensure they're sorted topologically.
495   struct OrderByIdAscending {
496     bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
497       return lhs.second < rhs.second;
498     }
499   };
500 
501   const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
502       const std::vector<const NodeDef*>& topo_order) const {
503     absl::flat_hash_map<const NodeDef*, int> map;
504     map.reserve(topo_order.size());
505     for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
506          ++i) {
507       map.emplace(topo_order[i], i);
508     }
509     return map;
510   }
511 
512   const absl::flat_hash_map<const NodeDef*, int> topo_order_;
513   std::set<NodeAndId, OrderByIdAscending> queue_;
514 };
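// Usage sketch for TopoQueue (assuming `topo_order` was produced by
// ComputeTopologicalOrder from topological_sort.h): nodes may be pushed in any
// order, but pop() always returns the one that appears earliest in
// `topo_order`.
//
//   TopoQueue ready(topo_order);
//   ready.push(node_b);
//   ready.push(node_a);  // node_a precedes node_b in topo_order.
//   const NodeDef* next = ready.pop();  // returns node_a.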
515 
516 bool IsNumericType(const DataType dtype) {
517   static const gtl::FlatSet<DataType>* const kRealNumberTypes =
518       CHECK_NOTNULL((new gtl::FlatSet<DataType>{
519           // Floating point.
520           DT_BFLOAT16,
521           DT_HALF,
522           DT_FLOAT,
523           DT_DOUBLE,
524           // Int / UInt.
525           DT_INT8,
526           DT_INT16,
527           DT_INT32,
528           DT_INT64,
529           DT_UINT8,
530           DT_UINT16,
531           DT_UINT32,
532           DT_UINT64,
533           // Quantized Int.
534           DT_QINT8,
535           DT_QUINT8,
536           DT_QINT16,
537           DT_QUINT16,
538           DT_QINT32,
539           // Bool.
540           DT_BOOL,
541       }));
542   return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
543 }
544 
545 bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
546   static const gtl::FlatSet<string>* const kOpTpeAllowlist =
547       CHECK_NOTNULL((new gtl::FlatSet<string>{
548           // Unary arithmetic ops
549           "Floor",
550           "Round",
551           "Sqrt",
552           "Square",
553           "Sign",
554           // Binary arithmetic ops
555           "Add",
556           "AddV2",
557           "Div",
558           "FloorDiv",
559           "FloorMod",
560           "Greater",
561           "GreaterEqual",
562           "Less",
563           "LessEqual",
564           "LogicalAnd",
565           "LogicalNot",
566           "LogicalOr",
567           "Maximum",
568           "Minimum",
569           "Mod",
570           "Mul",
571           "NotEqual",
572           "QuantizedAdd",
573           "QuantizedMul",
574           "SquaredDifference",
575           "Sub",
576           "TruncateDiv",
577           "TruncateMod",
578           "RealDiv",
579           // N-ary arithmetic ops
580           "AddN",
581           // Others
582           "StridedSlice",
583           "OnesLike",
584           "ZerosLike",
585           "Concat",
586           "ConcatV2",
587           "Split",
588           "Range",
589           "Fill",
590           "Cast",
591       }));
592   return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
593 }
594 
595 // A shape size of '-1' represents an unknown dimension, while negative sizes
596 // less than -1 represent unknown symbolic dimensions (e.g. the shape [-5, 5, -1,
597 // -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
598 // we need to normalize them: mark all values < -1 as "unknown" (-1).
599 static void NormalizeShapeForOutput(TensorShapeProto* shape) {
600   for (int i = 0; i < shape->dim_size(); i++) {
601     if (shape->dim(i).size() < -1) {
602       VLOG(2) << "Normalizing dimension: " << i << " from "
603               << shape->dim(i).size() << " to -1";
604       shape->mutable_dim(i)->set_size(-1);
605     }
606   }
607 }
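// For example (sketch): a TensorShapeProto with dims [-5, 5, -1, -5] is
// normalized in place to [-1, 5, -1, -1], so consumers of the output
// properties only ever see -1 for unknown dimensions.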
608 
609 // Processes symbolic shapes.
610 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
611 // shape refiner which creates new handles every time it processes an unknown
612 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
613 // unknown shape/dimension of a given node.
614 class SymbolicShapeRefiner {
615  public:
616   explicit SymbolicShapeRefiner(
617       const GraphView& graph,
618       const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
619       const bool aggressive_shape_inference)
620       : graph_(graph),
621         function_library_(OpRegistry::Global(), graph.graph()->library()),
622         fed_ports_(fed_ports),
623         aggressive_shape_inference_(aggressive_shape_inference) {
624     graph_def_version_ = graph.graph()->versions().producer();
625     node_to_context_.reserve(graph.graph()->node_size());
626   }
627 
628   const GraphView& graph() const { return graph_; }
629 
630   struct NodeContext {
631     const OpRegistrationData* op_data;
632     DataTypeVector input_types;
633     DataTypeVector output_types;
634     std::unique_ptr<InferenceContext> inference_context;
635     // Additional info for propagating tensor values and tensor shapes.
636     std::vector<const TensorProto*> input_tensor_protos;
637     std::vector<const TensorProto*> output_tensor_protos;
638     // This is the same as inference_context->input_tensors_as_shapes, except
639     // that some UnknownDims (-1) can be kUnknownDimFromConst.
640     std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
641     std::vector<ShapeHandle> output_tensors_as_shapes;
642 
643     // Output shapes incompatible between annotation and shape inference.
644     bool shape_incompatible = false;
645 
646     // Similar to DebugString() in InferenceContext, but prints out
647     // kUnknownDimFromConst properly.
648     std::string StringifyShapeHandle(ShapeHandle s) {
649       auto* ic = inference_context.get();
650       if (ic->RankKnown(s)) {
651         std::vector<std::string> vals;
652         for (int i = 0; i < ic->Rank(s); i++) {
653           DimensionHandle d = ic->Dim(s, i);
654           if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
655             vals.push_back("?(Const)");
656           } else {
657             vals.push_back(ic->DebugString(d));
658           }
659         }
660         return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
661       } else {
662         return "?";
663       }
664     }
665 
666     std::string DebugString(const NodeDef& node) {
667       std::string output;
668       auto* ic = inference_context.get();
669       absl::StrAppend(
670           &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
671           (ic->num_inputs() > 1 ? " inputs and " : " input and "),
672           ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
673       if (op_data->is_function_op) {
674         absl::StrAppend(&output, " (function op)");
675       }
676       absl::StrAppend(&output, ": \n");
677 
678       for (int i = 0; i < ic->num_inputs(); i++) {
679         absl::StrAppend(&output, " input [", i, "] ", node.input(i),
680                         " -- type: ", DataTypeString(input_types.at(i)),
681                         ", shape: ", ic->DebugString(ic->input(i)),
682                         ", tensor: ");
683         Tensor t1;
684         int input_tensor_protos_size = input_tensor_protos.size();
685         if (input_tensor_protos_size > i &&
686             input_tensor_protos.at(i) != nullptr &&
687             t1.FromProto(*input_tensor_protos.at(i))) {
688           absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
689         } else {
690           absl::StrAppend(&output, " null, tensor_as_shape: ");
691         }
692         int input_tensors_as_shapes_to_propagate_size =
693             input_tensors_as_shapes_to_propagate.size();
694         if (input_tensors_as_shapes_to_propagate_size > i) {
695           absl::StrAppend(
696               &output,
697               StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
698               "\n");
699         } else {
700           absl::StrAppend(&output, " null\n");
701         }
702       }
703       for (int i = 0; i < ic->num_outputs(); i++) {
704         absl::StrAppend(&output, " output [", i,
705                         "] -- type: ", DataTypeString(output_types.at(i)),
706                         ", shape: ", ic->DebugString(ic->output(i)),
707                         ", tensor: ");
708         Tensor t2;
709         int output_tensor_protos_size = output_tensor_protos.size();
710         if (output_tensor_protos_size > i &&
711             output_tensor_protos.at(i) != nullptr &&
712             t2.FromProto(*output_tensor_protos.at(i))) {
713           absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
714         } else {
715           absl::StrAppend(&output, " null, tensor_as_shape: ");
716         }
717         int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
718         if (output_tensors_as_shapes_size > i) {
719           absl::StrAppend(&output,
720                           StringifyShapeHandle(output_tensors_as_shapes.at(i)),
721                           "\n");
722         } else {
723           absl::StrAppend(&output, " null\n");
724         }
725       }
726       return output;
727     }
728   };
729 
730   NodeContext* GetNodeContext(const NodeDef* node) {
731     auto it = node_to_context_.find(node);
732     if (it == node_to_context_.end()) {
733       return nullptr;
734     }
735     return &it->second;
736   }
737 
738   InferenceContext* GetContext(const NodeDef* node) {
739     auto it = node_to_context_.find(node);
740     if (it == node_to_context_.end()) {
741       return nullptr;
742     }
743     return it->second.inference_context.get();
744   }
745 
746   // Forward the shapes from the inputs of the function call node
747   // (PartitionedCall or StatefulPartitionedCall) to the function's argument
748   // nodes (which are Placeholder nodes), then perform shape inference on the
749   // function body.
750   //
751   // Propagate the shape information of the final function body node
752   // to the function node `function_node`.
753   //
754   // In the event of an error, UpdateNode will simply set `function_node`'s
755   // output shape to be Unknown.
756   Status UpdateFunction(const NodeDef* function_node) {
757     NameAttrList function;
758     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
759     auto it = fun_to_grappler_function_item_.find(function.name());
760     if (it == fun_to_grappler_function_item_.end()) {
761       return errors::InvalidArgument(
762           function.name(),
763           " was not previously added to SymbolicShapeRefiner.");
764     }
765 
766     const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
767         it->second;
768     if (!maybe_grappler_function_item.has_value()) {
769       VLOG(3) << "Skip failed to instantiate function call: function_name="
770               << function.name();
771 
772       auto* ctx = GetNodeContext(function_node);
773       auto* ic = ctx->inference_context.get();
774       for (int i = 0; i < ic->num_outputs(); ++i) {
775         TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
776       }
777 
778       return Status::OK();
779     }
780 
781     // Copy (not reference) so that changes we make here (e.g., replacing
782     // _Arg with Const and _Retval with Identity) don't affect one in
783     // fun_to_grappler_function_item_.
784     GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
785     MutableGraphView gv(&grappler_function_item.graph);
786 
787     // Forward shapes from function input nodes to argument nodes.
788     for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
789          ++i) {
790       auto& fun_input = grappler_function_item.input(i);
791       NodeDef* fun_node = gv.GetNode(fun_input.node_name);
792       const TensorId input_tensor = ParseTensorName(function_node->input(i));
793 
794       if (IsControlInput(input_tensor)) {
795         return errors::FailedPrecondition(
796             "Function inputs should not contain control nodes.");
797       }
798 
799       const NodeDef* input_node = graph_.GetNode(input_tensor.node());
800       if (input_node == nullptr) {
801         return errors::FailedPrecondition(input_tensor.node(),
802                                           " was not found in the graph.");
803       }
804 
805       InferenceContext* input_ic = GetContext(input_node);
806       if (input_ic == nullptr) {
807         return errors::FailedPrecondition(
808             "Inference context has not been created for ", input_tensor.node());
809       }
810 
811       int output_port_num = input_tensor.index();
812       AttrValue attr_output_shape;
813       TensorShapeProto proto;
814       const auto handle = input_ic->output(output_port_num);
815       input_ic->ShapeHandleToProto(handle, &proto);
816       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
817       NormalizeShapeForOutput(&proto);
818       // _Arg op's output shape uses _output_shapes attr.
819       AttrValue output_attr;
820       output_attr.mutable_list()->add_shape()->Swap(&proto);
821       (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
822 
823       // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
824       // _handle_shapes attr for its shapes and dtypes.
825       if (fun_input.data_type == DT_RESOURCE) {
826         auto* shapes_and_types =
827             input_ic->output_handle_shapes_and_types(output_port_num);
828         if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
829           AttrValue dtype_attr;
830           AttrValue shape_attr;
831           for (const auto& shape_and_type : *shapes_and_types) {
832             const auto& dtype = shape_and_type.dtype;
833             const auto& shape_handle = shape_and_type.shape;
834             dtype_attr.mutable_list()->add_type(dtype);
835             input_ic->ShapeHandleToProto(
836                 shape_handle, shape_attr.mutable_list()->add_shape());
837           }
838           (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
839           (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
840         } else {
841           // Note that we do not return error here, even if the input node does
842           // not have shapes_and_types. Within the function, we cannot infer the
843           // output shape of the DT_RESOURCE input; hence, potentially unknown
844           // shapes/dims in the function output shapes.
845           VLOG(2)
846               << "A function node (" << function_node->name()
847               << ") has input with DT_RESOURCE, but the input node does not "
848               << "have shapes_and_types information: \n"
849               << "function_node: " << function_node->ShortDebugString() << "\n"
850               << "function input: " << i
851               << ", input node's output: " << output_port_num << "\n"
852               << "input node: " << input_node->ShortDebugString();
853         }
854       }
855     }
856 
857     // ReplaceInputWithConst() may break GraphView's internal node mapping
858     // structure; hence, we separately build a node-name-to-NodeDef* map for the
859     // output nodes (before the GraphView becomes invalid). Note that we use
860     // string, not string_view.
861     absl::flat_hash_map<std::string, NodeDef*> output_nodes;
862     for (const auto& output_arg : grappler_function_item.outputs()) {
863       output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
864     }
865 
866     // Replace input nodes with Consts, if values are known. Note that
867     // we don't check exceptions here as it's done in the above loop.
868     auto* ctx = GetNodeContext(function_node);
869     auto* ic = ctx->inference_context.get();
870     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
871       const string& input = function_node->input(i);
872       const string node_name = NodeName(input);
873       const NodeDef* input_node = graph_.GetNode(node_name);
874       if (IsConstant(*input_node)) {
875         TF_CHECK_OK(
876             ReplaceInputWithConst(*input_node, i, &grappler_function_item));
877       } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
878                  ctx->input_tensor_protos[i] != nullptr) {
879         NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
880             ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
881         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
882                                           &grappler_function_item));
883       } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
884                  IsShapeFullyDefinedIntegerVectorOrScalar(
885                      ic, ic->input(i), ic->input_tensors_as_shapes()[i],
886                      ctx->input_types[i])) {
887         // We have fully defined input_tensors_as_shapes for this input; use it
888         // as a const input to the function node.
889         NodeDef const_input_node = MakeConstNodeDefFromShape(
890             ic, ic->input(i), ic->input_tensors_as_shapes()[i],
891             ctx->input_types[i]);
892         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
893                                           &grappler_function_item));
894       }
895     }
896     // node_name to NodeDef* map in GraphView gv can be broken due to
897     // ReplaceInputWithConst(). gv should not be used after this.
898 
899     // Replace output _Retval nodes with Identity nodes. _Retval is a system op
900     // without outputs or a registered shape function.
901     for (const auto& output_arg : grappler_function_item.outputs()) {
902       NodeDef* output_node = output_nodes[output_arg.node_name];
903       DCHECK_EQ(output_node->op(), "_Retval");
904       output_node->set_op("Identity");
905       output_node->mutable_attr()->erase("index");
906     }
907 
908     // Perform inference on function body.
909     GraphProperties gp(grappler_function_item);
910     TF_RETURN_IF_ERROR(gp.InferStatically(
911         /*assume_valid_feeds=*/true,
912         /*aggressive_shape_inference=*/aggressive_shape_inference_,
913         /*include_tensor_values=*/true));
914 
915     // Add return nodes for output shapes.
916     int output = 0;
917     ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
918     ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
919                                      nullptr);
920     for (auto const& out_arg : grappler_function_item.outputs()) {
921       // It is guaranteed that output_tensors does not contain any control
922       // inputs, so port_id >= 0.
923       TensorId out_tensor = ParseTensorName(out_arg.node_name);
924 
925       if (output_nodes.count(out_tensor.node()) <= 0) {
926         return errors::FailedPrecondition(
927             "Unable to find return function_node ", out_tensor.node(), " for ",
928             function_node->name());
929       }
930       const NodeDef* retnode = output_nodes[out_tensor.node()];
931 
932       auto output_properties = gp.GetOutputProperties(retnode->name());
933       int output_properties_size = output_properties.size();
934       if (out_tensor.index() >= output_properties_size) {
935         return errors::InvalidArgument(
936             out_tensor.ToString(), " has invalid position ", out_tensor.index(),
937             " (output_properties.size() = ", output_properties.size(), ").");
938       }
939       auto& outprop = output_properties[out_tensor.index()];
940       TensorShapeProto shape = outprop.shape();
941       NormalizeShapeForOutput(&shape);
942       ShapeHandle out;
943       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
944       ic->set_output(output, out);
945       if (outprop.has_value()) {
946         // Forward tensor value to output_tensors_as_shape.
947         MaybeTensorProtoToShape(ic, outprop.value(),
948                                 &ctx->output_tensors_as_shapes[output]);
949         const_tensors_to_propagate_.push_back(outprop.value());
950         ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
951       }
952       output++;
953     }
954 
955     return Status::OK();
956   }
957 
958   // Prepares input shapes/values/handles, then runs shape inference, and
959   // finally sets output shapes/values/handles.
960   Status UpdateNode(const NodeDef* node, bool* refined) {
961     NodeContext* ctx = GetNodeContext(node);
962     if (ctx == nullptr) {
963       TF_RETURN_IF_ERROR(AddNode(node));
964       ctx = CHECK_NOTNULL(GetNodeContext(node));
965       *refined = true;
966     }
967 
968     // Check if the shapes of the nodes in the fan-in of this node have changed,
969     // and if they have, update the node input shapes.
970     InferenceContext* ic = ctx->inference_context.get();
971     ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
972     ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
973 
974     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
975       const GraphView::InputPort port(node, dst_input);
976       const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
977       int src_output = fanin.port_id;
978       const NodeDef* src = fanin.node;
979       NodeContext* src_ctx = GetNodeContext(src);
980       if (src_ctx == nullptr) {
981         return errors::FailedPrecondition(
982             "Input ", dst_input, " for '", node->name(),
983             "' was not previously added to SymbolicShapeRefiner.");
984       }
985 
986       InferenceContext* src_ic = src_ctx->inference_context.get();
987       if (src_output >= src_ic->num_outputs()) {
988         return errors::OutOfRange("src_output = ", src_output,
989                                   ", but num_outputs is only ",
990                                   src_ic->num_outputs());
991       }
992 
993       // Propagate input node's NodeContext info to the current node's
994       // NodeContext:
995       // output_tensor_protos to input_tensor_protos and input_tensors, and
996       // output_tensors_as_shapes to input_tensors_as_shapes.
997       if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
998           src_output) {
999         ctx->input_tensors_as_shapes_to_propagate[dst_input] =
1000             src_ctx->output_tensors_as_shapes[src_output];
1001       }
1002 
1003       if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
1004         const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
1005         if (tensor_proto != nullptr) {
1006           ctx->input_tensor_protos[dst_input] = tensor_proto;
1007         }
1008       }
1009 
1010       // NOTE: we only check whether the shape is refined; we do not (yet)
1011       // check whether the tensor value is refined.
1012       if (!*refined &&
1013           !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
1014         *refined = true;
1015       }
1016       ic->SetInput(dst_input, src_ic->output(src_output));
1017 
1018       if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
1019         // The input value may have changed. Since we have no way to know if
1020         // that's indeed the case, err on the safe side.
1021         *refined = true;
1022       }
1023 
1024       // Also propagate handle shape and dtype of edges which are carrying
1025       // resource handles.
1026       if (ctx->input_types[dst_input] == DT_RESOURCE) {
1027         auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
1028         if (!outputs) continue;
1029         auto* inputs = ic->input_handle_shapes_and_types(dst_input);
1030 
1031         if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
1032           *refined = true;
1033         ic->set_input_handle_shapes_and_types(dst_input, *outputs);
1034       }
1035     }
1036 
1037     // Make sure we schedule the fanout of resources (which have no input)
1038     // whenever the resources are updated.
1039     *refined |= ic->num_inputs() == 0;
1040 
1041     if (!*refined) {
1042       // No input shape has changed, we're done.
1043       return Status::OK();
1044     }
1045 
1046     // Convert all kUnknownDimFromConst to -1 for shape inference.
1047     ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
1048         ic, ctx->input_tensors_as_shapes_to_propagate));
1049     // Note: UpdateFunction uses input_tensors_as_shapes and
1050     // input_tensor_protos (not the Tensor objects) for input values, so for
1051     // function nodes we don't need to convert TensorProtos to Tensors here.
1052     // If the current op is not a function op, we convert TensorProtos to
1053     // Tensors before calling InferShapes.
1054 
1055     // Properly handle function nodes.
1056     if (ctx->op_data && ctx->op_data->is_function_op) {
1057       // TODO(jmdecker): Detect if the input shapes have changed for this
1058       // function. Note that when we hit a function call node, refined will be
1059       // true, as the updates to the call node will have changed, even if it's
1060       // the same function being called twice with the same input shapes.
1061       // Example: simple_function.pbtxt
1062       if (aggressive_shape_inference_) {
1063         // If output shapes are annotated, use it and skip UpdateFunction();
1064         // it can be very expensive when a function node has nested function
1065         // nodes internally. One downside with this approach is that we do not
1066         // get output values or output shapes as tensor from function node.
1067         auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
1068         if (s.ok() && AllOutputShapesKnown(ctx)) {
1069           return Status::OK();
1070         }
1071         // If shape annotation was not available, incomplete, or incompatible,
1072         // fall through to call UpdateFunction().
1073       }
1074       auto s = UpdateFunction(node);
1075       if (s.ok()) {
1076         return Status::OK();
1077       } else {
1078         VLOG(1) << "UpdateFunction failed for " << node->op()
1079                 << ". Defaulting to ShapeUnknown.\n"
1080                 << s.ToString();
1081       }
1082     }
1083 
1084     //  Construct Tensors for constant inputs used by shape functions.
1085     std::vector<Tensor> const_values(ic->num_inputs());
1086     std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
1087     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1088       const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
1089       if (tensor_proto != nullptr &&
1090           const_values[dst_input].FromProto(*tensor_proto)) {
1091         input_tensors[dst_input] = &const_values[dst_input];
1092       }
1093     }
1094     ic->set_input_tensors(input_tensors);
1095 
1096     // Update the shapes of the outputs.
1097     return InferShapes(*node, ctx);
1098   }
1099 
1100   Status SetUnknownShape(const NodeDef* node, int output_port) {
1101     shape_inference::ShapeHandle shape =
1102         GetUnknownOutputShape(node, output_port);
1103     InferenceContext* ctx = GetContext(node);
1104     if (ctx == nullptr) {
1105       return errors::InvalidArgument("Missing context");
1106     }
1107     ctx->set_output(output_port, shape);
1108     return Status::OK();
1109   }
1110 
1111   struct ShapeId {
1112     const NodeDef* node;
1113     int port_id;
1114     bool operator==(const ShapeId& other) const {
1115       return node == other.node && port_id == other.port_id;
1116     }
1117   };
1118   struct HashShapeId {
1119     std::size_t operator()(const ShapeId& shp) const {
1120       return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
1121     }
1122   };
1123 
1124   struct DimId {
1125     const NodeDef* node;
1126     int port_id;
1127     int dim_index;
1128     bool operator==(const DimId& other) const {
1129       return node == other.node && port_id == other.port_id &&
1130              dim_index == other.dim_index;
1131     }
1132   };
1133 
1134   struct HashDimId {
1135     std::size_t operator()(const DimId& dim) const {
1136       return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
1137              dim.dim_index;
1138     }
1139   };
1140 
1141   // Returns the shape of output 'port_index' of 'node' as the union of shape1 and shape2.
1142   ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
1143                             ShapeHandle shape1, ShapeHandle shape2) {
1144     if (shape1.SameHandle(shape2)) {
1145       return shape1;
1146     }
1147     InferenceContext* ctx = GetContext(node);
1148     ShapeHandle relaxed = shape1;
1149     const int rank = ctx->Rank(shape1);
1150     if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
1151       relaxed = GetUnknownOutputShape(node, port_index);
1152     } else {
1153       for (int d = 0; d < rank; ++d) {
1154         if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
1155           int64 val1 = ctx->Value(ctx->Dim(shape1, d));
1156           int64 val2 = ctx->Value(ctx->Dim(shape2, d));
1157           if (val1 != val2 || (val1 < 0 && val2 < 0)) {
1158             DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
1159             TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
1160           }
1161         }
1162       }
1163     }
1164     return relaxed;
1165   }
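  // Relaxation sketch for OutputAsUnion above: for shapes [2, 3] and [2, 4] it
  // keeps the matching first dimension and replaces the conflicting second one
  // with the node/port-specific unknown dimension, yielding [2, ?]. If the
  // ranks differ, the whole output is relaxed to the unknown shape for that
  // port.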
1166 
1167   bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
1168     if (s1.SameHandle(s2)) {
1169       return true;
1170     }
1171     if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
1172       return false;
1173     }
1174     if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
1175       return true;
1176     }
1177     const int rank = InferenceContext::Rank(s1);
1178     for (int i = 0; i < rank; ++i) {
1179       if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
1180               InferenceContext::DimKnownRank(s2, i))) {
1181         int64 val1 =
1182             InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
1183         int64 val2 =
1184             InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
1185         if (val1 >= 0 && val2 >= 0 && val1 == val2) {
1186           continue;
1187         }
1188         return false;
1189       }
1190     }
1191     return true;
1192   }
1193 
1194   // Return true if the annotated shape is compatible with shape inference
1195   // result. Examples:
1196   // Inferred shape: ?, annotated shape: [10, 10] -> true;
1197   // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
1198   // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
1199   // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
1200   bool CompatibleShapes(ShapeHandle inferred_shape,
1201                         ShapeHandle annotated_shape) const {
1202     if (inferred_shape.SameHandle(annotated_shape)) {
1203       return true;
1204     }
1205     if (!InferenceContext::RankKnown(inferred_shape)) {
1206       return true;
1207     }
1208     if (InferenceContext::Rank(inferred_shape) !=
1209         InferenceContext::Rank(annotated_shape)) {
1210       return false;
1211     }
1212     const int rank = InferenceContext::Rank(inferred_shape);
1213     for (int i = 0; i < rank; ++i) {
1214       if (!InferenceContext::DimKnownRank(inferred_shape, i)
1215                .SameHandle(
1216                    InferenceContext::DimKnownRank(annotated_shape, i))) {
1217         int64 val1 = InferenceContext::Value(
1218             InferenceContext::DimKnownRank(inferred_shape, i));
1219         int64 val2 = InferenceContext::Value(
1220             InferenceContext::DimKnownRank(annotated_shape, i));
1221         if (val1 >= 0 && val1 != val2) {
1222           return false;
1223         }
1224       }
1225     }
1226     return true;
1227   }
1228 
1229   bool SameShapes(ShapeHandle inferred_shape,
1230                   ShapeHandle annotated_shape) const {
1231     if (inferred_shape.SameHandle(annotated_shape)) {
1232       return true;
1233     }
1234     if (InferenceContext::Rank(inferred_shape) !=
1235         InferenceContext::Rank(annotated_shape)) {
1236       return false;
1237     }
1238     const int rank = InferenceContext::Rank(inferred_shape);
1239     for (int i = 0; i < rank; ++i) {
1240       int64 val1 = InferenceContext::Value(
1241           InferenceContext::DimKnownRank(inferred_shape, i));
1242       int64 val2 = InferenceContext::Value(
1243           InferenceContext::DimKnownRank(annotated_shape, i));
1244       if (val1 != val2) {
1245         return false;
1246       }
1247     }
1248     return true;
1249   }
1250 
1251   bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1252                                 const std::vector<ShapeAndType>& st2) const {
1253     if (st1.size() != st2.size()) {
1254       return false;
1255     }
1256     for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
1257       const ShapeAndType& s1 = st1[i];
1258       const ShapeAndType& s2 = st2[i];
1259       if (s1.dtype != s2.dtype) {
1260         return false;
1261       }
1262       if (!EquivalentShapes(s1.shape, s2.shape)) {
1263         return false;
1264       }
1265     }
1266     return true;
1267   }
1268 
1269   Status AddFunction(const NodeDef* function_node, NameAttrList function) {
1270     auto it = fun_to_grappler_function_item_.find(function.name());
1271     if (it != fun_to_grappler_function_item_.end()) {
1272       return Status::OK();
1273     }
1274 
1275     const FunctionDef* function_def =
1276         CHECK_NOTNULL(function_library_.Find(function.name()));
1277     GrapplerFunctionItem grappler_function_item;
1278     Status function_instantiated =
1279         MakeGrapplerFunctionItem(*function_def, function_library_,
1280                                  graph_def_version_, &grappler_function_item);
1281 
1282     // If function instantiation failed we will skip it during shape inference.
1283     if (!function_instantiated.ok()) {
1284       VLOG(3) << "Failed to instantiate a function. Error: "
1285               << function_instantiated.error_message();
1286       fun_to_grappler_function_item_[function_def->signature().name()] =
1287           absl::nullopt;
1288       return Status::OK();
1289     }
1290 
1291     if (static_cast<int>(grappler_function_item.inputs().size()) >
1292         function_node->input_size()) {
1293       return errors::FailedPrecondition(
1294           "Function input size should not be larger than node input size.");
1295     }
1296 
1297     for (int i = grappler_function_item.inputs().size(),
1298              end = function_node->input_size();
1299          i < end; ++i) {
1300       const string& input = function_node->input(i);
1301       if (!IsControlInput(input)) {
1302         return errors::FailedPrecondition(
1303             "Found regular input (", input,
1304             ") instead of control nodes for node ", function_node->name());
1305       }
1306     }
1307 
1308     fun_to_grappler_function_item_[function_def->signature().name()] =
1309         grappler_function_item;
1310 
1311     return Status::OK();
1312   }
1313 
1314   Status AddNode(const NodeDef* node) {
1315     NodeContext& node_ctx = node_to_context_[node];
1316     NameAttrList function;
1317     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1318 
1319     // For PartitionedCall, op_data represents the function info.
1320     TF_RETURN_IF_ERROR(
1321         function_library_.LookUp(function.name(), &node_ctx.op_data));
1322 
1323     if (node_ctx.op_data->is_function_op) {
1324       TF_RETURN_IF_ERROR(AddFunction(node, function));
1325     }
1326 
1327     TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1328                                          &node_ctx.input_types,
1329                                          &node_ctx.output_types));
1330 
1331     // Create the inference context for this node.
1332     const int num_inputs = node_ctx.input_types.size();
1333     std::vector<ShapeHandle> input_shapes(num_inputs);
1334     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1335         input_handle_shapes_and_types(num_inputs);
1336     std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1337     std::vector<ShapeHandle> input_tensors_as_shapes;
1338 
1339     node_ctx.inference_context.reset(new InferenceContext(
1340         graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
1341         input_tensors, input_tensors_as_shapes,
1342         std::move(input_handle_shapes_and_types)));
1343     const Status s = node_ctx.inference_context->construction_status();
1344     if (!s.ok()) {
1345       node_ctx.inference_context.reset(nullptr);
1346     }
1347     return s;
1348   }
1349 
1350  private:
1351   // Return the one ShapeHandle used to denote a fully unknown shape for a node
1352   // output.
1353   ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1354     ShapeId id{node, index};
1355     auto it = unknown_shapes_.find(id);
1356     if (it != unknown_shapes_.end()) {
1357       return it->second;
1358     }
1359     InferenceContext* c = GetContext(node);
1360     ShapeHandle shp = c->UnknownShape();
1361     unknown_shapes_[id] = shp;
1362     return shp;
1363   }
1364   // Return the one ShapeHandle used to denote a fully unknown dimension for a
1365   // node output.
1366   DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1367                                       int dim_id) {
1368     DimId id{node, index, dim_id};
1369     auto it = unknown_dims_.find(id);
1370     if (it != unknown_dims_.end()) {
1371       return it->second;
1372     }
1373     InferenceContext* c = GetContext(node);
1374     DimensionHandle dim = c->UnknownDim();
1375     unknown_dims_[id] = dim;
1376     return dim;
1377   }
1378 
1379   // Returns true if all the output tensors have known values.
1380   bool AllOutputValuesKnown(NodeContext* c) {
1381     InferenceContext* ic = c->inference_context.get();
1382     int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
1383     int c_output_tensor_protos_size = c->output_tensor_protos.size();
1384     if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
1385         c_output_tensor_protos_size < ic->num_outputs()) {
1386       return false;
1387     } else {
1388       // Checks if we can get output value via either output_tensor_proto or
1389       // output_tensors_as_shapes.
1390       for (int i = 0; i < ic->num_outputs(); i++) {
1391         if (c_output_tensor_protos_size > i &&
1392             c->output_tensor_protos[i] != nullptr) {
1393           continue;
1394         }
1395         if (c_output_tensors_as_shapes_size > i &&
1396             ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1397           bool no_unknown_dim_from_const = true;
1398           for (int32 j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]); ++j) {
1399             const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
1400             if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
1401               no_unknown_dim_from_const = false;
1402               break;
1403             }
1404           }
1405           if (no_unknown_dim_from_const) {
1406             continue;
1407           }
1408         }
1409         return false;
1410       }
1411     }
1412     return true;
1413   }
1414 
1415   // Returns true if all the output shapes are known.
1416   bool AllOutputShapesKnown(NodeContext* c) {
1417     InferenceContext* ic = c->inference_context.get();
1418     // Checks if all the output shapes are fully defined.
1419     for (int i = 0; i < ic->num_outputs(); i++) {
1420       if (!ic->FullyDefined(ic->output(i))) {
1421         return false;
1422       }
1423     }
1424     return true;
1425   }
1426 
1427   // Returns true if we can infer output tensors' values, i.e., we know the
1428   // values of all the input tensors.
1429   bool AllInputValuesKnown(NodeContext* c) {
1430     InferenceContext* ic = c->inference_context.get();
1431 
1432     // Check inputs are fully defined and values are known.
1433     for (int i = 0; i < ic->num_inputs(); i++) {
1434       const Tensor* tensor = ic->input_tensor(i);
1435       // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1436       // already converted it to ic->input_tensor(i);
1437       const ShapeHandle& input_tensors_as_shape =
1438           ic->input_tensors_as_shapes()[i];
1439       // Either input_tensor must be valid, or input_tensors_as_shape, which
1440       // holds the input tensor's value in shape form, must be fully defined.
1441       if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1442         return false;
1443       }
1444     }
1445     return true;
1446   }
1447 
1448   // Returns true if we want to update output shapes and values with running
1449   // EvaluateNode() for this op, based on op type, data type, and size.
1450   bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
1451     InferenceContext* ic = c->inference_context.get();
1452 
1453     // Due to the cost of running EvaluateNode(), we restrict evaluation to
1454     // allowlisted op types.
1455     if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1456       return false;
1457     }
1458 
1459     // Check input dtypes are number types.
1460     for (const auto& input_type : c->input_types) {
1461       if (!IsNumericType(input_type)) {
1462         return false;
1463       }
1464     }
1465 
1466     // Check output dtypes are number types.
1467     for (const auto& output_type : c->output_types) {
1468       if (!IsNumericType(output_type)) {
1469         return false;
1470       }
1471     }
1472 
1473     // Check that the number of elements of each input tensor is no larger
1474     // than the given max size.
1475     for (int i = 0; i < ic->num_inputs(); i++) {
1476       const Tensor* tensor = ic->input_tensor(i);
1477       const ShapeHandle& input_shape_handle = ic->input(i);
1478       if (tensor != nullptr) {
1479         if (tensor->NumElements() > max_size) {
1480           return false;
1481         }
1482       } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1483         return false;
1484       }
1485     }
1486 
1487     // Check that we know the shape of each output tensor, and that its number
1488     // of elements is no larger than the given max size.
1489     for (int i = 0; i < ic->num_outputs(); i++) {
1490       const ShapeHandle& shape_handle = ic->output(i);
1491       if (!ic->FullyDefined(shape_handle) ||
1492           ic->Value(ic->NumElements(shape_handle)) > max_size) {
1493         return false;
1494       }
1495     }
1496     return true;
1497   }
1498 
1499   // Create input tensors from the NodeContext.
1500   void CreateInputTensors(NodeContext* c,
1501                           std::vector<Tensor>* input_tensor_vector,
1502                           TensorVector* inputs) {
1503     InferenceContext* ic = c->inference_context.get();
1504     for (int i = 0; i < ic->num_inputs(); i++) {
1505       if (ic->input_tensor(i)) {
1506         input_tensor_vector->at(i) = *ic->input_tensor(i);
1507         inputs->emplace_back(&input_tensor_vector->at(i));
1508         // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1509         // already converted it to ic->input_tensor(i);
1510       } else {
1511         // Create Tensor from input_tensors_as_shapes, and then emplace it
1512         // back to inputs.
1513         // Note that input_tensors_as_shapes is scalar or vector.
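        // As an illustrative sketch: a vector shape handle such as [3, 5]
        // becomes a rank-1 tensor of length 2 holding {3, 5}, while a scalar
        // handle becomes a 0-d tensor, matching the construction below.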
1514         const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1515         const DataType& data_type = c->input_types[i];
1516         int32 rank = ic->Rank(shape_handle);
1517         if (rank < 1) {
1518           input_tensor_vector->at(i) = Tensor(data_type, {});
1519         } else {
1520           input_tensor_vector->at(i) = Tensor(data_type, {rank});
1521         }
1522         auto* tensor = &input_tensor_vector->at(i);
1523         if (data_type == DT_INT32) {
1524           auto flat = tensor->flat<int32>();
1525           for (int j = 0; j < rank; j++) {
1526             int32 dim = ic->Value(ic->Dim(shape_handle, j));
1527             flat(j) = dim;
1528           }
1529         } else {
1530           auto flat = tensor->flat<int64>();
1531           for (int j = 0; j < rank; j++) {
1532             int64 dim = ic->Value(ic->Dim(shape_handle, j));
1533             flat(j) = dim;
1534           }
1535         }
1536         inputs->emplace_back(tensor);
1537       }
1538     }
1539   }
1540 
1541   // Run a node to infer output shapes and values, and add it to the
1542   // NodeContext.
1543   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1544     InferenceContext* ic = c->inference_context.get();
1545 
1546     // Input to EvaluateNode()
1547     TensorVector inputs;
1548     // Container for temporarily created tensor objects.
1549     std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1550     CreateInputTensors(c, &input_tensor_vector, &inputs);
1551 
1552     // Output for EvaluateNode() and output tensor clean up object.
1553     TensorVector outputs;
1554     auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1555       for (const auto& output : outputs) {
1556         if (output.tensor) {
1557           delete output.tensor;
1558         }
1559       }
1560     });
1561 
1562     TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1563                                     &resource_mgr_, &outputs));
1564     c->output_tensors_as_shapes.resize(outputs.size());
1565     c->output_tensor_protos.resize(outputs.size(), nullptr);
1566     for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1567       const auto& t = outputs[k];
1568       // Override output shape.
1569       ShapeHandle output_shape;
1570       TF_RETURN_IF_ERROR(
1571           ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1572       if (ic->FullyDefined(ic->output(k)) &&
1573           !EquivalentShapes(ic->output(k), output_shape)) {
1574         LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1575                      << ", inferred output shape "
1576                      << "doesn't match for k=" << k << ": "
1577                      << "ic->output(k): " << ic->DebugString(ic->output(k))
1578                      << ", output_shape: " << ic->DebugString(output_shape)
1579                      << " -- " << node.DebugString();
1580       }
1581       ic->set_output(k, output_shape);
1582       // Set output_tensors_as_shape.
1583       MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1584 
1585       // Set output_tensor_protos.
1586       TensorProto tensor_proto;
1587       t->AsProtoTensorContent(&tensor_proto);
1588       const_tensors_to_propagate_.push_back(tensor_proto);
1589       c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1590     }
1591     return Status::OK();
1592   }
1593 
1594   // Update output shapes with annotated information.
1595   // Currently handles only nodes with static shapes, i.e., shapes that do not
1596   // change during execution.
1597   // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1598   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1599                                                      NodeContext* c) const {
1600     const auto& attr = node.attr();
1601     if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1602         attr.count(kOutputShapes) == 0)
1603       return Status::OK();
1604 
1605     InferenceContext* ic = c->inference_context.get();
1606     int output_size = attr.at(kOutputShapes).list().shape_size();
1607 
1608     for (int i = 0; i < ic->num_outputs(); i++) {
1609       // Annotated Switch node has only one output. Propagate the shape to all
1610       // the outputs.
1611       int shape_index = IsSwitch(node) ? 0 : i;
1612       if (shape_index >= output_size) {
1613         LOG(WARNING)
1614             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1615             << node.name() << ", inferred output shape size "
1616             << ic->num_outputs() << ", annotated output shape size "
1617             << output_size;
1618         break;
1619       }
1620 
1621       const TensorShapeProto& shape =
1622           attr.at(kOutputShapes).list().shape(shape_index);
1623       if (shape.dim().empty()) continue;
1624 
1625       ShapeHandle output_shape;
1626       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1627 
1628       // Check if annotated shapes are incompatible with inferred shapes.
1629       if ((ic->FullyDefined(ic->output(i)) &&
1630            !SameShapes(ic->output(i), output_shape)) ||
1631           (!ic->FullyDefined(ic->output(i)) &&
1632            !CompatibleShapes(ic->output(i), output_shape))) {
1633         LOG(WARNING)
1634             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1635             << node.name() << ", inferred output shape "
1636             << "doesn't match for i=" << i << ": "
1637             << "ic->output(i): " << ic->DebugString(ic->output(i))
1638             << ", annotated output shape: " << ic->DebugString(output_shape)
1639             << " -- " << node.DebugString();
1640         c->shape_incompatible = true;
1641       }
1642 
1643       // Only use annotated shapes if the inference shape is unknown and
1644       // compatible with annotated shapes.
1645       if (!ic->FullyDefined(ic->output(i)) &&
1646           CompatibleShapes(ic->output(i), output_shape)) {
1647         VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1648                 << node.name() << ", inferred output shape " << i << ": "
1649                 << "ic->output(i): " << ic->DebugString(ic->output(i))
1650                 << ", annotated output shape: " << ic->DebugString(output_shape)
1651                 << " -- " << node.ShortDebugString();
1652         ic->set_output(i, output_shape);
1653       }
1654     }
1655 
1656     return Status::OK();
1657   }
1658 
1659   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1660                                       NodeContext* c) {
1661     // Propagate tensors and shape tensors unless the node is fed.
1662     // TODO(bsteiner) We should still propagate the shapes to the ports that
1663     // aren't fed in the case of a ShapeN node.
1664 
1665     // Note that when propagating tensors_as_shapes, we use
1666     // c->input_tensors_as_shapes_to_propagate instead of
1667     // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1668     // UnknownDim is from Const tensor, and it is propagated through shape
1669     // inference. Before calling shape functions, we convert it to UnknownDim,
1670     // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1671     // inference through UnknownDim from Const.
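    // Illustrative example (not from a specific graph): if a Const shape
    // tensor [-1, 16] feeds a Reshape, the -1 is recorded as
    // kUnknownDimFromConst rather than as a regular symbolic UnknownDim, so
    // that unrelated unknown dims are not accidentally unified with it.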
1672     InferenceContext* ic = c->inference_context.get();
1673     if (!is_fed) {
1674       if (IsConstant(node)) {
1675         c->output_tensor_protos.resize(1);
1676         const TensorProto& tensor_proto = node.attr().at("value").tensor();
1677         c->output_tensor_protos[0] = &tensor_proto;
1678         c->output_tensors_as_shapes.resize(1);
1679         MaybeTensorProtoToShape(ic, tensor_proto,
1680                                 &c->output_tensors_as_shapes[0]);
1681       } else if (IsRank(node)) {
1682         if (ic->RankKnown(ic->input(0))) {
1683           // Propagate rank value.
1684           int32 rank = ic->Rank(ic->input(0));
1685           const_tensors_to_propagate_.push_back(
1686               MakeIntegerScalarTensorProto(DT_INT32, rank));
1687           c->output_tensor_protos.resize(1);
1688           c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1689         }
1690       } else if (IsSize(node)) {
1691         DimensionHandle size = ic->NumElements(ic->input(0));
1692         if (ic->ValueKnown(size)) {
1693           // Propagate size value.
1694           int64 sz = ic->Value(size);
1695           bool valid = false;
1696           if (node.attr().at("out_type").type() == DT_INT32) {
1697             if (sz < std::numeric_limits<int32>::max()) {
1698               const_tensors_to_propagate_.push_back(
1699                   MakeIntegerScalarTensorProto(DT_INT32, sz));
1700               valid = true;
1701             }
1702           } else {
1703             const_tensors_to_propagate_.push_back(
1704                 MakeIntegerScalarTensorProto(DT_INT64, sz));
1705             valid = true;
1706           }
1707           if (valid) {
1708             c->output_tensor_protos.resize(1);
1709             c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1710           }
1711         }
1712       } else if (IsShape(node)) {
1713         c->output_tensors_as_shapes.resize(1);
1714         c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1715       } else if (IsShapeN(node)) {
1716         c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1717         for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1718           c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1719         }
1720       } else if (node.op() == "ConcatV2") {
1721         bool valid = true;
1722         ShapeHandle result;
1723         for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1724           ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1725           if (!ic->RankKnown(input)) {
1726             valid = false;
1727             break;
1728           } else if (i == 0) {
1729             result = input;
1730           } else {
1731             TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1732           }
1733         }
1734         if (valid) {
1735           c->output_tensors_as_shapes.resize(1);
1736           c->output_tensors_as_shapes[0] = result;
1737         }
1738       } else if (IsPack(node)) {
1739         // A Pack node concatenating scalars is often used to generate a shape.
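        // For example (illustrative), Pack(Size(x), 128) yields a rank-1
        // tensor that can be read as the shape [num_elements(x), 128]; each
        // scalar input below becomes one dimension of that shape.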
1740         std::vector<DimensionHandle> dims;
1741         bool valid = true;
1742         for (int i = 0; i < ic->num_inputs(); ++i) {
1743           const Tensor* t = ic->input_tensor(i);
1744           if (t) {
1745             if (t->dims() != 0 ||
1746                 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1747               valid = false;
1748               break;
1749             }
1750             int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1751                                                 : t->scalar<int64>()();
1752             dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1753                                     : ic->MakeDim(size));
1754           } else {
1755             // Don't have tensor value, but use input_tensors_as_shapes, if
1756             // possible.
1757             const ShapeHandle& shape_handle =
1758                 c->input_tensors_as_shapes_to_propagate[i];
1759             if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1760                 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1761               dims.push_back(ic->Dim(shape_handle, 0));
1762             } else {
1763             // This is not from Const, but as it shouldn't be used as a symbolic
1764             // unknown dim for different ops, we use kUnknownDimFromConst.
1765               dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1766             }
1767           }
1768         }
1769         if (valid) {
1770           c->output_tensors_as_shapes.resize(1);
1771           c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1772         }
1773       } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1774         c->output_tensors_as_shapes.resize(1);
1775         c->output_tensors_as_shapes[0] =
1776             c->input_tensors_as_shapes_to_propagate[0];
1777         if (c->input_tensor_protos[0] != nullptr) {
1778           c->output_tensor_protos.resize(1);
1779           c->output_tensor_protos[0] = c->input_tensor_protos[0];
1780         }
1781       } else if (IsSlice(node)) {
1782         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1783         bool valid = ic->RankKnown(input);
1784         const Tensor* slice_offset = ic->input_tensor(1);
1785         valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1786         const Tensor* slice_size = ic->input_tensor(2);
1787         valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1788         if (valid) {
1789           int64 start = slice_offset->dtype() == DT_INT32
1790                             ? slice_offset->flat<int32>()(0)
1791                             : slice_offset->flat<int64>()(0);
1792           int64 size =
1793               (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1794                                                : slice_size->flat<int64>()(0));
1795           ShapeHandle result;
1796           if (size == -1) {
1797             TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1798           } else {
1799             int64 end = start + size;
1800             TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1801           }
1802           c->output_tensors_as_shapes.resize(1);
1803           c->output_tensors_as_shapes[0] = result;
1804         }
1805       } else if (IsStridedSlice(node)) {
1806         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1807         bool valid = ic->RankKnown(input);
1808         const Tensor* slice_begin = ic->input_tensor(1);
1809         valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1810         const Tensor* slice_end = ic->input_tensor(2);
1811         valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1812         const Tensor* slice_stride = ic->input_tensor(3);
1813         valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1814 
1815         if (node.attr().count("ellipsis_mask") > 0 &&
1816             node.attr().at("ellipsis_mask").i() != 0) {
1817           valid = false;
1818         }
1819         if (node.attr().count("new_axis_mask") > 0 &&
1820             node.attr().at("new_axis_mask").i() != 0) {
1821           valid = false;
1822         }
1823         if (node.attr().count("shrink_axis_mask") > 0 &&
1824             node.attr().at("shrink_axis_mask").i() != 0) {
1825           valid = false;
1826         }
1827         int begin_mask = 0;
1828         if (node.attr().count("begin_mask") > 0) {
1829           begin_mask = node.attr().at("begin_mask").i();
1830         }
1831         int end_mask = 0;
1832         if (node.attr().count("end_mask") > 0) {
1833           end_mask = node.attr().at("end_mask").i();
1834         }
1835         if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1836           valid = false;
1837         }
1838         if (valid) {
1839           int64 begin = 0;
1840           if (begin_mask == 0) {
1841             begin = slice_begin->dtype() == DT_INT32
1842                         ? slice_begin->flat<int32>()(0)
1843                         : slice_begin->flat<int64>()(0);
1844           }
1845           int64 end = std::numeric_limits<int64>::max();
1846           if (end_mask == 0) {
1847             end =
1848                 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1849                                                 : slice_end->flat<int64>()(0));
1850           }
1851           int64 stride = slice_stride->dtype() == DT_INT32
1852                              ? slice_stride->flat<int32>()(0)
1853                              : slice_stride->flat<int64>()(0);
1854           ShapeHandle result;
1855           TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1856           c->output_tensors_as_shapes.resize(1);
1857           c->output_tensors_as_shapes[0] = result;
1858         }
1859       }
1860     }
1861 
1862     if (aggressive_shape_inference_) {
1863       // Update output shapes with annotated information. This is optional.
1864       UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1865 
1866       // Update output tensor values using EvaluateNode() if we can.
1867       // Due to the cost of EvaluateNode(), we run it only for certain op types
1868       // (allowlisted) and small integer tensors.
1869 
1870       const int max_element_size = 17;  // Max up to 4x4 matrix or similar.
1871       if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1872           !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1873         return Status::OK();
1874       }
1875       UpdateOutputShapesAndValues(node, c).IgnoreError();  // This is optional.
1876     }
1877     return Status::OK();
1878   }
1879 
1880   Status InferShapes(const NodeDef& node, NodeContext* c) {
1881     // Infer the shapes of output tensors.
1882     if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1883         !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1884       // Annotate outputs with unknown shapes. Update output shapes with
1885       // annotated information later on if available.
1886       // Note that the shape inference function may return an error, but we
1887       // ignore it and use UnknownShape in that case.
1888       TF_RETURN_IF_ERROR(
1889           c->inference_context->Run(shape_inference::UnknownShape));
1890     }
1891     Status status = Status::OK();
1892     auto it = fed_ports_.find(node.name());
1893     const bool is_fed = it != fed_ports_.end();
1894     if (is_fed) {
1895       // It is possible to feed node output ports with tensors of any shape: as
1896       // a result, the shape of a fed port is completely unknown.
1897       for (const int output_port : it->second) {
1898         status.Update(SetUnknownShape(&node, output_port));
1899       }
1900     }
1901 
1902     // Update NodeContext output fields after shape inference function runs.
1903     status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1904 
1905     return status;
1906   }
1907 
1908  private:
1909   bool IsIntegerVector(const Tensor& tensor) {
1910     if (tensor.dims() == 1 &&
1911         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1912       return true;
1913     }
1914     return false;
1915   }
1916 
1917   bool IsIntegerScalar(const Tensor& tensor) {
1918     if (tensor.dims() == 0 &&
1919         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1920         tensor.NumElements() == 1) {
1921       return true;
1922     }
1923     return false;
1924   }
1925 
1926   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1927                                            const int64 val) {
1928     TensorProto tensor_proto;
1929     tensor_proto.set_dtype(dtype);
1930     // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1931     tensor_proto.mutable_tensor_shape();
1932     if (dtype == DT_INT32) {
1933       tensor_proto.add_int_val(val);
1934     } else if (dtype == DT_INT64) {
1935       tensor_proto.add_int64_val(val);
1936     }
1937     return tensor_proto;
1938   }
1939 
1940   bool MaybeTensorProtoToShape(InferenceContext* ic,
1941                                const TensorProto& tensor_proto,
1942                                ShapeHandle* tensors_as_shapes) {
1943     // Skip if dtype is not integer.
1944     if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1945       return false;
1946     }
1947     // Skip if shape is neither scalar nor vector.
1948     if (tensor_proto.tensor_shape().unknown_rank() ||
1949         tensor_proto.tensor_shape().dim_size() > 1) {
1950       return false;
1951     }
1952     Tensor tensor;
1953     if (!tensor.FromProto(tensor_proto)) {
1954       return false;
1955     }
1956     return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1957   }
1958 
1959   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1960                                ShapeHandle* tensors_as_shapes) {
1961     // Integer tensors of rank one can also be interpreted as a shape
1962     // provided all their values are >= -1.
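    // For instance (illustrative values), the tensor [2, -1, 3] maps to the
    // shape [2, ?, 3], with the -1 recorded as kUnknownDimFromConst below.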
1963 
1964     if (IsIntegerVector(tensor)) {
1965       bool has_values_smaller_than_minus_1 = false;
1966       std::vector<DimensionHandle> dims;
1967       for (int i = 0; i < tensor.NumElements(); i++) {
1968         int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
1969                                                  : tensor.flat<int64>()(i);
1970         has_values_smaller_than_minus_1 |= (value < -1);
1971         // Mark this as UnknownDim from Const.
1972         dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
1973                                  : ic->MakeDim(value));
1974       }
1975 
1976       if (!has_values_smaller_than_minus_1) {
1977         *tensors_as_shapes = ic->MakeShape(dims);
1978         return true;
1979       }
1980     } else if (IsIntegerScalar(tensor)) {
1981       // Scalar constant.
1982       int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
1983                                                : tensor.flat<int64>()(0);
1984       if (value == -1) {
1985         // Scalar value -1 represents an unknown shape. If we tried to
1986         // MakeShape(MakeDim) with it, we would get a vector of unknown size.
1987         *tensors_as_shapes = ic->UnknownShape();
1988         return true;
1989       } else if (value >= 0) {
1990         // Tensor values can be < -1, but MakeDim() fails with a value < -1;
1991         // it's a limitation of using ShapeHandle as a means to pass values.
1992         *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
1993         return true;
1994       }
1995     }
1996     return false;
1997   }
1998 
1999   const GraphView& graph_;
2000   int graph_def_version_;
2001   absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2002   absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2003   absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2004   // Store function instantiations only for valid functions. If function
2005   // instantiation failed, the entry holds an `absl::nullopt`.
2006   absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2007       fun_to_grappler_function_item_;
2008   FunctionLibraryDefinition function_library_;
2009   const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2010   // Store TensorProtos for tensor value propagation. Note that we use deque,
2011   // not vector, as we use pointers to the TensorProtos in this container.
2012   // A vector may resize and copy its elements into a new buffer, which would
2013   // leave the existing pointers dangling.
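  // (std::deque does not relocate existing elements when new ones are added at
  // either end, so pointers taken earlier remain valid.)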
2014   std::deque<TensorProto> const_tensors_to_propagate_;
2015 
2016   // For more aggressive shape and value inference.
2017   bool aggressive_shape_inference_;
2018   ResourceMgr resource_mgr_;
2019 };
2020 
2021 // Keep track of shapes and dimensions in a graph.
2022 // In particular, use disjoint sets to track equivalence between shapes and
2023 // dims, and consolidate the information globally.
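// As a rough sketch of the intended use: when shape inference determines that
// two handles must denote the same shape (e.g., a producer's output and a
// consumer's input), Merge() places them in one disjoint-set class, and later
// queries such as GetMergedShape() report the class representative.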
2024 class SymbolicShapeManager {
2025  public:
2026   SymbolicShapeManager() {}
2027 
2028   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2029     if (!s1.IsSet() || !s2.IsSet()) {
2030       return Status::OK();
2031     }
2032     TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2033     if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2034       CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2035       for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2036         TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2037                                        InferenceContext::DimKnownRank(s2, i)));
2038       }
2039     }
2040     return Status::OK();
2041   }
2042   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2043     if (!d1.IsSet() || !d2.IsSet()) {
2044       return Status::OK();
2045     }
2046     return dims_.Merge(d1, d2);
2047   }
2048 
2049   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2050                           OpInfo::TensorProperties* properties) {
2051     properties->set_dtype(type);
2052     ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2053     if (!InferenceContext::RankKnown(actual_shape)) {
2054       properties->mutable_shape()->set_unknown_rank(true);
2055     } else {
2056       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2057         shape_inference::DimensionHandle dim =
2058             InferenceContext::DimKnownRank(actual_shape, j);
2059         int64 d = dims_.GetMergedValue(dim);
2060         properties->mutable_shape()->add_dim()->set_size(d);
2061       }
2062     }
2063   }
2064 
2065   // Returns merged shape with merged dimensions.
2066   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2067     const auto& actual_shape = shapes_.GetMergedValue(s);
2068     if (!InferenceContext::RankKnown(actual_shape)) {
2069       return ic->UnknownShape();
2070     } else {
2071       std::vector<DimensionHandle> dims;
2072       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2073         shape_inference::DimensionHandle dim =
2074             InferenceContext::DimKnownRank(actual_shape, j);
2075         int64 d = dims_.GetMergedValue(dim);
2076         // The symbolic shape manager may have made some dims < -1, which
2077         // causes errors when creating a Dimension.
2078         if (d < -1) {
2079           d = -1;
2080         }
2081         dims.push_back(ic->MakeDim(d));
2082       }
2083       return ic->MakeShape(dims);
2084     }
2085   }
2086 
2087  private:
2088   DisjointSet<shape_inference::ShapeHandle> shapes_;
2089   DisjointSet<shape_inference::DimensionHandle> dims_;
2090 };
2091 
2092 // Checks whether there is any conflict in merged shapes and dims in
2093 // SymbolicShapeManager.
2094 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2095                                     SymbolicShapeRefiner* refiner,
2096                                     SymbolicShapeManager* shape_manager) {
2097   if (!VLOG_IS_ON(1)) {
2098     return Status::OK();
2099   }
2100 
2101   VLOG(1) << "Checking for any conflicts in shapes and dimensions ...";
2102   int64 num_incompatible_shapes = 0;
2103   for (const NodeDef& node : graph_def.node()) {
2104     auto ctx = refiner->GetNodeContext(&node);
2105     if (!ctx) {
2106       continue;
2107     }
2108     auto* ic = ctx->inference_context.get();
2109     for (int i = 0; i < ic->num_inputs(); ++i) {
2110       const auto& shape = ic->input(i);
2111       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2112       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2113         num_incompatible_shapes++;
2114         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2115                 << "for node " << node.name() << " input (" << i << ") "
2116                 << ic->DebugString(shape)
2117                 << " vs. merged: " << ic->DebugString(merged_shape);
2118       }
2119     }
2120     for (int i = 0; i < ic->num_outputs(); ++i) {
2121       const auto& shape = ic->output(i);
2122       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2123       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2124         num_incompatible_shapes++;
2125         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2126                 << "for node " << node.name() << " output (" << i << ") "
2127                 << ic->DebugString(shape)
2128                 << " vs. merged: " << ic->DebugString(merged_shape);
2129       }
2130     }
2131   }
2132   if (num_incompatible_shapes > 0) {
2133     VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2134             << " incompatible shapes from SymbolicShapeManager.";
2135   } else {
2136     VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2137   }
2138 
2139   return Status::OK();
2140 }
2141 
2142 // Log shape inference and its merged shapes.
2143 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2144                                     SymbolicShapeRefiner* refiner,
2145                                     SymbolicShapeManager* shape_manager) {
2146   // As logging all the nodes would generate too many lines, we skip this
2147   // detailed logging by default. Users may add nodes of interest to
2148   // node_names_for_logging to enable detailed logging.
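  // e.g., node_names_for_logging = {"my_model/conv1/Conv2D"}; (hypothetical
  // node name, shown for illustration only).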
2149   absl::flat_hash_set<std::string> node_names_for_logging = {};
2150   if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2151     return Status::OK();
2152   }
2153 
2154   auto should_log = [&node_names_for_logging](std::string node_name) {
2155     return node_names_for_logging.find(node_name) !=
2156            node_names_for_logging.end();
2157   };
2158 
2159   for (const NodeDef& node : graph_def.node()) {
2160     if (!should_log(node.name())) {
2161       continue;
2162     }
2163     auto ctx = refiner->GetNodeContext(&node);
2164     if (!ctx) {
2165       continue;
2166     }
2167     auto* ic = ctx->inference_context.get();
2168     VLOG(3) << "Shape inference for node : " << node.name();
2169     VLOG(3) << ctx->DebugString(node);
2170     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2171     for (int i = 0; i < ic->num_inputs(); ++i) {
2172       absl::StrAppend(
2173           &merged_shapes, " input[", i, "] -- ",
2174           ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2175           "\n");
2176     }
2177     for (int i = 0; i < ic->num_outputs(); ++i) {
2178       absl::StrAppend(
2179           &merged_shapes, " output[", i, "] -- ",
2180           ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2181           "\n");
2182     }
2183     VLOG(3) << merged_shapes;
2184     VLOG(3) << "--------------------------------";
2185     VLOG(3) << "";
2186   }
2187 
2188   return Status::OK();
2189 }
2190 
2191 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2192     SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2193     const std::vector<ShapeAndType>& shapes_and_types,
2194     std::vector<ShapeAndType>* queue_shapes_and_types) {
2195   if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2196     return errors::InvalidArgument(
2197         "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2198         "  vs ", queue_shapes_and_types->size());
2199   }
2200   for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2201     const ShapeAndType& a = shapes_and_types[i];
2202     ShapeAndType& b = (*queue_shapes_and_types)[i];
2203     if (a.dtype != b.dtype) {
2204       return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2205                                      i, ": ", DataTypeString(a.dtype), " vs ",
2206                                      DataTypeString(b.dtype));
2207     }
2208 
2209     b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2210   }
2211   return Status::OK();
2212 }
2213 
2214 // Compute the output shape of the merge node as the union of the available
2215 // input shapes.
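// As an illustrative example, if the Merge inputs arrive with shapes [16, 32]
// and [8, 32], the union typically relaxes the disagreeing dimension and the
// output becomes [?, 32].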
2216 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2217                                     const NodeDef* node,
2218                                     bool* new_shapes) const {
2219   InferenceContext* ic = shape_refiner->GetContext(node);
2220   if (!ic) {
2221     // Now we can run shape inference
2222     TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2223     ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2224     *new_shapes = true;
2225 
2226     // Infer the shape of the second output once and for all since it never
2227     // changes.
2228     ShapeHandle out1 = ic->Scalar();
2229     if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2230   }
2231 
2232   ShapeHandle out;
2233   const std::vector<ShapeAndType>* out_handle = nullptr;
2234   bool out_initialized = false;
2235   for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2236            *node, /*include_controlling_edges=*/false)) {
2237     InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2238     if (!src_ic) {
2239       // Handling a loop for the first time, the back edge won't have any shape
2240       // info.
2241       continue;
2242     }
2243     ShapeHandle input = src_ic->output(fanin.src.port_id);
2244     ic->SetInput(fanin.dst.port_id, input);
2245     auto* input_handle =
2246         src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2247     if (input_handle)
2248       ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2249     if (!out_initialized) {
2250       out_initialized = true;
2251       out = input;
2252       out_handle = input_handle;
2253     } else {
2254       // Note here only out, not out_handle, is modified.
2255       out = shape_refiner->OutputAsUnion(node, 0, input, out);
2256     }
2257   }
2258 
2259   if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2260     ic->set_output(0, out);
2261     if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2262     *new_shapes = true;
2263   }
2264 
2265   return Status::OK();
2266 }
2267 
2268 // Manually propagate the input shape for Enter nodes.
2269 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2270                                     const NodeDef* node, bool* new_shapes) {
2271   InferenceContext* ic = shape_refiner->GetContext(node);
2272   if (!ic) {
2273     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2274     ic = shape_refiner->GetContext(node);
2275   }
2276 
2277   GraphView::InputPort port(node, 0);
2278   GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2279 
2280   InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2281   ShapeHandle input = src_ic->output(fanin.port_id);
2282   if (!ic->output(0).SameHandle(input)) {
2283     ic->SetInput(0, input);
2284     ic->set_output(0, input);
2285     *new_shapes = true;
2286   }
2287   auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2288   if (outputs) {
2289     ic->set_input_handle_shapes_and_types(0, *outputs);
2290     ic->set_output_handle_shapes_and_types(0, *outputs);
2291     *new_shapes = true;
2292   }
2293   return Status::OK();
2294 }
2295 
2296 Status GraphProperties::UpdateShapes(
2297     SymbolicShapeRefiner* shape_refiner,
2298     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2299     const NodeDef* n, bool* new_shapes) const {
2300   if (IsEnter(*n)) {
2301     // The Enter shape function always forwards an UnknownShape, so do the right
2302     // thing here.
2303     TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2304   } else if (IsMerge(*n)) {
2305     // Properly handle merge nodes.
2306     TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2307   } else if (IsEnqueue(*n)) {
2308     // Make sure the shapes of enqueued tensors are propagated to the queue
2309     // itself.
2310     TF_RETURN_IF_ERROR(
2311         UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2312   } else if (IsQueue(*n)) {
2313     // Set shapes and types of Queue ops, if needed.
2314     TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2315   } else {
2316     // Rely on regular TF shape refinement for all the other nodes.
2317     // UpdateNode calls UpdateFunction if a function node is detected.
2318     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2319   }
2320 
2321   return Status::OK();
2322 }
2323 
2324 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2325 Status GraphProperties::PropagateShapes(
2326     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2327     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2328     int num_loops) const {
2329   // Limit the number of iterations to prevent infinite loops in the presence of
2330   // incorrect shape functions. The algorithm should converge in at most
2331   // num_nested_loops^2 * max_rank passes over the graph. We approximate max_rank
2332   // with the constant 4. The same applies to resources.
2333   VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2334           << num_loops << " loops and " << resource_handles.size()
2335           << " resources" << std::endl;
2336 
2337   const int64 max_loop_length = item_.graph.node_size();
2338   const int64 max_rank = 4;
2339   const int64 max_loop_iterations =
2340       max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
2341   const int64 num_queues = resource_handles.size();
2342   const int64 max_resource_iterations = num_queues * num_queues * max_rank;
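  // Worked example with illustrative numbers: for a graph of 1000 nodes with
  // 2 nested loops and 3 queues, max_loop_iterations = 4 * 1000 * 4 = 16000
  // and max_resource_iterations = 3 * 3 * 4 = 36.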
2343 
2344   int64 num_resource_iterations = 0;
2345   do {
2346     int64 num_loop_iterations = 0;
2347     while (!new_shapes->empty() &&
2348            num_loop_iterations++ < max_loop_iterations) {
2349       const NodeDef* n = new_shapes->pop();
2350       bool updated = false;
2351       TF_RETURN_IF_ERROR(
2352           UpdateShapes(shape_refiner, resource_handles, n, &updated));
2353       if (updated) {
2354         for (const auto& fanout : shape_refiner->graph().GetFanouts(
2355                  *n, /*include_controlled_nodes=*/false)) {
2356           new_shapes->push(fanout.node);
2357         }
2358         // Make sure the corresponding queue nodes are (re)processed.
2359         if (IsEnqueue(*n)) {
2360           auto it = resource_handles.find(n);
2361           if (it != resource_handles.end()) {
2362             new_shapes->push(it->second);
2363           }
2364         }
2365       }
2366     }
2367   } while (!new_shapes->empty() &&
2368            num_resource_iterations++ < max_resource_iterations);
2369 
2370   if (!new_shapes->empty()) {
2371     return errors::Internal("Shape inference failed to converge");
2372   }
2373 
2374   return Status::OK();
2375 }
2376 
2377 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2378                                     SymbolicShapeRefiner* shape_refiner,
2379                                     bool* new_shapes) {
2380   auto* ctx = shape_refiner->GetNodeContext(queue_node);
2381   if (!ctx) {
2382     TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2383     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2384   }
2385   auto* ic = ctx->inference_context.get();
2386 
2387   auto* outputs = ic->output_handle_shapes_and_types(0);
2388   if (outputs) {
2389     // Shapes and types are already set, presumably by Enqueue ops.
2390     return shape_refiner->UpdateNode(queue_node, new_shapes);
2391   }
2392 
2393   if (queue_node->attr().count("shapes") <= 0 ||
2394       queue_node->attr().count("component_types") <= 0 ||
2395       queue_node->attr().at("shapes").list().shape_size() !=
2396           queue_node->attr().at("component_types").list().type_size()) {
2397     // Errors in shapes and component_types attr.
2398     return shape_refiner->UpdateNode(queue_node, new_shapes);
2399   }
2400 
2401   // Extract types and shapes from Queue attr.
2402   const auto& shapes = queue_node->attr().at("shapes").list().shape();
2403   const auto& types = queue_node->attr().at("component_types").list().type();
2404   std::vector<ShapeAndType> shapes_and_types;
2405   for (int i = 0; i < types.size(); i++) {
2406     const auto& shape = shapes[i];
2407     ShapeHandle shape_handle;
2408     TF_RETURN_IF_ERROR(
2409         ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2410     DataType data_type =
2411         queue_node->attr().at("component_types").list().type(i);
2412     ShapeAndType shape_and_type(shape_handle, data_type);
2413     shapes_and_types.push_back(shape_and_type);
2414   }
2415   ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2416 
2417   // The queue node is updated with output_handle_shapes_and_types, so set
2418   // new_shapes here and ignore the value reported by UpdateNode().
2419   *new_shapes = true;
2420   bool dummy_new_shapes = false;
2421   return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2422 }
2423 
2424 Status GraphProperties::UpdateEnqueue(
2425     const NodeDef* enqueue_node,
2426     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2427     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2428   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2429   if (!ctx) {
2430     TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2431     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2432   }
2433 
2434   auto it = resource_handles.find(enqueue_node);
2435   if (it == resource_handles.end()) {
2436     // The corresponding queue was not found, there isn't much we can do.
2437     return Status::OK();
2438   }
2439   const NodeDef* qnode = it->second;
2440   auto qctx = shape_refiner->GetContext(qnode);
2441   if (!qctx) {
2442     return Status::OK();
2443   }
2444   auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2445 
2446   // TODO(bsteiner): handle EnqueueMany as well.
2447   std::vector<ShapeAndType> shapes_and_types;
2448   for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2449     GraphView::InputPort inp(enqueue_node, i);
2450     GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2451     InferenceContext* in = shape_refiner->GetContext(fanin.node);
2452     ShapeHandle input = in->output(fanin.port_id);
2453     ctx->inference_context->SetInput(i, input);
2454     shapes_and_types.push_back({input, ctx->input_types[i]});
2455   }
2456 
2457   if (queue_handle_data == nullptr) {
2458     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2459     *new_shapes = true;
2460   } else {
2461     TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2462         shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2463     *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2464                                                             shapes_and_types);
2465     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2466   }
2467 
2468   return Status::OK();
2469 }
2470 
2471 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2472                                         bool aggressive_shape_inference,
2473                                         bool include_input_tensor_values,
2474                                         bool include_output_tensor_values) {
2475   FunctionLibraryDefinition function_library(OpRegistry::Global(),
2476                                              item_.graph.library());
2477   absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
2478   if (!assume_valid_feeds) {
2479     for (const auto& feed : item_.feed) {
2480       SafeTensorId tensor_id = ParseTensorName(feed.first);
2481       fed_ports[tensor_id.node()].insert(tensor_id.index());
2482     }
2483   }
2484 
2485   GraphView graph_view(&item_.graph);
2486 
2487   // List the resources and the nodes using them. Also collect the Merge nodes,
2488   // fed nodes, and primary inputs.
2489   absl::flat_hash_map<const NodeDef*,
2490                       std::pair<absl::flat_hash_set<const NodeDef*>,
2491                                 absl::flat_hash_set<const NodeDef*>>>
2492       resources;
2493   absl::flat_hash_set<const NodeDef*> merge_nodes;
2494   absl::flat_hash_set<const NodeDef*> fed_nodes;
2495   absl::flat_hash_set<const NodeDef*> primary_inputs;
  int num_loops = 0;
  for (const NodeDef& node : item_.graph.node()) {
    if (IsQueue(node)) {
      for (const GraphView::InputPort& fanout :
           graph_view.GetFanouts(node, false)) {
        if (IsEnter(*fanout.node)) {
          const NodeDef& enter = *fanout.node;
          for (const GraphView::InputPort& fanout :
               graph_view.GetFanouts(enter, false)) {
            if (IsEnqueue(*fanout.node)) {
              resources[&node].first.insert(fanout.node);
            } else if (IsDequeue(*fanout.node)) {
              resources[&node].second.insert(fanout.node);
            }
          }
        } else {
          if (IsEnqueue(*fanout.node)) {
            resources[&node].first.insert(fanout.node);
          } else if (IsDequeue(*fanout.node)) {
            resources[&node].second.insert(fanout.node);
          }
        }
      }
    }
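    // Classify the remaining nodes: nodes without regular (data) inputs are
    // primary inputs and seed the shape propagation below, while NextIteration
    // nodes indicate the presence of a loop.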
    if (!HasRegularInputs(node)) {
      primary_inputs.insert(&node);
    } else if (IsMerge(node)) {
      merge_nodes.insert(&node);
    } else if (IsNextIteration(node)) {
      ++num_loops;
    }
    if (fed_ports.find(node.name()) != fed_ports.end()) {
      fed_nodes.insert(&node);
    }
  }

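  // Map every enqueue op back to the queue whose handle it updates, and add
  // explicit dependencies so that each dequeue is ordered after the enqueues
  // feeding the same queue.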
  absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
  std::vector<TopologicalDependency> extra_deps;
  for (const auto& resource : resources) {
    for (const NodeDef* src : resource.second.first) {
      resource_handles[src] = resource.first;
      for (const NodeDef* dst : resource.second.second) {
        // Add control edges from enqueue to dequeue nodes to ensure they are
        // processed in their logical order.
        extra_deps.emplace_back(src, dst);
      }
    }
  }

  std::vector<const NodeDef*> topo_order;
  Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
  if (!s.ok()) {
    if (extra_deps.empty()) {
      return s;
    } else {
      // There is a loop between queues: we'll just use the graph topological
      // order. This will make the shape inference less precise, but since this
      // isn't common it's not worth figuring out where to break the loop and
      // doing a proper relaxation.
      TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
    }
  }

  // Heap-allocate SymbolicShapeRefiner in order to not consume a large amount
  // of stack space.
  auto refiner = absl::make_unique<SymbolicShapeRefiner>(
      graph_view, fed_ports, aggressive_shape_inference);

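  // Visit nodes following the topological order computed above.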
  TopoQueue new_shapes(topo_order);
  // Also seed the propagation of shapes in the fanout of primary inputs.
  for (const NodeDef* node : primary_inputs) {
    new_shapes.push(node);
  }
  // Also seed the propagation of shapes in the fanout of fed nodes.
  for (const NodeDef* node : fed_nodes) {
    new_shapes.push(node);
  }
  // Propagate shapes normally.
  TF_RETURN_IF_ERROR(
      PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));

  // Track shapes globally across the graph.
  std::unique_ptr<SymbolicShapeManager> shape_manager =
      absl::make_unique<SymbolicShapeManager>();
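  // Merge every node's symbolic shapes and dimensions into the global manager.
  // If any merge fails, the inferred shapes are mutually inconsistent and all
  // symbolic information gathered so far is discarded.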
  bool found_error = false;
  for (const NodeDef& node : item_.graph.node()) {
    auto node_ctx = refiner->GetContext(&node);
    if (!node_ctx) {
      continue;
    }
    // Skip any information that comes from fed nodes.
    if (fed_ports.find(node.name()) != fed_ports.end()) {
      VLOG(2) << "Skipping feed node shape: " << node.name();
      continue;
    }
    for (const auto& merged_shapes : node_ctx->MergedShapes()) {
      if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
               .ok()) {
        found_error = true;
        break;
      }
    }
    for (const auto& merged_dims : node_ctx->MergedDims()) {
      if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
        found_error = true;
        break;
      }
    }
    if (found_error) {
      // The shapes aren't consistent; we can't infer safely, so discard all
      // the information discovered so far.
      shape_manager = absl::make_unique<SymbolicShapeManager>();
      break;
    }
  }

  TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
                                                  shape_manager.get()));

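  // Export the results: for every node, convert the inferred shapes into
  // OpInfo::TensorProperties for its inputs and outputs, and optionally attach
  // statically known tensor values.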
  for (const NodeDef& node : item_.graph.node()) {
    VLOG(4) << "Filling in graph properties for node: " << node.name();
    auto ctx = refiner->GetNodeContext(&node);
    if (!ctx) {
      continue;
    }

    auto* ic = ctx->inference_context.get();

    // Fill input properties.
    {
      auto& input_properties = input_properties_[node.name()];

      // Should always be empty; node names in the graph are supposed to be
      // unique.
      CHECK_EQ(input_properties.size(), 0);

      input_properties.resize(ic->num_inputs());
      GraphView::InputPort input(&node, -1);
      for (int i = 0; i < ic->num_inputs(); ++i) {
        shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
                                          &input_properties[i]);
        input.port_id = i;
        GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
        if (include_input_tensor_values) {
          // Export tensor value to input_properties.value.
          if (IsConstant(*fanin.node)) {
            const TensorProto& raw_val =
                fanin.node->attr().at("value").tensor();
            *input_properties[i].mutable_value() = raw_val;
          } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
                     ctx->input_tensor_protos[i] != nullptr) {
            *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
          } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
                         i &&
                     IsShapeFullyDefinedIntegerVectorOrScalar(
                         ic, ic->input(i), ic->input_tensors_as_shapes()[i],
                         ctx->input_types[i])) {
            *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
                ic, ic->input(i), ic->input_tensors_as_shapes()[i],
                ctx->input_types[i]);
          }
        }
      }
    }

    // Fill output properties.
    {
      auto& output_properties = output_properties_[node.name()];

      // Should always be empty; node names in the graph are supposed to be
      // unique.
      CHECK_EQ(output_properties.size(), 0);

      output_properties.resize(ic->num_outputs());
      for (int i = 0; i < ic->num_outputs(); ++i) {
        shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
                                          &output_properties[i]);
        auto converted_output_tensors_as_shapes =
            ReplaceUnknownDimFromConstWithUnknownDim(
                ic, ctx->output_tensors_as_shapes);
        if (include_output_tensor_values) {
          // Export tensor value to output_properties.value.
          if (IsConstant(node)) {
            // TODO(rmlarsen): Eliminate this copy.
            const TensorProto& raw_val = node.attr().at("value").tensor();
            *output_properties[i].mutable_value() = raw_val;
          } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
                     ctx->output_tensor_protos[i] != nullptr) {
            *output_properties[i].mutable_value() =
                *ctx->output_tensor_protos[i];
          } else if (static_cast<int>(
                         converted_output_tensors_as_shapes.size()) > i &&
                     IsShapeFullyDefinedIntegerVectorOrScalar(
                         ic, ic->output(i),
                         converted_output_tensors_as_shapes[i],
                         ctx->output_types[i])) {
            *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
                ic, ic->output(i), converted_output_tensors_as_shapes[i],
                ctx->output_types[i]);
          }
        }
      }
    }

    if (aggressive_shape_inference && ctx->shape_incompatible)
      incompatible_shape_nodes_.insert(node.name());
  }

  if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
    LOG(WARNING) << incompatible_shape_nodes_.size()
                 << " nodes have incompatible output shapes.";

  // Help trace the unknown dimensions to their origins.
  VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
                                    output_properties_);

  TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
                                                  shape_manager.get()));

  return Status::OK();
}

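// Dynamic shape inference: runs the item once on the provided cluster and
// infers the properties from the shapes recorded in the resulting cost graph.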
Status GraphProperties::InferDynamically(Cluster* cluster) {
  TF_RETURN_IF_ERROR(cluster->Initialize(item_));

  // Runs the model once to collect the shapes in the cost model.
  RunMetadata metadata;
  TF_RETURN_IF_ERROR(
      cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));

  return InferFromCostGraph(metadata.cost_graph());
}

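// Writes each node's inferred output shapes into its "_output_shapes" attr on
// a copy of the input graph. When symbolic shapes are not allowed, the shapes
// are first passed through NormalizeShapeForOutput.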
Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
                                             bool allow_symbolic_shapes) const {
  *output_graph_def = item_.graph;
  for (int i = 0; i < output_graph_def->node_size(); i++) {
    auto node = output_graph_def->mutable_node(i);
    AttrValue attr_output_shape;
    auto tensor_properties = GetOutputProperties(node->name());
    for (const auto& tensor_property : tensor_properties) {
      TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
      *proto = tensor_property.shape();
      if (!allow_symbolic_shapes) {
        NormalizeShapeForOutput(proto);
      }
    }
    (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
  }
  return Status::OK();
}

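// Fills the per-node properties from a cost graph produced by running the
// model; nodes absent from the cost graph are skipped.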
Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
  if (cost_graph.node_size() == 0) {
    LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
  }
  std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
  std::unordered_map<string, const NodeDef*> name_to_node;  // Empty
  for (auto& node : cost_graph.node()) {
    name_to_cost[node.name()] = &node;

    std::vector<OpInfo::TensorProperties> output_properties;
    for (const auto& out : node.output_info()) {
      OpInfo::TensorProperties properties;
      properties.set_dtype(out.dtype());
      *properties.mutable_shape() = out.shape();
      output_properties.push_back(properties);
    }
    output_properties_[node.name()] = output_properties;
  }

  for (const auto& node : item_.graph.node()) {
    // Skip the nodes that are not in the cost graph: these are nodes that
    // aren't run, because they aren't in the intersection of transitive fan-in
    // of a fetch node and the transitive fan-out of an input, or nodes that
    // were optimized away by the optimizer.
    auto it = name_to_cost.find(node.name());
    if (it == name_to_cost.end()) {
      continue;
    }
    std::vector<OpInfo::TensorProperties> inputs =
        FindInputFeatures(node, name_to_cost, name_to_node);

    input_properties_[node.name()] = inputs;
  }
  return Status::OK();
}

bool GraphProperties::HasInputProperties(const string& node_name) const {
  return input_properties_.find(node_name) != input_properties_.end();
}

bool GraphProperties::HasOutputProperties(const string& node_name) const {
  return output_properties_.find(node_name) != output_properties_.end();
}

const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetInputProperties(const string& node_name) const {
  auto it = input_properties_.find(node_name);
  if (it != input_properties_.end()) {
    return it->second;
  }
  return missing_properties_;
}

const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetOutputProperties(const string& node_name) const {
  auto it = output_properties_.find(node_name);
  if (it != output_properties_.end()) {
    return it->second;
  }
  return missing_properties_;
}

void GraphProperties::ClearInputProperties(const string& node_name) {
  input_properties_.erase(node_name);
}
void GraphProperties::ClearOutputProperties(const string& node_name) {
  output_properties_.erase(node_name);
}

}  // end namespace grappler
}  // end namespace tensorflow