/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not set
    // in outer context, set _Arg node output shape to unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      LOG(WARNING) << "Function instantiation has undefined input shape at "
                   << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const FunctionDef* function_def, AttrSlice attributes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls inference lambda for each node after visiting all predecessors.
    // Ensures that we are adding nodes to ShapeRefiner in the topological
    // order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

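// Creates an InferenceContext for `node`, wires up the shapes of its inputs
// from the contexts of their source nodes, runs the op's shape function, and
// stores the resulting context. Non-control inputs must already have been
// added via AddNode(); the one exception is back edges introduced by v1
// control flow, whose shapes are left unknown. A typical driver (a sketch;
// callers just need some topological ordering, e.g. GetReversePostOrder()
// from graph/algorithm.h) visits nodes sources-first:
//
//   for (const Node* n : order) {
//     TF_RETURN_IF_ERROR(refiner.AddNode(n));
//   }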
Status ShapeRefiner::AddNode(const Node* node) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

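// Merges `shape` into the currently inferred shape of `output_port` of
// `node`, returning an error if the two are incompatible. This lets callers
// that know more than inference alone could derive (e.g. a client feeding a
// placeholder of known shape) narrow a previously inferred shape.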
Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

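// Re-runs shape inference for `node` if the shapes of its inputs have
// changed since the last run. With relax=false, changed inputs are merged
// (the new shape must stay compatible with the old one); with relax=true
// they are relaxed toward a shape that generalizes both, as needed for loop
// fixpoints. *refined is set to true when any input shape, input handle
// shape/type, or requested input value may have changed.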
Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

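// Tries to constant-fold the tensor feeding input `dst_idx` of `node` by
// evaluating its producing subgraph with the GraphRunner. On success,
// *evaluated is set to true and *result holds the value; previously folded
// values are reused via const_tensor_map_, and tensors larger than
// kMaxTensorSize are not materialized.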
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(tensor, *this, *ops_registry_,
                                graph_def_version_, evaluated, result,
                                &graph_runner_, &const_tensor_map_,
                                kMaxTensorSize, disable_constant_propagation_);
}

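// As above, but expects the edge to carry a single int32 or int64 scalar,
// which is widened to int64.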
Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   int64* result) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(
      EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

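// Interprets the tensor flowing into input `dst_idx` of `node` as a partial
// shape, producing `*result` in `target_context`. Well-known shape-producing
// ops (Shape, ShapeN, Pack, Concat/ConcatV2, Cast, StridedSlice,
// VariableShape) are handled symbolically, so a partial shape can often be
// recovered even when the tensor itself cannot be constant-folded; for any
// other op the tensor is evaluated (if possible) and converted with
// MakeShapeFromTensor.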
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value.  A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t).ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return Status::OK();
      }
    }

    // Then try to infer partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return Status::OK();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return Status::OK();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64 size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
                                                       &evaluated, &size));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(
        PartialStridedSliceShape(input_edge->src(), src_context, result));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

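// Computes the partial shape that a StridedSlice over a shape vector
// produces, e.g. tf.shape(x)[1:3]. Only the simple case is handled: scalar
// begin/end/strides and no ellipsis/new-axis/shrink-axis masks; anything
// else yields an unknown shape.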
Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
                                              InferenceContext* ctx,
                                              ShapeHandle* result) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  bool evaluated;
  int64 begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(
        EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(
        EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 stride;
  TF_RETURN_IF_ERROR(
      EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply stride to input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

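// Runs the shape function for `node` in context `ec`, then re-runs it as
// needed: whenever the shape function requests an input's constant value (or
// its value interpreted as a partial shape), that input is materialized via
// constant folding and inference is run again, until no new requests can be
// satisfied. Function-call nodes are dispatched to InferShapesForFunction
// instead.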
Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), ec);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

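// Returns true iff s0 and s1 are known to denote the same shape: either the
// same handle, or equal known ranks with every dimension pair sharing a
// handle or a known, equal value. Two shapes of unknown rank are *not*
// considered the same, since either could later resolve differently.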
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

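// Returns true if `updated` differs from `existing` in length, in any
// element's dtype, or in any element's shape per SameDefinedShape().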
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow