1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "tensorflow/core/framework/common_shape_fns.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/graph/graph_constructor.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/grappler/costs/utils.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/grappler/utils/functions.h"
34 #include "tensorflow/core/grappler/utils/topological_sort.h"
35 #include "tensorflow/core/lib/gtl/cleanup.h"
36 #include "tensorflow/core/lib/gtl/flatset.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 
39 namespace tensorflow {
40 namespace grappler {
41 
42 namespace {
43 
44 using shape_inference::DimensionHandle;
45 using shape_inference::InferenceContext;
46 using shape_inference::ShapeAndType;
47 using shape_inference::ShapeHandle;
48 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
49 
50 template <typename Handle>
51 struct HashHandle {
52   std::size_t operator()(const Handle& h) const { return h.Handle(); }
53 };
54 template <typename Handle>
55 struct CompareHandle {
56   bool operator()(const Handle& h1, const Handle& h2) const {
57     return h1.SameHandle(h2);
58   }
59 };
60 
61 template <typename Handle>
62 struct HandleToObject {};
63 template <>
64 struct HandleToObject<ShapeHandle> {
65   typedef ShapeHandle Object;
66 
67   static ShapeHandle Unknown() { return ShapeHandle(); }
68 };
69 
70 template <>
71 struct HandleToObject<DimensionHandle> {
72   typedef int64 Object;
73 
74   static int64 Unknown() { return -1; }
75 };
76 
77 template <typename Handle>
78 struct Processor {};
79 
80 template <>
81 struct Processor<ShapeHandle> {
82   // Extract the shape or dim denoted by the handle.
83   void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
84   // Merge the shapes or dims.
85   Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
86     if (InferenceContext::RankKnown(*result)) {
87       // The result was initialized in a previous merge to a shape of known
88       // rank, make sure we preserve that information.
89       return Status::OK();
90     }
91     if (InferenceContext::RankKnown(h1)) {
92       *result = h1;
93     } else {
94       *result = h2;
95     }
96     return Status::OK();
97   }
98 };
99 
100 template <>
101 struct Processor<DimensionHandle> {
102   // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
103   // reserved by TensorFlow).
104   void ExtractValue(DimensionHandle d, int64* result) {
105     if (!InferenceContext::ValueKnown(d)) {
106       *result = -counter;
107       counter++;
108     } else {
109       int64 val = InferenceContext::Value(d);
110       if (val >= 0) {
111         *result = val;
112       } else {
113         // A shape inference function generated an invalid dimension handle.
114         // Use a symbolic dimension to encode this.
115         *result = -counter;
116         counter++;
117       }
118     }
119   }
120 
121   // Merge the dimensions d1 and d2. Return the known shape if there is one,
122   // otherwise look for a symbolic shape. If there is no symbolic shape and no
123   // known shape, the shape is fully unknown, so return -1.
124   Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
125     const int64 dim1 = InferenceContext::Value(d1);
126     const int64 dim2 = InferenceContext::Value(d2);
127 
128     if (dim1 >= 0 && dim2 >= 0) {
129       CHECK_EQ(dim1, dim2);
130       return RefineDim(dim1, result);
131     } else if (dim1 >= 0 && dim2 < 0) {
132       return RefineDim(dim1, result);
133     } else if (dim1 < 0 && dim2 >= 0) {
134       return RefineDim(dim2, result);
135     } else if (dim1 < -1) {
136       return RefineDim(dim1, result);
137     } else if (dim2 < -1) {
138       return RefineDim(dim2, result);
139     } else {
140       CHECK_EQ(dim1, dim2);
141       CHECK_EQ(-1, dim1);
142       return RefineDim(-1, result);
143     }
144     return Status::OK();
145   }
146 
147  private:
148   Status RefineDim(int64 dim, int64* result) {
149     if (*result >= 0) {
150       if (!(*result == dim || dim < 0)) {
151         return errors::InvalidArgument("Inconsistent dimensions detected");
152       }
153     } else if (dim >= 0) {
154       *result = dim;
155     } else if (dim < *result) {
156       *result = dim;
157     }
158     return Status::OK();
159   }
160 
161   int64 counter = 2;
162 };
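// Illustrative sketch (not part of the original TensorFlow source): with the
// Processor<DimensionHandle> above, a known dimension keeps its value while
// each unknown dimension receives a fresh symbolic id below -1. Assuming `ic`
// is an InferenceContext:
//
//   Processor<DimensionHandle> p;
//   int64 d0, d1, d2;
//   p.ExtractValue(ic->MakeDim(8), &d0);    // d0 == 8
//   p.ExtractValue(ic->UnknownDim(), &d1);  // d1 == -2 (first symbolic id)
//   p.ExtractValue(ic->UnknownDim(), &d2);  // d2 == -3
//
// When merging, a known value always wins over a symbolic one, and between two
// symbolic ids the more negative one is kept.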
163 
164 // Traditional Disjoint-Set data structure with path compression.
165 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
166 template <typename Handle>
167 class DisjointSet {
168  public:
169   DisjointSet() {}
170   ~DisjointSet() {
171     for (auto rep : nodes_) {
172       delete rep.second;
173     }
174   }
175 
176   Status Merge(Handle x, Handle y);
177   const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
178 
179  private:
180   // All the handles that belong to the same set are part of the same tree, and
181   // ultimately represented by the root of that tree.
182   struct Rep {
183     // Parent in the tree used to encode the set.
184     Rep* parent;
185     // Rank in the tree, used to figure out how to compress the path to the root
186     // of the tree.
187     int rank;
188     // The handle.
189     typename HandleToObject<Handle>::Object value;
190   };
191 
192   // Create a new set for the value if none exists, or return its representative
193   // node otherwise.
194   Rep* Find(Handle value);
195 
196  private:
197   Processor<Handle> processor_;
198   std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
199       nodes_;
200 };
201 
202 template <typename Handle>
203 const typename HandleToObject<Handle>::Object
204 DisjointSet<Handle>::GetMergedValue(Handle value) {
205   Rep* rep = Find(value);
206   if (!rep) {
207     // We don't know anything about this handle.
208     return HandleToObject<Handle>::Unknown();
209   }
210   return rep->value;
211 }
212 
213 template <typename Handle>
214 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
215   Rep* x_root = Find(x);
216   Rep* y_root = Find(y);
217 
218   // x and y are already in the same set
219   if (x_root == y_root) {
220     return Status::OK();
221   }
222   // x and y are not in same set, so we merge them
223   // Use the occasion to strengthen what we know about the handle by merging the
224   // information about the 2 subsets.
225   if (x_root->rank < y_root->rank) {
226     TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
227     x_root->parent = y_root;
228   } else if (x_root->rank > y_root->rank) {
229     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
230     y_root->parent = x_root;
231   } else {
232     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
233     // Arbitrarily make one root the new parent
234     y_root->parent = x_root;
235     x_root->rank = x_root->rank + 1;
236   }
237   return Status::OK();
238 }
239 
240 template <typename Handle>
241 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
242   auto it = nodes_.find(value);
243   if (it == nodes_.end()) {
244     // This is the first time we process this handle, create an entry for it.
245     Rep* node = new Rep;
246     node->parent = node;
247     node->rank = 0;
248     processor_.ExtractValue(value, &node->value);
249     nodes_[value] = node;
250     return node;
251   }
252   // Return the representative for the set, which is the root of the tree. Apply
253   // path compression to speed up future queries.
254   Rep* node = it->second;
255   Rep* root = node->parent;
256   while (root != root->parent) {
257     root = root->parent;
258   }
259   while (node->parent != root) {
260     Rep* next = node->parent;
261     node->parent = root;
262     node = next;
263   }
264   return root;
265 }
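// Illustrative usage sketch (not part of the original TensorFlow source):
// callers merge handles that must denote the same shape or dimension and later
// query one canonical value per equivalence class. Assuming `a` and `b` are
// ShapeHandles produced by inference contexts:
//
//   DisjointSet<ShapeHandle> shapes;
//   TF_CHECK_OK(shapes.Merge(a, b));
//   ShapeHandle rep = shapes.GetMergedValue(a);  // GetMergedValue(b) == rep
//
// Find() applies path compression, so long merge chains stay cheap to query.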
266 
267 // TODO(dyoon): Move many helper functions in this file (including those within
268 // SymbolicShapeRefiner class) to shared utils.
269 bool IsEnqueue(const NodeDef& n) {
270   return (n.op().find("Enqueue") != string::npos &&
271           n.op().find("EnqueueMany") == string::npos);
272 }
273 
274 bool IsDequeue(const NodeDef& n) {
275   return (n.op().find("Dequeue") != string::npos &&
276           n.op().find("DequeueMany") == string::npos);
277 }
278 
279 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
280   if (proto.unknown_rank()) {
281     return true;
282   }
283   for (const auto& dim : proto.dim()) {
284     if (dim.size() < 0) {
285       return true;
286     }
287   }
288   return false;
289 }
290 
291 // This really should be done in an external debugging tool
292 void VerboseLogUnknownDimensionSources(
293     const GraphDef& graph,
294     const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
295         input_properties_map,
296     const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
297         output_properties_map) {
298   if (!VLOG_IS_ON(2)) {
299     return;
300   }
301 
302   VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
303 
304   // Find all nodes in the graph for which we
305   // do not have any unknown dimensions in their inputs, but
306   // we have some unknown dimensions in their outputs.
307   std::map<string, int> op_to_count;
308   for (const NodeDef& node : graph.node()) {
309     const auto& input_properties = input_properties_map.at(node.name());
310     const auto& output_properties = output_properties_map.at(node.name());
311 
312     bool has_unknown_inputs = false;
313     for (const auto& input_prop : input_properties) {
314       if (HasAnyUnknownDimensions(input_prop.shape())) {
315         has_unknown_inputs = true;
316         break;
317       }
318     }
319 
320     if (has_unknown_inputs) {
321       continue;
322     }
323 
324     for (const auto& output_prop : output_properties) {
325       if (HasAnyUnknownDimensions(output_prop.shape())) {
326         string inputs = "input_shapes=[";
327         for (const auto& input_prop : input_properties) {
328           inputs += PartialTensorShape::DebugString(input_prop.shape());
329         }
330         inputs += "]";
331 
332         string outputs = "output_shapes=[";
333         for (const auto& output_prop : output_properties) {
334           outputs += PartialTensorShape::DebugString(output_prop.shape());
335         }
336         outputs += "]";
337 
338         VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
339                 << inputs << ", " << outputs;
340 
341         op_to_count[node.op()]++;
342 
343         // don't log again for this node
344         break;
345       }
346     }
347   }
348   VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
349           << "(format: <op_type> (<count>)):";
350   for (const auto& p : op_to_count) {
351     VLOG(2) << p.first << " (" << p.second << ")";
352   }
353 }
354 
355 bool IsShapeFullyDefinedIntegerVectorOrScalar(
356     InferenceContext* ic, const ShapeHandle& shape,
357     const ShapeHandle& tensor_as_shape, const DataType& dtype) {
358   if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
359       !ic->FullyDefined(tensor_as_shape) ||
360       (dtype != DT_INT32 && dtype != DT_INT64)) {
361     return false;
362   }
363   return true;
364 }
365 
366 // Returned tensor's shape is like `shape`, and its values and dtype are from
367 // `tensor_as_shape` and `dtype`.
368 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
369                                      const ShapeHandle& shape,
370                                      const ShapeHandle& tensor_as_shape,
371                                      const DataType& dtype) {
372   TensorProto tensor_proto;
373   tensor_proto.set_dtype(dtype);
374   auto* shape_proto = tensor_proto.mutable_tensor_shape();
375   if (ic->Rank(shape) == 1) {
376     shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
377   }
378   // For a scalar tensor, tensor_shape field will be left empty; no dim.
379   for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
380     int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
381     if (dtype == DT_INT32) {
382       tensor_proto.add_int_val(value);
383     } else {
384       tensor_proto.add_int64_val(value);
385     }
386   }
387   return tensor_proto;
388 }
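// Worked example (not part of the original TensorFlow source): if `shape` is a
// rank-1 handle and `tensor_as_shape` is [2, 3, 4] with dtype DT_INT32, the
// proto built above is roughly
//
//   dtype: DT_INT32
//   tensor_shape { dim { size: 3 } }
//   int_val: 2  int_val: 3  int_val: 4
//
// i.e. a length-3 int32 vector holding the dimension sizes; for a scalar
// `shape` the tensor_shape field stays empty.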
389 
390 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
391 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
392                                         const TensorProto& tensor_proto,
393                                         const DataType& dtype) {
394   NodeDef const_node;
395   const_node.set_name("const_from_shape");
396   const_node.set_op("Const");
397   auto* attr = const_node.mutable_attr();
398   (*attr)["dtype"].set_type(dtype);
399   auto* tensor = (*attr)["value"].mutable_tensor();
400   *tensor = tensor_proto;
401   return const_node;
402 }
403 
404 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
405 // and dtype = `dtype`.
406 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
407                                   const ShapeHandle& shape,
408                                   const ShapeHandle& tensor_as_shape,
409                                   const DataType& dtype) {
410   return MakeConstNodeDefFromTensorProto(
411       ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
412 }
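// Illustrative result (not part of the original TensorFlow source): for the
// [2, 3, 4] example above, the returned NodeDef is roughly
//
//   name: "const_from_shape"
//   op: "Const"
//   attr { key: "dtype" value { type: DT_INT32 } }
//   attr { key: "value" value { tensor { dtype: DT_INT32 int_val: 2
//                                        int_val: 3 int_val: 4 ... } } }
//
// which lets the refiner feed a fully known shape tensor into a function body
// as a constant input.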
413 
414 }  // namespace
415 
416 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
417 // dequeued in (roughly) topological order. Propagating shapes following a
418 // topological ordering isn't required for correctness but helps speed things up
419 // since it avoids processing the same node multiple times as information about
420 // its inputs is refined.
421 class TopoQueue {
422  public:
423   explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
424       : topo_order_(TopoOrder(topo_order)) {}
425 
426   void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
427 
428   const NodeDef* pop() {
429     CHECK(!empty());
430     auto it = queue_.begin();
431     const NodeDef* n = it->first;
432     queue_.erase(it);
433     return n;
434   }
435 
436   bool empty() const { return queue_.empty(); }
437   std::size_t size() const { return queue_.size(); }
438 
439  private:
440   using NodeAndId = std::pair<const NodeDef*, int>;
441   // Nodes are keyed by their position in the precomputed topological order;
442   // using that index as an id keeps the queue sorted topologically.
443   struct OrderByIdAscending {
444     bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
445       return lhs.second < rhs.second;
446     }
447   };
448 
449   const std::unordered_map<const NodeDef*, int> TopoOrder(
450       const std::vector<const NodeDef*>& topo_order) const {
451     std::unordered_map<const NodeDef*, int> map;
452     map.reserve(topo_order.size());
453     for (int i = 0; i < topo_order.size(); ++i) {
454       map.emplace(topo_order[i], i);
455     }
456     return map;
457   }
458 
459   const std::unordered_map<const NodeDef*, int> topo_order_;
460   std::set<NodeAndId, OrderByIdAscending> queue_;
461 };
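// Illustrative usage sketch (not part of the original TensorFlow source):
// callers seed the queue and re-enqueue fanouts whose input shapes were
// refined; pop() always returns the pending node that comes earliest in the
// precomputed topological order. Assuming `item` is a GrapplerItem:
//
//   std::vector<const NodeDef*> order;
//   TF_CHECK_OK(ComputeTopologicalOrder(item.graph, &order));
//   TopoQueue queue(order);
//   for (const NodeDef* n : order) queue.push(n);
//   while (!queue.empty()) {
//     const NodeDef* n = queue.pop();
//     // ... update shapes of n; push its fanout if anything was refined ...
//   }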
462 
463 bool IsNumericType(const DataType dtype) {
464   static const gtl::FlatSet<DataType>* const kRealNumberTypes =
465       CHECK_NOTNULL((new gtl::FlatSet<DataType>{
466           // Floating point.
467           DT_BFLOAT16,
468           DT_HALF,
469           DT_FLOAT,
470           DT_DOUBLE,
471           // Int / UInt.
472           DT_INT8,
473           DT_INT16,
474           DT_INT32,
475           DT_INT64,
476           DT_UINT8,
477           DT_UINT16,
478           DT_UINT32,
479           DT_UINT64,
480           // Quantized Int.
481           DT_QINT8,
482           DT_QUINT8,
483           DT_QINT16,
484           DT_QUINT16,
485           DT_QINT32,
486           // Bool.
487           DT_BOOL,
488       }));
489   return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
490 }
491 
492 bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
493   static const gtl::FlatSet<string>* const kOpTypeWhitelist =
494       CHECK_NOTNULL((new gtl::FlatSet<string>{
495           // Unary arithmetic ops
496           "Floor",
497           "Round",
498           "Sqrt",
499           "Square",
500           "Sign",
501           // Binary arithmetic ops
502           "Add",
503           "Div",
504           "FloorDiv",
505           "FloorMod",
506           "Greater",
507           "GreaterEqual",
508           "Less",
509           "LessEqual",
510           "LogicalAnd",
511           "LogicalNot",
512           "LogicalOr",
513           "Maximum",
514           "Minimum",
515           "Mod",
516           "Mul",
517           "NotEqual",
518           "QuantizedAdd",
519           "QuantizedMul",
520           "SquareDifference",
521           "Sub",
522           "TruncateDiv",
523           "TruncateMod",
524           "RealDiv",
525           // N-ary arithmetic ops
526           "AddN",
527           // Others
528           "StridedSlice",
529           "OnesLike",
530           "ZerosLike",
531           "Concat",
532           "ConcatV2",
533           "Split",
534           "Range",
535           "Fill",
536           "Cast",
537       }));
538   return kOpTypeWhitelist->find(op_type) != kOpTypeWhitelist->end();
539 }
540 
541 // Processes symbolic shapes.
542 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
543 // shape refiner which creates new handles every time it processes an unknown
544 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
545 // unknown shape/dimension of a given node.
546 class SymbolicShapeRefiner {
547  public:
548   explicit SymbolicShapeRefiner(
549       const GraphView& graph,
550       const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
551       const bool aggressive_shape_inference)
552       : graph_(graph),
553         function_library_(OpRegistry::Global(), graph.graph()->library()),
554         fed_ports_(fed_ports),
555         aggressive_shape_inference_(aggressive_shape_inference) {
556     graph_def_version_ = graph.graph()->versions().producer();
557     node_to_context_.reserve(graph.graph()->node_size());
558   }
559 
560   const GraphView& graph() const { return graph_; }
561 
562   struct NodeContext {
563     const OpRegistrationData* op_data;
564     DataTypeVector input_types;
565     DataTypeVector output_types;
566     std::unique_ptr<InferenceContext> inference_context;
567     // Additional info for propagating tensor values and tensor shapes.
568     std::vector<const TensorProto*> input_tensor_protos;
569     std::vector<const TensorProto*> output_tensor_protos;
570     std::vector<ShapeHandle> output_tensors_as_shapes;
571   };
572 
573   NodeContext* GetNodeContext(const NodeDef* node) {
574     auto it = node_to_context_.find(node);
575     if (it == node_to_context_.end()) {
576       return nullptr;
577     }
578     return &it->second;
579   }
580 
581   InferenceContext* GetContext(const NodeDef* node) {
582     auto it = node_to_context_.find(node);
583     if (it == node_to_context_.end()) {
584       return nullptr;
585     }
586     return it->second.inference_context.get();
587   }
588 
589   // Forward the shapes from the function input nodes to
590   // the argument nodes (which are Placeholder nodes), then
591   // perform shape inference on the function body.
592   //
593   // Propagate shape information of final function body node
594   // to function node `function_node`.
595   //
596   // In the event of an error, UpdateNode will simply set `function_node`'s
597   // output shape to be Unknown.
598   Status UpdateFunction(const NodeDef* function_node) {
599     auto it = fun_to_grappler_function_item_.find(function_node->op());
600     if (it == fun_to_grappler_function_item_.end()) {
601       return errors::InvalidArgument(
602           function_node->op(),
603           " was not previously added to SymbolicShapeRefiner.");
604     }
605 
606     // Copy (not reference) so that changes we make here (e.g., replacing
607     // Placeholder with Const) don't affect the one in
608     // fun_to_grappler_function_item_.
609     GrapplerFunctionItem grappler_function_item = it->second;
610     MutableGraphView gv(&grappler_function_item.graph);
611 
612     // Forward shapes from function input nodes to argument nodes.
613     for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
614       auto& fun_input = grappler_function_item.input(i);
615       if (fun_input.placeholders.size() > 1) {
616         // TODO(jmdecker): Handle case with multiple input placeholders
617         return errors::Unimplemented(
618             "Input arguments with multiple placeholders are not yet "
619             "supported.");
620       }
621       NodeDef* fun_node = gv.GetNode(fun_input.input_name);
622       const TensorId input_tensor = ParseTensorName(function_node->input(i));
623 
624       if (IsControlInput(input_tensor)) {
625         return errors::FailedPrecondition(
626             "Function inputs should not contain control nodes.");
627       }
628 
629       const NodeDef* input_node = graph_.GetNode(input_tensor.node());
630       if (input_node == nullptr) {
631         return errors::FailedPrecondition(input_tensor.node(),
632                                           " was not found in the graph.");
633       }
634 
635       InferenceContext* input_ic = GetContext(input_node);
636       if (input_ic == nullptr) {
637         return errors::FailedPrecondition(
638             "Inference context has not been created for ", input_tensor.node());
639       }
640 
641       int output_port_num = input_tensor.index();
642       AttrValue attr_output_shape;
643       TensorShapeProto proto;
644       const auto& handle = input_ic->output(output_port_num);
645       input_ic->ShapeHandleToProto(handle, &proto);
646       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
647       for (int i = 0; i < proto.dim_size(); i++) {
648         if (proto.dim(i).size() < -1) {
649           proto.mutable_dim(i)->set_size(-1);
650         }
651       }
652       *attr_output_shape.mutable_shape() = proto;
653       (*fun_node->mutable_attr())["shape"] = attr_output_shape;
654     }
655 
656     // Replace input Placeholders with Consts, if values are known. Note that
657     // we don't check exceptions here as it's done in the above loop.
658     auto* ctx = GetNodeContext(function_node);
659     auto* ic = ctx->inference_context.get();
660     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
661       const string& input = function_node->input(i);
662       const string& node_name = NodeName(input);
663       const NodeDef* input_node = graph_.GetNode(node_name);
664       if (IsConstant(*input_node)) {
665         TF_CHECK_OK(
666             ReplaceInputWithConst(*input_node, i, &grappler_function_item));
667       } else if (ctx->input_tensor_protos.size() > i &&
668                  ctx->input_tensor_protos[i] != nullptr) {
669         NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
670             ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
671         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
672                                           &grappler_function_item));
673       } else if (ic->input_tensors_as_shapes().size() > i &&
674                  IsShapeFullyDefinedIntegerVectorOrScalar(
675                      ic, ic->input(i), ic->input_tensors_as_shapes()[i],
676                      ctx->input_types[i])) {
677         // We have fully defined input_tensors_as_shapes for this input; use it
678         // as a const input to the function node.
679         NodeDef const_input_node = MakeConstNodeDefFromShape(
680             ic, ic->input(i), ic->input_tensors_as_shapes()[i],
681             ctx->input_types[i]);
682         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
683                                           &grappler_function_item));
684       }
685     }
686 
687     // Perform inference on function body.
688     GraphProperties gp(grappler_function_item);
689     TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
690 
691     // Add return nodes for output shapes.
692     int output = 0;
693     ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
694     ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
695                                      nullptr);
696     for (auto const& out_arg : grappler_function_item.outputs()) {
697       if (out_arg.output_nodes.size() > 1) {
698         // TODO(jmdecker): Handle case of multiple output tensors
699         return errors::Unimplemented(
700             "Output arguments with multiple output tensors are not yet "
701             "supported.");
702       }
703 
704       // It is guaranteed that output_tensors does not contain any control
705       // inputs, so port_id >= 0.
706       TensorId out_tensor = ParseTensorName(out_arg.output_nodes[0]);
707 
708       const NodeDef* retnode = gv.GetNode(out_tensor.node());
709       if (retnode == nullptr) {
710         return errors::FailedPrecondition(
711             "Unable to find return function_node ", out_tensor.node(), " for ",
712             function_node->name());
713       }
714 
715       auto output_properties = gp.GetOutputProperties(retnode->name());
716       if (out_tensor.index() >= output_properties.size()) {
717         return errors::InvalidArgument(
718             out_tensor.ToString(), " has invalid position ", out_tensor.index(),
719             " (output_properties.size() = ", output_properties.size(), ").");
720       }
721       auto const& outprop = output_properties[out_tensor.index()];
722       const TensorShapeProto& shape = outprop.shape();
723       ShapeHandle out;
724       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
725       ic->set_output(output, out);
726       if (outprop.has_value()) {
727         // Forward tensor value to output_tensors_as_shape.
728         Tensor tensor;
729         if (tensor.FromProto(outprop.value())) {
730           MaybeTensorValueToShape(ic, tensor,
731                                   &ctx->output_tensors_as_shapes[output]);
732           const_tensors_to_propagate_.push_back(outprop.value());
733           ctx->output_tensor_protos[output] =
734               &const_tensors_to_propagate_.back();
735         }
736       }
737       output++;
738     }
739 
740     return Status::OK();
741   }
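// Illustrative trace (not part of the original TensorFlow source): for a call
// node `y = MyFunc(x)` where x's inferred shape is [2, ?], the loop above
// stamps the matching argument Placeholder with
//
//   attr { key: "shape" value { shape { dim { size: 2 } dim { size: -1 } } } }
//
// (symbolic sizes below -1 are clamped to -1), then runs GraphProperties on
// the copied function body and maps the body's output shapes and known values
// back onto this call node's InferenceContext.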
742 
743   // Prepares input shapes/values/handles, then runs shape inference, and
744   // finally sets output shapes/values/handles.
745   Status UpdateNode(const NodeDef* node, bool* refined) {
746     NodeContext* ctx = GetNodeContext(node);
747     if (ctx == nullptr) {
748       TF_RETURN_IF_ERROR(AddNode(node));
749       ctx = CHECK_NOTNULL(GetNodeContext(node));
750       *refined = true;
751     }
752 
753     // Check if the shapes of the nodes in the fan-in of this node have changed,
754     // and if they have, update the node input shapes.
755     InferenceContext* ic = ctx->inference_context.get();
756     std::vector<Tensor> const_values(ic->num_inputs());
757     std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
758     std::vector<ShapeHandle> input_tensors_as_shapes(ic->num_inputs());
759     ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
760 
761     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
762       const GraphView::InputPort port(node, dst_input);
763       const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
764       int src_output = fanin.port_id;
765       const NodeDef* src = fanin.node;
766       NodeContext* src_ctx = GetNodeContext(src);
767       if (src_ctx == nullptr) {
768         return errors::FailedPrecondition(
769             "Input ", dst_input, " ('", src->name(), "') for '", node->name(),
770             "' was not previously added to SymbolicShapeRefiner.");
771       }
772       InferenceContext* src_ic = src_ctx->inference_context.get();
773 
774       if (src_output >= src_ic->num_outputs()) {
775         return errors::OutOfRange("src_output = ", src_output,
776                                   ", but num_outputs is only ",
777                                   src_ic->num_outputs());
778       }
779 
780       // Propagate input node's NodeContext info to the current node's
781       // NodeContext:
782       // output_tensor_protos to input_tensor_protos and input_tensors, and
783       // output_tensors_as_shapes to input_tensors_as_shapes.
784 
785       if (src_ctx->output_tensors_as_shapes.size() > src_output) {
786         input_tensors_as_shapes[dst_input] =
787             src_ctx->output_tensors_as_shapes[src_output];
788       }
789 
790       if (src_ctx->output_tensor_protos.size() > src_output) {
791         auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
792         if (tensor_proto != nullptr &&
793             const_values[dst_input].FromProto(*tensor_proto)) {
794           input_tensors[dst_input] = &const_values[dst_input];
795           ctx->input_tensor_protos[dst_input] = tensor_proto;
796 
797           if (!ic->FullyDefined(input_tensors_as_shapes[dst_input])) {
798             // Shape from a Const is not fully defined when the Const has
799             // value -1 (e.g., Reshape(x, Const(-1)) to reshape an arbitrary
800             // tensor x to a vector).
801             // It's possible that the same Const with -1 is used in many
802             // places, but that doesn't mean the resultant shapes are
803             // identical. e.g., x1 = Reshape(x, c) and y1 = Reshape(y, c),
804             // where c is -1. In this case, shape inference yields both x1 and
805             // y1 as rank 1, size unknown, but still the shapes of x1 and y1
806             // can be different. (even if we use different Const(-1) for x1
807             // and x2, graph optimizer may merge them to a single Const through
808             // duplicate removal.)
809             // If we reuse output_tensors_as_shapes to input_tensors_as_shapes
810             // by copying ShapeHandle, they share the same Shape object, and
811             // SymbolicShapeManager, later in InferStatically(), assigns the
812             // same symbolic dim value (unique value < -1); in the above
813             // Reshape example, the shapes of x1 and y1 become, for example,
814             // [-278] and graph optimizer may yield incorrect output 'cause it
815             // assumes x1 and y1 have the same shape.
816             // To prevent this, we re-create a ShapeHandle from the Const
817             // tensor, instead of reusing output_tensors_as_shapes (so that
818             // ShapeHandles of the const fanouts have the same values,
819             // but different Shape objects -- SymbolicShapeManager assigns
820             // different symbol id to each fanout shape).
821             // TODO(dyoon): clean up the way values are propagated.
822             MaybeTensorValueToShape(ic, const_values[dst_input],
823                                     &input_tensors_as_shapes[dst_input]);
824           }
825         }
826       }
827 
828       // NOTE: we only check whether the shape is refined; we do not (yet)
829       // check whether the tensor value is refined.
830       if (!*refined &&
831           !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
832         *refined = true;
833       }
834       ic->SetInput(dst_input, src_ic->output(src_output));
835 
836       if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
837         // The input value may have changed. Since we have no way to know if
838         // that's indeed the case, err on the safe side.
839         *refined = true;
840       }
841 
842       // Also propagate handle shape and dtype of edges which are carrying
843       // resource handles.
844       if (ctx->input_types[dst_input] == DT_RESOURCE) {
845         auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
846         if (!outputs) continue;
847         auto* inputs = ic->input_handle_shapes_and_types(dst_input);
848 
849         if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
850           *refined = true;
851         ic->set_input_handle_shapes_and_types(dst_input, *outputs);
852       }
853     }
854 
855     // Make sure we schedule the fanout of resources (which have no input)
856     // whenever the resources are updated.
857     *refined |= ic->num_inputs() == 0;
858 
859     if (!*refined) {
860       // No input shape has changed, we're done.
861       return Status::OK();
862     }
863 
864     ic->set_input_tensors(input_tensors);
865     ic->set_input_tensors_as_shapes(input_tensors_as_shapes);
866 
867     // Properly handle function nodes.
868     if (ctx->op_data && ctx->op_data->is_function_op) {
869       // TODO(jmdecker): Detect if the input shapes have changed for this
870       // function. Note that when we hit a function call node, refined will be
871       // true, as the updates to the call node will have changed, even if it's
872       // the same function being called twice with the same input shapes.
873       // Example: simple_function.pbtxt
874       auto s = UpdateFunction(node);
875       if (s.ok()) {
876         return Status::OK();
877       } else {
878         VLOG(1) << "UpdateFunction failed for " << node->op()
879                 << ". Defaulting to ShapeUnknown.\n"
880                 << s.ToString();
881       }
882     }
883 
884     // Update the shapes of the outputs.
885     return InferShapes(*node, ctx);
886   }
887 
888   Status SetUnknownShape(const NodeDef* node, int output_port) {
889     shape_inference::ShapeHandle shape =
890         GetUnknownOutputShape(node, output_port);
891     InferenceContext* ctx = GetContext(node);
892     if (ctx == nullptr) {
893       return errors::InvalidArgument("Missing context");
894     }
895     ctx->set_output(output_port, shape);
896     return Status::OK();
897   }
898 
899   struct ShapeId {
900     const NodeDef* node;
901     int port_id;
902     bool operator==(const ShapeId& other) const {
903       return node == other.node && port_id == other.port_id;
904     }
905   };
906   struct HashShapeId {
907     std::size_t operator()(const ShapeId& shp) const {
908       return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
909     }
910   };
911 
912   struct DimId {
913     const NodeDef* node;
914     int port_id;
915     int dim_index;
916     bool operator==(const DimId& other) const {
917       return node == other.node && port_id == other.port_id &&
918              dim_index == other.dim_index;
919     }
920   };
921 
922   struct HashDimId {
923     std::size_t operator()(const DimId& dim) const {
924       return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
925              dim.dim_index;
926     }
927   };
928 
929   // Returns the shape of output 'port_index' as the union of shape1 and shape2.
930   ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
931                             ShapeHandle shape1, ShapeHandle shape2) {
932     if (shape1.SameHandle(shape2)) {
933       return shape1;
934     }
935     InferenceContext* ctx = GetContext(node);
936     ShapeHandle relaxed = shape1;
937     const int rank = ctx->Rank(shape1);
938     if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
939       relaxed = GetUnknownOutputShape(node, port_index);
940     } else {
941       for (int d = 0; d < rank; ++d) {
942         if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
943           int64 val1 = ctx->Value(ctx->Dim(shape1, d));
944           int64 val2 = ctx->Value(ctx->Dim(shape2, d));
945           if (val1 != val2 || (val1 < 0 && val2 < 0)) {
946             DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
947             TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
948           }
949         }
950       }
951     }
952     return relaxed;
953   }
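// Worked example (not part of the original TensorFlow source): relaxing
// output 0 of a node across two candidate shapes,
//
//   shape1 = [2, 3], shape2 = [2, 4]     ->  relaxed = [2, ?]
//   shape1 = [2, 3], shape2 = [2, 3, 5]  ->  relaxed = ?   (rank mismatch)
//
// where the '?' entries reuse the per-(node, port[, dim]) unknown handles from
// GetUnknownOutputShape()/GetUnknownOutputDim(), so repeated relaxations of
// the same output remain stable across iterations.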
954 
955   bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
956     if (s1.SameHandle(s2)) {
957       return true;
958     }
959     if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
960       return false;
961     }
962     if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
963       return true;
964     }
965     const int rank = InferenceContext::Rank(s1);
966     for (int i = 0; i < rank; ++i) {
967       if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
968               InferenceContext::DimKnownRank(s2, i))) {
969         int64 val1 =
970             InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
971         int64 val2 =
972             InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
973         if (val1 >= 0 && val2 >= 0 && val1 == val2) {
974           continue;
975         }
976         return false;
977       }
978     }
979     return true;
980   }
981 
982   // Return true if the annotated shape is compatible with shape inference
983   // result. Examples:
984   // Inferred shape: ?, annotated shape: [10, 10] -> true;
985   // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
986   // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
987   // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
988   bool CompatibleShapes(ShapeHandle inferred_shape,
989                         ShapeHandle annotated_shape) const {
990     if (inferred_shape.SameHandle(annotated_shape)) {
991       return true;
992     }
993     if (!InferenceContext::RankKnown(inferred_shape)) {
994       return true;
995     }
996     if (InferenceContext::Rank(inferred_shape) !=
997         InferenceContext::Rank(annotated_shape)) {
998       return false;
999     }
1000     const int rank = InferenceContext::Rank(inferred_shape);
1001     for (int i = 0; i < rank; ++i) {
1002       if (!InferenceContext::DimKnownRank(inferred_shape, i)
1003                .SameHandle(
1004                    InferenceContext::DimKnownRank(annotated_shape, i))) {
1005         int64 val1 = InferenceContext::Value(
1006             InferenceContext::DimKnownRank(inferred_shape, i));
1007         int64 val2 = InferenceContext::Value(
1008             InferenceContext::DimKnownRank(annotated_shape, i));
1009         if (val1 >= 0 && val1 != val2) {
1010           return false;
1011         }
1012       }
1013     }
1014     return true;
1015   }
1016 
1017   bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1018                                 const std::vector<ShapeAndType>& st2) const {
1019     if (st1.size() != st2.size()) {
1020       return false;
1021     }
1022     for (int i = 0; i < st1.size(); ++i) {
1023       const ShapeAndType& s1 = st1[i];
1024       const ShapeAndType& s2 = st2[i];
1025       if (s1.dtype != s2.dtype) {
1026         return false;
1027       }
1028       if (!EquivalentShapes(s1.shape, s2.shape)) {
1029         return false;
1030       }
1031     }
1032     return true;
1033   }
1034 
1035   Status AddFunction(const NodeDef* function_node) {
1036     auto it = fun_to_grappler_function_item_.find(function_node->op());
1037     if (it != fun_to_grappler_function_item_.end()) {
1038       return Status::OK();
1039     }
1040 
1041     const FunctionDef* function_def =
1042         CHECK_NOTNULL(function_library_.Find(function_node->op()));
1043 
1044     GrapplerFunctionItem grappler_function_item;
1045     TF_RETURN_IF_ERROR(
1046         MakeGrapplerFunctionItem(*function_def, function_library_,
1047                                  graph_def_version_, &grappler_function_item));
1048 
1049     if (grappler_function_item.inputs().size() > function_node->input_size()) {
1050       return errors::FailedPrecondition(
1051           "Function input size should be smaller than node input size.");
1052     }
1053 
1054     for (int i = grappler_function_item.inputs().size();
1055          i < function_node->input_size(); ++i) {
1056       const string& input = function_node->input(i);
1057       if (!IsControlInput(input)) {
1058         return errors::FailedPrecondition(
1059             "Found regular input (", input,
1060             ") instead of control nodes for node ", function_node->name());
1061       }
1062     }
1063 
1064     fun_to_grappler_function_item_[function_def->signature().name()] =
1065         grappler_function_item;
1066 
1067     return Status::OK();
1068   }
1069 
1070   Status AddNode(const NodeDef* node) {
1071     NodeContext& node_ctx = node_to_context_[node];
1072     TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
1073 
1074     if (node_ctx.op_data->is_function_op) {
1075       TF_RETURN_IF_ERROR(AddFunction(node));
1076     }
1077 
1078     TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1079                                          &node_ctx.input_types,
1080                                          &node_ctx.output_types));
1081 
1082     // Create the inference context for this node.
1083     const int num_inputs = node_ctx.input_types.size();
1084     std::vector<ShapeHandle> input_shapes(num_inputs);
1085     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1086         input_handle_shapes_and_types(num_inputs);
1087     std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1088     std::vector<ShapeHandle> input_tensors_as_shapes;
1089 
1090     node_ctx.inference_context.reset(new InferenceContext(
1091         graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
1092         input_tensors, input_tensors_as_shapes,
1093         std::move(input_handle_shapes_and_types)));
1094     const Status s = node_ctx.inference_context->construction_status();
1095     if (!s.ok()) {
1096       node_ctx.inference_context.reset(nullptr);
1097     }
1098     return s;
1099   }
1100 
1101  private:
1102   // Return the one ShapeHandle used to denote a fully unknown shape for a node
1103   // output.
1104   ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1105     ShapeId id{node, index};
1106     auto it = unknown_shapes_.find(id);
1107     if (it != unknown_shapes_.end()) {
1108       return it->second;
1109     }
1110     InferenceContext* c = GetContext(node);
1111     ShapeHandle shp = c->UnknownShape();
1112     unknown_shapes_[id] = shp;
1113     return shp;
1114   }
1115   // Return the one ShapeHandle used to denote a fully unknown dimension for a
1116   // node output.
1117   DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1118                                       int dim_id) {
1119     DimId id{node, index, dim_id};
1120     auto it = unknown_dims_.find(id);
1121     if (it != unknown_dims_.end()) {
1122       return it->second;
1123     }
1124     InferenceContext* c = GetContext(node);
1125     DimensionHandle dim = c->UnknownDim();
1126     unknown_dims_[id] = dim;
1127     return dim;
1128   }
1129 
1130   // Returns true if all the output tensors have known values.
1131   bool AllOutputValuesKnown(NodeContext* c) {
1132     InferenceContext* ic = c->inference_context.get();
1133     if (c->output_tensors_as_shapes.size() < ic->num_outputs() &&
1134         c->output_tensor_protos.size() < ic->num_outputs()) {
1135       return false;
1136     } else {
1137       // Checks if we can get output value via either output_tensor_proto or
1138       // output_tensors_as_shapes.
1139       for (int i = 0; i < ic->num_outputs(); i++) {
1140         if (c->output_tensor_protos.size() > i &&
1141             c->output_tensor_protos[i] != nullptr) {
1142           continue;
1143         }
1144         if (c->output_tensors_as_shapes.size() > i &&
1145             ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1146           continue;
1147         }
1148 
1149         // Unknown for output[i].
1150         return false;
1151       }
1152     }
1153     return true;
1154   }
1155 
1156   // Returns true if we can infer output tensors' values -- we know values of
1157   // all the input tensors.
1158   bool AllInputValuesKnown(NodeContext* c) {
1159     InferenceContext* ic = c->inference_context.get();
1160 
1161     // Check inputs are fully defined and values are known.
1162     for (int i = 0; i < ic->num_inputs(); i++) {
1163       const Tensor* tensor = ic->input_tensor(i);
1164       // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1165       // already converted it to ic->input_tensor(i);
1166       const ShapeHandle& input_tensors_as_shape =
1167           ic->input_tensors_as_shapes()[i];
1168       // Either input_tensor must be valid, or input_tensors_as_shape (which
1169       // holds the input tensor's value encoded as a shape) must be fully defined.
1170       if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1171         return false;
1172       }
1173     }
1174     return true;
1175   }
1176 
1177   // Returns true if we want to update output shapes and values by running
1178   // EvaluateNode() for this op, based on op type, data type, and size.
1179   bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
1180     InferenceContext* ic = c->inference_context.get();
1181 
1182     // Due to the cost of running EvaluateNode(), we limit evaluation to
1183     // whitelisted op types.
1184     if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1185       return false;
1186     }
1187 
1188     // Check input dtypes are number types.
1189     for (const auto& input_type : c->input_types) {
1190       if (!IsNumericType(input_type)) {
1191         return false;
1192       }
1193     }
1194 
1195     // Check output dtypes are number types.
1196     for (const auto& output_type : c->output_types) {
1197       if (!IsNumericType(output_type)) {
1198         return false;
1199       }
1200     }
1201 
1202     // Check that the number of elements of each input tensor is no larger than
1203     // the given max size.
1204     for (int i = 0; i < ic->num_inputs(); i++) {
1205       const Tensor* tensor = ic->input_tensor(i);
1206       const ShapeHandle& input_shape_handle = ic->input(i);
1207       if (tensor != nullptr) {
1208         if (tensor->NumElements() > max_size) {
1209           return false;
1210         }
1211       } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1212         return false;
1213       }
1214     }
1215 
1216     // Check that we know the shape of each output tensor, and that its number
1217     // of elements is no larger than the given max size.
1218     for (int i = 0; i < ic->num_outputs(); i++) {
1219       const ShapeHandle& shape_handle = ic->output(i);
1220       if (!ic->FullyDefined(shape_handle) ||
1221           ic->Value(ic->NumElements(shape_handle)) > max_size) {
1222         return false;
1223       }
1224     }
1225     return true;
1226   }
1227 
1228   // Create input tensors from the NodeContext.
1229   void CreateInputTensors(NodeContext* c,
1230                           std::vector<Tensor>* input_tensor_vector,
1231                           TensorVector* inputs) {
1232     InferenceContext* ic = c->inference_context.get();
1233     for (int i = 0; i < ic->num_inputs(); i++) {
1234       if (ic->input_tensor(i)) {
1235         input_tensor_vector->at(i) = *ic->input_tensor(i);
1236         inputs->emplace_back(&input_tensor_vector->at(i));
1237         // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1238         // already converted it to ic->input_tensor(i);
1239       } else {
1240         // Create Tensor from input_tensors_as_shapes, and then emplace it
1241         // back to inputs.
1242         // Note that input_tensors_as_shapes is scalar or vector.
1243         const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1244         const DataType& data_type = c->input_types[i];
1245         int32 rank = ic->Rank(shape_handle);
1246         if (rank < 1) {
1247           input_tensor_vector->at(i) = Tensor(data_type, {});
1248         } else {
1249           input_tensor_vector->at(i) = Tensor(data_type, {rank});
1250         }
1251         auto* tensor = &input_tensor_vector->at(i);
1252         if (data_type == DT_INT32) {
1253           auto flat = tensor->flat<int32>();
1254           for (int j = 0; j < rank; j++) {
1255             int32 dim = ic->Value(ic->Dim(shape_handle, j));
1256             flat(j) = dim;
1257           }
1258         } else {
1259           auto flat = tensor->flat<int64>();
1260           for (int j = 0; j < rank; j++) {
1261             int64 dim = ic->Value(ic->Dim(shape_handle, j));
1262             flat(j) = dim;
1263           }
1264         }
1265         inputs->emplace_back(tensor);
1266       }
1267     }
1268   }
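// Illustrative example (not part of the original TensorFlow source): for an
// input whose value is only known through input_tensors_as_shapes, e.g. a
// handle describing [5, 7] with input type DT_INT32, the loop above
// materializes a rank-1 int32 Tensor of length 2 holding {5, 7}; a scalar
// handle (rank < 1) becomes a scalar Tensor. These temporaries back the
// TensorValue pointers handed to EvaluateNode() in
// UpdateOutputShapesAndValues() below.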
1269 
1270   // Run a node to infer output shapes and values, and add it to the
1271   // NodeContext.
1272   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1273     InferenceContext* ic = c->inference_context.get();
1274 
1275     // Input to EvaluateNode()
1276     TensorVector inputs;
1277     // Container for temporarily created tensor objects.
1278     std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1279     CreateInputTensors(c, &input_tensor_vector, &inputs);
1280 
1281     // Output for EvaluateNode() and output tensor clean up object.
1282     TensorVector outputs;
1283     auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1284       for (const auto& output : outputs) {
1285         if (output.tensor) {
1286           delete output.tensor;
1287         }
1288       }
1289     });
1290 
1291     TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1292                                     &resource_mgr_, &outputs));
1293     c->output_tensors_as_shapes.resize(outputs.size());
1294     c->output_tensor_protos.resize(outputs.size(), nullptr);
1295     for (int k = 0; k < outputs.size(); k++) {
1296       const auto& t = outputs[k];
1297       // Override output shape.
1298       ShapeHandle output_shape;
1299       TF_RETURN_IF_ERROR(
1300           ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1301       if (ic->FullyDefined(ic->output(k)) &&
1302           !EquivalentShapes(ic->output(k), output_shape)) {
1303         LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1304                      << ", inferred output shape "
1305                      << "doesn't match for k=" << k << ": "
1306                      << "ic->output(k): " << ic->DebugString(ic->output(k))
1307                      << ", output_shape: " << ic->DebugString(output_shape)
1308                      << " -- " << node.DebugString();
1309       }
1310       ic->set_output(k, output_shape);
1311       // Set output_tensors_as_shape.
1312       MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1313 
1314       // Set output_tensor_protos.
1315       TensorProto tensor_proto;
1316       t->AsProtoTensorContent(&tensor_proto);
1317       const_tensors_to_propagate_.push_back(tensor_proto);
1318       c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1319     }
1320     return Status::OK();
1321   }
1322 
1323   // Update output shapes with annotated information.
1324   // Currently only handles nodes with static shapes, i.e. shapes do not change
1325   // during execution.
1326   // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1327   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1328                                                      NodeContext* c) const {
1329     const auto& attr = node.attr();
1330     if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1331         attr.count(kOutputShapes) == 0)
1332       return Status::OK();
1333 
1334     InferenceContext* ic = c->inference_context.get();
1335     int output_size = attr.at(kOutputShapes).list().shape_size();
1336 
1337     for (int i = 0; i < ic->num_outputs(); i++) {
1338       // An annotated Switch node has only one shape annotation. Propagate it
1339       // to all the outputs.
1340       int shape_index = IsSwitch(node) ? 0 : i;
1341       if (shape_index >= output_size) {
1342         LOG(WARNING)
1343             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1344             << node.name() << ", inferred output shape size "
1345             << ic->num_outputs() << ", annotated output shape size "
1346             << output_size;
1347         break;
1348       }
1349 
1350       const TensorShapeProto& shape =
1351           attr.at(kOutputShapes).list().shape(shape_index);
1352       ShapeHandle output_shape;
1353       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1354 
1355       // Only use annotated shapes if the inference shape is unknown and
1356       // compatible with annotated shapes.
1357       if (!ic->FullyDefined(ic->output(i)) &&
1358           CompatibleShapes(ic->output(i), output_shape)) {
1359         VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1360                 << node.name() << ", inferred output shape " << i << ": "
1361                 << "ic->output(i): " << ic->DebugString(ic->output(i))
1362                 << ", annotated output shape: " << ic->DebugString(output_shape)
1363                 << " -- " << node.ShortDebugString();
1364         ic->set_output(i, output_shape);
1365       }
1366     }
1367 
1368     return Status::OK();
1369   }
1370 
1371   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1372                                       NodeContext* c) {
1373     // Propagate tensors and shape tensors unless the node is fed.
1374     // TODO(bsteiner) We should still propagate the shapes to the ports that
1375     // aren't fed in the case of a ShapeN node.
1376 
1377     InferenceContext* ic = c->inference_context.get();
1378     if (!is_fed) {
1379       if (IsConstant(node)) {
1380         c->output_tensor_protos.resize(1);
1381         const TensorProto& tensor_proto = node.attr().at("value").tensor();
1382         c->output_tensor_protos[0] = &tensor_proto;
1383         c->output_tensors_as_shapes.resize(1);
1384         MaybeTensorProtoToShape(ic, tensor_proto,
1385                                 &c->output_tensors_as_shapes[0]);
1386       } else if (IsRank(node)) {
1387         if (ic->RankKnown(ic->input(0))) {
1388           // Propagate rank value.
1389           int32 rank = ic->Rank(ic->input(0));
1390           const_tensors_to_propagate_.push_back(
1391               MakeIntegerScalarTensorProto(DT_INT32, rank));
1392           c->output_tensor_protos.resize(1);
1393           c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1394         }
1395       } else if (IsSize(node)) {
1396         DimensionHandle size = ic->NumElements(ic->input(0));
1397         if (ic->ValueKnown(size)) {
1398           // Propagate size value.
1399           int64 sz = ic->Value(size);
1400           bool valid = false;
1401           if (node.attr().at("out_type").type() == DT_INT32) {
1402             if (sz < std::numeric_limits<int32>::max()) {
1403               const_tensors_to_propagate_.push_back(
1404                   MakeIntegerScalarTensorProto(DT_INT32, sz));
1405               valid = true;
1406             }
1407           } else {
1408             const_tensors_to_propagate_.push_back(
1409                 MakeIntegerScalarTensorProto(DT_INT64, sz));
1410             valid = true;
1411           }
1412           if (valid) {
1413             c->output_tensor_protos.resize(1);
1414             c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1415           }
1416         }
1417       } else if (IsShape(node)) {
1418         c->output_tensors_as_shapes.resize(1);
1419         c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1420       } else if (IsShapeN(node)) {
1421         c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1422         for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1423           c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1424         }
1425       } else if (node.op() == "ConcatV2") {
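        // ConcatV2 is often used to build shape vectors. When each input
        // (except the trailing axis argument) is itself known as a shape, the
        // output value is their concatenation, e.g. concatenating the values
        // [2, 3] and [4] yields [2, 3, 4]. (Illustrative example.)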
1426         bool valid = true;
1427         ShapeHandle result;
1428         for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1429           ShapeHandle input = ic->input_tensors_as_shapes()[i];
1430           if (!ic->RankKnown(input)) {
1431             valid = false;
1432             break;
1433           } else if (i == 0) {
1434             result = input;
1435           } else {
1436             TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1437           }
1438         }
1439         if (valid) {
1440           c->output_tensors_as_shapes.resize(1);
1441           c->output_tensors_as_shapes[0] = result;
1442         }
1443       } else if (IsPack(node)) {
1444         // A Pack node concatenating scalars is often used to generate a shape.
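        // For instance (illustrative), Pack(height, width) with scalar int32
        // inputs produces a rank-1 tensor that can be read back as the shape
        // [height, width]; unknown or negative scalars become unknown
        // dimensions below.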
1445         std::vector<DimensionHandle> dims;
1446         bool valid = true;
1447         for (int i = 0; i < ic->num_inputs(); ++i) {
1448           const Tensor* t = ic->input_tensor(i);
1449           if (t) {
1450             if (t->dims() != 0 ||
1451                 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1452               valid = false;
1453               break;
1454             }
1455             int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1456                                                 : t->scalar<int64>()();
1457             dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
1458           } else {
1459             // Don't have tensor value, but use input_tensors_as_shapes, if
1460             // We don't have the tensor value; fall back to
1461             // input_tensors_as_shapes if possible.
1462             if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1463                 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1464               dims.push_back(ic->Dim(shape_handle, 0));
1465             } else {
1466               dims.push_back(ic->UnknownDim());
1467             }
1468           }
1469         }
1470         if (valid) {
1471           c->output_tensors_as_shapes.resize(1);
1472           c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1473         }
1474       } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1475         c->output_tensors_as_shapes.resize(1);
1476         c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
1477         if (c->input_tensor_protos[0] != nullptr) {
1478           c->output_tensor_protos.resize(1);
1479           c->output_tensor_protos[0] = c->input_tensor_protos[0];
1480         }
1481       } else if (IsSlice(node)) {
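        // Slice applied to a shape-like tensor: the output value is the
        // corresponding sub-range of the input value, e.g. slicing the value
        // [2, 3, 5, 7] with begin=1 and size=2 yields [3, 5]; a size of -1
        // takes everything from begin onward. (Illustrative example.)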
1482         ShapeHandle input = ic->input_tensors_as_shapes()[0];
1483         bool valid = ic->RankKnown(input);
1484         const Tensor* slice_offset = ic->input_tensor(1);
1485         valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1486         const Tensor* slice_size = ic->input_tensor(2);
1487         valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1488         if (valid) {
1489           int64 start = slice_offset->dtype() == DT_INT32
1490                             ? slice_offset->flat<int32>()(0)
1491                             : slice_offset->flat<int64>()(0);
1492           int64 size =
1493               (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
1494                                                : slice_size->flat<int64>()(0));
1495           ShapeHandle result;
1496           if (size == -1) {
1497             TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1498           } else {
1499             int64 end = start + size;
1500             TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1501           }
1502           c->output_tensors_as_shapes.resize(1);
1503           c->output_tensors_as_shapes[0] = result;
1504         }
1505       } else if (IsStridedSlice(node)) {
1506         ShapeHandle input = ic->input_tensors_as_shapes()[0];
1507         bool valid = ic->RankKnown(input);
1508         const Tensor* slice_begin = ic->input_tensor(1);
1509         valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1510         const Tensor* slice_end = ic->input_tensor(2);
1511         valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1512         const Tensor* slice_stride = ic->input_tensor(3);
1513         valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1514 
1515         if (node.attr().count("ellipsis_mask") > 0 &&
1516             node.attr().at("ellipsis_mask").i() != 0) {
1517           valid = false;
1518         }
1519         if (node.attr().count("new_axis_mask") > 0 &&
1520             node.attr().at("new_axis_mask").i() != 0) {
1521           valid = false;
1522         }
1523         if (node.attr().count("shrink_axis_mask") > 0 &&
1524             node.attr().at("shrink_axis_mask").i() != 0) {
1525           valid = false;
1526         }
1527         int begin_mask = 0;
1528         if (node.attr().count("begin_mask") > 0) {
1529           begin_mask = node.attr().at("begin_mask").i();
1530         }
1531         int end_mask = 0;
1532         if (node.attr().count("end_mask") > 0) {
1533           end_mask = node.attr().at("end_mask").i();
1534         }
1535         if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1536           valid = false;
1537         }
1538         if (valid) {
1539           int64 begin = 0;
1540           if (begin_mask == 0) {
1541             begin = slice_begin->dtype() == DT_INT32
1542                         ? slice_begin->flat<int32>()(0)
1543                         : slice_begin->flat<int64>()(0);
1544           }
1545           int64 end = std::numeric_limits<int64>::max();
1546           if (end_mask == 0) {
1547             end =
1548                 (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
1549                                                 : slice_end->flat<int64>()(0));
1550           }
1551           int64 stride = slice_stride->dtype() == DT_INT32
1552                              ? slice_stride->flat<int32>()(0)
1553                              : slice_stride->flat<int64>()(0);
1554           ShapeHandle result;
1555           TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1556           c->output_tensors_as_shapes.resize(1);
1557           c->output_tensors_as_shapes[0] = result;
1558         }
1559       }
1560     }
1561 
1562     if (aggressive_shape_inference_) {
1563       // Update output shapes with annotated information. This is optional.
1564       UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1565 
1566       // Update output tensor values using EvaluateNode() if we can.
1567       // Due to the cost of EvaluateNode(), we run it only for certain op types
1568       // (whitelisted) and small integer tensors.
1569 
1570       const int max_element_size = 17;  // Max up to 4x4 matrix or similar.
1571       if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1572           !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1573         return Status::OK();
1574       }
1575       UpdateOutputShapesAndValues(node, c).IgnoreError();  // This is optional.
1576     }
1577     return Status::OK();
1578   }
1579 
1580   Status InferShapes(const NodeDef& node, NodeContext* c) {
1581     // Infer the shapes of output tensors.
1582     if (!c->op_data || c->op_data->shape_inference_fn == nullptr) {
1583       // There is nothing more we can infer; annotate outputs with unknown
1584       // shapes.
1585       return c->inference_context->Run(shape_inference::UnknownShape);
1586     }
1587 
1588     TF_RETURN_IF_ERROR(
1589         c->inference_context->Run(c->op_data->shape_inference_fn));
1590 
1591     Status status = Status::OK();
1592     auto it = fed_ports_.find(node.name());
1593     const bool is_fed = it != fed_ports_.end();
1594     if (is_fed) {
1595       // It is possible to feed node output ports with tensors of any shape: as
1596       // a result, the shape of a fed port is completely unknown.
1597       for (const int output_port : it->second) {
1598         status.Update(SetUnknownShape(&node, output_port));
1599       }
1600     }
1601 
1602     // Update NodeContext output fields after shape inference function runs.
1603     status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1604 
1605     return status;
1606   }
1607 
1608  private:
1609   bool IsIntegerVector(const Tensor& tensor) {
1610     if (tensor.dims() == 1 &&
1611         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1612       return true;
1613     }
1614     return false;
1615   }
1616 
1617   bool IsIntegerScalar(const Tensor& tensor) {
1618     if (tensor.dims() == 0 &&
1619         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1620         tensor.NumElements() == 1) {
1621       return true;
1622     }
1623     return false;
1624   }
1625 
1626   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1627                                            const int64 val) {
1628     TensorProto tensor_proto;
1629     tensor_proto.set_dtype(dtype);
1630     // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1631     tensor_proto.mutable_tensor_shape();
1632     if (dtype == DT_INT32) {
1633       tensor_proto.add_int_val(val);
1634     } else if (dtype == DT_INT64) {
1635       tensor_proto.add_int64_val(val);
1636     }
1637     return tensor_proto;
1638   }
1639 
1640   bool MaybeTensorProtoToShape(InferenceContext* ic,
1641                                const TensorProto& tensor_proto,
1642                                ShapeHandle* tensors_as_shapes) {
1643     // Skip if dtype is not integer.
1644     if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1645       return false;
1646     }
1647     // Skip if shape is neither scalar nor vector.
1648     if (tensor_proto.tensor_shape().unknown_rank() ||
1649         tensor_proto.tensor_shape().dim_size() > 1) {
1650       return false;
1651     }
1652     Tensor tensor;
1653     if (!tensor.FromProto(tensor_proto)) {
1654       return false;
1655     }
1656     return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
1657   }
1658 
1659   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
1660                                ShapeHandle* tensors_as_shapes) {
1661     // Integer tensors of rank one can also be interpreted as a shape
1662     // provided all their values are >= -1.
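    // For example (illustrative), the vector [2, -1, 3] is interpreted as the
    // partial shape [2, ?, 3], whereas a vector containing any value below -1
    // cannot be interpreted as a shape and is left alone.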
1663     if (IsIntegerVector(tensor)) {
1664       bool has_values_smaller_than_minus_1 = false;
1665       std::vector<DimensionHandle> dims;
1666       for (int i = 0; i < tensor.NumElements(); i++) {
1667         int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
1668                                                  : tensor.flat<int64>()(i);
1669         has_values_smaller_than_minus_1 |= (value < -1);
1670         dims.push_back(value < 0 ? ic->UnknownDim() : ic->MakeDim(value));
1671       }
1672       if (!has_values_smaller_than_minus_1) {
1673         *tensors_as_shapes = ic->MakeShape(dims);
1674       }
1675     } else if (IsIntegerScalar(tensor)) {
1676       // Scalar constant.
1677       int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
1678                                                : tensor.flat<int64>()(0);
1679       // Ideally we would also accept values < -1, but MakeDim() fails for any
1680       // value < -1; this is a limitation of using ShapeHandle to pass values.
1681       if (value >= -1) {
1682         *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
1683         return true;
1684       }
1685     }
1686     return false;
1687   }
1688 
1689   const GraphView& graph_;
1690   int graph_def_version_;
1691   std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
1692   std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
1693   std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
1694   std::unordered_map<string, GrapplerFunctionItem>
1695       fun_to_grappler_function_item_;
1696   FunctionLibraryDefinition function_library_;
1697   const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
1698   // Store TensorProtos for tensor value propagation. Note that we use a list,
1699   // not a vector, because we keep pointers to the TensorProtos in this
1700   // container; a vector may reallocate and copy its elements into a new
1701   // buffer, leaving the existing pointers dangling.
1702   std::list<TensorProto> const_tensors_to_propagate_;
1703 
1704   // For more aggressive shape and value inference.
1705   bool aggressive_shape_inference_;
1706   ResourceMgr resource_mgr_;
1707 };
1708 
1709 // Keep track of shapes and dimensions in a graph.
1710 // In particular, use disjoint sets to track equivalence between shapes and
1711 // dims, and consolidate the information globally.
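// For example (illustrative), if shape inference determines that an output of
// node A must have the same symbolic shape as an input of node B, merging the
// two handles places them in the same disjoint set, so a concrete dimension
// learned for either one is later reported for both by AsTensorProperties().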
1712 class SymbolicShapeManager {
1713  public:
1714   SymbolicShapeManager() {}
1715 
1716   Status Merge(ShapeHandle s1, ShapeHandle s2) {
1717     if (!s1.IsSet() || !s2.IsSet()) {
1718       return Status::OK();
1719     }
1720     TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
1721     if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
1722       CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
1723       for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
1724         TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
1725                                        InferenceContext::DimKnownRank(s2, i)));
1726       }
1727     }
1728     return Status::OK();
1729   }
1730   Status Merge(DimensionHandle d1, DimensionHandle d2) {
1731     if (!d1.IsSet() || !d2.IsSet()) {
1732       return Status::OK();
1733     }
1734     return dims_.Merge(d1, d2);
1735   }
1736 
1737   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
1738                           OpInfo::TensorProperties* properties) {
1739     properties->set_dtype(type);
1740     ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
1741     if (!InferenceContext::RankKnown(actual_shape)) {
1742       properties->mutable_shape()->set_unknown_rank(true);
1743     } else {
1744       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
1745         shape_inference::DimensionHandle dim =
1746             InferenceContext::DimKnownRank(actual_shape, j);
1747         int64 d = dims_.GetMergedValue(dim);
1748         properties->mutable_shape()->add_dim()->set_size(d);
1749       }
1750     }
1751   }
1752 
1753  private:
1754   DisjointSet<shape_inference::ShapeHandle> shapes_;
1755   DisjointSet<shape_inference::DimensionHandle> dims_;
1756 };
1757 
1758 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
1759     SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
1760     const std::vector<ShapeAndType>& shapes_and_types,
1761     std::vector<ShapeAndType>* queue_shapes_and_types) {
1762   if (shapes_and_types.size() != queue_shapes_and_types->size()) {
1763     return errors::InvalidArgument(
1764         "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
1765         "  vs ", queue_shapes_and_types->size());
1766   }
1767   for (size_t i = 0; i < shapes_and_types.size(); ++i) {
1768     const ShapeAndType& a = shapes_and_types[i];
1769     ShapeAndType& b = (*queue_shapes_and_types)[i];
1770     if (a.dtype != b.dtype) {
1771       return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
1772                                      i, ": ", DataTypeString(a.dtype), " vs ",
1773                                      DataTypeString(b.dtype));
1774     }
1775 
1776     b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
1777   }
1778   return Status::OK();
1779 }
1780 
1781 // Compute the output shape of the merge node as the union of the available
1782 // input shapes.
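// For instance (illustrative), if the known inputs of a Merge node have shapes
// [16, 32] and [8, 32], their union relaxes the mismatched dimension and the
// node's first output becomes [?, 32]; the second output (the index of the
// chosen input) is always a scalar.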
1783 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
1784                                     const NodeDef* node,
1785                                     bool* new_shapes) const {
1786   InferenceContext* ic = shape_refiner->GetContext(node);
1787   if (!ic) {
1788     // Now we can run shape inference
1789     TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
1790     ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
1791     *new_shapes = true;
1792 
1793     // Infer the shape of the second output once and for all since it never
1794     // changes.
1795     ShapeHandle out1 = ic->Scalar();
1796     ic->set_output(1, out1);
1797   }
1798 
1799   ShapeHandle out;
1800   const std::vector<ShapeAndType>* out_handle = nullptr;
1801   bool out_initialized = false;
1802   for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
1803            *node, /*include_controlling_edges=*/false)) {
1804     InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
1805     if (!src_ic) {
1806       // Handling a loop for the first time, the back edge won't have any shape
1807       // info.
1808       continue;
1809     }
1810     ShapeHandle input = src_ic->output(fanin.src.port_id);
1811     ic->SetInput(fanin.dst.port_id, input);
1812     auto* input_handle =
1813         src_ic->output_handle_shapes_and_types(fanin.src.port_id);
1814     if (input_handle)
1815       ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
1816     if (!out_initialized) {
1817       out_initialized = true;
1818       out = input;
1819       out_handle = input_handle;
1820     } else {
1821       // Note here only out, not out_handle, is modified.
1822       out = shape_refiner->OutputAsUnion(node, 0, input, out);
1823     }
1824   }
1825 
1826   if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
1827     ic->set_output(0, out);
1828     if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
1829     *new_shapes = true;
1830   }
1831 
1832   return Status::OK();
1833 }
1834 
1835 // Manually propagate the input shape for Enter nodes.
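// The registered shape function for Enter returns an unknown shape, so instead
// we copy the shape (and any handle shapes/types) of the node's regular fanin
// directly onto the Enter node's input and output.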
1836 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
1837                                     const NodeDef* node, bool* new_shapes) {
1838   InferenceContext* ic = shape_refiner->GetContext(node);
1839   if (!ic) {
1840     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
1841     ic = shape_refiner->GetContext(node);
1842   }
1843 
1844   GraphView::InputPort port(node, 0);
1845   GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
1846 
1847   InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
1848   ShapeHandle input = src_ic->output(fanin.port_id);
1849   if (!ic->output(0).SameHandle(input)) {
1850     ic->SetInput(0, input);
1851     ic->set_output(0, input);
1852     *new_shapes = true;
1853   }
1854   auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
1855   if (outputs) {
1856     ic->set_input_handle_shapes_and_types(0, *outputs);
1857     ic->set_output_handle_shapes_and_types(0, *outputs);
1858     *new_shapes = true;
1859   }
1860   return Status::OK();
1861 }
1862 
1863 Status GraphProperties::UpdateShapes(
1864     SymbolicShapeRefiner* shape_refiner,
1865     const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1866     const NodeDef* n, bool* new_shapes) const {
1867   if (IsEnter(*n)) {
1868     // The Enter shape function always forwards an UnknownShape, so do the right
1869     // thing here.
1870     TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
1871   } else if (IsMerge(*n)) {
1872     // Properly handle merge nodes.
1873     TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
1874   } else if (IsEnqueue(*n)) {
1875     // Make sure the shapes of enqueued tensors are propagated to the queue
1876     // itself.
1877     TF_RETURN_IF_ERROR(
1878         UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
1879   } else if (IsQueue(*n)) {
1880     // Set shapes and types of Queue ops, if needed.
1881     TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
1882   } else {
1883     // Rely on regular TF shape refinement for all the other nodes.
1884     // UpdateNode calls UpdateFunction if a function node is detected.
1885     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
1886   }
1887 
1888   return Status::OK();
1889 }
1890 
1891 // Propagates the shapes in the transitive fan-out of <new_shapes>.
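// This is a worklist algorithm: nodes are popped from <new_shapes> in
// topological order, their shapes are recomputed, and their fanouts (plus the
// corresponding queue node for an Enqueue op) are pushed back whenever
// something changed. The iteration limits below bound the work in the presence
// of loops or incorrect shape functions.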
1892 Status GraphProperties::PropagateShapes(
1893     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
1894     const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1895     int num_loops) const {
1896   // Limit the number of iterations to prevent infinite loops in the presence of
1897   // incorrect shape functions. The algorithm should converge in at most
1898   // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
1899   // The same applies to resources.
1900   VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
1901           << num_loops << " loops and " << resource_handles.size()
1902           << " resources" << std::endl;
1903 
1904   const int64 max_loop_length = item_.graph.node_size();
1905   const int64 max_rank = 4;
1906   const int64 max_loop_iterations =
1907       max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
1908   const int64 num_queues = resource_handles.size();
1909   const int64 max_resource_iterations = num_queues * num_queues * max_rank;
1910 
1911   int64 num_resource_iterations = 0;
1912   do {
1913     int64 num_loop_iterations = 0;
1914     while (!new_shapes->empty() &&
1915            num_loop_iterations++ < max_loop_iterations) {
1916       const NodeDef* n = new_shapes->pop();
1917       bool updated = false;
1918       TF_RETURN_IF_ERROR(
1919           UpdateShapes(shape_refiner, resource_handles, n, &updated));
1920       if (updated) {
1921         for (const auto& fanout : shape_refiner->graph().GetFanouts(
1922                  *n, /*include_controlled_nodes=*/false)) {
1923           new_shapes->push(fanout.node);
1924         }
1925         // Make sure the corresponding queue nodes are (re)processed.
1926         if (IsEnqueue(*n)) {
1927           auto it = resource_handles.find(n);
1928           if (it != resource_handles.end()) {
1929             new_shapes->push(it->second);
1930           }
1931         }
1932       }
1933     }
1934   } while (!new_shapes->empty() &&
1935            num_resource_iterations++ < max_resource_iterations);
1936 
1937   if (!new_shapes->empty()) {
1938     return errors::Internal("Shape inference failed to converge");
1939   }
1940 
1941   return Status::OK();
1942 }
1943 
1944 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
1945                                     SymbolicShapeRefiner* shape_refiner,
1946                                     bool* new_shapes) {
1947   auto* ctx = shape_refiner->GetNodeContext(queue_node);
1948   if (!ctx) {
1949     TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
1950     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
1951   }
1952   auto* ic = ctx->inference_context.get();
1953 
1954   auto* outputs = ic->output_handle_shapes_and_types(0);
1955   if (outputs) {
1956     // Shapes and types are already set, presumably by Enqueue ops.
1957     return shape_refiner->UpdateNode(queue_node, new_shapes);
1958   }
1959 
1960   if (queue_node->attr().count("shapes") <= 0 ||
1961       queue_node->attr().count("component_types") <= 0 ||
1962       queue_node->attr().at("shapes").list().shape_size() !=
1963           queue_node->attr().at("component_types").list().type_size()) {
1964     // Errors in shapes and component_types attr.
1965     return shape_refiner->UpdateNode(queue_node, new_shapes);
1966   }
1967 
1968   // Extract types and shapes from Queue attr.
1969   const auto& shapes = queue_node->attr().at("shapes").list().shape();
1970   const auto& types = queue_node->attr().at("component_types").list().type();
1971   std::vector<ShapeAndType> shapes_and_types;
1972   for (int i = 0; i < types.size(); i++) {
1973     const auto& shape = shapes[i];
1974     ShapeHandle shape_handle;
1975     TF_RETURN_IF_ERROR(
1976         ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
1977     DataType data_type =
1978         queue_node->attr().at("component_types").list().type(i);
1979     ShapeAndType shape_and_type(shape_handle, data_type);
1980     shapes_and_types.push_back(shape_and_type);
1981   }
1982   ic->set_output_handle_shapes_and_types(0, shapes_and_types);
1983 
1984   // The queue node has been updated with output_handle_shapes_and_types, so
1985   // set new_shapes here and ignore the result of UpdateNode().
1986   *new_shapes = true;
1987   bool dummy_new_shapes = false;
1988   return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
1989 }
1990 
1991 Status GraphProperties::UpdateEnqueue(
1992     const NodeDef* enqueue_node,
1993     const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
1994     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
1995   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
1996   if (!ctx) {
1997     TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
1998     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
1999   }
2000 
2001   auto it = resource_handles.find(enqueue_node);
2002   if (it == resource_handles.end()) {
2003     // The corresponding queue was not found; there isn't much we can do.
2004     return Status::OK();
2005   }
2006   const NodeDef* qnode = it->second;
2007   auto qctx = shape_refiner->GetContext(qnode);
2008   if (!qctx) {
2009     return Status::OK();
2010   }
2011   auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2012 
2013   // TODO(bsteiner): handle EnqueueMany as well.
2014   std::vector<ShapeAndType> shapes_and_types;
2015   for (int i = 1; i < ctx->input_types.size(); ++i) {
2016     GraphView::InputPort inp(enqueue_node, i);
2017     GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2018     InferenceContext* in = shape_refiner->GetContext(fanin.node);
2019     ShapeHandle input = in->output(fanin.port_id);
2020     ctx->inference_context->SetInput(i, input);
2021     shapes_and_types.push_back({input, ctx->input_types[i]});
2022   }
2023 
2024   if (queue_handle_data == nullptr) {
2025     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2026     *new_shapes = true;
2027   } else {
2028     TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2029         shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2030     *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2031                                                             shapes_and_types);
2032     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2033   }
2034 
2035   return Status::OK();
2036 }
2037 
2038 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2039                                         bool aggressive_shape_inference) {
2040   FunctionLibraryDefinition function_library(OpRegistry::Global(),
2041                                              item_.graph.library());
2042   std::unordered_map<string, std::unordered_set<int>> fed_ports;
2043   if (!assume_valid_feeds) {
2044     for (const auto& feed : item_.feed) {
2045       SafeTensorId tensor_id = ParseTensorName(feed.first);
2046       fed_ports[tensor_id.node()].insert(tensor_id.index());
2047     }
2048   }
2049 
2050   GraphView graph_view(&item_.graph);
2051 
2052   // List the resources and the nodes using them. Also collect the Merge nodes,
2053   // fed nodes, and primary inputs.
2054   std::unordered_map<const NodeDef*,
2055                      std::pair<std::unordered_set<const NodeDef*>,
2056                                std::unordered_set<const NodeDef*>>>
2057       resources;
2058   std::unordered_set<const NodeDef*> merge_nodes;
2059   std::unordered_set<const NodeDef*> fed_nodes;
2060   std::unordered_set<const NodeDef*> primary_inputs;
2061   int num_loops = 0;
2062   for (const NodeDef& node : item_.graph.node()) {
2063     if (IsQueue(node)) {
2064       for (const GraphView::InputPort& fanout :
2065            graph_view.GetFanouts(node, false)) {
2066         if (IsEnter(*fanout.node)) {
2067           const NodeDef& enter = *fanout.node;
2068           for (const GraphView::InputPort& fanout :
2069                graph_view.GetFanouts(enter, false)) {
2070             if (IsEnqueue(*fanout.node)) {
2071               resources[&node].first.insert(fanout.node);
2072             } else if (IsDequeue(*fanout.node)) {
2073               resources[&node].second.insert(fanout.node);
2074             }
2075           }
2076         } else {
2077           if (IsEnqueue(*fanout.node)) {
2078             resources[&node].first.insert(fanout.node);
2079           } else if (IsDequeue(*fanout.node)) {
2080             resources[&node].second.insert(fanout.node);
2081           }
2082         }
2083       }
2084     }
2085     if (NumNonControlInputs(node) == 0) {
2086       primary_inputs.insert(&node);
2087     } else if (IsMerge(node)) {
2088       merge_nodes.insert(&node);
2089     } else if (IsNextIteration(node)) {
2090       ++num_loops;
2091     }
2092     if (fed_ports.find(node.name()) != fed_ports.end()) {
2093       fed_nodes.insert(&node);
2094     }
2095   }
2096 
2097   std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
2098   std::vector<TopologicalDependency> extra_deps;
2099   for (const auto& resource : resources) {
2100     for (const NodeDef* src : resource.second.first) {
2101       resource_handles[src] = resource.first;
2102       for (const NodeDef* dst : resource.second.second) {
2103         // Add control edges from enqueue to dequeue nodes to ensure they are
2104         // processed in their logical order.
2105         extra_deps.emplace_back(src, dst);
2106       }
2107     }
2108   }
2109 
2110   std::vector<const NodeDef*> topo_order;
2111   Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2112   if (!s.ok()) {
2113     if (extra_deps.empty()) {
2114       return s;
2115     } else {
2116       // There is a loop between queues: we'll just use the graph topological
2117       // order. This will make the shape inference less precise, but since this
2118       // isn't common it's not worth figuring out where to break the loop and
2119       // do a proper relaxation.
2120       TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2121     }
2122   }
2123 
2124   SymbolicShapeRefiner refiner(graph_view, fed_ports,
2125                                aggressive_shape_inference);
2126 
2127   TopoQueue new_shapes(topo_order);
2128   // Also seed the propagation of shapes in the fanout of primary inputs.
2129   for (const NodeDef* node : primary_inputs) {
2130     new_shapes.push(node);
2131   }
2132   // Also seed the propagation of shapes in the fanout of fed nodes.
2133   for (const NodeDef* node : fed_nodes) {
2134     new_shapes.push(node);
2135   }
2136   // Propagate shapes normally.
2137   TF_RETURN_IF_ERROR(
2138       PropagateShapes(&refiner, &new_shapes, resource_handles, num_loops));
2139 
2140   // Track shapes globally across the graph.
2141   std::unique_ptr<SymbolicShapeManager> shape_manager =
2142       absl::make_unique<SymbolicShapeManager>();
2143   bool found_error = false;
2144   for (const NodeDef& node : item_.graph.node()) {
2145     auto node_ctx = refiner.GetContext(&node);
2146     if (!node_ctx) {
2147       continue;
2148     }
2149     // Skip any information that comes from fed nodes.
2150     if (fed_ports.find(node.name()) != fed_ports.end()) {
2151       VLOG(2) << "Skipping feed node shape: " << node.name();
2152       continue;
2153     }
2154     for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2155       if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2156                .ok()) {
2157         found_error = true;
2158         break;
2159       }
2160     }
2161     for (const auto& merged_dims : node_ctx->MergedDims()) {
2162       if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2163         found_error = true;
2164         break;
2165       }
2166     }
2167     if (found_error) {
2168       // The shapes aren't consistent, we can't infer safely: discard all the
2169       // information discovered so far.
2170       shape_manager = absl::make_unique<SymbolicShapeManager>();
2171       break;
2172     }
2173   }
2174 
2175   for (const NodeDef& node : item_.graph.node()) {
2176     VLOG(3) << "Filling in graph properties for node: " << node.name();
2177     auto ctx = refiner.GetNodeContext(&node);
2178     if (!ctx) {
2179       continue;
2180     }
2181 
2182     auto* ic = ctx->inference_context.get();
2183 
2184     // Fill input properties.
2185     {
2186       auto& input_properties = input_properties_[node.name()];
2187 
2188       // Should always be empty; node names in graph are supposed to be unique.
2189       CHECK_EQ(input_properties.size(), 0);
2190 
2191       input_properties.resize(ic->num_inputs());
2192       GraphView::InputPort input(&node, -1);
2193       for (int i = 0; i < ic->num_inputs(); ++i) {
2194         shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2195                                           &input_properties[i]);
2196         input.port_id = i;
2197         GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2198         // Export tensor value to input_properties.value.
2199         if (IsConstant(*fanin.node)) {
2200           const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
2201           *input_properties[i].mutable_value() = raw_val;
2202         } else if (ctx->input_tensor_protos.size() > i &&
2203                    ctx->input_tensor_protos[i] != nullptr) {
2204           *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2205         } else if (ic->input_tensors_as_shapes().size() > i &&
2206                    IsShapeFullyDefinedIntegerVectorOrScalar(
2207                        ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2208                        ctx->input_types[i])) {
2209           *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2210               ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2211               ctx->input_types[i]);
2212         }
2213       }
2214     }
2215 
2216     // Fill output properties.
2217     {
2218       auto& output_properties = output_properties_[node.name()];
2219 
2220       // Should always be empty; node names in graph are supposed to be unique.
2221       CHECK_EQ(output_properties.size(), 0);
2222 
2223       output_properties.resize(ic->num_outputs());
2224       for (int i = 0; i < ic->num_outputs(); ++i) {
2225         shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2226                                           &output_properties[i]);
2227         // Export tensor value to output_properties.value.
2228         if (IsConstant(node)) {
2229           const TensorProto& raw_val = node.attr().at("value").tensor();
2230           *output_properties[i].mutable_value() = raw_val;
2231         } else if (ctx->output_tensor_protos.size() > i &&
2232                    ctx->output_tensor_protos[i] != nullptr) {
2233           *output_properties[i].mutable_value() = *ctx->output_tensor_protos[i];
2234         } else if (ctx->output_tensors_as_shapes.size() > i &&
2235                    IsShapeFullyDefinedIntegerVectorOrScalar(
2236                        ic, ic->output(i), ctx->output_tensors_as_shapes[i],
2237                        ctx->output_types[i])) {
2238           *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2239               ic, ic->output(i), ctx->output_tensors_as_shapes[i],
2240               ctx->output_types[i]);
2241         }
2242       }
2243     }
2244   }
2245 
2246   // Help trace the unknown dimensions to their origins.
2247   VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2248                                     output_properties_);
2249 
2250   return Status::OK();
2251 }
2252 
2253 Status GraphProperties::InferDynamically(Cluster* cluster) {
2254   TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2255 
2256   // Runs the model once to collect the shapes in the cost model.
2257   RunMetadata metadata;
2258   TF_RETURN_IF_ERROR(
2259       cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2260 
2261   return InferFromCostGraph(metadata.cost_graph());
2262 }
2263 
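// Writes the inferred output shapes back into the graph as an
// "_output_shapes" attribute on every node. Illustrative example (not from a
// real run): an output inferred as [?, 128] is exported as
//   attr { key: "_output_shapes"
//          value { list { shape { dim { size: -1 } dim { size: 128 } } } } }
// since unknown dimensions are serialized with size -1.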
2264 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
2265   *output_graph_def = item_.graph;
2266   for (int i = 0; i < output_graph_def->node_size(); i++) {
2267     auto node = output_graph_def->mutable_node(i);
2268     AttrValue attr_output_shape;
2269     auto tensor_properties = GetOutputProperties(node->name());
2270     for (const auto& tensor_property : tensor_properties) {
2271       *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
2272     }
2273     (*node->mutable_attr())["_output_shapes"] = attr_output_shape;
2274   }
2275   return Status::OK();
2276 }
2277 
2278 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2279   if (cost_graph.node_size() == 0) {
2280     LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2281   }
2282   std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2283   std::unordered_map<string, const NodeDef*> name_to_node;  // Empty
2284   for (auto& node : cost_graph.node()) {
2285     name_to_cost[node.name()] = &node;
2286 
2287     std::vector<OpInfo::TensorProperties> output_properties;
2288     for (const auto& out : node.output_info()) {
2289       OpInfo::TensorProperties properties;
2290       properties.set_dtype(out.dtype());
2291       *properties.mutable_shape() = out.shape();
2292       output_properties.push_back(properties);
2293     }
2294     output_properties_[node.name()] = output_properties;
2295   }
2296 
2297   for (const auto& node : item_.graph.node()) {
2298     // Skip the nodes that are not in the cost graph: these are nodes that
2299     // aren't run, because they aren't in the intersection of transitive fan-in
2300     // of a fetch node and the transitive fan-out of an input, or nodes that
2301     // were optimized away by the optimizer.
2302     auto it = name_to_cost.find(node.name());
2303     if (it == name_to_cost.end()) {
2304       continue;
2305     }
2306     std::vector<OpInfo::TensorProperties> inputs =
2307         FindInputFeatures(node, name_to_cost, name_to_node);
2308 
2309     input_properties_[node.name()] = inputs;
2310   }
2311   return Status::OK();
2312 }
2313 
2314 bool GraphProperties::HasInputProperties(const string& node_name) const {
2315   return input_properties_.find(node_name) != input_properties_.end();
2316 }
2317 
2318 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2319   return output_properties_.find(node_name) != output_properties_.end();
2320 }
2321 
2322 const std::vector<OpInfo::TensorProperties>&
2323 GraphProperties::GetInputProperties(const string& node_name) const {
2324   auto it = input_properties_.find(node_name);
2325   if (it != input_properties_.end()) {
2326     return it->second;
2327   }
2328   return missing_properties_;
2329 }
2330 
2331 const std::vector<OpInfo::TensorProperties>&
2332 GraphProperties::GetOutputProperties(const string& node_name) const {
2333   auto it = output_properties_.find(node_name);
2334   if (it != output_properties_.end()) {
2335     return it->second;
2336   }
2337   return missing_properties_;
2338 }
2339 
2340 void GraphProperties::ClearInputProperties(const string& node_name) {
2341   input_properties_.erase(node_name);
2342 }
2343 void GraphProperties::ClearOutputProperties(const string& node_name) {
2344   output_properties_.erase(node_name);
2345 }
2346 
2347 }  // end namespace grappler
2348 }  // end namespace tensorflow
2349