/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

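// A sketch of typical usage (illustrative only; real call sites live in graph
// construction code such as GraphConstructor):
//
//   ShapeRefiner refiner(graph.versions(), OpRegistry::Global());
//   for (Node* node : topo_order) {  // visit nodes in topological order
//     TF_RETURN_IF_ERROR(refiner.AddNode(node));
//   }
//   InferenceContext* ctx = refiner.GetContext(some_node);
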
ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}
ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the tensors
  // must be deleted before it is destroyed.
  const_tensor_map_.clear();
}

namespace {

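// "_Arg" and "_Retval" are the internal ops that a function-body graph uses
// to represent its inputs and outputs; InferShapesForFunctionSubNode below
// bridges shape information between them and the calling node's context.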
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not
    // set in outer context, set _Arg node output shape to unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const FunctionDef* function_def, AttrSlice attributes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status =
          InferShapesForFunctionSubNode(node, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls inference lambda for each node after visiting all predecessors.
    // Ensures that we are adding nodes to ShapeRefiner in the topological
    // order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
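  // The input shapes start out as unset handles; the three trailing empty
  // vectors (input tensors, input tensors interpreted as shapes, and input
  // handle shapes/types) are filled in below and by RunShapeFn.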
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

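    // If the source output carries handle metadata (e.g. the shape and dtype
    // of the tensor that a DT_RESOURCE variable handle points to), forward a
    // copy of it to this node's input.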
    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
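  // When 'relax' is false, each input is merged with the new source shape,
  // i.e. combined into a shape at least as precise as both (used while shapes
  // only ever become more specific). When 'relax' is true, the input is
  // relaxed instead, i.e. generalized to a shape compatible with both, which
  // lets repeated updates over loops converge to a fixed point.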
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

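// Attempts to interpret the tensor flowing into input 'dst_idx' of 'node' as
// a (possibly partial) shape, without requiring the whole tensor to be
// statically known. For example, if that tensor is produced by
// Pack(2, <unknown scalar>, 3), the result is the partial shape [2,?,3].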
Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value. A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return Status::OK();
      }
    }

    // Then try to infer partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return Status::OK();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return Status::OK();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64 size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

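// Computes the partial shape that results from applying a StridedSlice to a
// shape vector, e.g. the value of tf.shape(x)[1:] viewed as a partial shape.
// Falls back to UnknownShape whenever the slice uses features that cannot be
// mapped onto Subshape(), such as ellipsis/new-axis/shrink-axis masks or
// non-scalar begin/end/strides.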
Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  bool evaluated;
  int64 begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply stride to input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), ec);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
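  // For example, a shape function may decide whether to request the value of
  // its second input via input_tensor() only after seeing what was inferred
  // for its first input, so a single pass is not guaranteed to surface every
  // request.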
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below. If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

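// Returns true only when 's0' and 's1' are provably the same: either the same
// handle, or equal ranks with every dimension pair either sharing a handle or
// having equal known values. Shapes with unknown ranks or unknown dimension
// values are conservatively reported as different.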
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

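// Returns true if the two vectors differ in size, in any dtype, or in any
// shape (per SameDefinedShape); used above to decide whether propagated
// handle shape/type information constitutes a refinement.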
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow