/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

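// ShapeRefiner drives shape inference over a Graph: it creates an
// InferenceContext per Node as nodes are added and keeps the contexts in
// node_to_context_ so that downstream nodes can read their inputs' shapes.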
namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If the input shape handle is
    // not set in the outer context, set the _Arg node's output shape to
    // unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    const std::vector<ShapeAndType>* resource =
        node_context->input_handle_shapes_and_types(0);
    if (resource) {
      // `ShapeAndType`s contain `ShapeHandle`s.  These `ShapeHandle`s point
      // to `Shape`s that are owned by a different inference context too.  We
      // need to copy them to the outer context to prevent them from being
      // destroyed before they are used.
      std::vector<ShapeAndType> copied_shapes_and_types;
      for (auto& shape_and_type : *resource) {
        ShapeHandle handle;
        TensorShapeProto proto;
        node_context->ShapeHandleToProto(shape_and_type.shape, &proto);
        TF_RETURN_IF_ERROR(
            outer_context->MakeShapeFromShapeProto(proto, &handle));
        copied_shapes_and_types.push_back(
            ShapeAndType(handle, shape_and_type.dtype, shape_and_type.type));
      }

      outer_context->set_output_handle_shapes_and_types(
          index, copied_shapes_and_types);
    }
  }

  return Status::OK();
}

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const FunctionDef* function_def, AttrSlice attributes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
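  // The function body is instantiated at most once per FunctionDef; the
  // resulting Graph is cached in functions_ so repeated call sites reuse it.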
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status =
          InferShapesForFunctionSubNode(node, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after visiting all of its
    // predecessors, which ensures that nodes are added to the ShapeRefiner
    // in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store it in this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

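    // Propagate handle shape/type metadata (e.g. the shape of the tensor a
    // resource variable holds) along edges that carry resource or variant
    // handles.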
    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
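    // Merging tightens the input to the most precise shape compatible with
    // both the old and the incoming shape; relaxing loosens it to a shape
    // compatible with both. Relaxation is what allows refinement to reach a
    // fixed point when the graph contains loops.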
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate the handle shape and dtype of edges which carry
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed; we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
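  // Delegate to the shared constant-evaluation helper, which caches
  // sufficiently small results in const_tensor_map_.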
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

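  // A rank-0 source tensor can only describe the fully unknown shape, which
  // must be encoded as the static scalar value -1.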
  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value.  A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

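  // The source tensor is a rank-1 shape vector. For several shape-producing
  // ops we can still recover a partial shape symbolically even when the
  // tensor's value cannot be evaluated as a constant.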
  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // The source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return Status::OK();
      }
    }

    // Then try to infer a partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return Status::OK();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return Status::OK();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64_t size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for ConcatV2 it is the last
    // input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // The concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
  // Only attempt to evaluate if begin/end/strides each contain exactly one
  // element (i.e. the slice is over the single dimension of a shape vector).
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1, which make begin/end default to the full
  // range).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  bool evaluated;
  int64_t begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64_t end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64_t stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply the stride to the input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as a lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), ec);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefited
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
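  // Two shapes are considered the same only if they are the same handle, or
  // if they have identical known ranks and every dimension pair is either the
  // same handle or has the same known (non-negative) value.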
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64_t val0 = c->Value(c->Dim(s0, i));
      int64_t val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow