1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include <cstdint>
18
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/framework/full_type_util.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/util/overflow.h"
26
27 namespace tensorflow {
28 namespace shape_inference {
29
30 constexpr int32_t InferenceContext::kUnknownRank;
31 constexpr int64_t InferenceContext::kUnknownDim;
32
33 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)34 InferenceContext::InferenceContext(
35 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
36 const std::vector<PartialTensorShape>& input_shapes,
37 const std::vector<const Tensor*>& input_tensors,
38 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
39 const std::vector<
40 std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
41 input_handle_shapes_and_types)
42 : graph_def_version_(graph_def_version), attrs_(attrs) {
43 std::vector<ShapeHandle> input_tensors_as_shape_handles;
44 input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
45 for (const PartialTensorShape& p : input_tensors_as_shapes) {
46 ShapeHandle shape;
47 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
48 if (!construction_status_.ok()) {
49 return;
50 }
51 input_tensors_as_shape_handles.push_back(shape);
52 }
53 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
54 if (!construction_status_.ok()) return;
55 inputs_.reserve(input_shapes.size());
56 for (const PartialTensorShape& p : input_shapes) {
57 ShapeHandle shape;
58 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
59 if (!construction_status_.ok()) {
60 return;
61 }
62 inputs_.push_back(shape);
63 }
64 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
65 input_shapes.size());
66 for (int i = 0, end = input_handle_shapes_and_types.size(); i < end; ++i) {
67 const auto& v = input_handle_shapes_and_types[i];
68 if (v == nullptr) {
69 continue;
70 }
71 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
72 auto& new_v = *handle_data[i];
73 for (int j = 0, end = v->size(); j < end; ++j) {
74 const auto& p = (*v)[j];
75 construction_status_.Update(
76 MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
77 if (!construction_status_.ok()) {
78 return;
79 }
80 new_v[j].dtype = p.second;
81 }
82 }
83 PostInputInit(std::move(handle_data));
84 }
85
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)86 InferenceContext::InferenceContext(
87 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
88 const std::vector<ShapeHandle>& input_shapes,
89 const std::vector<const Tensor*>& input_tensors,
90 const std::vector<ShapeHandle>& input_tensors_as_shapes,
91 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
92 input_handle_shapes_and_types)
93 : graph_def_version_(graph_def_version), attrs_(attrs) {
94 PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
95 if (!construction_status_.ok()) return;
96 inputs_ = input_shapes;
97
98 PostInputInit(std::move(input_handle_shapes_and_types));
99 }
100
~InferenceContext()101 InferenceContext::~InferenceContext() {}
102
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)103 Status InferenceContext::Run(
104 const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
105 ForgetMerges();
106 Status s = fn(this);
107 if (!s.ok()) {
108 ForgetMerges();
109 return AttachContext(s);
110 }
111 #ifndef NDEBUG
112 for (int i = 0; i < num_outputs(); ++i) {
113 DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
114 }
115 #endif // NDEBUG
116 return s;
117 }
118
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)119 Status InferenceContext::set_output(StringPiece output_name,
120 const std::vector<ShapeHandle>& shapes) {
121 auto result = output_name_map_.find(output_name);
122 if (result == output_name_map_.end()) {
123 return errors::InvalidArgument("Unknown output name: ", output_name);
124 } else {
125 const int start = result->second.first;
126 const int size = result->second.second - start;
127 const int shapes_size = shapes.size();
128 if (size != shapes_size) {
129 return errors::InvalidArgument("Must provide exactly ", size, " shapes.");
130 }
131 for (int i = 0; i < shapes_size; ++i) {
132 outputs_[i + start] = shapes[i];
133 }
134 }
135 return OkStatus();
136 }
137
input(StringPiece input_name,std::vector<ShapeHandle> * output) const138 Status InferenceContext::input(StringPiece input_name,
139 std::vector<ShapeHandle>* output) const {
140 const auto result = input_name_map_.find(input_name);
141 if (result == input_name_map_.end()) {
142 return errors::InvalidArgument("Unknown input name: ", input_name);
143 } else {
144 output->clear();
145 for (int i = result->second.first; i < result->second.second; ++i) {
146 output->push_back(inputs_[i]);
147 }
148 }
149 return OkStatus();
150 }
151
output(StringPiece output_name,std::vector<ShapeHandle> * output) const152 Status InferenceContext::output(StringPiece output_name,
153 std::vector<ShapeHandle>* output) const {
154 const auto result = output_name_map_.find(output_name);
155 if (result == output_name_map_.end()) {
156 return errors::InvalidArgument("Unknown output name: ", output_name);
157 } else {
158 output->clear();
159 for (int i = result->second.first; i < result->second.second; ++i) {
160 output->push_back(outputs_[i]);
161 }
162 }
163 return OkStatus();
164 }
165
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)166 void InferenceContext::PreInputInit(
167 const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
168 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
169 // TODO(mdan): This is also done at graph construction. Run only here instead?
170 Status s = full_type::SpecializeType(attrs_, op_def, ret_types_);
171 if (!s.ok()) {
172 construction_status_ = s;
173 return;
174 }
175
176 input_tensors_ = input_tensors;
177 input_tensors_as_shapes_ = input_tensors_as_shapes;
178
179 construction_status_ =
180 NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
181 if (!construction_status_.ok()) return;
182
183 int num_outputs = 0;
184 for (const auto& e : output_name_map_) {
185 num_outputs = std::max(num_outputs, e.second.second);
186 }
187 outputs_.assign(num_outputs, nullptr);
188 output_handle_shapes_and_types_.resize(num_outputs);
189 }
190
ExpandOutputs(int new_output_size)191 Status InferenceContext::ExpandOutputs(int new_output_size) {
192 const int outputs_size = outputs_.size();
193 if (new_output_size < outputs_size) {
194 return errors::InvalidArgument("Trying to reduce number of outputs of op.");
195 }
196 outputs_.resize(new_output_size, nullptr);
197 output_handle_shapes_and_types_.resize(new_output_size);
198 return OkStatus();
199 }
200
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)201 void InferenceContext::PostInputInit(
202 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
203 int num_inputs_from_node_def = 0;
204 for (const auto& e : input_name_map_) {
205 num_inputs_from_node_def =
206 std::max(num_inputs_from_node_def, e.second.second);
207 }
208
209 // Allow passing empty shapes/dtypes to avoid changing every single test.
210 if (input_handle_data.empty()) {
211 input_handle_shapes_and_types_.resize(inputs_.size());
212 } else {
213 if (input_handle_data.size() != inputs_.size()) {
214 construction_status_ = errors::InvalidArgument(
215 "Wrong number of handle shapes passed; expected ", inputs_.size(),
216 " got ", input_handle_data.size());
217 return;
218 }
219 input_handle_shapes_and_types_ = std::move(input_handle_data);
220 }
221 const int inputs_size = inputs_.size();
222 if (inputs_size != num_inputs_from_node_def) {
223 construction_status_ = errors::InvalidArgument(
224 "Wrong number of inputs passed: ", inputs_.size(), " while ",
225 num_inputs_from_node_def, " expected based on NodeDef");
226 return;
227 }
228
229 CHECK_LE(input_tensors_.size(), inputs_.size());
230 input_tensors_.resize(inputs_.size());
231 requested_input_tensor_.resize(inputs_.size());
232 requested_input_tensor_as_partial_shape_.resize(inputs_.size());
233 }
234
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)235 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
236 TensorShapeProto* proto) {
237 if (!RankKnown(handle)) {
238 proto->set_unknown_rank(true);
239 return;
240 }
241
242 for (int32_t i = 0; i < Rank(handle); ++i) {
243 DimensionHandle dim = Dim(handle, i);
244 auto* dim_shape = proto->add_dim();
245 if (ValueKnown(dim)) {
246 dim_shape->set_size(Value(dim));
247 } else {
248 dim_shape->set_size(-1);
249 }
250 }
251 }
252
FullyDefined(ShapeHandle s)253 bool InferenceContext::FullyDefined(ShapeHandle s) {
254 if (!RankKnown(s)) return false;
255 for (int i = 0; i < Rank(s); ++i) {
256 if (!ValueKnown(Dim(s, i))) return false;
257 }
258 return true;
259 }
260
NumElements(ShapeHandle s)261 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
262 const auto rank = Rank(s);
263 if (rank == kUnknownRank) return UnknownDim();
264 bool found_unknown = false;
265 int64_t size = 1;
266 for (int i = 0; i < rank; ++i) {
267 int64_t dim_val = Value(Dim(s, i));
268 if (dim_val == kUnknownDim) {
269 found_unknown = true;
270 } else if (dim_val == 0) {
271 return MakeDim(0);
272 } else {
273 size *= dim_val;
274 }
275 }
276 if (found_unknown) {
277 return UnknownDim();
278 } else {
279 return MakeDim(size);
280 }
281 }
282
DebugString(ShapeHandle s)283 string InferenceContext::DebugString(ShapeHandle s) {
284 if (RankKnown(s)) {
285 std::vector<string> vals;
286 for (auto d : s->dims_) vals.push_back(DebugString(d));
287 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
288 } else {
289 return "?";
290 }
291 }
292
DebugString(DimensionHandle d)293 string InferenceContext::DebugString(DimensionHandle d) {
294 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
295 }
296
DebugString() const297 string InferenceContext::DebugString() const {
298 return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
299 }
300
DebugString(const ShapeAndType & shape_and_type)301 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
302 return strings::StrCat(DebugString(shape_and_type.shape), ":",
303 DataTypeString(shape_and_type.dtype));
304 }
305
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)306 string InferenceContext::DebugString(
307 gtl::ArraySlice<ShapeAndType> shape_and_types) {
308 std::vector<string> pieces;
309 for (const ShapeAndType& s : shape_and_types) {
310 pieces.push_back(DebugString(s));
311 }
312 return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
313 }
314
WithRank(ShapeHandle shape,int64_t rank,ShapeHandle * out)315 Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank,
316 ShapeHandle* out) {
317 if (rank > kint32max) {
318 return errors::InvalidArgument("Rank cannot exceed kint32max");
319 }
320 const int32_t existing = Rank(shape);
321 if (existing == rank) {
322 *out = shape;
323 return OkStatus();
324 }
325 if (existing == kUnknownRank) {
326 std::vector<DimensionHandle> dims;
327 dims.reserve(rank);
328 for (int i = 0; i < rank; ++i) {
329 dims.push_back(UnknownDim());
330 }
331 ShapeHandle shp = shape_manager_.MakeShape(dims);
332 return Merge(shape, shp, out);
333 }
334 *out = nullptr;
335
336 return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
337 existing);
338 }
339
WithRankAtLeast(ShapeHandle shape,int64_t rank,ShapeHandle * out)340 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank,
341 ShapeHandle* out) {
342 if (rank > kint32max) {
343 return errors::InvalidArgument("Rank cannot exceed kint32max");
344 }
345 const int32_t existing = Rank(shape);
346 if (existing >= rank || existing == kUnknownRank) {
347 *out = shape;
348 return OkStatus();
349 }
350 *out = nullptr;
351 return errors::InvalidArgument("Shape must be at least rank ", rank,
352 " but is rank ", existing);
353 }
354
WithRankAtMost(ShapeHandle shape,int64_t rank,ShapeHandle * out)355 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank,
356 ShapeHandle* out) {
357 if (rank > kint32max) {
358 return errors::InvalidArgument("Rank cannot exceed kint32max");
359 }
360 const int32_t existing = Rank(shape);
361 if (existing <= rank || existing == kUnknownRank) {
362 *out = shape;
363 return OkStatus();
364 }
365 *out = nullptr;
366 return errors::InvalidArgument("Shape must be at most rank ", rank,
367 " but is rank ", existing);
368 }
369
WithValue(DimensionHandle dim,int64_t value,DimensionHandle * out)370 Status InferenceContext::WithValue(DimensionHandle dim, int64_t value,
371 DimensionHandle* out) {
372 const int64_t existing = Value(dim);
373 if (existing == value) {
374 *out = dim;
375 return OkStatus();
376 }
377 if (existing == kUnknownDim) {
378 DimensionHandle d = MakeDim(value);
379 return Merge(dim, d, out);
380 }
381 *out = nullptr;
382 return errors::InvalidArgument("Dimension must be ", value, " but is ",
383 existing);
384 }
385
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)386 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
387 DimensionHandle* out) {
388 if (d_old.SameHandle(d_new)) {
389 *out = d_old;
390 } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
391 // The node will be fed by the dimension d_new instead of d_old: any
392 // equality assertion between d_old and other input dimension on this node
393 // may not be true anymore, so forget them all.
394 ForgetMerges();
395 // Return the new shape handle to force the relaxation to propagate to the
396 // fanout of the context.
397 *out = d_new;
398 } else if (!ValueKnown(d_new)) {
399 ForgetMerges();
400 *out = d_new;
401 } else if (Value(d_old) == Value(d_new)) {
402 // Return the old shape handle. This will stop the relaxation in the fanout
403 // of the context.
404 *out = d_old;
405 } else {
406 // Return a new handle that encodes a different unknown dim.
407 ForgetMerges();
408 *out = UnknownDim();
409 }
410 }
411
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)412 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
413 DimensionHandle* out) {
414 if (d0.SameHandle(d1)) {
415 *out = d0;
416 return OkStatus();
417 } else if (!ValueKnown(d1)) {
418 *out = d0;
419 merged_dims_.emplace_back(d0, d1);
420 return OkStatus();
421 } else if (!ValueKnown(d0)) {
422 *out = d1;
423 merged_dims_.emplace_back(d0, d1);
424 return OkStatus();
425 } else if (Value(d0) == Value(d1)) {
426 *out = d0;
427 return OkStatus();
428 } else {
429 *out = nullptr;
430 return errors::InvalidArgument("Dimensions must be equal, but are ",
431 Value(d0), " and ", Value(d1));
432 }
433 }
434
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)435 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
436 ShapeHandle* s_out,
437 ShapeHandle* prefix_out) {
438 *s_out = *prefix_out = nullptr;
439 if (!RankKnown(prefix) || !RankKnown(s)) {
440 *s_out = s;
441 *prefix_out = prefix;
442 return OkStatus();
443 }
444 const int32_t rank = Rank(prefix);
445 TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
446
447 // Merge the prefix dims and create the new output shapes.
448 const int32_t rank_s = Rank(s);
449 std::vector<DimensionHandle> dims;
450 dims.reserve(std::max(rank, rank_s));
451 dims.resize(rank);
452 for (int i = 0; i < rank; ++i) {
453 TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
454 }
455 *prefix_out = MakeShape(dims);
456 for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
457 *s_out = MakeShape(dims);
458 return OkStatus();
459 }
460
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)461 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
462 ShapeHandle* out) {
463 if (s_old.SameHandle(s_new)) {
464 *out = s_old;
465 return;
466 } else if (!RankKnown(s_new) || !s_old.IsSet()) {
467 ForgetMerges();
468 *out = s_new;
469 return;
470 }
471
472 const int32_t rank = Rank(s_old);
473 if (rank != Rank(s_new)) {
474 ForgetMerges();
475 *out = UnknownShape();
476 return;
477 }
478
479 bool return_s_old = true;
480 for (int i = 0; i < rank; ++i) {
481 auto d0 = Dim(s_old, i);
482 auto d1 = Dim(s_new, i);
483 if (d0.SameHandle(d1)) continue;
484
485 auto v0 = Value(d0);
486 auto v1 = Value(d1);
487 if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
488 return_s_old = false;
489 break;
490 }
491 }
492 if (return_s_old) {
493 *out = s_old;
494 return;
495 }
496
497 // Relax dims.
498 std::vector<DimensionHandle> dims(rank);
499 for (int i = 0; i < rank; ++i) {
500 Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
501 }
502 ForgetMerges();
503 *out = MakeShape(dims);
504 }
505
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)506 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
507 ShapeHandle* out) {
508 if (s0.SameHandle(s1)) {
509 *out = s0;
510 return OkStatus();
511 } else if (!RankKnown(s1)) {
512 *out = s0;
513 merged_shapes_.emplace_back(s0, s1);
514 return OkStatus();
515 } else if (!RankKnown(s0)) {
516 *out = s1;
517 merged_shapes_.emplace_back(s0, s1);
518 return OkStatus();
519 }
520
521 const int32_t rank = Rank(s0);
522 if (rank != Rank(s1)) {
523 *out = nullptr;
524 return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
525 " and ", Rank(s1));
526 }
527
528 bool return_s0 = true;
529 bool return_s1 = true;
530 for (int i = 0; i < rank; ++i) {
531 auto d0 = Dim(s0, i);
532 auto d1 = Dim(s1, i);
533 if (d0.SameHandle(d1)) continue;
534
535 auto v0 = Value(d0);
536 auto v1 = Value(d1);
537 if (v0 == kUnknownDim) {
538 if (v1 != kUnknownDim) {
539 return_s0 = false;
540 }
541 } else if (v1 == kUnknownDim) {
542 return_s1 = false;
543 } else if (v0 != v1) {
544 *out = nullptr;
545 return errors::InvalidArgument(
546 "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
547 " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
548 DebugString(s1), ".");
549 }
550 }
551
552 merged_shapes_.emplace_back(s0, s1);
553
554 if (return_s0 || return_s1) {
555 *out = return_s0 ? s0 : s1;
556 return OkStatus();
557 }
558
559 // Merge dims.
560 std::vector<DimensionHandle> dims(rank, nullptr);
561 for (int i = 0; i < rank; ++i) {
562 // Invariant for merge was checked earlier, so CHECK is ok.
563 TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
564 }
565
566 Status s = ReturnCreatedShape(dims, out);
567 if (s.ok()) {
568 // Merge the new shape with s0. Since s0 and s1 are merged, this implies
569 // that s1 and out are also merged.
570 merged_shapes_.emplace_back(s0, *out);
571 }
572 return s;
573 }
574
Subshape(ShapeHandle s,int64_t start,ShapeHandle * out)575 Status InferenceContext::Subshape(ShapeHandle s, int64_t start,
576 ShapeHandle* out) {
577 return Subshape(s, start, std::numeric_limits<int64_t>::max() /* end */, out);
578 }
579
Subshape(ShapeHandle s,int64_t start,int64_t end,ShapeHandle * out)580 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
581 ShapeHandle* out) {
582 return Subshape(s, start, end, 1 /* stride */, out);
583 }
584
Subshape(ShapeHandle s,int64_t start,int64_t end,int64_t stride,ShapeHandle * out)585 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
586 int64_t stride, ShapeHandle* out) {
587 int64_t start_in = start;
588 int64_t end_in = end;
589
590 const int32_t rank = Rank(s);
591 if (start == 0 && stride == 1 &&
592 ((RankKnown(s) && end >= rank) ||
593 end == std::numeric_limits<int64_t>::max())) {
594 *out = s;
595 return OkStatus();
596 }
597 if (!RankKnown(s)) {
598 return ReturnUnknownShape(out);
599 }
600
601 if (start > rank) start = rank;
602 if (end > rank) end = rank;
603
604 if (stride < 0 && start == rank) --start;
605
606 if (start < 0) {
607 start = rank + start;
608 if (start < 0) {
609 *out = nullptr;
610 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
611 ", for shape with rank ", rank);
612 }
613 }
614
615 if (end < 0) {
616 end = rank + end;
617 if (end < 0) {
618 *out = nullptr;
619 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
620 ", for shape with rank ", rank);
621 }
622 }
623 if (stride > 0 && start > end) {
624 *out = nullptr;
625 return errors::InvalidArgument(
626 "Subshape must have computed start <= end, but is ", start, " and ",
627 end, " (computed from start ", start_in, " and end ", end_in,
628 " over shape with rank ", rank, ")");
629 } else if (stride < 0 && start < end) {
630 *out = nullptr;
631 return errors::InvalidArgument(
632 "Subshape must have computed start >= end since stride is negative, "
633 "but is ",
634 start, " and ", end, " (computed from start ", start_in, " and end ",
635 end_in, " over shape with rank ", rank, " and stride", stride, ")");
636 }
637
638 std::vector<DimensionHandle> dims;
639 for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
640 dims.push_back(Dim(s, i));
641 }
642 return ReturnCreatedShape(dims, out);
643 }
644
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)645 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
646 ShapeHandle* out) {
647 if (!RankKnown(s1) || !RankKnown(s2)) {
648 return ReturnUnknownShape(out);
649 }
650 const int32_t s1_rank = Rank(s1);
651 const int32_t s2_rank = Rank(s2);
652 const int32_t rank = s1_rank + s2_rank;
653 std::vector<DimensionHandle> dims;
654 dims.reserve(rank);
655 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
656 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
657 return ReturnCreatedShape(dims, out);
658 }
659
ReplaceDim(ShapeHandle s,int64_t dim_index_in,DimensionHandle new_dim,ShapeHandle * out)660 Status InferenceContext::ReplaceDim(ShapeHandle s, int64_t dim_index_in,
661 DimensionHandle new_dim, ShapeHandle* out) {
662 if (!RankKnown(s)) {
663 return ReturnUnknownShape(out);
664 }
665 int64_t dim_index = dim_index_in;
666 if (dim_index < 0) {
667 dim_index = s->dims_.size() + dim_index;
668 }
669 if (!FastBoundsCheck(dim_index, s->dims_.size())) {
670 *out = nullptr;
671 return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
672 " for shape with ", s->dims_.size(),
673 " dimensions");
674 }
675 std::vector<DimensionHandle> dims(s->dims_);
676 dims[dim_index] = new_dim;
677 return ReturnCreatedShape(dims, out);
678 }
679
MakeShape(const std::vector<DimensionHandle> & dims)680 ShapeHandle InferenceContext::MakeShape(
681 const std::vector<DimensionHandle>& dims) {
682 return shape_manager_.MakeShape(dims);
683 }
684
MakeShape(std::initializer_list<DimensionOrConstant> dims)685 ShapeHandle InferenceContext::MakeShape(
686 std::initializer_list<DimensionOrConstant> dims) {
687 std::vector<DimensionHandle> dims_actual;
688 dims_actual.reserve(dims.size());
689 for (const DimensionOrConstant& d : dims) {
690 dims_actual.push_back(MakeDim(d));
691 }
692
693 return shape_manager_.MakeShape(dims_actual);
694 }
695
UnknownShape()696 ShapeHandle InferenceContext::UnknownShape() {
697 return shape_manager_.UnknownShape();
698 }
699
UnknownShapeOfRank(int64_t rank)700 ShapeHandle InferenceContext::UnknownShapeOfRank(int64_t rank) {
701 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
702 if (rank == kUnknownRank) {
703 return UnknownShape();
704 }
705 CHECK_GE(rank, 0) << "rank must not be negative";
706 std::vector<DimensionHandle> dims(rank);
707 for (int32_t i = 0; i < rank; ++i) {
708 dims[i] = UnknownDim();
709 }
710 return MakeShape(dims);
711 }
712
Scalar()713 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
714
Vector(DimensionOrConstant dim)715 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
716 return MakeShape({dim});
717 }
718
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)719 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
720 DimensionOrConstant dim2) {
721 return MakeShape({dim1, dim2});
722 }
723
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)724 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
725 int input_idx, ShapeHandle* out) {
726 ShapeHandle input_shape;
727 TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
728
729 request_input_tensor_as_partial_shape(input_idx);
730 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
731 if (input_idx < input_tensors_as_shapes_size &&
732 input_tensors_as_shapes_[input_idx].IsSet() &&
733 RankKnown(input_tensors_as_shapes_[input_idx])) {
734 *out = input_tensors_as_shapes_[input_idx];
735 return OkStatus();
736 }
737
738 return InternalMakeShapeFromTensor(
739 true /* treat_unknown_scalar_tensor_as_unknown_shape */,
740 input_tensor(input_idx), input_shape, out);
741 }
742
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)743 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
744 ShapeHandle* out) {
745 ShapeHandle input_shape;
746 TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
747
748 request_input_tensor_as_partial_shape(input_idx);
749 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
750 if (input_idx < input_tensors_as_shapes_size &&
751 input_tensors_as_shapes_[input_idx].IsSet() &&
752 RankKnown(input_tensors_as_shapes_[input_idx])) {
753 *out = input_tensors_as_shapes_[input_idx];
754 return OkStatus();
755 }
756
757 return InternalMakeShapeFromTensor(
758 false /* treat_unknown_scalar_tensor_as_unknown_shape */,
759 input_tensor(input_idx), input_shape, out);
760 }
761
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)762 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
763 ShapeHandle tensor_shape,
764 ShapeHandle* out) {
765 return InternalMakeShapeFromTensor(
766 false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
767 out);
768 }
769
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)770 Status InferenceContext::InternalMakeShapeFromTensor(
771 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
772 ShapeHandle tensor_shape, ShapeHandle* out) {
773 // Only callers who have set
774 if (!treat_unknown_scalar_tensor_as_unknown_shape) {
775 TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
776 }
777 if (t == nullptr) {
778 // This is guarded by the check above.
779 if (Rank(tensor_shape) == 0) {
780 return ReturnUnknownShape(out);
781 }
782 // Shape tensor is not known, but if the shape of the shape tensor is then
783 // the right number of unknown dims can be created.
784 DimensionHandle shape_dim = Dim(tensor_shape, 0);
785 if (!ValueKnown(shape_dim)) {
786 return ReturnUnknownShape(out);
787 }
788 const auto num_dims = Value(shape_dim);
789 // Note: This should be `TensorShape::MaxDimensions()` as we are not able to
790 // materialize shapes with more than this number of dimensions but then
791 // shape inference would fail for operations such as `tf.range`/`tf.ones`,
792 // etc. where the shape is not really materialized, only used during the
793 // inference. Hence, just prevent doing a `reserve` with a very large
794 // argument.
795 const int64_t max_dimensions = 1 << 25;
796 if (num_dims >= max_dimensions) {
797 return errors::Internal(
798 "Cannot create a tensor with ", num_dims,
799 " dimensions, as these would be more than maximum of ",
800 max_dimensions);
801 }
802 std::vector<DimensionHandle> dims;
803 dims.reserve(num_dims);
804 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
805 return ReturnCreatedShape(dims, out);
806 }
807
808 if (t->shape().dims() == 0) {
809 if (t->dtype() == DataType::DT_INT32) {
810 auto flat_t = t->scalar<int32>();
811 if (flat_t() != -1) {
812 *out = nullptr;
813 return errors::InvalidArgument(
814 "Input tensor must be rank 1, or if its rank 0 it must have value "
815 "-1 "
816 "(representing an unknown shape). Saw value: ",
817 flat_t());
818 }
819 return ReturnUnknownShape(out);
820 } else if (t->dtype() == DataType::DT_INT64) {
821 auto flat_t = t->scalar<int64_t>();
822 if (flat_t() != -1) {
823 *out = nullptr;
824 return errors::InvalidArgument(
825 "Input tensor must be rank 1, or if its rank 0 it must have value "
826 "-1 "
827 "(representing an unknown shape). Saw value: ",
828 flat_t());
829 }
830 return ReturnUnknownShape(out);
831 } else {
832 *out = nullptr;
833 return errors::InvalidArgument(
834 "Input tensor must be int32 or int64, but was ",
835 DataTypeString(t->dtype()));
836 }
837 }
838
839 if (t->shape().dims() != 1) {
840 *out = nullptr;
841 return errors::InvalidArgument(
842 "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
843 ((t->shape().dims() == 0)
844 ? "If it is rank 0 it must have statically known value -1 "
845 "(representing an unknown shape). "
846 : " "),
847 "Saw tensor shape ", t->shape().DebugString());
848 }
849 std::vector<DimensionHandle> dims;
850 if (t->dtype() == DataType::DT_INT32) {
851 auto flat_t = t->flat<int32>();
852 for (int i = 0; i < flat_t.size(); ++i) {
853 const int32_t val = flat_t(i);
854 if (val < -1) {
855 return errors::InvalidArgument(
856 "Invalid value in tensor used for shape: ", val);
857 }
858 // -1 will become an unknown dim.
859 dims.push_back(MakeDim(val));
860 }
861 } else if (t->dtype() == DataType::DT_INT64) {
862 auto flat_t = t->flat<int64_t>();
863 for (int i = 0; i < flat_t.size(); ++i) {
864 const int64_t val = flat_t(i);
865 if (val < -1) {
866 return errors::InvalidArgument(
867 "Invalid value in tensor used for shape: ", val);
868 }
869 // -1 will become an unknown dim.
870 dims.push_back(MakeDim(val));
871 }
872 } else {
873 *out = nullptr;
874 return errors::InvalidArgument(
875 "Input tensor must be int32 or int64, but was ",
876 DataTypeString(t->dtype()));
877 }
878
879 return ReturnCreatedShape(dims, out);
880 }
881
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)882 Status InferenceContext::MakeShapeFromPartialTensorShape(
883 const PartialTensorShape& partial_shape, ShapeHandle* out) {
884 *out = nullptr;
885 if (partial_shape.dims() == -1) {
886 return ReturnUnknownShape(out);
887 }
888 const int num_dims = partial_shape.dims();
889 std::vector<DimensionHandle> dims(num_dims);
890 for (int i = 0; i < num_dims; ++i) {
891 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
892 // can be passed directly to MakeDim.
893 dims[i] = MakeDim(partial_shape.dim_size(i));
894 }
895 return ReturnCreatedShape(dims, out);
896 }
897
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)898 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
899 ShapeHandle* out) {
900 return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
901 out);
902 }
903
MakeShapeFromShapeTensor(const TensorShape & shape)904 StatusOr<ShapeHandle> InferenceContext::MakeShapeFromShapeTensor(
905 const TensorShape& shape) {
906 ShapeHandle out;
907 TF_RETURN_IF_ERROR(MakeShapeFromTensorShape(shape, &out));
908 return out;
909 }
910
ShapeHandleToProto(ShapeHandle handle)911 TensorShapeProto InferenceContext::ShapeHandleToProto(ShapeHandle handle) {
912 TensorShapeProto out;
913 ShapeHandleToProto(handle, &out);
914 return out;
915 }
916
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)917 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
918 ShapeHandle* out) {
919 *out = nullptr;
920 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
921 PartialTensorShape partial_shape(proto);
922 return MakeShapeFromPartialTensorShape(partial_shape, out);
923 }
924
GetScalarFromTensor(const Tensor * t,int64_t * val)925 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t* val) {
926 // Caller must ensure that <t> is not NULL.
927 const int rank = t->dims();
928 if (rank != 0) {
929 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
930 }
931
932 if (t->dtype() == DataType::DT_INT32) {
933 *val = t->scalar<int32>()();
934 return OkStatus();
935 } else if (t->dtype() == DataType::DT_INT64) {
936 *val = t->scalar<int64_t>()();
937 return OkStatus();
938 } else {
939 return errors::InvalidArgument("Scalar input must be int32 or int64.");
940 }
941 }
942
GetScalarFromTensor(const Tensor * t,int64_t idx,int64_t * val)943 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx,
944 int64_t* val) {
945 // Caller must ensure that <t> is not NULL.
946 const int rank = t->dims();
947 if (rank != 1) {
948 return errors::InvalidArgument("Input must be 1D but has rank ", rank);
949 }
950
951 if (t->dtype() == DataType::DT_INT32) {
952 auto flat_t = t->flat<int32>();
953 if (idx < 0 || idx >= flat_t.size()) {
954 return errors::InvalidArgument("Invalid index ", idx,
955 " for Tensor of size ", flat_t.size());
956 }
957 *val = flat_t(idx);
958 return OkStatus();
959 } else if (t->dtype() == DataType::DT_INT64) {
960 auto flat_t = t->flat<int64_t>();
961 if (idx < 0 || idx >= flat_t.size()) {
962 return errors::InvalidArgument("Invalid index ", idx,
963 " for Tensor of size ", flat_t.size());
964 }
965 *val = flat_t(idx);
966 return OkStatus();
967 } else {
968 return errors::InvalidArgument("Tensor input must be int32 or int64.");
969 }
970 }
971
972 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)973 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
974 int64_t val;
975 const Tensor* t = input_tensor(idx);
976 if (t == nullptr) {
977 *out = UnknownDim();
978 return OkStatus();
979 }
980 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
981 if (val < 0) {
982 return errors::InvalidArgument("Dimension size, given by scalar input ",
983 idx, ", must be non-negative but is ", val);
984 }
985 *out = MakeDim(val);
986 return OkStatus();
987 }
988
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)989 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
990 int idx, int input_rank, DimensionHandle* out) {
991 int64_t val;
992 const Tensor* t = input_tensor(idx);
993 if (t == nullptr) {
994 *out = UnknownDim();
995 return OkStatus();
996 }
997 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
998 if (val < 0) {
999 if (input_rank < 0) {
1000 *out = UnknownDim();
1001 return OkStatus();
1002 } else if (val + input_rank < 0) {
1003 return errors::InvalidArgument("Dimension size, given by scalar input ",
1004 val, " must be in range [-", input_rank,
1005 ", ", input_rank, ")");
1006 } else {
1007 val += input_rank;
1008 }
1009 } else if (input_rank >= 0 && val >= input_rank) {
1010 return errors::InvalidArgument("Dimension size, given by scalar input ",
1011 val, " must be in range [-", input_rank,
1012 ", ", input_rank, ")");
1013 }
1014 *out = MakeDim(val);
1015 return OkStatus();
1016 }
1017
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)1018 Status InferenceContext::Divide(DimensionHandle dividend,
1019 DimensionOrConstant divisor,
1020 bool evenly_divisible, DimensionHandle* out) {
1021 const int64_t divisor_value = Value(divisor);
1022 if (divisor_value == 1) {
1023 *out = dividend;
1024 } else if (!ValueKnown(dividend) ||
1025 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
1026 *out = UnknownDim();
1027 } else {
1028 const int64_t v = Value(dividend);
1029 if (divisor_value <= 0) {
1030 return errors::InvalidArgument("Divisor must be positive but is ",
1031 divisor_value);
1032 }
1033 if (evenly_divisible && (v % divisor_value) != 0) {
1034 return errors::InvalidArgument(
1035 "Dimension size must be evenly divisible by ", divisor_value,
1036 " but is ", v);
1037 }
1038 *out = MakeDim(v / divisor_value);
1039 }
1040 return OkStatus();
1041 }
1042
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1043 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1044 DimensionHandle* out) {
1045 const int64_t first_value = Value(first);
1046 const int64_t second_value = Value(second);
1047 // Special cases.
1048 if (first_value == 0) {
1049 *out = MakeDim(second);
1050 } else if (second_value == 0) {
1051 *out = first;
1052 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1053 *out = UnknownDim();
1054 } else {
1055 // Invariant: Both values are known and positive. Still in run-time we can
1056 // get pair of values which cannot be store in output. Check below will
1057 // report error. We still need to avoid undefined behavior of signed
1058 // overflow and use unsigned addition.
1059 const int64_t sum = static_cast<uint64>(first_value) + second_value;
1060 if (sum < 0) {
1061 return errors::InvalidArgument("Dimension size overflow from adding ",
1062 first_value, " and ", second_value);
1063 }
1064 *out = MakeDim(sum);
1065 }
1066 return OkStatus();
1067 }
1068
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1069 Status InferenceContext::Subtract(DimensionHandle first,
1070 DimensionOrConstant second,
1071 DimensionHandle* out) {
1072 const int64_t first_value = Value(first);
1073 const int64_t second_value = Value(second);
1074 // Special cases.
1075 if (second_value == 0) {
1076 *out = first;
1077 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1078 *out = UnknownDim();
1079 } else {
1080 // Invariant: Both values are known, first_value is non-negative, and
1081 // second_value is positive.
1082 if (first_value < second_value) {
1083 return errors::InvalidArgument(
1084 "Negative dimension size caused by subtracting ", second_value,
1085 " from ", first_value);
1086 }
1087 *out = MakeDim(first_value - second_value);
1088 }
1089 return OkStatus();
1090 }
1091
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1092 Status InferenceContext::Multiply(DimensionHandle first,
1093 DimensionOrConstant second,
1094 DimensionHandle* out) {
1095 const int64_t first_value = Value(first);
1096 const int64_t second_value = Value(second);
1097 // Special cases.
1098 if (first_value == 0) {
1099 *out = first;
1100 } else if (second_value == 0) {
1101 *out = MakeDim(second);
1102 } else if (first_value == 1) {
1103 *out = MakeDim(second);
1104 } else if (second_value == 1) {
1105 *out = first;
1106 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1107 *out = UnknownDim();
1108 } else {
1109 // Invariant: Both values are known and greater than 1.
1110 const int64_t product = MultiplyWithoutOverflow(first_value, second_value);
1111 if (product < 0) {
1112 return errors::InvalidArgument(
1113 "Negative dimension size caused by overflow when multiplying ",
1114 first_value, " and ", second_value);
1115 }
1116 *out = MakeDim(product);
1117 }
1118 return OkStatus();
1119 }
1120
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1121 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1122 DimensionHandle* out) {
1123 const int64_t first_value = Value(first);
1124 const int64_t second_value = Value(second);
1125 if (first_value == 0) {
1126 *out = first;
1127 } else if (second_value == 0) {
1128 *out = MakeDim(second);
1129 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1130 *out = UnknownDim();
1131 } else {
1132 if (first_value <= second_value) {
1133 *out = first;
1134 } else {
1135 *out = MakeDim(second);
1136 }
1137 }
1138 return OkStatus();
1139 }
1140
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1141 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1142 DimensionHandle* out) {
1143 const int64_t first_value = Value(first);
1144 const int64_t second_value = Value(second);
1145 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1146 *out = UnknownDim();
1147 } else {
1148 if (first_value >= second_value) {
1149 *out = first;
1150 } else {
1151 *out = MakeDim(second);
1152 }
1153 }
1154 return OkStatus();
1155 }
1156
// Returns a copy of `status` whose message is extended with:
//   * the node summary,
//   * the debug string of every input shape, and
//   * any requested input values (tensor contents or tensors-as-partial-shapes)
//     that were available.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& input_shape : inputs_) {
    input_shapes.emplace_back(DebugString(input_shape));
  }

  // Per input, prefer reporting the tensor-as-partial-shape. That happens when
  // it was requested, is set, and has known rank. Otherwise report the tensor
  // value, when it was requested and is present.
  const int num_inputs = inputs_.size();
  const int num_tensors_as_shapes = input_tensors_as_shapes_.size();
  const int num_tensors = input_tensors_.size();
  std::vector<string> input_from_tensors_str;
  std::vector<string> input_from_tensors_as_shape_str;
  input_from_tensors_as_shape_str.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    const bool report_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < num_tensors_as_shapes && input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (report_partial_shape) {
      input_from_tensors_as_shape_str.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }
    if (requested_input_tensor_[i] && i < num_tensors &&
        input_tensors_[i] != nullptr) {
      input_from_tensors_str.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string error_context = strings::StrCat(
      " for '", attrs_.SummarizeNode(),
      "' with input shapes: ", absl::StrJoin(input_shapes, ", "));
  if (!input_from_tensors_str.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       absl::StrJoin(input_from_tensors_str, ", "));
  }
  if (!input_from_tensors_as_shape_str.empty()) {
    // NOTE(review): joined with "," rather than ", " as above. Kept unchanged
    // because callers or tests may match this exact text.
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(input_from_tensors_as_shape_str, ","));
  }
  strings::StrAppend(&error_context, ".");

  return errors::CreateWithUpdatedMessage(
      status, strings::StrCat(status.error_message(), error_context));
}
1202
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1203 bool InferenceContext::MergeHandleShapesAndTypes(
1204 const std::vector<ShapeAndType>& shapes_and_types,
1205 std::vector<ShapeAndType>* to_update) {
1206 if (shapes_and_types.size() != to_update->size()) {
1207 return false;
1208 }
1209 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1210 bool refined = false;
1211 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1212 const ShapeAndType& existing = (*to_update)[i];
1213 if (shapes_and_types[i].dtype == existing.dtype) {
1214 new_values[i].dtype = existing.dtype;
1215 } else {
1216 if (existing.dtype != DT_INVALID) {
1217 return false;
1218 } else {
1219 new_values[i].dtype = shapes_and_types[i].dtype;
1220 refined = true;
1221 }
1222 }
1223 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1224 .ok()) {
1225 // merge failed, ignore the new value.
1226 new_values[i].shape = existing.shape;
1227 }
1228 if (!existing.shape.SameHandle(new_values[i].shape)) {
1229 refined = true;
1230 }
1231 }
1232 if (!refined) {
1233 return false;
1234 }
1235 to_update->swap(new_values);
1236 return true;
1237 }
1238
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1239 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1240 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1241 if (output_handle_shapes_and_types_[idx] == nullptr) {
1242 output_handle_shapes_and_types_[idx].reset(
1243 new std::vector<ShapeAndType>(shapes_and_types));
1244 return true;
1245 }
1246 return MergeHandleShapesAndTypes(shapes_and_types,
1247 output_handle_shapes_and_types_[idx].get());
1248 }
1249
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1250 bool InferenceContext::MergeInputHandleShapesAndTypes(
1251 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1252 if (input_handle_shapes_and_types_[idx] == nullptr) {
1253 input_handle_shapes_and_types_[idx].reset(
1254 new std::vector<ShapeAndType>(shapes_and_types));
1255 return true;
1256 }
1257 return MergeHandleShapesAndTypes(shapes_and_types,
1258 input_handle_shapes_and_types_[idx].get());
1259 }
1260
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1261 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1262 const std::vector<ShapeAndType>& shapes_and_types,
1263 std::vector<ShapeAndType>* to_update) {
1264 if (shapes_and_types.size() != to_update->size()) {
1265 return false;
1266 }
1267 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1268 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1269 const ShapeAndType& existing = (*to_update)[i];
1270 if (shapes_and_types[i].dtype == existing.dtype) {
1271 new_values[i].dtype = existing.dtype;
1272 } else {
1273 if (existing.dtype != DT_INVALID) {
1274 return false;
1275 } else {
1276 new_values[i].dtype = shapes_and_types[i].dtype;
1277 }
1278 }
1279 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1280 }
1281 to_update->swap(new_values);
1282 return true;
1283 }
1284
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1285 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1286 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1287 if (output_handle_shapes_and_types_[idx] == nullptr) {
1288 output_handle_shapes_and_types_[idx].reset(
1289 new std::vector<ShapeAndType>(shapes_and_types));
1290 return true;
1291 }
1292 return RelaxHandleShapesAndMergeTypes(
1293 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1294 }
1295
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1296 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1297 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1298 if (input_handle_shapes_and_types_[idx] == nullptr) {
1299 input_handle_shapes_and_types_[idx].reset(
1300 new std::vector<ShapeAndType>(shapes_and_types));
1301 return true;
1302 }
1303 return RelaxHandleShapesAndMergeTypes(
1304 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1305 }
1306
1307 // -----------------------------------------------------------------------------
1308 // ShapeManager
1309 // -----------------------------------------------------------------------------
ShapeManager()1310 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1311 InferenceContext::ShapeManager::~ShapeManager() {
1312 for (auto* s : all_shapes_) delete s;
1313 for (auto* d : all_dims_) delete d;
1314 }
1315
MakeShape(const std::vector<DimensionHandle> & dims)1316 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1317 const std::vector<DimensionHandle>& dims) {
1318 all_shapes_.push_back(new Shape(dims));
1319 return all_shapes_.back();
1320 }
1321
UnknownShape()1322 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1323 all_shapes_.push_back(new Shape());
1324 return all_shapes_.back();
1325 }
1326
1327 } // namespace shape_inference
1328 } // namespace tensorflow
1329