1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/full_type_util.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/framework/partial_tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/strings/numbers.h"
25 #include "tensorflow/core/lib/strings/scanner.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27
28 namespace tensorflow {
29 namespace shape_inference {
30
31 constexpr int32_t InferenceContext::kUnknownRank;
32 constexpr int64_t InferenceContext::kUnknownDim;
33
34 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)35 InferenceContext::InferenceContext(
36 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
37 const std::vector<PartialTensorShape>& input_shapes,
38 const std::vector<const Tensor*>& input_tensors,
39 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
40 const std::vector<
41 std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
42 input_handle_shapes_and_types)
43 : graph_def_version_(graph_def_version), attrs_(attrs) {
44 std::vector<ShapeHandle> input_tensors_as_shape_handles;
45 input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
46 for (const PartialTensorShape& p : input_tensors_as_shapes) {
47 ShapeHandle shape;
48 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
49 if (!construction_status_.ok()) {
50 return;
51 }
52 input_tensors_as_shape_handles.push_back(shape);
53 }
54 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
55 if (!construction_status_.ok()) return;
56 inputs_.reserve(input_shapes.size());
57 for (const PartialTensorShape& p : input_shapes) {
58 ShapeHandle shape;
59 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
60 if (!construction_status_.ok()) {
61 return;
62 }
63 inputs_.push_back(shape);
64 }
65 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
66 input_shapes.size());
67 for (int i = 0, end = input_handle_shapes_and_types.size(); i < end; ++i) {
68 const auto& v = input_handle_shapes_and_types[i];
69 if (v == nullptr) {
70 continue;
71 }
72 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
73 auto& new_v = *handle_data[i];
74 for (int j = 0, end = v->size(); j < end; ++j) {
75 const auto& p = (*v)[j];
76 construction_status_.Update(
77 MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
78 if (!construction_status_.ok()) {
79 return;
80 }
81 new_v[j].dtype = p.second;
82 }
83 }
84 PostInputInit(std::move(handle_data));
85 }
86
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)87 InferenceContext::InferenceContext(
88 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
89 const std::vector<ShapeHandle>& input_shapes,
90 const std::vector<const Tensor*>& input_tensors,
91 const std::vector<ShapeHandle>& input_tensors_as_shapes,
92 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
93 input_handle_shapes_and_types)
94 : graph_def_version_(graph_def_version), attrs_(attrs) {
95 PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
96 if (!construction_status_.ok()) return;
97 inputs_ = input_shapes;
98
99 PostInputInit(std::move(input_handle_shapes_and_types));
100 }
101
~InferenceContext()102 InferenceContext::~InferenceContext() {}
103
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)104 Status InferenceContext::Run(
105 const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
106 ForgetMerges();
107 Status s = fn(this);
108 if (!s.ok()) {
109 ForgetMerges();
110 return AttachContext(s);
111 }
112 #ifndef NDEBUG
113 for (int i = 0; i < num_outputs(); ++i) {
114 DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
115 }
116 #endif // NDEBUG
117 return s;
118 }
119
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)120 Status InferenceContext::set_output(StringPiece output_name,
121 const std::vector<ShapeHandle>& shapes) {
122 auto result = output_name_map_.find(output_name);
123 if (result == output_name_map_.end()) {
124 return errors::InvalidArgument("Unknown output name: ", output_name);
125 } else {
126 const int start = result->second.first;
127 const int size = result->second.second - start;
128 const int shapes_size = shapes.size();
129 if (size != shapes_size) {
130 return errors::InvalidArgument("Must have exactly ", shapes.size(),
131 " shapes.");
132 }
133 for (int i = 0; i < shapes_size; ++i) {
134 outputs_[i + start] = shapes[i];
135 }
136 }
137 return Status::OK();
138 }
139
input(StringPiece input_name,std::vector<ShapeHandle> * output) const140 Status InferenceContext::input(StringPiece input_name,
141 std::vector<ShapeHandle>* output) const {
142 const auto result = input_name_map_.find(input_name);
143 if (result == input_name_map_.end()) {
144 return errors::InvalidArgument("Unknown input name: ", input_name);
145 } else {
146 output->clear();
147 for (int i = result->second.first; i < result->second.second; ++i) {
148 output->push_back(inputs_[i]);
149 }
150 }
151 return Status::OK();
152 }
153
output(StringPiece output_name,std::vector<ShapeHandle> * output) const154 Status InferenceContext::output(StringPiece output_name,
155 std::vector<ShapeHandle>* output) const {
156 const auto result = output_name_map_.find(output_name);
157 if (result == output_name_map_.end()) {
158 return errors::InvalidArgument("Unknown output name: ", output_name);
159 } else {
160 output->clear();
161 for (int i = result->second.first; i < result->second.second; ++i) {
162 output->push_back(outputs_[i]);
163 }
164 }
165 return Status::OK();
166 }
167
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)168 void InferenceContext::PreInputInit(
169 const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
170 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
171 // TODO(mdan): This is also done at graph construction. Run only here instead?
172 const auto ret = full_type::SpecializeType(attrs_, op_def);
173 DCHECK(ret.status().ok()) << "while instantiating types: " << ret.status();
174 ret_types_ = ret.ValueOrDie();
175
176 input_tensors_ = input_tensors;
177 input_tensors_as_shapes_ = input_tensors_as_shapes;
178
179 construction_status_ =
180 NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
181 if (!construction_status_.ok()) return;
182
183 int num_outputs = 0;
184 for (const auto& e : output_name_map_) {
185 num_outputs = std::max(num_outputs, e.second.second);
186 }
187 outputs_.assign(num_outputs, nullptr);
188 output_handle_shapes_and_types_.resize(num_outputs);
189 }
190
ExpandOutputs(int new_output_size)191 Status InferenceContext::ExpandOutputs(int new_output_size) {
192 const int outputs_size = outputs_.size();
193 if (new_output_size < outputs_size) {
194 return errors::InvalidArgument("Trying to reduce number of outputs of op.");
195 }
196 outputs_.resize(new_output_size, nullptr);
197 output_handle_shapes_and_types_.resize(new_output_size);
198 return Status::OK();
199 }
200
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)201 void InferenceContext::PostInputInit(
202 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
203 int num_inputs_from_node_def = 0;
204 for (const auto& e : input_name_map_) {
205 num_inputs_from_node_def =
206 std::max(num_inputs_from_node_def, e.second.second);
207 }
208
209 // Allow passing empty shapes/dtypes to avoid changing every single test.
210 if (input_handle_data.empty()) {
211 input_handle_shapes_and_types_.resize(inputs_.size());
212 } else {
213 if (input_handle_data.size() != inputs_.size()) {
214 construction_status_ = errors::InvalidArgument(
215 "Wrong number of handle shapes passed; expected ", inputs_.size(),
216 " got ", input_handle_data.size());
217 return;
218 }
219 input_handle_shapes_and_types_ = std::move(input_handle_data);
220 }
221 const int inputs_size = inputs_.size();
222 if (inputs_size != num_inputs_from_node_def) {
223 construction_status_ = errors::InvalidArgument(
224 "Wrong number of inputs passed: ", inputs_.size(), " while ",
225 num_inputs_from_node_def, " expected based on NodeDef");
226 return;
227 }
228
229 CHECK_LE(input_tensors_.size(), inputs_.size());
230 input_tensors_.resize(inputs_.size());
231 requested_input_tensor_.resize(inputs_.size());
232 requested_input_tensor_as_partial_shape_.resize(inputs_.size());
233 }
234
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)235 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
236 TensorShapeProto* proto) {
237 if (!RankKnown(handle)) {
238 proto->set_unknown_rank(true);
239 return;
240 }
241
242 for (int32_t i = 0; i < Rank(handle); ++i) {
243 DimensionHandle dim = Dim(handle, i);
244 auto* dim_shape = proto->add_dim();
245 if (ValueKnown(dim)) {
246 dim_shape->set_size(Value(dim));
247 } else {
248 dim_shape->set_size(-1);
249 }
250 }
251 }
252
FullyDefined(ShapeHandle s)253 bool InferenceContext::FullyDefined(ShapeHandle s) {
254 if (!RankKnown(s)) return false;
255 for (int i = 0; i < Rank(s); ++i) {
256 if (!ValueKnown(Dim(s, i))) return false;
257 }
258 return true;
259 }
260
NumElements(ShapeHandle s)261 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
262 const auto rank = Rank(s);
263 if (rank == kUnknownRank) return UnknownDim();
264 bool found_unknown = false;
265 int64_t size = 1;
266 for (int i = 0; i < rank; ++i) {
267 int64_t dim_val = Value(Dim(s, i));
268 if (dim_val == kUnknownDim) {
269 found_unknown = true;
270 } else if (dim_val == 0) {
271 return MakeDim(0);
272 } else {
273 size *= dim_val;
274 }
275 }
276 if (found_unknown) {
277 return UnknownDim();
278 } else {
279 return MakeDim(size);
280 }
281 }
282
DebugString(ShapeHandle s)283 string InferenceContext::DebugString(ShapeHandle s) {
284 if (RankKnown(s)) {
285 std::vector<string> vals;
286 for (auto d : s->dims_) vals.push_back(DebugString(d));
287 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
288 } else {
289 return "?";
290 }
291 }
292
DebugString(DimensionHandle d)293 string InferenceContext::DebugString(DimensionHandle d) {
294 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
295 }
296
DebugString() const297 string InferenceContext::DebugString() const {
298 return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
299 }
300
DebugString(const ShapeAndType & shape_and_type)301 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
302 return strings::StrCat(DebugString(shape_and_type.shape), ":",
303 DataTypeString(shape_and_type.dtype));
304 }
305
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)306 string InferenceContext::DebugString(
307 gtl::ArraySlice<ShapeAndType> shape_and_types) {
308 std::vector<string> pieces;
309 for (const ShapeAndType& s : shape_and_types) {
310 pieces.push_back(DebugString(s));
311 }
312 return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
313 }
314
WithRank(ShapeHandle shape,int64_t rank,ShapeHandle * out)315 Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank,
316 ShapeHandle* out) {
317 if (rank > kint32max) {
318 return errors::InvalidArgument("Rank cannot exceed kint32max");
319 }
320 const int32_t existing = Rank(shape);
321 if (existing == rank) {
322 *out = shape;
323 return Status::OK();
324 }
325 if (existing == kUnknownRank) {
326 std::vector<DimensionHandle> dims;
327 dims.reserve(rank);
328 for (int i = 0; i < rank; ++i) {
329 dims.push_back(UnknownDim());
330 }
331 ShapeHandle shp = shape_manager_.MakeShape(dims);
332 return Merge(shape, shp, out);
333 }
334 *out = nullptr;
335
336 return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
337 existing);
338 }
339
WithRankAtLeast(ShapeHandle shape,int64_t rank,ShapeHandle * out)340 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank,
341 ShapeHandle* out) {
342 if (rank > kint32max) {
343 return errors::InvalidArgument("Rank cannot exceed kint32max");
344 }
345 const int32_t existing = Rank(shape);
346 if (existing >= rank || existing == kUnknownRank) {
347 *out = shape;
348 return Status::OK();
349 }
350 *out = nullptr;
351 return errors::InvalidArgument("Shape must be at least rank ", rank,
352 " but is rank ", existing);
353 }
354
WithRankAtMost(ShapeHandle shape,int64_t rank,ShapeHandle * out)355 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank,
356 ShapeHandle* out) {
357 if (rank > kint32max) {
358 return errors::InvalidArgument("Rank cannot exceed kint32max");
359 }
360 const int32_t existing = Rank(shape);
361 if (existing <= rank || existing == kUnknownRank) {
362 *out = shape;
363 return Status::OK();
364 }
365 *out = nullptr;
366 return errors::InvalidArgument("Shape must be at most rank ", rank,
367 " but is rank ", existing);
368 }
369
WithValue(DimensionHandle dim,int64_t value,DimensionHandle * out)370 Status InferenceContext::WithValue(DimensionHandle dim, int64_t value,
371 DimensionHandle* out) {
372 const int64_t existing = Value(dim);
373 if (existing == value) {
374 *out = dim;
375 return Status::OK();
376 }
377 if (existing == kUnknownDim) {
378 DimensionHandle d = MakeDim(value);
379 return Merge(dim, d, out);
380 }
381 *out = nullptr;
382 return errors::InvalidArgument("Dimension must be ", value, " but is ",
383 existing);
384 }
385
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)386 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
387 DimensionHandle* out) {
388 if (d_old.SameHandle(d_new)) {
389 *out = d_old;
390 } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
391 // The node will be fed by the dimension d_new instead of d_old: any
392 // equality assertion between d_old and other input dimension on this node
393 // may not be true anymore, so forget them all.
394 ForgetMerges();
395 // Return the new shape handle to force the relaxation to propagate to the
396 // fanout of the context.
397 *out = d_new;
398 } else if (!ValueKnown(d_new)) {
399 ForgetMerges();
400 *out = d_new;
401 } else if (Value(d_old) == Value(d_new)) {
402 // Return the old shape handle. This will stop the relaxation in the fanout
403 // of the context.
404 *out = d_old;
405 } else {
406 // Return a new handle that encodes a different unknown dim.
407 ForgetMerges();
408 *out = UnknownDim();
409 }
410 }
411
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)412 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
413 DimensionHandle* out) {
414 if (d0.SameHandle(d1)) {
415 *out = d0;
416 return Status::OK();
417 } else if (!ValueKnown(d1)) {
418 *out = d0;
419 merged_dims_.emplace_back(d0, d1);
420 return Status::OK();
421 } else if (!ValueKnown(d0)) {
422 *out = d1;
423 merged_dims_.emplace_back(d0, d1);
424 return Status::OK();
425 } else if (Value(d0) == Value(d1)) {
426 *out = d0;
427 return Status::OK();
428 } else {
429 *out = nullptr;
430 return errors::InvalidArgument("Dimensions must be equal, but are ",
431 Value(d0), " and ", Value(d1));
432 }
433 }
434
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)435 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
436 ShapeHandle* s_out,
437 ShapeHandle* prefix_out) {
438 *s_out = *prefix_out = nullptr;
439 if (!RankKnown(prefix) || !RankKnown(s)) {
440 *s_out = s;
441 *prefix_out = prefix;
442 return Status::OK();
443 }
444 const int32_t rank = Rank(prefix);
445 TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
446
447 // Merge the prefix dims and create the new output shapes.
448 const int32_t rank_s = Rank(s);
449 std::vector<DimensionHandle> dims;
450 dims.reserve(std::max(rank, rank_s));
451 dims.resize(rank);
452 for (int i = 0; i < rank; ++i) {
453 TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
454 }
455 *prefix_out = MakeShape(dims);
456 for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
457 *s_out = MakeShape(dims);
458 return Status::OK();
459 }
460
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)461 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
462 ShapeHandle* out) {
463 if (s_old.SameHandle(s_new)) {
464 *out = s_old;
465 return;
466 } else if (!RankKnown(s_new) || !s_old.IsSet()) {
467 ForgetMerges();
468 *out = s_new;
469 return;
470 }
471
472 const int32_t rank = Rank(s_old);
473 if (rank != Rank(s_new)) {
474 ForgetMerges();
475 *out = UnknownShape();
476 return;
477 }
478
479 bool return_s_old = true;
480 for (int i = 0; i < rank; ++i) {
481 auto d0 = Dim(s_old, i);
482 auto d1 = Dim(s_new, i);
483 if (d0.SameHandle(d1)) continue;
484
485 auto v0 = Value(d0);
486 auto v1 = Value(d1);
487 if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
488 return_s_old = false;
489 break;
490 }
491 }
492 if (return_s_old) {
493 *out = s_old;
494 return;
495 }
496
497 // Relax dims.
498 std::vector<DimensionHandle> dims(rank);
499 for (int i = 0; i < rank; ++i) {
500 Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
501 }
502 ForgetMerges();
503 *out = MakeShape(dims);
504 }
505
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)506 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
507 ShapeHandle* out) {
508 if (s0.SameHandle(s1)) {
509 *out = s0;
510 return Status::OK();
511 } else if (!RankKnown(s1)) {
512 *out = s0;
513 merged_shapes_.emplace_back(s0, s1);
514 return Status::OK();
515 } else if (!RankKnown(s0)) {
516 *out = s1;
517 merged_shapes_.emplace_back(s0, s1);
518 return Status::OK();
519 }
520
521 const int32_t rank = Rank(s0);
522 if (rank != Rank(s1)) {
523 *out = nullptr;
524 return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
525 " and ", Rank(s1));
526 }
527
528 bool return_s0 = true;
529 bool return_s1 = true;
530 for (int i = 0; i < rank; ++i) {
531 auto d0 = Dim(s0, i);
532 auto d1 = Dim(s1, i);
533 if (d0.SameHandle(d1)) continue;
534
535 auto v0 = Value(d0);
536 auto v1 = Value(d1);
537 if (v0 == kUnknownDim) {
538 if (v1 != kUnknownDim) {
539 return_s0 = false;
540 }
541 } else if (v1 == kUnknownDim) {
542 return_s1 = false;
543 } else if (v0 != v1) {
544 *out = nullptr;
545 return errors::InvalidArgument(
546 "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
547 " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
548 DebugString(s1), ".");
549 }
550 }
551
552 merged_shapes_.emplace_back(s0, s1);
553
554 if (return_s0 || return_s1) {
555 *out = return_s0 ? s0 : s1;
556 return Status::OK();
557 }
558
559 // Merge dims.
560 std::vector<DimensionHandle> dims(rank, nullptr);
561 for (int i = 0; i < rank; ++i) {
562 // Invariant for merge was checked earlier, so CHECK is ok.
563 TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
564 }
565
566 Status s = ReturnCreatedShape(dims, out);
567 if (s.ok()) {
568 // Merge the new shape with s0. Since s0 and s1 are merged, this implies
569 // that s1 and out are also merged.
570 merged_shapes_.emplace_back(s0, *out);
571 }
572 return s;
573 }
574
Subshape(ShapeHandle s,int64_t start,ShapeHandle * out)575 Status InferenceContext::Subshape(ShapeHandle s, int64_t start,
576 ShapeHandle* out) {
577 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
578 }
579
Subshape(ShapeHandle s,int64_t start,int64_t end,ShapeHandle * out)580 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
581 ShapeHandle* out) {
582 return Subshape(s, start, end, 1 /* stride */, out);
583 }
584
Subshape(ShapeHandle s,int64_t start,int64_t end,int64_t stride,ShapeHandle * out)585 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
586 int64_t stride, ShapeHandle* out) {
587 int64_t start_in = start;
588 int64_t end_in = end;
589
590 const int32_t rank = Rank(s);
591 if (start == 0 && stride == 1 &&
592 ((RankKnown(s) && end >= rank) ||
593 end == std::numeric_limits<int64>::max())) {
594 *out = s;
595 return Status::OK();
596 }
597 if (!RankKnown(s)) {
598 return ReturnUnknownShape(out);
599 }
600
601 if (start > rank) start = rank;
602 if (end > rank) end = rank;
603
604 if (stride < 0 && start == rank) --start;
605
606 if (start < 0) {
607 start = rank + start;
608 if (start < 0) {
609 *out = nullptr;
610 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
611 ", for shape with rank ", rank);
612 }
613 }
614
615 if (end < 0) {
616 end = rank + end;
617 if (end < 0) {
618 *out = nullptr;
619 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
620 ", for shape with rank ", rank);
621 }
622 }
623 if (stride > 0 && start > end) {
624 *out = nullptr;
625 return errors::InvalidArgument(
626 "Subshape must have computed start <= end, but is ", start, " and ",
627 end, " (computed from start ", start_in, " and end ", end_in,
628 " over shape with rank ", rank, ")");
629 } else if (stride < 0 && start < end) {
630 *out = nullptr;
631 return errors::InvalidArgument(
632 "Subshape must have computed start >= end since stride is negative, "
633 "but is ",
634 start, " and ", end, " (computed from start ", start_in, " and end ",
635 end_in, " over shape with rank ", rank, " and stride", stride, ")");
636 }
637
638 std::vector<DimensionHandle> dims;
639 for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
640 dims.push_back(Dim(s, i));
641 }
642 return ReturnCreatedShape(dims, out);
643 }
644
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)645 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
646 ShapeHandle* out) {
647 if (!RankKnown(s1) || !RankKnown(s2)) {
648 return ReturnUnknownShape(out);
649 }
650 const int32_t s1_rank = Rank(s1);
651 const int32_t s2_rank = Rank(s2);
652 const int32_t rank = s1_rank + s2_rank;
653 std::vector<DimensionHandle> dims;
654 dims.reserve(rank);
655 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
656 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
657 return ReturnCreatedShape(dims, out);
658 }
659
ReplaceDim(ShapeHandle s,int64_t dim_index_in,DimensionHandle new_dim,ShapeHandle * out)660 Status InferenceContext::ReplaceDim(ShapeHandle s, int64_t dim_index_in,
661 DimensionHandle new_dim, ShapeHandle* out) {
662 if (!RankKnown(s)) {
663 return ReturnUnknownShape(out);
664 }
665 int64_t dim_index = dim_index_in;
666 if (dim_index < 0) {
667 dim_index = s->dims_.size() + dim_index;
668 }
669 if (!FastBoundsCheck(dim_index, s->dims_.size())) {
670 *out = nullptr;
671 return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
672 " for shape with ", s->dims_.size(),
673 " dimensions");
674 }
675 std::vector<DimensionHandle> dims(s->dims_);
676 dims[dim_index] = new_dim;
677 return ReturnCreatedShape(dims, out);
678 }
679
MakeShape(const std::vector<DimensionHandle> & dims)680 ShapeHandle InferenceContext::MakeShape(
681 const std::vector<DimensionHandle>& dims) {
682 return shape_manager_.MakeShape(dims);
683 }
684
MakeShape(std::initializer_list<DimensionOrConstant> dims)685 ShapeHandle InferenceContext::MakeShape(
686 std::initializer_list<DimensionOrConstant> dims) {
687 std::vector<DimensionHandle> dims_actual;
688 dims_actual.reserve(dims.size());
689 for (const DimensionOrConstant& d : dims) {
690 dims_actual.push_back(MakeDim(d));
691 }
692
693 return shape_manager_.MakeShape(dims_actual);
694 }
695
UnknownShape()696 ShapeHandle InferenceContext::UnknownShape() {
697 return shape_manager_.UnknownShape();
698 }
699
UnknownShapeOfRank(int64_t rank)700 ShapeHandle InferenceContext::UnknownShapeOfRank(int64_t rank) {
701 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
702 if (rank == kUnknownRank) {
703 return UnknownShape();
704 }
705 CHECK_GE(rank, 0) << "rank must not be negative";
706 std::vector<DimensionHandle> dims(rank);
707 for (int32_t i = 0; i < rank; ++i) {
708 dims[i] = UnknownDim();
709 }
710 return MakeShape(dims);
711 }
712
Scalar()713 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
714
Vector(DimensionOrConstant dim)715 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
716 return MakeShape({dim});
717 }
718
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)719 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
720 DimensionOrConstant dim2) {
721 return MakeShape({dim1, dim2});
722 }
723
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)724 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
725 int input_idx, ShapeHandle* out) {
726 ShapeHandle input_shape;
727 TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
728
729 request_input_tensor_as_partial_shape(input_idx);
730 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
731 if (input_idx < input_tensors_as_shapes_size &&
732 input_tensors_as_shapes_[input_idx].IsSet() &&
733 RankKnown(input_tensors_as_shapes_[input_idx])) {
734 *out = input_tensors_as_shapes_[input_idx];
735 return Status::OK();
736 }
737
738 return InternalMakeShapeFromTensor(
739 true /* treat_unknown_scalar_tensor_as_unknown_shape */,
740 input_tensor(input_idx), input_shape, out);
741 }
742
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)743 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
744 ShapeHandle* out) {
745 ShapeHandle input_shape;
746 TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
747
748 request_input_tensor_as_partial_shape(input_idx);
749 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
750 if (input_idx < input_tensors_as_shapes_size &&
751 input_tensors_as_shapes_[input_idx].IsSet() &&
752 RankKnown(input_tensors_as_shapes_[input_idx])) {
753 *out = input_tensors_as_shapes_[input_idx];
754 return Status::OK();
755 }
756
757 return InternalMakeShapeFromTensor(
758 false /* treat_unknown_scalar_tensor_as_unknown_shape */,
759 input_tensor(input_idx), input_shape, out);
760 }
761
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)762 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
763 ShapeHandle tensor_shape,
764 ShapeHandle* out) {
765 return InternalMakeShapeFromTensor(
766 false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
767 out);
768 }
769
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)770 Status InferenceContext::InternalMakeShapeFromTensor(
771 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
772 ShapeHandle tensor_shape, ShapeHandle* out) {
773 // Only callers who have set
774 if (!treat_unknown_scalar_tensor_as_unknown_shape) {
775 TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
776 }
777 if (t == nullptr) {
778 // This is guarded by the check above.
779 if (Rank(tensor_shape) == 0) {
780 return ReturnUnknownShape(out);
781 }
782 // Shape tensor is not known, but if the shape of the shape tensor is then
783 // the right number of unknown dims can be created.
784 DimensionHandle shape_dim = Dim(tensor_shape, 0);
785 if (!ValueKnown(shape_dim)) {
786 return ReturnUnknownShape(out);
787 }
788 const auto num_dims = Value(shape_dim);
789 std::vector<DimensionHandle> dims;
790 dims.reserve(num_dims);
791 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
792 return ReturnCreatedShape(dims, out);
793 }
794
795 if (t->shape().dims() == 0) {
796 if (t->dtype() == DataType::DT_INT32) {
797 auto flat_t = t->scalar<int32>();
798 if (flat_t() != -1) {
799 *out = nullptr;
800 return errors::InvalidArgument(
801 "Input tensor must be rank 1, or if its rank 0 it must have value "
802 "-1 "
803 "(representing an unknown shape). Saw value: ",
804 flat_t());
805 }
806 return ReturnUnknownShape(out);
807 } else if (t->dtype() == DataType::DT_INT64) {
808 auto flat_t = t->scalar<int64>();
809 if (flat_t() != -1) {
810 *out = nullptr;
811 return errors::InvalidArgument(
812 "Input tensor must be rank 1, or if its rank 0 it must have value "
813 "-1 "
814 "(representing an unknown shape). Saw value: ",
815 flat_t());
816 }
817 return ReturnUnknownShape(out);
818 } else {
819 *out = nullptr;
820 return errors::InvalidArgument(
821 "Input tensor must be int32 or int64, but was ",
822 DataTypeString(t->dtype()));
823 }
824 }
825
826 if (t->shape().dims() != 1) {
827 *out = nullptr;
828 return errors::InvalidArgument(
829 "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
830 ((t->shape().dims() == 0)
831 ? "If it is rank 0 rank 0 it must have statically known value -1 "
832 "(representing an unknown shape). "
833 : " "),
834 "Saw tensor shape ", t->shape().DebugString());
835 }
836 std::vector<DimensionHandle> dims;
837 if (t->dtype() == DataType::DT_INT32) {
838 auto flat_t = t->flat<int32>();
839 for (int i = 0; i < flat_t.size(); ++i) {
840 const int32_t val = flat_t(i);
841 if (val < -1) {
842 return errors::InvalidArgument(
843 "Invalid value in tensor used for shape: ", val);
844 }
845 // -1 will become an unknown dim.
846 dims.push_back(MakeDim(val));
847 }
848 } else if (t->dtype() == DataType::DT_INT64) {
849 auto flat_t = t->flat<int64>();
850 for (int i = 0; i < flat_t.size(); ++i) {
851 const int64_t val = flat_t(i);
852 if (val < -1) {
853 return errors::InvalidArgument(
854 "Invalid value in tensor used for shape: ", val);
855 }
856 // -1 will become an unknown dim.
857 dims.push_back(MakeDim(val));
858 }
859 } else {
860 *out = nullptr;
861 return errors::InvalidArgument(
862 "Input tensor must be int32 or int64, but was ",
863 DataTypeString(t->dtype()));
864 }
865
866 return ReturnCreatedShape(dims, out);
867 }
868
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)869 Status InferenceContext::MakeShapeFromPartialTensorShape(
870 const PartialTensorShape& partial_shape, ShapeHandle* out) {
871 *out = nullptr;
872 if (partial_shape.dims() == -1) {
873 return ReturnUnknownShape(out);
874 }
875 const int num_dims = partial_shape.dims();
876 std::vector<DimensionHandle> dims(num_dims);
877 for (int i = 0; i < num_dims; ++i) {
878 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
879 // can be passed directly to MakeDim.
880 dims[i] = MakeDim(partial_shape.dim_size(i));
881 }
882 return ReturnCreatedShape(dims, out);
883 }
884
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)885 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
886 ShapeHandle* out) {
887 return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
888 out);
889 }
890
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)891 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
892 ShapeHandle* out) {
893 *out = nullptr;
894 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
895 PartialTensorShape partial_shape(proto);
896 return MakeShapeFromPartialTensorShape(partial_shape, out);
897 }
898
GetScalarFromTensor(const Tensor * t,int64 * val)899 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
900 // Caller must ensure that <t> is not NULL.
901 const int rank = t->dims();
902 if (rank != 0) {
903 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
904 }
905
906 if (t->dtype() == DataType::DT_INT32) {
907 *val = t->scalar<int32>()();
908 return Status::OK();
909 } else if (t->dtype() == DataType::DT_INT64) {
910 *val = t->scalar<int64>()();
911 return Status::OK();
912 } else {
913 return errors::InvalidArgument("Scalar input must be int32 or int64.");
914 }
915 }
916
GetScalarFromTensor(const Tensor * t,int64_t idx,int64 * val)917 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx,
918 int64* val) {
919 // Caller must ensure that <t> is not NULL.
920 const int rank = t->dims();
921 if (rank != 1) {
922 return errors::InvalidArgument("Input must be 1D but has rank ", rank);
923 }
924
925 if (t->dtype() == DataType::DT_INT32) {
926 auto flat_t = t->flat<int32>();
927 if (idx < 0 || idx >= flat_t.size()) {
928 return errors::InvalidArgument("Invalid index ", idx,
929 " for Tensor of size ", flat_t.size());
930 }
931 *val = flat_t(idx);
932 return Status::OK();
933 } else if (t->dtype() == DataType::DT_INT64) {
934 auto flat_t = t->flat<int64>();
935 if (idx < 0 || idx >= flat_t.size()) {
936 return errors::InvalidArgument("Invalid index ", idx,
937 " for Tensor of size ", flat_t.size());
938 }
939 *val = flat_t(idx);
940 return Status::OK();
941 } else {
942 return errors::InvalidArgument("Tensor input must be int32 or int64.");
943 }
944 }
945
946 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)947 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
948 int64_t val;
949 const Tensor* t = input_tensor(idx);
950 if (t == nullptr) {
951 *out = UnknownDim();
952 return Status::OK();
953 }
954 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
955 if (val < 0) {
956 return errors::InvalidArgument("Dimension size, given by scalar input ",
957 idx, ", must be non-negative but is ", val);
958 }
959 *out = MakeDim(val);
960 return Status::OK();
961 }
962
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)963 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
964 int idx, int input_rank, DimensionHandle* out) {
965 int64_t val;
966 const Tensor* t = input_tensor(idx);
967 if (t == nullptr) {
968 *out = UnknownDim();
969 return Status::OK();
970 }
971 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
972 if (val < 0) {
973 if (input_rank < 0) {
974 *out = UnknownDim();
975 return Status::OK();
976 } else if (val + input_rank < 0) {
977 return errors::InvalidArgument("Dimension size, given by scalar input ",
978 val, " must be in range [-", input_rank,
979 ", ", input_rank, ")");
980 } else {
981 val += input_rank;
982 }
983 } else if (input_rank >= 0 && val >= input_rank) {
984 return errors::InvalidArgument("Dimension size, given by scalar input ",
985 val, " must be in range [-", input_rank,
986 ", ", input_rank, ")");
987 }
988 *out = MakeDim(val);
989 return Status::OK();
990 }
991
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)992 Status InferenceContext::Divide(DimensionHandle dividend,
993 DimensionOrConstant divisor,
994 bool evenly_divisible, DimensionHandle* out) {
995 const int64_t divisor_value = Value(divisor);
996 if (divisor_value == 1) {
997 *out = dividend;
998 } else if (!ValueKnown(dividend) ||
999 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
1000 *out = UnknownDim();
1001 } else {
1002 const int64_t v = Value(dividend);
1003 if (divisor_value <= 0) {
1004 return errors::InvalidArgument("Divisor must be positive but is ",
1005 divisor_value);
1006 }
1007 if (evenly_divisible && (v % divisor_value) != 0) {
1008 return errors::InvalidArgument(
1009 "Dimension size must be evenly divisible by ", divisor_value,
1010 " but is ", v);
1011 }
1012 *out = MakeDim(v / divisor_value);
1013 }
1014 return Status::OK();
1015 }
1016
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1017 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1018 DimensionHandle* out) {
1019 const int64_t first_value = Value(first);
1020 const int64_t second_value = Value(second);
1021 // Special cases.
1022 if (first_value == 0) {
1023 *out = MakeDim(second);
1024 } else if (second_value == 0) {
1025 *out = first;
1026 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1027 *out = UnknownDim();
1028 } else {
1029 // Invariant: Both values are known and positive. Still in run-time we can
1030 // get pair of values which cannot be store in output. Check below will
1031 // report error. We still need to avoid undefined behavior of signed
1032 // overflow and use unsigned addition.
1033 const int64_t sum = static_cast<uint64>(first_value) + second_value;
1034 if (sum < 0) {
1035 return errors::InvalidArgument("Dimension size overflow from adding ",
1036 first_value, " and ", second_value);
1037 }
1038 *out = MakeDim(sum);
1039 }
1040 return Status::OK();
1041 }
1042
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1043 Status InferenceContext::Subtract(DimensionHandle first,
1044 DimensionOrConstant second,
1045 DimensionHandle* out) {
1046 const int64_t first_value = Value(first);
1047 const int64_t second_value = Value(second);
1048 // Special cases.
1049 if (second_value == 0) {
1050 *out = first;
1051 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1052 *out = UnknownDim();
1053 } else {
1054 // Invariant: Both values are known, first_value is non-negative, and
1055 // second_value is positive.
1056 if (first_value < second_value) {
1057 return errors::InvalidArgument(
1058 "Negative dimension size caused by subtracting ", second_value,
1059 " from ", first_value);
1060 }
1061 *out = MakeDim(first_value - second_value);
1062 }
1063 return Status::OK();
1064 }
1065
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1066 Status InferenceContext::Multiply(DimensionHandle first,
1067 DimensionOrConstant second,
1068 DimensionHandle* out) {
1069 const int64_t first_value = Value(first);
1070 const int64_t second_value = Value(second);
1071 // Special cases.
1072 if (first_value == 0) {
1073 *out = first;
1074 } else if (second_value == 0) {
1075 *out = MakeDim(second);
1076 } else if (first_value == 1) {
1077 *out = MakeDim(second);
1078 } else if (second_value == 1) {
1079 *out = first;
1080 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1081 *out = UnknownDim();
1082 } else {
1083 // Invariant: Both values are known and greater than 1.
1084 const int64_t product = first_value * second_value;
1085 if (product < 0) {
1086 return errors::InvalidArgument(
1087 "Negative dimension size caused by overflow when multiplying ",
1088 first_value, " and ", second_value);
1089 }
1090 *out = MakeDim(product);
1091 }
1092 return Status::OK();
1093 }
1094
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1095 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1096 DimensionHandle* out) {
1097 const int64_t first_value = Value(first);
1098 const int64_t second_value = Value(second);
1099 if (first_value == 0) {
1100 *out = first;
1101 } else if (second_value == 0) {
1102 *out = MakeDim(second);
1103 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1104 *out = UnknownDim();
1105 } else {
1106 if (first_value <= second_value) {
1107 *out = first;
1108 } else {
1109 *out = MakeDim(second);
1110 }
1111 }
1112 return Status::OK();
1113 }
1114
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1115 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1116 DimensionHandle* out) {
1117 const int64_t first_value = Value(first);
1118 const int64_t second_value = Value(second);
1119 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1120 *out = UnknownDim();
1121 } else {
1122 if (first_value >= second_value) {
1123 *out = first;
1124 } else {
1125 *out = MakeDim(second);
1126 }
1127 }
1128 return Status::OK();
1129 }
1130
AttachContext(const Status & status)1131 Status InferenceContext::AttachContext(const Status& status) {
1132 std::vector<string> input_shapes;
1133 input_shapes.reserve(inputs_.size());
1134 for (const ShapeHandle& input_shape : inputs_) {
1135 input_shapes.emplace_back(DebugString(input_shape));
1136 }
1137
1138 // Add information about the input tensors and partial tensor shapes used.
1139 std::vector<string> input_from_tensors_str;
1140 std::vector<string> input_from_tensors_as_shape_str;
1141 input_from_tensors_as_shape_str.reserve(inputs_.size());
1142 for (int i = 0, end = inputs_.size(); i < end; ++i) {
1143 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
1144 const int input_tensors_size = input_tensors_.size();
1145 if (requested_input_tensor_as_partial_shape_[i] &&
1146 i < input_tensors_as_shapes_size &&
1147 input_tensors_as_shapes_[i].IsSet() &&
1148 RankKnown(input_tensors_as_shapes_[i])) {
1149 input_from_tensors_as_shape_str.push_back(strings::StrCat(
1150 "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
1151 } else if (requested_input_tensor_[i] && i < input_tensors_size &&
1152 input_tensors_[i] != nullptr) {
1153 input_from_tensors_str.push_back(strings::StrCat(
1154 "input[", i, "] = <",
1155 input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
1156 }
1157 }
1158
1159 string error_context = strings::StrCat(
1160 " for '", attrs_.SummarizeNode(),
1161 "' with input shapes: ", absl::StrJoin(input_shapes, ", "));
1162 if (!input_from_tensors_str.empty()) {
1163 strings::StrAppend(&error_context, " and with computed input tensors: ",
1164 absl::StrJoin(input_from_tensors_str, ", "));
1165 }
1166 if (!input_from_tensors_as_shape_str.empty()) {
1167 strings::StrAppend(&error_context,
1168 " and with input tensors computed as partial shapes: ",
1169 absl::StrJoin(input_from_tensors_as_shape_str, ","));
1170 }
1171
1172 strings::StrAppend(&error_context, ".");
1173 return Status(status.code(),
1174 strings::StrCat(status.error_message(), error_context));
1175 }
1176
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1177 bool InferenceContext::MergeHandleShapesAndTypes(
1178 const std::vector<ShapeAndType>& shapes_and_types,
1179 std::vector<ShapeAndType>* to_update) {
1180 if (shapes_and_types.size() != to_update->size()) {
1181 return false;
1182 }
1183 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1184 bool refined = false;
1185 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1186 const ShapeAndType& existing = (*to_update)[i];
1187 if (shapes_and_types[i].dtype == existing.dtype) {
1188 new_values[i].dtype = existing.dtype;
1189 } else {
1190 if (existing.dtype != DT_INVALID) {
1191 return false;
1192 } else {
1193 new_values[i].dtype = shapes_and_types[i].dtype;
1194 refined = true;
1195 }
1196 }
1197 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1198 .ok()) {
1199 // merge failed, ignore the new value.
1200 new_values[i].shape = existing.shape;
1201 }
1202 if (!existing.shape.SameHandle(new_values[i].shape)) {
1203 refined = true;
1204 }
1205 }
1206 if (!refined) {
1207 return false;
1208 }
1209 for (int i = 0, end = new_values.size(); i < end; ++i) {
1210 (*to_update)[i] = new_values[i];
1211 }
1212 return true;
1213 }
1214
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1215 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1216 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1217 if (output_handle_shapes_and_types_[idx] == nullptr) {
1218 output_handle_shapes_and_types_[idx].reset(
1219 new std::vector<ShapeAndType>(shapes_and_types));
1220 return true;
1221 }
1222 return MergeHandleShapesAndTypes(shapes_and_types,
1223 output_handle_shapes_and_types_[idx].get());
1224 }
1225
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1226 bool InferenceContext::MergeInputHandleShapesAndTypes(
1227 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1228 if (input_handle_shapes_and_types_[idx] == nullptr) {
1229 input_handle_shapes_and_types_[idx].reset(
1230 new std::vector<ShapeAndType>(shapes_and_types));
1231 return true;
1232 }
1233 return MergeHandleShapesAndTypes(shapes_and_types,
1234 input_handle_shapes_and_types_[idx].get());
1235 }
1236
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1237 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1238 const std::vector<ShapeAndType>& shapes_and_types,
1239 std::vector<ShapeAndType>* to_update) {
1240 if (shapes_and_types.size() != to_update->size()) {
1241 return false;
1242 }
1243 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1244 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1245 const ShapeAndType& existing = (*to_update)[i];
1246 if (shapes_and_types[i].dtype == existing.dtype) {
1247 new_values[i].dtype = existing.dtype;
1248 } else {
1249 if (existing.dtype != DT_INVALID) {
1250 return false;
1251 } else {
1252 new_values[i].dtype = shapes_and_types[i].dtype;
1253 }
1254 }
1255 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1256 }
1257 to_update->swap(new_values);
1258 return true;
1259 }
1260
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1261 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1262 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1263 if (output_handle_shapes_and_types_[idx] == nullptr) {
1264 output_handle_shapes_and_types_[idx].reset(
1265 new std::vector<ShapeAndType>(shapes_and_types));
1266 return true;
1267 }
1268 return RelaxHandleShapesAndMergeTypes(
1269 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1270 }
1271
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1272 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1273 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1274 if (input_handle_shapes_and_types_[idx] == nullptr) {
1275 input_handle_shapes_and_types_[idx].reset(
1276 new std::vector<ShapeAndType>(shapes_and_types));
1277 return true;
1278 }
1279 return RelaxHandleShapesAndMergeTypes(
1280 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1281 }
1282
1283 // -----------------------------------------------------------------------------
1284 // ShapeManager
1285 // -----------------------------------------------------------------------------
ShapeManager()1286 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1287 InferenceContext::ShapeManager::~ShapeManager() {
1288 for (auto* s : all_shapes_) delete s;
1289 for (auto* d : all_dims_) delete d;
1290 }
1291
MakeShape(const std::vector<DimensionHandle> & dims)1292 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1293 const std::vector<DimensionHandle>& dims) {
1294 all_shapes_.push_back(new Shape(dims));
1295 return all_shapes_.back();
1296 }
1297
UnknownShape()1298 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1299 all_shapes_.push_back(new Shape());
1300 return all_shapes_.back();
1301 }
1302
1303 } // namespace shape_inference
1304 } // namespace tensorflow
1305