/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

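// Ops that mark the inputs (_Arg) and outputs (_Retval) of an instantiated
// function body.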
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If the input shape handle is
    // not set in the outer context, set the _Arg node's output shape to
    // unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    const std::vector<ShapeAndType>* resource =
        node_context->input_handle_shapes_and_types(0);
    if (resource) {
      // `ShapeAndType`s contain `ShapeHandle`s.  These `ShapeHandle`s point
      // to `Shape`s that are owned by a different inference context too.  We
      // need to copy them to the outer context to prevent them from being
      // destroyed before they are used.
      std::vector<ShapeAndType> copied_shapes_and_types;
      for (auto& shape_and_type : *resource) {
        ShapeHandle handle;
        TensorShapeProto proto;
        node_context->ShapeHandleToProto(shape_and_type.shape, &proto);
        TF_RETURN_IF_ERROR(
            outer_context->MakeShapeFromShapeProto(proto, &handle));
        copied_shapes_and_types.push_back(
            ShapeAndType(handle, shape_and_type.dtype, shape_and_type.type));
      }

      outer_context->set_output_handle_shapes_and_types(
          index, copied_shapes_and_types);
    }
  }

  return OkStatus();
}

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set while input(i) is an _Arg op, the request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def,
                                            AttrSlice attributes,
                                            InferenceContext* outer_context) {
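  // Lazily instantiate the function body and cache the resulting graph,
  // keyed by FunctionDef, so repeated calls to the same function reuse it.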
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = OkStatus();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(node, outer_context);
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after all of its predecessors
    // have been visited, ensuring that nodes are added to the ShapeRefiner
    // in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store it in this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
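      // Only resource and variant tensors carry auxiliary handle
      // shape/type information, which must be forwarded along the edge.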
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return OkStatus();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return OkStatus();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
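    // Merge tightens the input shape toward the most specific compatible
    // shape, while Relax loosens it toward a common, less specific shape
    // (used to reach a fixed point when iterating over loops).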
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return OkStatus();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return OkStatus();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
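  // Delegates to EvaluateConstantTensor, which evaluates the upstream
  // constant subgraph feeding this edge with graph_runner_ and caches
  // small results in const_tensor_map_.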
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64_t* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64_t>()();
    }
  }
  return OkStatus();
}

Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  // All shapes are expected to be 1D integer tensors, with the exception of
  // the sentinel that represents an unknown shape (a scalar/rank-0 tensor
  // with -1 as its value). Handle the special case first before considering
  // the more general rank-1 case.
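  // For example, the constant shape tensor [2, -1, 3] denotes the partial
  // shape [2,?,3], while the scalar -1 denotes a completely unknown shape.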

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value.  A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return OkStatus();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1) {
        *result = target_context->UnknownShape();
        return OkStatus();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return OkStatus();
      }
    }

    // Then try to infer the partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return OkStatus();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return OkStatus();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack concatenates its input scalars to form the shape tensor vector.
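    // e.g. packing the scalar values 2 and -1 yields the partial shape [2,?].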
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64_t size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat concatenates its input shape vectors.
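    // e.g. concatenating the partial shapes [2,?] and [3] yields [2,?,3].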
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // The concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return OkStatus();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return OkStatus();
}

Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
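  // Handles patterns like `tf.shape(x)[1:]`: slicing a shape vector yields
  // the corresponding sub-shape, e.g. slicing [2,?,3] with [1:] gives [?,3].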
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return OkStatus();
  }

  bool evaluated;
  int64_t begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int64_t end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64_t>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int64_t stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return OkStatus();
  }

  // Apply the stride to the input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return OkStatus();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // These will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Captured as a lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), c);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return OkStatus();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
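  // (e.g. Reshape's shape function requests the constant value of its
  // 'shape' input in order to compute its output shape.)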
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return OkStatus();
}

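// Two shapes are considered "same defined" only if they have the same rank
// and every dimension pair is either the same handle or has the same known
// value; unknown dimensions held by distinct handles compare as different.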
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64_t val0 = c->Value(c->Dim(s0, i));
      int64_t val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow