1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb_text.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25
26 namespace tensorflow {
27 namespace shape_inference {
28
// Namespace-scope definitions of the class's static constexpr sentinels
// (kUnknownRank for a shape of unknown rank, kUnknownDim for a dimension of
// unknown size). Before C++17, a static constexpr data member that is
// odr-used (e.g. bound to a const reference) needs such an out-of-line
// definition to have storage.
constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
31
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<TensorShapeProto> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<TensorShapeProto> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<TensorShapeProto,DataType>>>> & input_handle_shapes_and_types)32 InferenceContext::InferenceContext(
33 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
34 const std::vector<TensorShapeProto>& input_shapes,
35 const std::vector<const Tensor*>& input_tensors,
36 const std::vector<TensorShapeProto>& input_tensors_as_shapes,
37 const std::vector<
38 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
39 input_handle_shapes_and_types)
40 : graph_def_version_(graph_def_version),
41 node_def_(CHECK_NOTNULL(node_def)) {
42 std::vector<ShapeHandle> input_tensors_as_shape_handles;
43 input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
44 for (const TensorShapeProto& p : input_tensors_as_shapes) {
45 ShapeHandle shape;
46 construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
47 if (!construction_status_.ok()) {
48 return;
49 }
50 input_tensors_as_shape_handles.push_back(shape);
51 }
52 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
53 if (!construction_status_.ok()) return;
54 inputs_.reserve(input_shapes.size());
55 for (const TensorShapeProto& p : input_shapes) {
56 ShapeHandle shape;
57 construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
58 if (!construction_status_.ok()) {
59 return;
60 }
61 inputs_.push_back(shape);
62 }
63
64 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
65 input_shapes.size());
66 for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
67 const auto& v = input_handle_shapes_and_types[i];
68 if (v == nullptr) {
69 continue;
70 }
71 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
72 auto& new_v = *handle_data[i];
73 for (int j = 0; j < v->size(); ++j) {
74 const auto& p = (*v)[j];
75 construction_status_.Update(
76 MakeShapeFromShapeProto(p.first, &new_v[j].shape));
77 if (!construction_status_.ok()) {
78 return;
79 }
80 new_v[j].dtype = p.second;
81 }
82 }
83 PostInputInit(std::move(handle_data));
84 }
85
86 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)87 InferenceContext::InferenceContext(
88 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
89 const std::vector<PartialTensorShape>& input_shapes,
90 const std::vector<const Tensor*>& input_tensors,
91 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
92 const std::vector<
93 std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
94 input_handle_shapes_and_types)
95 : graph_def_version_(graph_def_version),
96 node_def_(CHECK_NOTNULL(node_def)) {
97 std::vector<ShapeHandle> input_tensors_as_shape_handles;
98 input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
99 for (const PartialTensorShape& p : input_tensors_as_shapes) {
100 ShapeHandle shape;
101 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
102 if (!construction_status_.ok()) {
103 return;
104 }
105 input_tensors_as_shape_handles.push_back(shape);
106 }
107 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
108 if (!construction_status_.ok()) return;
109 inputs_.reserve(input_shapes.size());
110 for (const PartialTensorShape& p : input_shapes) {
111 ShapeHandle shape;
112 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
113 if (!construction_status_.ok()) {
114 return;
115 }
116 inputs_.push_back(shape);
117 }
118 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
119 input_shapes.size());
120 for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
121 const auto& v = input_handle_shapes_and_types[i];
122 if (v == nullptr) {
123 continue;
124 }
125 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
126 auto& new_v = *handle_data[i];
127 for (int j = 0; j < v->size(); ++j) {
128 const auto& p = (*v)[j];
129 construction_status_.Update(
130 MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
131 if (!construction_status_.ok()) {
132 return;
133 }
134 new_v[j].dtype = p.second;
135 }
136 }
137 PostInputInit(std::move(handle_data));
138 }
139
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)140 InferenceContext::InferenceContext(
141 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
142 const std::vector<ShapeHandle>& input_shapes,
143 const std::vector<const Tensor*>& input_tensors,
144 const std::vector<ShapeHandle>& input_tensors_as_shapes,
145 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
146 input_handle_shapes_and_types)
147 : graph_def_version_(graph_def_version),
148 node_def_(CHECK_NOTNULL(node_def)) {
149 PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
150 if (!construction_status_.ok()) return;
151 inputs_ = input_shapes;
152
153 PostInputInit(std::move(input_handle_shapes_and_types));
154 }
155
~InferenceContext()156 InferenceContext::~InferenceContext() {}
157
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)158 Status InferenceContext::Run(
159 const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
160 ForgetMerges();
161 Status s = fn(this);
162 if (!s.ok()) {
163 ForgetMerges();
164 return AttachContext(s);
165 }
166 #ifndef NDEBUG
167 for (int i = 0; i < num_outputs(); ++i) {
168 DCHECK(output(i).IsSet())
169 << i << " for " << node_def_->name() << " of type " << node_def_->op();
170 }
171 #endif // NDEBUG
172 return s;
173 }
174
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)175 Status InferenceContext::set_output(StringPiece output_name,
176 const std::vector<ShapeHandle>& shapes) {
177 auto result = output_name_map_.find(output_name);
178 if (result == output_name_map_.end()) {
179 return errors::InvalidArgument("Unknown output name: ", output_name);
180 } else {
181 const int start = result->second.first;
182 const int size = result->second.second - start;
183 if (size != shapes.size()) {
184 return errors::InvalidArgument("Must have exactly ", shapes.size(),
185 " shapes.");
186 }
187 for (int i = 0; i < size; ++i) {
188 outputs_[i + start] = shapes[i];
189 }
190 }
191 return Status::OK();
192 }
193
input(StringPiece input_name,std::vector<ShapeHandle> * output) const194 Status InferenceContext::input(StringPiece input_name,
195 std::vector<ShapeHandle>* output) const {
196 const auto result = input_name_map_.find(input_name);
197 if (result == input_name_map_.end()) {
198 return errors::InvalidArgument("Unknown input name: ", input_name);
199 } else {
200 output->clear();
201 for (int i = result->second.first; i < result->second.second; ++i) {
202 output->push_back(inputs_[i]);
203 }
204 }
205 return Status::OK();
206 }
207
output(StringPiece output_name,std::vector<ShapeHandle> * output) const208 Status InferenceContext::output(StringPiece output_name,
209 std::vector<ShapeHandle>* output) const {
210 const auto result = output_name_map_.find(output_name);
211 if (result == output_name_map_.end()) {
212 return errors::InvalidArgument("Unknown output name: ", output_name);
213 } else {
214 output->clear();
215 for (int i = result->second.first; i < result->second.second; ++i) {
216 output->push_back(outputs_[i]);
217 }
218 }
219 return Status::OK();
220 }
221
op() const222 string InferenceContext::op() const { return node_def_->op(); }
223
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)224 void InferenceContext::PreInputInit(
225 const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
226 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
227 input_tensors_ = input_tensors;
228 input_tensors_as_shapes_ = input_tensors_as_shapes;
229
230 construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
231 &output_name_map_);
232 if (!construction_status_.ok()) return;
233
234 int num_outputs = 0;
235 for (const auto& e : output_name_map_) {
236 num_outputs = std::max(num_outputs, e.second.second);
237 }
238 outputs_.assign(num_outputs, nullptr);
239 output_handle_shapes_and_types_.resize(num_outputs);
240 }
241
ExpandOutputs(int new_output_size)242 Status InferenceContext::ExpandOutputs(int new_output_size) {
243 if (new_output_size < outputs_.size()) {
244 return errors::InvalidArgument("Trying to reduce number of outputs of op.");
245 }
246 outputs_.resize(new_output_size, nullptr);
247 output_handle_shapes_and_types_.resize(new_output_size);
248 return Status::OK();
249 }
250
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)251 void InferenceContext::PostInputInit(
252 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
253 int num_inputs_from_node_def = 0;
254 for (const auto& e : input_name_map_) {
255 num_inputs_from_node_def =
256 std::max(num_inputs_from_node_def, e.second.second);
257 }
258
259 // Allow passing empty shapes/dtypes to avoid changing every single test.
260 if (input_handle_data.empty()) {
261 input_handle_shapes_and_types_.resize(inputs_.size());
262 } else {
263 if (input_handle_data.size() != inputs_.size()) {
264 construction_status_ = errors::InvalidArgument(
265 "Wrong number of handle shapes passed; expected ", inputs_.size(),
266 " got ", input_handle_data.size());
267 return;
268 }
269 input_handle_shapes_and_types_ = std::move(input_handle_data);
270 }
271
272 if (inputs_.size() != num_inputs_from_node_def) {
273 construction_status_ = errors::InvalidArgument(
274 "Wrong number of inputs passed: ", inputs_.size(), " while ",
275 num_inputs_from_node_def, " expected based on NodeDef");
276 return;
277 }
278
279 CHECK_LE(input_tensors_.size(), inputs_.size());
280 input_tensors_.resize(inputs_.size());
281 requested_input_tensor_.resize(inputs_.size());
282 requested_input_tensor_as_partial_shape_.resize(inputs_.size());
283 }
284
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)285 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
286 TensorShapeProto* proto) {
287 if (!RankKnown(handle)) {
288 proto->set_unknown_rank(true);
289 return;
290 }
291
292 for (int32 i = 0; i < Rank(handle); ++i) {
293 DimensionHandle dim = Dim(handle, i);
294 auto* dim_shape = proto->add_dim();
295 if (ValueKnown(dim)) {
296 dim_shape->set_size(Value(dim));
297 } else {
298 dim_shape->set_size(-1);
299 }
300 }
301 }
302
FullyDefined(ShapeHandle s)303 bool InferenceContext::FullyDefined(ShapeHandle s) {
304 if (!RankKnown(s)) return false;
305 for (int i = 0; i < Rank(s); ++i) {
306 if (!ValueKnown(Dim(s, i))) return false;
307 }
308 return true;
309 }
310
NumElements(ShapeHandle s)311 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
312 const auto rank = Rank(s);
313 if (rank == kUnknownRank) return UnknownDim();
314 bool found_unknown = false;
315 int64 size = 1;
316 for (int i = 0; i < rank; ++i) {
317 int64 dim_val = Value(Dim(s, i));
318 if (dim_val == kUnknownDim) {
319 found_unknown = true;
320 } else if (dim_val == 0) {
321 return MakeDim(0);
322 } else {
323 size *= dim_val;
324 }
325 }
326 if (found_unknown) {
327 return UnknownDim();
328 } else {
329 return MakeDim(size);
330 }
331 }
332
DebugString(ShapeHandle s)333 string InferenceContext::DebugString(ShapeHandle s) {
334 if (RankKnown(s)) {
335 std::vector<string> vals;
336 for (auto d : s->dims_) vals.push_back(DebugString(d));
337 return strings::StrCat("[", str_util::Join(vals, ","), "]");
338 } else {
339 return "?";
340 }
341 }
342
DebugString(DimensionHandle d)343 string InferenceContext::DebugString(DimensionHandle d) {
344 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
345 }
346
DebugString() const347 string InferenceContext::DebugString() const {
348 return strings::StrCat("InferenceContext for node: ",
349 ProtoDebugString(*node_def_));
350 }
351
DebugString(const ShapeAndType & shape_and_type)352 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
353 return strings::StrCat(DebugString(shape_and_type.shape), ":",
354 DataTypeString(shape_and_type.dtype));
355 }
356
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)357 string InferenceContext::DebugString(
358 gtl::ArraySlice<ShapeAndType> shape_and_types) {
359 std::vector<string> pieces;
360 for (const ShapeAndType& s : shape_and_types) {
361 pieces.push_back(DebugString(s));
362 }
363 return strings::StrCat("[", str_util::Join(pieces, ","), "]");
364 }
365
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)366 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
367 ShapeHandle* out) {
368 if (rank > kint32max) {
369 return errors::InvalidArgument("Rank cannot exceed kint32max");
370 }
371 const int32 existing = Rank(shape);
372 if (existing == rank) {
373 *out = shape;
374 return Status::OK();
375 }
376 if (existing == kUnknownRank) {
377 std::vector<DimensionHandle> dims;
378 dims.reserve(rank);
379 for (int i = 0; i < rank; ++i) {
380 dims.push_back(UnknownDim());
381 }
382 ShapeHandle shp = shape_manager_.MakeShape(dims);
383 return Merge(shape, shp, out);
384 }
385 *out = nullptr;
386
387 return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
388 existing);
389 }
390
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)391 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
392 ShapeHandle* out) {
393 if (rank > kint32max) {
394 return errors::InvalidArgument("Rank cannot exceed kint32max");
395 }
396 const int32 existing = Rank(shape);
397 if (existing >= rank || existing == kUnknownRank) {
398 *out = shape;
399 return Status::OK();
400 }
401 *out = nullptr;
402 return errors::InvalidArgument("Shape must be at least rank ", rank,
403 " but is rank ", existing);
404 }
405
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)406 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
407 ShapeHandle* out) {
408 if (rank > kint32max) {
409 return errors::InvalidArgument("Rank cannot exceed kint32max");
410 }
411 const int32 existing = Rank(shape);
412 if (existing <= rank || existing == kUnknownRank) {
413 *out = shape;
414 return Status::OK();
415 }
416 *out = nullptr;
417 return errors::InvalidArgument("Shape must be at most rank ", rank,
418 " but is rank ", existing);
419 }
420
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)421 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
422 DimensionHandle* out) {
423 const int64 existing = Value(dim);
424 if (existing == value) {
425 *out = dim;
426 return Status::OK();
427 }
428 if (existing == kUnknownDim) {
429 DimensionHandle d = MakeDim(value);
430 return Merge(dim, d, out);
431 }
432 *out = nullptr;
433 return errors::InvalidArgument("Dimension must be ", value, " but is ",
434 existing);
435 }
436
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)437 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
438 DimensionHandle* out) {
439 if (d_old.SameHandle(d_new)) {
440 *out = d_old;
441 } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
442 // The node will be fed by the dimension d_new instead of d_old: any
443 // equality assertion between d_old and other input dimension on this node
444 // may not be true anymore, so forget them all.
445 ForgetMerges();
446 // Return the new shape handle to force the relaxation to propagate to the
447 // fanout of the context.
448 *out = d_new;
449 } else if (!ValueKnown(d_new)) {
450 ForgetMerges();
451 *out = d_new;
452 } else if (Value(d_old) == Value(d_new)) {
453 // Return the old shape handle. This will stop the relaxation in the fanout
454 // of the context.
455 *out = d_old;
456 } else {
457 // Return a new handle that encodes a different unknown dim.
458 ForgetMerges();
459 *out = UnknownDim();
460 }
461 }
462
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)463 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
464 DimensionHandle* out) {
465 if (d0.SameHandle(d1)) {
466 *out = d0;
467 return Status::OK();
468 } else if (!ValueKnown(d1)) {
469 *out = d0;
470 merged_dims_.emplace_back(d0, d1);
471 return Status::OK();
472 } else if (!ValueKnown(d0)) {
473 *out = d1;
474 merged_dims_.emplace_back(d0, d1);
475 return Status::OK();
476 } else if (Value(d0) == Value(d1)) {
477 *out = d0;
478 return Status::OK();
479 } else {
480 *out = nullptr;
481 return errors::InvalidArgument("Dimensions must be equal, but are ",
482 Value(d0), " and ", Value(d1));
483 }
484 }
485
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)486 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
487 ShapeHandle* s_out,
488 ShapeHandle* prefix_out) {
489 *s_out = *prefix_out = nullptr;
490 if (!RankKnown(prefix) || !RankKnown(s)) {
491 *s_out = s;
492 *prefix_out = prefix;
493 return Status::OK();
494 }
495 const int32 rank = Rank(prefix);
496 TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
497
498 // Merge the prefix dims and create the new output shapes.
499 const int32 rank_s = Rank(s);
500 std::vector<DimensionHandle> dims;
501 dims.reserve(std::max(rank, rank_s));
502 dims.resize(rank);
503 for (int i = 0; i < rank; ++i) {
504 TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
505 }
506 *prefix_out = MakeShape(dims);
507 for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
508 *s_out = MakeShape(dims);
509 return Status::OK();
510 }
511
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)512 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
513 ShapeHandle* out) {
514 if (s_old.SameHandle(s_new)) {
515 *out = s_old;
516 return;
517 } else if (!RankKnown(s_new) || !s_old.IsSet()) {
518 ForgetMerges();
519 *out = s_new;
520 return;
521 }
522
523 const int32 rank = Rank(s_old);
524 if (rank != Rank(s_new)) {
525 ForgetMerges();
526 *out = UnknownShape();
527 return;
528 }
529
530 bool return_s_old = true;
531 for (int i = 0; i < rank; ++i) {
532 auto d0 = Dim(s_old, i);
533 auto d1 = Dim(s_new, i);
534 if (d0.SameHandle(d1)) continue;
535
536 auto v0 = Value(d0);
537 auto v1 = Value(d1);
538 if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
539 return_s_old = false;
540 break;
541 }
542 }
543 if (return_s_old) {
544 *out = s_old;
545 return;
546 }
547
548 // Relax dims.
549 std::vector<DimensionHandle> dims(rank);
550 for (int i = 0; i < rank; ++i) {
551 Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
552 }
553 ForgetMerges();
554 *out = MakeShape(dims);
555 }
556
// Asserts that shapes s0 and s1 are compatible and stores the merged shape in
// *out.
//
// - Identical handles: *out = s0.
// - Either rank unknown: *out is the other shape, and (s0, s1) is recorded in
//   merged_shapes_.
// - Otherwise the ranks must match and every pair of known dims must be
//   equal. A mismatch sets *out = nullptr and returns InvalidArgument.
// - If s0 (or else s1) is at least as specific as the other in every dim, it
//   is returned as-is. Otherwise a new shape is built from the per-dim Merge
//   results.
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
                               ShapeHandle* out) {
  if (s0.SameHandle(s1)) {
    *out = s0;
    return Status::OK();
  } else if (!RankKnown(s1)) {
    *out = s0;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  } else if (!RankKnown(s0)) {
    *out = s1;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  }

  const int32 rank = Rank(s0);
  if (rank != Rank(s1)) {
    *out = nullptr;
    return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
                                   " and ", Rank(s1));
  }

  // return_s0 stays true unless some dim is unknown in s0 but known in s1,
  // i.e. s0 alone already carries every known size. return_s1 is the mirror
  // condition for s1.
  bool return_s0 = true;
  bool return_s1 = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s0, i);
    auto d1 = Dim(s1, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim) {
      if (v1 != kUnknownDim) {
        return_s0 = false;
      }
    } else if (v1 == kUnknownDim) {
      return_s1 = false;
    } else if (v0 != v1) {
      *out = nullptr;
      return errors::InvalidArgument(
          "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
          " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
          DebugString(s1), ".");
    }
  }

  // Record the equality assertion between s0 and s1.
  merged_shapes_.emplace_back(s0, s1);

  if (return_s0 || return_s1) {
    *out = return_s0 ? s0 : s1;
    return Status::OK();
  }

  // Merge dims. Neither input alone carries every known size, so build a new
  // shape whose dims are the per-dim merges.
  std::vector<DimensionHandle> dims(rank, nullptr);
  for (int i = 0; i < rank; ++i) {
    // Invariant for merge was checked earlier, so CHECK is ok.
    TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
  }

  Status s = ReturnCreatedShape(dims, out);
  if (s.ok()) {
    // Merge the new shape with s0. Since s0 and s1 are merged, this implies
    // that s1 and out are also merged.
    merged_shapes_.emplace_back(s0, *out);
  }
  return s;
}
625
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)626 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
627 ShapeHandle* out) {
628 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
629 }
630
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)631 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
632 ShapeHandle* out) {
633 return Subshape(s, start, end, 1 /* stride */, out);
634 }
635
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)636 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
637 int64 stride, ShapeHandle* out) {
638 int64 start_in = start;
639 int64 end_in = end;
640
641 const int32 rank = Rank(s);
642 if (start == 0 && stride == 1 &&
643 ((RankKnown(s) && end >= rank) ||
644 end == std::numeric_limits<int64>::max())) {
645 *out = s;
646 return Status::OK();
647 }
648 if (!RankKnown(s)) {
649 return ReturnUnknownShape(out);
650 }
651
652 if (start > rank) start = rank;
653 if (end > rank) end = rank;
654
655 if (stride < 0 && start == rank) --start;
656
657 if (start < 0) {
658 start = rank + start;
659 if (start < 0) {
660 *out = nullptr;
661 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
662 ", for shape with rank ", rank);
663 }
664 }
665
666 if (end < 0) {
667 end = rank + end;
668 if (end < 0) {
669 *out = nullptr;
670 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
671 ", for shape with rank ", rank);
672 }
673 }
674 if (stride > 0 && start > end) {
675 *out = nullptr;
676 return errors::InvalidArgument(
677 "Subshape must have computed start <= end, but is ", start, " and ",
678 end, " (computed from start ", start_in, " and end ", end_in,
679 " over shape with rank ", rank, ")");
680 } else if (stride < 0 && start < end) {
681 *out = nullptr;
682 return errors::InvalidArgument(
683 "Subshape must have computed start >= end since stride is negative, "
684 "but is ",
685 start, " and ", end, " (computed from start ", start_in, " and end ",
686 end_in, " over shape with rank ", rank, " and stride", stride, ")");
687 }
688
689 std::vector<DimensionHandle> dims;
690 for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
691 dims.push_back(Dim(s, i));
692 }
693 return ReturnCreatedShape(dims, out);
694 }
695
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)696 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
697 ShapeHandle* out) {
698 if (!RankKnown(s1) || !RankKnown(s2)) {
699 return ReturnUnknownShape(out);
700 }
701 const int32 s1_rank = Rank(s1);
702 const int32 s2_rank = Rank(s2);
703 const int32 rank = s1_rank + s2_rank;
704 std::vector<DimensionHandle> dims;
705 dims.reserve(rank);
706 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
707 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
708 return ReturnCreatedShape(dims, out);
709 }
710
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)711 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
712 DimensionHandle new_dim, ShapeHandle* out) {
713 if (!RankKnown(s)) {
714 return ReturnUnknownShape(out);
715 }
716 int64 dim_index = dim_index_in;
717 if (dim_index < 0) {
718 dim_index = s->dims_.size() + dim_index;
719 }
720 if (!FastBoundsCheck(dim_index, s->dims_.size())) {
721 *out = nullptr;
722 return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
723 " for shape with ", s->dims_.size(),
724 " dimensions");
725 }
726 std::vector<DimensionHandle> dims(s->dims_);
727 dims[dim_index] = new_dim;
728 return ReturnCreatedShape(dims, out);
729 }
730
MakeShape(const std::vector<DimensionHandle> & dims)731 ShapeHandle InferenceContext::MakeShape(
732 const std::vector<DimensionHandle>& dims) {
733 return shape_manager_.MakeShape(dims);
734 }
735
MakeShape(std::initializer_list<DimensionOrConstant> dims)736 ShapeHandle InferenceContext::MakeShape(
737 std::initializer_list<DimensionOrConstant> dims) {
738 std::vector<DimensionHandle> dims_actual;
739 dims_actual.reserve(dims.size());
740 for (const DimensionOrConstant& d : dims) {
741 dims_actual.push_back(MakeDim(d));
742 }
743
744 return shape_manager_.MakeShape(dims_actual);
745 }
746
UnknownShape()747 ShapeHandle InferenceContext::UnknownShape() {
748 return shape_manager_.UnknownShape();
749 }
750
UnknownShapeOfRank(int64 rank)751 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
752 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
753 if (rank == kUnknownRank) {
754 return UnknownShape();
755 }
756 CHECK_GE(rank, 0) << "rank must not be negative";
757 std::vector<DimensionHandle> dims(rank);
758 for (int32 i = 0; i < rank; ++i) {
759 dims[i] = UnknownDim();
760 }
761 return MakeShape(dims);
762 }
763
Scalar()764 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
765
Vector(DimensionOrConstant dim)766 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
767 return MakeShape({dim});
768 }
769
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)770 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
771 DimensionOrConstant dim2) {
772 return MakeShape({dim1, dim2});
773 }
774
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)775 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
776 int input_idx, ShapeHandle* out) {
777 ShapeHandle input_shape;
778 TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
779
780 requested_input_tensor_as_partial_shape_[input_idx] = true;
781 if (input_idx < input_tensors_as_shapes_.size() &&
782 input_tensors_as_shapes_[input_idx].IsSet() &&
783 RankKnown(input_tensors_as_shapes_[input_idx])) {
784 *out = input_tensors_as_shapes_[input_idx];
785 return Status::OK();
786 }
787
788 return InternalMakeShapeFromTensor(
789 true /* treat_unknown_scalar_tensor_as_unknown_shape */,
790 input_tensor(input_idx), input_shape, out);
791 }
792
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)793 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
794 ShapeHandle* out) {
795 ShapeHandle input_shape;
796 TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
797
798 requested_input_tensor_as_partial_shape_[input_idx] = true;
799 if (input_idx < input_tensors_as_shapes_.size() &&
800 input_tensors_as_shapes_[input_idx].IsSet() &&
801 RankKnown(input_tensors_as_shapes_[input_idx])) {
802 *out = input_tensors_as_shapes_[input_idx];
803 return Status::OK();
804 }
805
806 return InternalMakeShapeFromTensor(
807 false /* treat_unknown_scalar_tensor_as_unknown_shape */,
808 input_tensor(input_idx), input_shape, out);
809 }
810
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)811 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
812 ShapeHandle tensor_shape,
813 ShapeHandle* out) {
814 return InternalMakeShapeFromTensor(
815 false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
816 out);
817 }
818
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)819 Status InferenceContext::InternalMakeShapeFromTensor(
820 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
821 ShapeHandle tensor_shape, ShapeHandle* out) {
822 // Only callers who have set
823 if (!treat_unknown_scalar_tensor_as_unknown_shape) {
824 TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
825 }
826 if (t == nullptr) {
827 // This is guarded by the check above.
828 if (Rank(tensor_shape) == 0) {
829 return ReturnUnknownShape(out);
830 }
831 // Shape tensor is not known, but if the shape of the shape tensor is then
832 // the right number of unknown dims can be created.
833 DimensionHandle shape_dim = Dim(tensor_shape, 0);
834 if (!ValueKnown(shape_dim)) {
835 return ReturnUnknownShape(out);
836 }
837 const auto num_dims = Value(shape_dim);
838 std::vector<DimensionHandle> dims;
839 dims.reserve(num_dims);
840 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
841 return ReturnCreatedShape(dims, out);
842 }
843
844 if (t->shape().dims() == 0) {
845 if (t->dtype() == DataType::DT_INT32) {
846 auto flat_t = t->scalar<int32>();
847 if (flat_t() != -1) {
848 *out = nullptr;
849 return errors::InvalidArgument(
850 "Input tensor must be rank 1, or if its rank 0 it must have value "
851 "-1 "
852 "(representing an unknown shape). Saw value: ",
853 flat_t());
854 }
855 return ReturnUnknownShape(out);
856 } else if (t->dtype() == DataType::DT_INT64) {
857 auto flat_t = t->scalar<int64>();
858 if (flat_t() != -1) {
859 *out = nullptr;
860 return errors::InvalidArgument(
861 "Input tensor must be rank 1, or if its rank 0 it must have value "
862 "-1 "
863 "(representing an unknown shape). Saw value: ",
864 flat_t());
865 }
866 return ReturnUnknownShape(out);
867 } else {
868 *out = nullptr;
869 return errors::InvalidArgument(
870 "Input tensor must be int32 or int64, but was ",
871 DataTypeString(t->dtype()));
872 }
873 }
874
875 if (t->shape().dims() != 1) {
876 *out = nullptr;
877 return errors::InvalidArgument(
878 "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
879 ((t->shape().dims() == 0)
880 ? "If it is rank 0 rank 0 it must have statically known value -1 "
881 "(representing an unknown shape). "
882 : " "),
883 "Saw tensor shape ", t->shape().DebugString());
884 }
885 std::vector<DimensionHandle> dims;
886 if (t->dtype() == DataType::DT_INT32) {
887 auto flat_t = t->flat<int32>();
888 for (int i = 0; i < flat_t.size(); ++i) {
889 const int32 val = flat_t(i);
890 if (val < -1) {
891 return errors::InvalidArgument(
892 "Invalid value in tensor used for shape: ", val);
893 }
894 // -1 will become an unknown dim.
895 dims.push_back(MakeDim(val));
896 }
897 } else if (t->dtype() == DataType::DT_INT64) {
898 auto flat_t = t->flat<int64>();
899 for (int i = 0; i < flat_t.size(); ++i) {
900 const int64 val = flat_t(i);
901 if (val < -1) {
902 return errors::InvalidArgument(
903 "Invalid value in tensor used for shape: ", val);
904 }
905 // -1 will become an unknown dim.
906 dims.push_back(MakeDim(val));
907 }
908 } else {
909 *out = nullptr;
910 return errors::InvalidArgument(
911 "Input tensor must be int32 or int64, but was ",
912 DataTypeString(t->dtype()));
913 }
914
915 return ReturnCreatedShape(dims, out);
916 }
917
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)918 Status InferenceContext::MakeShapeFromPartialTensorShape(
919 const PartialTensorShape& partial_shape, ShapeHandle* out) {
920 *out = nullptr;
921 if (partial_shape.dims() == -1) {
922 return ReturnUnknownShape(out);
923 }
924 const int num_dims = partial_shape.dims();
925 std::vector<DimensionHandle> dims(num_dims);
926 for (int i = 0; i < num_dims; ++i) {
927 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
928 // can be passed directly to MakeDim.
929 dims[i] = MakeDim(partial_shape.dim_size(i));
930 }
931 return ReturnCreatedShape(dims, out);
932 }
933
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)934 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
935 ShapeHandle* out) {
936 return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
937 out);
938 }
939
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)940 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
941 ShapeHandle* out) {
942 *out = nullptr;
943 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
944 PartialTensorShape partial_shape(proto);
945 return MakeShapeFromPartialTensorShape(partial_shape, out);
946 }
947
GetScalarFromTensor(const Tensor * t,int64 * val)948 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
949 // Caller must ensure that <t> is not NULL.
950 const int rank = t->dims();
951 if (rank != 0) {
952 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
953 }
954
955 if (t->dtype() == DT_INT32) {
956 *val = t->scalar<int32>()();
957 return Status::OK();
958 } else if (t->dtype() == DT_INT64) {
959 *val = t->scalar<int64>()();
960 return Status::OK();
961 } else {
962 return errors::InvalidArgument("Scalar input must be int32 or int64.");
963 }
964 }
965
966 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)967 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
968 int64 val;
969 const Tensor* t = input_tensor(idx);
970 if (t == nullptr) {
971 *out = UnknownDim();
972 return Status::OK();
973 }
974 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
975 if (val < 0) {
976 return errors::InvalidArgument("Dimension size, given by scalar input ",
977 idx, ", must be non-negative but is ", val);
978 }
979 *out = MakeDim(val);
980 return Status::OK();
981 }
982
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)983 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
984 int idx, int input_rank, DimensionHandle* out) {
985 int64 val;
986 const Tensor* t = input_tensor(idx);
987 if (t == nullptr) {
988 *out = UnknownDim();
989 return Status::OK();
990 }
991 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
992 if (val < 0) {
993 if (input_rank < 0) {
994 *out = UnknownDim();
995 return Status::OK();
996 } else if (val + input_rank < 0) {
997 return errors::InvalidArgument("Dimension size, given by scalar input ",
998 val, " must be in range [-", input_rank,
999 ", ", input_rank, ")");
1000 } else {
1001 val += input_rank;
1002 }
1003 } else if (input_rank >= 0 && val >= input_rank) {
1004 return errors::InvalidArgument("Dimension size, given by scalar input ",
1005 val, " must be in range [-", input_rank,
1006 ", ", input_rank, ")");
1007 }
1008 *out = MakeDim(val);
1009 return Status::OK();
1010 }
1011
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)1012 Status InferenceContext::Divide(DimensionHandle dividend,
1013 DimensionOrConstant divisor,
1014 bool evenly_divisible, DimensionHandle* out) {
1015 const int64 divisor_value = Value(divisor);
1016 if (divisor_value == 1) {
1017 *out = dividend;
1018 } else if (!ValueKnown(dividend) ||
1019 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
1020 *out = UnknownDim();
1021 } else {
1022 const int64 v = Value(dividend);
1023 if (divisor_value <= 0) {
1024 return errors::InvalidArgument("Divisor must be positive but is ",
1025 divisor_value);
1026 }
1027 if (evenly_divisible && (v % divisor_value) != 0) {
1028 return errors::InvalidArgument(
1029 "Dimension size must be evenly divisible by ", divisor_value,
1030 " but is ", v);
1031 }
1032 *out = MakeDim(v / divisor_value);
1033 }
1034 return Status::OK();
1035 }
1036
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1037 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1038 DimensionHandle* out) {
1039 const int64 first_value = Value(first);
1040 const int64 second_value = Value(second);
1041 // Special cases.
1042 if (first_value == 0) {
1043 *out = MakeDim(second);
1044 } else if (second_value == 0) {
1045 *out = first;
1046 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1047 *out = UnknownDim();
1048 } else {
1049 // Invariant: Both values are known and positive. Still in run-time we can
1050 // get pair of values which cannot be store in output. Check below will
1051 // report error. We still need to avoid undefined behavior of signed
1052 // overflow and use unsigned addition.
1053 const int64 sum = static_cast<uint64>(first_value) + second_value;
1054 if (sum < 0) {
1055 return errors::InvalidArgument("Dimension size overflow from adding ",
1056 first_value, " and ", second_value);
1057 }
1058 *out = MakeDim(sum);
1059 }
1060 return Status::OK();
1061 }
1062
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1063 Status InferenceContext::Subtract(DimensionHandle first,
1064 DimensionOrConstant second,
1065 DimensionHandle* out) {
1066 const int64 first_value = Value(first);
1067 const int64 second_value = Value(second);
1068 // Special cases.
1069 if (second_value == 0) {
1070 *out = first;
1071 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1072 *out = UnknownDim();
1073 } else {
1074 // Invariant: Both values are known, first_value is non-negative, and
1075 // second_value is positive.
1076 if (first_value < second_value) {
1077 return errors::InvalidArgument(
1078 "Negative dimension size caused by subtracting ", second_value,
1079 " from ", first_value);
1080 }
1081 *out = MakeDim(first_value - second_value);
1082 }
1083 return Status::OK();
1084 }
1085
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1086 Status InferenceContext::Multiply(DimensionHandle first,
1087 DimensionOrConstant second,
1088 DimensionHandle* out) {
1089 const int64 first_value = Value(first);
1090 const int64 second_value = Value(second);
1091 // Special cases.
1092 if (first_value == 0) {
1093 *out = first;
1094 } else if (second_value == 0) {
1095 *out = MakeDim(second);
1096 } else if (first_value == 1) {
1097 *out = MakeDim(second);
1098 } else if (second_value == 1) {
1099 *out = first;
1100 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1101 *out = UnknownDim();
1102 } else {
1103 // Invariant: Both values are known and greater than 1.
1104 const int64 product = first_value * second_value;
1105 if (product < 0) {
1106 return errors::InvalidArgument(
1107 "Negative dimension size caused by overflow when multiplying ",
1108 first_value, " and ", second_value);
1109 }
1110 *out = MakeDim(product);
1111 }
1112 return Status::OK();
1113 }
1114
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1115 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1116 DimensionHandle* out) {
1117 const int64 first_value = Value(first);
1118 const int64 second_value = Value(second);
1119 if (first_value == 0) {
1120 *out = first;
1121 } else if (second_value == 0) {
1122 *out = MakeDim(second);
1123 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1124 *out = UnknownDim();
1125 } else {
1126 if (first_value <= second_value) {
1127 *out = first;
1128 } else {
1129 *out = MakeDim(second);
1130 }
1131 }
1132 return Status::OK();
1133 }
1134
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1135 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1136 DimensionHandle* out) {
1137 const int64 first_value = Value(first);
1138 const int64 second_value = Value(second);
1139 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1140 *out = UnknownDim();
1141 } else {
1142 if (first_value >= second_value) {
1143 *out = first;
1144 } else {
1145 *out = MakeDim(second);
1146 }
1147 }
1148 return Status::OK();
1149 }
1150
// Returns a Status with the same code as `status` and an extended message.
// The appended context names the node and its op and lists the debug string
// of every input shape. It also includes any input tensors, or tensors read
// as partial shapes, that were requested and had values.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> shape_strs;
  shape_strs.reserve(inputs_.size());
  for (const ShapeHandle& handle : inputs_) {
    shape_strs.emplace_back(DebugString(handle));
  }

  // Each requested input value is rendered in one of two ways. A partial
  // shape of known rank is preferred when one is available; otherwise a
  // summary of the tensor contents is used.
  std::vector<string> tensor_strs;
  std::vector<string> partial_shape_strs;
  partial_shape_strs.reserve(inputs_.size());
  for (int i = 0; i < inputs_.size(); ++i) {
    const bool use_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (use_partial_shape) {
      partial_shape_strs.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }
    const bool use_tensor = requested_input_tensor_[i] &&
                            i < input_tensors_.size() &&
                            input_tensors_[i] != nullptr;
    if (use_tensor) {
      tensor_strs.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string message = strings::StrCat(
      status.error_message(), " for '", node_def_->name(), "' (op: '",
      node_def_->op(), "') with input shapes: ",
      str_util::Join(shape_strs, ", "));
  if (!tensor_strs.empty()) {
    strings::StrAppend(&message, " and with computed input tensors: ",
                       str_util::Join(tensor_strs, ", "));
  }
  if (!partial_shape_strs.empty()) {
    strings::StrAppend(&message,
                       " and with input tensors computed as partial shapes: ",
                       str_util::Join(partial_shape_strs, ","));
  }
  strings::StrAppend(&message, ".");
  return Status(status.code(), message);
}
1194
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1195 bool InferenceContext::MergeHandleShapesAndTypes(
1196 const std::vector<ShapeAndType>& shapes_and_types,
1197 std::vector<ShapeAndType>* to_update) {
1198 if (shapes_and_types.size() != to_update->size()) {
1199 return false;
1200 }
1201 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1202 bool refined = false;
1203 for (int i = 0; i < shapes_and_types.size(); ++i) {
1204 const ShapeAndType& existing = (*to_update)[i];
1205 if (shapes_and_types[i].dtype == existing.dtype) {
1206 new_values[i].dtype = existing.dtype;
1207 } else {
1208 if (existing.dtype != DT_INVALID) {
1209 return false;
1210 } else {
1211 new_values[i].dtype = shapes_and_types[i].dtype;
1212 refined = true;
1213 }
1214 }
1215 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1216 .ok()) {
1217 // merge failed, ignore the new value.
1218 new_values[i].shape = existing.shape;
1219 }
1220 if (!existing.shape.SameHandle(new_values[i].shape)) {
1221 refined = true;
1222 }
1223 }
1224 if (!refined) {
1225 return false;
1226 }
1227 for (int i = 0; i < new_values.size(); ++i) {
1228 (*to_update)[i] = new_values[i];
1229 }
1230 return true;
1231 }
1232
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1233 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1234 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1235 if (output_handle_shapes_and_types_[idx] == nullptr) {
1236 output_handle_shapes_and_types_[idx].reset(
1237 new std::vector<ShapeAndType>(shapes_and_types));
1238 return true;
1239 }
1240 return MergeHandleShapesAndTypes(shapes_and_types,
1241 output_handle_shapes_and_types_[idx].get());
1242 }
1243
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1244 bool InferenceContext::MergeInputHandleShapesAndTypes(
1245 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1246 if (input_handle_shapes_and_types_[idx] == nullptr) {
1247 input_handle_shapes_and_types_[idx].reset(
1248 new std::vector<ShapeAndType>(shapes_and_types));
1249 return true;
1250 }
1251 return MergeHandleShapesAndTypes(shapes_and_types,
1252 input_handle_shapes_and_types_[idx].get());
1253 }
1254
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1255 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1256 const std::vector<ShapeAndType>& shapes_and_types,
1257 std::vector<ShapeAndType>* to_update) {
1258 if (shapes_and_types.size() != to_update->size()) {
1259 return false;
1260 }
1261 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1262 for (int i = 0; i < shapes_and_types.size(); ++i) {
1263 const ShapeAndType& existing = (*to_update)[i];
1264 if (shapes_and_types[i].dtype == existing.dtype) {
1265 new_values[i].dtype = existing.dtype;
1266 } else {
1267 if (existing.dtype != DT_INVALID) {
1268 return false;
1269 } else {
1270 new_values[i].dtype = shapes_and_types[i].dtype;
1271 }
1272 }
1273 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1274 }
1275 to_update->swap(new_values);
1276 return true;
1277 }
1278
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1279 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1280 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1281 if (output_handle_shapes_and_types_[idx] == nullptr) {
1282 output_handle_shapes_and_types_[idx].reset(
1283 new std::vector<ShapeAndType>(shapes_and_types));
1284 return true;
1285 }
1286 return RelaxHandleShapesAndMergeTypes(
1287 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1288 }
1289
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1290 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1291 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1292 if (input_handle_shapes_and_types_[idx] == nullptr) {
1293 input_handle_shapes_and_types_[idx].reset(
1294 new std::vector<ShapeAndType>(shapes_and_types));
1295 return true;
1296 }
1297 return RelaxHandleShapesAndMergeTypes(
1298 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1299 }
1300
1301 // -----------------------------------------------------------------------------
1302 // ShapeManager
1303 // -----------------------------------------------------------------------------
ShapeManager()1304 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1305 InferenceContext::ShapeManager::~ShapeManager() {
1306 for (auto* s : all_shapes_) delete s;
1307 for (auto* d : all_dims_) delete d;
1308 }
1309
MakeShape(const std::vector<DimensionHandle> & dims)1310 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1311 const std::vector<DimensionHandle>& dims) {
1312 all_shapes_.push_back(new Shape(dims));
1313 return all_shapes_.back();
1314 }
1315
UnknownShape()1316 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1317 all_shapes_.push_back(new Shape());
1318 return all_shapes_.back();
1319 }
1320
1321 } // namespace shape_inference
1322 } // namespace tensorflow
1323