/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node, and outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.
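    //
    // Illustrative example (not part of the original source): if the outer
    // graph calls this function with an argument whose inferred shape is
    // [2,3], the _Arg node carrying attribute index=0 re-exports that
    // [2,3] shape as its own output below.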

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    node_context->set_output(0, outer_context->input(index));

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set and input(i) is an _Arg op, the request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// We may never support recursive functions in TF at all, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, outer_context->get_context()->attrs(),
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after all of its predecessors
    // have been visited, which ensures that nodes are added to the
    // ShapeRefiner in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  if (keep_nested_shapes && inference_status.ok()) {
    // Fill the nested inferences map.
    //
    // The materialized function graph has extra nodes for arguments and
    // return values, which are not explicitly listed in the FunctionDef.
    // We filter out these special nodes here to avoid exposing the
    // implementation details, and keep only inferences for the nodes
    // listed in the FunctionDef.
    std::unordered_map<string, const NodeDef*> user_defined_nodes;
    for (const auto& node_def : function_def->node_def()) {
      user_defined_nodes[node_def.name()] = &node_def;
    }

    std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
        nested_inferences;
    for (const Node* node : function_nodes) {
      const string& node_name = node->name();
      if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
        nested_inferences[node_name] = std::move(node_to_context_[node]);
        node_to_context_.erase(node);
        // By default InferenceContext refers to a NodeDef from the Graph.
        // Change it to the publicly accessible NodeDef of the function
        // definition.
        nested_inferences[node_name]->get_context()->node_def_ =
            user_defined_nodes[node_name];
      }
    }
    outer_context->set_nested_inferences(std::move(nested_inferences));
  } else {
    // Delete the contexts created for the function's nodes to save memory.
    for (const Node* node : function_nodes) {
      node_to_context_.erase(node);
    }
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store it into a vector
  // indexed by 'node's input index.
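  //
  // Illustrative example (not part of the original source): for a MatMul
  // node with inputs a:0 and b:0, input_shapes[0] and input_shapes[1] are
  // filled with the shapes previously inferred for those source outputs.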
  std::vector<const Node*> input_nodes(node->num_inputs());
  std::vector<ShapeHandle> input_shapes(node->num_inputs());
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types(node->num_inputs());
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = it->second->get_context();
    DCHECK_GE(e->dst_input(), 0);
    input_nodes[e->dst_input()] = input;
    input_shapes[e->dst_input()] = c->output(e->src_output());

    const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      input_handle_shapes_and_types[e->dst_input()].reset(
          new std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<ShapeHandle> input_tensors_as_shapes;

  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> c(
      new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
                           input_shapes, input_tensors, input_tensors_as_shapes,
                           std::move(input_handle_shapes_and_types)));
  if (!c->construction_status().ok()) {
    return c->construction_status();
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(c), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}
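
// Illustrative usage sketch (not part of the original source, assuming a
// Graph 'graph' and a topologically ordered node list 'topo_order'):
//
//   ShapeRefiner refiner(graph.versions().producer(), graph.op_registry());
//   for (const Node* n : topo_order) {
//     TF_CHECK_OK(refiner.AddNode(n));
//   }
//   InferenceContext* c = refiner.GetContext(some_node);
//   ShapeHandle out = c->output(0);  // Inferred shape of output 0.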

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);
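
  // Illustrative example (not part of the original source): Merge() refines
  // partially known shapes and rejects contradictions. Merging [2,?] with
  // [?,3] yields [2,3], while merging [2,3] with [2,4] returns an error.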

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
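  //
  // Illustrative example (not part of the original source): when 'relax' is
  // false, inputs are merged toward more specific shapes (merging [2,?]
  // with [?,3] gives [2,3]); when 'relax' is true, they are relaxed toward
  // what is common to both (relaxing [2,3] with [2,4] gives [2,?]).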
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(tensor, *this, *ops_registry_,
                                graph_def_version_, evaluated, result,
                                &graph_runner_, &const_tensor_map_,
                                kMaxTensorSize, disable_constant_propagation_);
}
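
// Illustrative note (not part of the original source): this is what lets a
// shape function observe constant input *values* before the graph runs.
// For example, if Reshape's 'shape' input is produced by a small constant
// subgraph, evaluating that subgraph here yields the concrete shape vector.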

Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   int64* result) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(
      EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
  if (*evaluated) {
    DCHECK_EQ(scalar.NumElements(), 1)
        << "EvaluateConstantIntScalarEdge called on non-scalar edge: "
        << scalar.NumElements();
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      DCHECK_EQ(scalar.dtype(), DT_INT64)
          << "EvaluateConstantIntScalarEdge called on non-integer edge: "
          << scalar.dtype();
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value. A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }
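
  // Illustrative note (not part of the original source): the scalar case
  // above encodes the convention that a shape given as the scalar -1 means
  // "completely unknown shape", as opposed to a rank-1 shape vector in
  // which -1 marks a single unknown dimension.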

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
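    // Illustrative example (not part of the original source): packing the
    // scalars (2, <unevaluable>, 3) produces the partial shape [2,?,3],
    // since a scalar that cannot be evaluated becomes an unknown dim below.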
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64 size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
                                                       &evaluated, &size));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
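    // Illustrative example (not part of the original source): concatenating
    // the partial shapes [2,?] and [3] yields [2,?,3]; if any input's rank
    // is unknown, the loop below gives up and returns an unknown shape.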
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // The concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context,
                                              input_edge->src(), i,
                                              &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(
        PartialStridedSliceShape(input_edge->src(), src_context, result));
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
                                              InferenceContext* ctx,
                                              ShapeHandle* result) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }
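
  // Illustrative example (not part of the original source): for a slice
  // written as shape(t)[1:], the end bound that was elided in the source
  // program shows up as a mask bit, so end_mask == 1 here and 'end' is
  // treated as the maximum int64 value below.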

  bool evaluated;
  int64 begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(
        EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(
        EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 stride;
  TF_RETURN_IF_ERROR(
      EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply the stride to the input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as a lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && op_reg_data->is_function_op) {
      // Special inference logic for user-defined functions.

      auto* func_def = function_library_->Find(op_reg_data->op_def.name());
      if (func_def) {
        return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
                                      ec);
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
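  //
  // Illustrative example (not part of the original source): a shape function
  // may inspect input_tensor(0) first and, depending on its value, request
  // input_tensor(1); that second request only becomes visible after the
  // function has run once, so inference is re-run until it settles.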
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below. If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}
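
// Illustrative example (not part of the original source): two shapes [2,?]
// built from independently created unknown dimensions do not compare as the
// same defined shape above, because unknown dims only count as "the same"
// when they are literally the same handle.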

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow