/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
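//
// Example (illustrative): for a call f(x) with f(a) { return a * 2; }, the
// function body's _Arg node copies x's shape in from the call node's
// InferenceContext (`outer_context`), and the _Retval node copies the shape
// inferred for `a * 2` back out to the call node's output 0.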
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not
    // set in outer context, set _Arg node output shape to unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    const std::vector<ShapeAndType>* resource =
        node_context->input_handle_shapes_and_types(0);
    if (resource) {
      // `ShapeAndType`s contain `ShapeHandle`s. These `ShapeHandle`s point
      // to `Shape`s that are owned by a different inference context too. We
      // need to copy them to the outer context to prevent them from being
      // destroyed before they are used.
      std::vector<ShapeAndType> copied_shapes_and_types;
      for (auto& shape_and_type : *resource) {
        ShapeHandle handle;
        TensorShapeProto proto;
        node_context->ShapeHandleToProto(shape_and_type.shape, &proto);
        TF_RETURN_IF_ERROR(
            outer_context->MakeShapeFromShapeProto(proto, &handle));
        copied_shapes_and_types.push_back(
            ShapeAndType(handle, shape_and_type.dtype, shape_and_type.type));
      }

      outer_context->set_output_handle_shapes_and_types(
          index, copied_shapes_and_types);
    }
  }

  return Status::OK();
}

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const FunctionDef* function_def, AttrSlice attributes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status =
          InferShapesForFunctionSubNode(node, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after visiting all of its
    // predecessors. Ensures that we add nodes to the ShapeRefiner in
    // topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

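// Example (illustrative sketch, not part of this file): a typical caller
// walks the graph in topological order and adds each node, so that a node's
// inputs already have inference contexts by the time the node is visited:
//
//   ShapeRefiner refiner(graph.versions(), graph.op_registry());
//   std::vector<Node*> order;
//   GetReversePostOrder(graph, &order);
//   for (const Node* n : order) {
//     TF_RETURN_IF_ERROR(refiner.AddNode(n));
//   }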
Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

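// Example (illustrative): merging refines shapes, so merging an input of
// [2, ?] with a producer output of [2, 3] updates the input to [2, 3];
// relaxing generalizes, so relaxing [2, 3] against [2, 4] yields [2, ?].
// Relaxation (`relax == true`) is what callers use when revisiting nodes
// inside loops, where shapes may legitimately differ across iterations.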
Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed; we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

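// Example (illustrative): ConstantPartialShape interprets the tensor feeding
// input `dst_idx` as a (possibly partially known) shape. For instance, if
// that tensor is produced by Pack(2, k, 3) where `k` is not statically
// known, the result is the partial shape [2, ?, 3]; if it is the scalar
// constant -1, the result is an unknown shape.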
Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value. A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return Status::OK();
      }
    }

    // Then try to infer a partial shape from the input to the Cast op.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return Status::OK();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return Status::OK();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64_t size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for ConcatV2 it is the last
    // input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // The concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(
          target_context, input_edge->src(), i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

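// Example (illustrative): with `x` of partial shape [10, ?, 3] and
// `t = Shape(x)[1:3]`, the slice's begin/end/strides scalars evaluate to
// 1, 3 and 1, the partial shape computed for Shape(x) is [10, ?, 3], and
// Subshape yields the partial shape [?, 3] for `t`.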
Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  bool evaluated;
  int64_t begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64_t end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64_t stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply stride to input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

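// Example (illustrative): Reshape's shape function asks for its input 1 (the
// target shape) as a partial shape. On the first pass below that value is
// not yet available, so the output shape is unknown; the loop then tries to
// materialize the constant and, if it succeeds, re-runs the shape function
// with the concrete value so it can produce a fully defined output shape.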
Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), ec);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below. If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(
            ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

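// Example (illustrative): two shapes that both print as [2, ?] but were
// built independently are not "the same defined shape": their unknown
// dimensions are distinct handles with negative values, so the loop below
// returns false for them.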
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64_t val0 = c->Value(c->Dim(s0, i));
      int64_t val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow