1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25
26 namespace tensorflow {
27 namespace shape_inference {
28
// Out-of-class definitions for the class-level constexpr constants; required
// for ODR-use (e.g. taking their address) prior to C++17 inline variables.
constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
31
// Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext::InferenceContext(
    int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
    const std::vector<PartialTensorShape>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<PartialTensorShape>& input_tensors_as_shapes,
    const std::vector<
        std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version), attrs_(attrs) {
  // Convert the "input tensors as shapes" hints into ShapeHandles. Any
  // conversion failure is recorded in construction_status_ and aborts
  // construction immediately.
  std::vector<ShapeHandle> input_tensors_as_shape_handles;
  input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
  for (const PartialTensorShape& p : input_tensors_as_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    input_tensors_as_shape_handles.push_back(shape);
  }
  PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
  if (!construction_status_.ok()) return;
  // Convert each input shape to a ShapeHandle.
  inputs_.reserve(input_shapes.size());
  for (const PartialTensorShape& p : input_shapes) {
    ShapeHandle shape;
    construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
    if (!construction_status_.ok()) {
      return;
    }
    inputs_.push_back(shape);
  }
  // Convert the per-input handle (shape, dtype) pairs; inputs with no handle
  // data keep a null entry. NOTE(review): handle_data is sized by
  // input_shapes.size() while the loop is bounded by
  // input_handle_shapes_and_types.size() — assumes the latter is never
  // larger; TODO confirm callers guarantee this.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
      input_shapes.size());
  for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
    const auto& v = input_handle_shapes_and_types[i];
    if (v == nullptr) {
      continue;
    }
    handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
    auto& new_v = *handle_data[i];
    for (int j = 0; j < v->size(); ++j) {
      const auto& p = (*v)[j];
      construction_status_.Update(
          MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
      if (!construction_status_.ok()) {
        return;
      }
      new_v[j].dtype = p.second;
    }
  }
  PostInputInit(std::move(handle_data));
}
84
// Constructor taking inputs that are already ShapeHandles; no conversion
// pass is needed, so this simply runs the two init phases.
InferenceContext::InferenceContext(
    int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
    const std::vector<ShapeHandle>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<ShapeHandle>& input_tensors_as_shapes,
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version), attrs_(attrs) {
  PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
  if (!construction_status_.ok()) return;
  inputs_ = input_shapes;

  PostInputInit(std::move(input_handle_shapes_and_types));
}
99
~InferenceContext()100 InferenceContext::~InferenceContext() {}
101
// Runs the shape-inference function `fn` against this context. Merge state
// is cleared before the run and again on failure, so a failed attempt leaves
// no stale dim/shape equivalences behind; on failure, node context is
// attached to the returned error.
Status InferenceContext::Run(
    const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
  ForgetMerges();
  Status s = fn(this);
  if (!s.ok()) {
    ForgetMerges();
    return AttachContext(s);
  }
#ifndef NDEBUG
  // In debug builds, verify that `fn` populated every output shape.
  for (int i = 0; i < num_outputs(); ++i) {
    DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
  }
#endif  // NDEBUG
  return s;
}
117
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)118 Status InferenceContext::set_output(StringPiece output_name,
119 const std::vector<ShapeHandle>& shapes) {
120 auto result = output_name_map_.find(output_name);
121 if (result == output_name_map_.end()) {
122 return errors::InvalidArgument("Unknown output name: ", output_name);
123 } else {
124 const int start = result->second.first;
125 const int size = result->second.second - start;
126 if (size != shapes.size()) {
127 return errors::InvalidArgument("Must have exactly ", shapes.size(),
128 " shapes.");
129 }
130 for (int i = 0; i < size; ++i) {
131 outputs_[i + start] = shapes[i];
132 }
133 }
134 return Status::OK();
135 }
136
input(StringPiece input_name,std::vector<ShapeHandle> * output) const137 Status InferenceContext::input(StringPiece input_name,
138 std::vector<ShapeHandle>* output) const {
139 const auto result = input_name_map_.find(input_name);
140 if (result == input_name_map_.end()) {
141 return errors::InvalidArgument("Unknown input name: ", input_name);
142 } else {
143 output->clear();
144 for (int i = result->second.first; i < result->second.second; ++i) {
145 output->push_back(inputs_[i]);
146 }
147 }
148 return Status::OK();
149 }
150
output(StringPiece output_name,std::vector<ShapeHandle> * output) const151 Status InferenceContext::output(StringPiece output_name,
152 std::vector<ShapeHandle>* output) const {
153 const auto result = output_name_map_.find(output_name);
154 if (result == output_name_map_.end()) {
155 return errors::InvalidArgument("Unknown output name: ", output_name);
156 } else {
157 output->clear();
158 for (int i = result->second.first; i < result->second.second; ++i) {
159 output->push_back(outputs_[i]);
160 }
161 }
162 return Status::OK();
163 }
164
// First phase of construction: stores the input-tensor hints, builds the
// name -> index-range maps from the node's attrs and the OpDef, and sizes
// the output vectors. Failures are recorded in construction_status_.
void InferenceContext::PreInputInit(
    const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
    const std::vector<ShapeHandle>& input_tensors_as_shapes) {
  input_tensors_ = input_tensors;
  input_tensors_as_shapes_ = input_tensors_as_shapes;

  construction_status_ =
      NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
  if (!construction_status_.ok()) return;

  // The number of outputs is the largest end-of-range over all output names.
  int num_outputs = 0;
  for (const auto& e : output_name_map_) {
    num_outputs = std::max(num_outputs, e.second.second);
  }
  outputs_.assign(num_outputs, nullptr);
  output_handle_shapes_and_types_.resize(num_outputs);
}
182
ExpandOutputs(int new_output_size)183 Status InferenceContext::ExpandOutputs(int new_output_size) {
184 if (new_output_size < outputs_.size()) {
185 return errors::InvalidArgument("Trying to reduce number of outputs of op.");
186 }
187 outputs_.resize(new_output_size, nullptr);
188 output_handle_shapes_and_types_.resize(new_output_size);
189 return Status::OK();
190 }
191
// Second phase of construction: installs per-input handle data and validates
// the input count against the NodeDef. Failures are recorded in
// construction_status_.
void InferenceContext::PostInputInit(
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
  // Number of inputs is the largest end-of-range over all input names.
  int num_inputs_from_node_def = 0;
  for (const auto& e : input_name_map_) {
    num_inputs_from_node_def =
        std::max(num_inputs_from_node_def, e.second.second);
  }

  // Allow passing empty shapes/dtypes to avoid changing every single test.
  if (input_handle_data.empty()) {
    input_handle_shapes_and_types_.resize(inputs_.size());
  } else {
    if (input_handle_data.size() != inputs_.size()) {
      construction_status_ = errors::InvalidArgument(
          "Wrong number of handle shapes passed; expected ", inputs_.size(),
          " got ", input_handle_data.size());
      return;
    }
    input_handle_shapes_and_types_ = std::move(input_handle_data);
  }

  if (inputs_.size() != num_inputs_from_node_def) {
    construction_status_ = errors::InvalidArgument(
        "Wrong number of inputs passed: ", inputs_.size(), " while ",
        num_inputs_from_node_def, " expected based on NodeDef");
    return;
  }

  // Callers may pass fewer tensors than inputs; pad the bookkeeping vectors
  // out to one entry per input.
  CHECK_LE(input_tensors_.size(), inputs_.size());
  input_tensors_.resize(inputs_.size());
  requested_input_tensor_.resize(inputs_.size());
  requested_input_tensor_as_partial_shape_.resize(inputs_.size());
}
225
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)226 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
227 TensorShapeProto* proto) {
228 if (!RankKnown(handle)) {
229 proto->set_unknown_rank(true);
230 return;
231 }
232
233 for (int32 i = 0; i < Rank(handle); ++i) {
234 DimensionHandle dim = Dim(handle, i);
235 auto* dim_shape = proto->add_dim();
236 if (ValueKnown(dim)) {
237 dim_shape->set_size(Value(dim));
238 } else {
239 dim_shape->set_size(-1);
240 }
241 }
242 }
243
FullyDefined(ShapeHandle s)244 bool InferenceContext::FullyDefined(ShapeHandle s) {
245 if (!RankKnown(s)) return false;
246 for (int i = 0; i < Rank(s); ++i) {
247 if (!ValueKnown(Dim(s, i))) return false;
248 }
249 return true;
250 }
251
NumElements(ShapeHandle s)252 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
253 const auto rank = Rank(s);
254 if (rank == kUnknownRank) return UnknownDim();
255 bool found_unknown = false;
256 int64 size = 1;
257 for (int i = 0; i < rank; ++i) {
258 int64 dim_val = Value(Dim(s, i));
259 if (dim_val == kUnknownDim) {
260 found_unknown = true;
261 } else if (dim_val == 0) {
262 return MakeDim(0);
263 } else {
264 size *= dim_val;
265 }
266 }
267 if (found_unknown) {
268 return UnknownDim();
269 } else {
270 return MakeDim(size);
271 }
272 }
273
DebugString(ShapeHandle s)274 string InferenceContext::DebugString(ShapeHandle s) {
275 if (RankKnown(s)) {
276 std::vector<string> vals;
277 for (auto d : s->dims_) vals.push_back(DebugString(d));
278 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
279 } else {
280 return "?";
281 }
282 }
283
DebugString(DimensionHandle d)284 string InferenceContext::DebugString(DimensionHandle d) {
285 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
286 }
287
// Human-readable description of this context (node summary), for logging
// and error messages.
string InferenceContext::DebugString() const {
  return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
}
291
// Formats a (shape, dtype) pair as "<shape>:<dtype>".
string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
  return strings::StrCat(DebugString(shape_and_type.shape), ":",
                         DataTypeString(shape_and_type.dtype));
}
296
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)297 string InferenceContext::DebugString(
298 gtl::ArraySlice<ShapeAndType> shape_and_types) {
299 std::vector<string> pieces;
300 for (const ShapeAndType& s : shape_and_types) {
301 pieces.push_back(DebugString(s));
302 }
303 return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
304 }
305
// Asserts that `shape` has exactly `rank` dims. If the current rank is
// unknown, merges `shape` with a fresh shape of `rank` unknown dims so the
// refinement is recorded; a known-but-different rank is an error.
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
                                  ShapeHandle* out) {
  if (rank > kint32max) {
    return errors::InvalidArgument("Rank cannot exceed kint32max");
  }
  const int32 existing = Rank(shape);
  if (existing == rank) {
    *out = shape;
    return Status::OK();
  }
  if (existing == kUnknownRank) {
    // Build [?, ?, ..., ?] of length `rank` and merge, which also records
    // the equivalence in the merge bookkeeping.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) {
      dims.push_back(UnknownDim());
    }
    ShapeHandle shp = shape_manager_.MakeShape(dims);
    return Merge(shape, shp, out);
  }
  *out = nullptr;

  return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
                                 existing);
}
330
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)331 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
332 ShapeHandle* out) {
333 if (rank > kint32max) {
334 return errors::InvalidArgument("Rank cannot exceed kint32max");
335 }
336 const int32 existing = Rank(shape);
337 if (existing >= rank || existing == kUnknownRank) {
338 *out = shape;
339 return Status::OK();
340 }
341 *out = nullptr;
342 return errors::InvalidArgument("Shape must be at least rank ", rank,
343 " but is rank ", existing);
344 }
345
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)346 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
347 ShapeHandle* out) {
348 if (rank > kint32max) {
349 return errors::InvalidArgument("Rank cannot exceed kint32max");
350 }
351 const int32 existing = Rank(shape);
352 if (existing <= rank || existing == kUnknownRank) {
353 *out = shape;
354 return Status::OK();
355 }
356 *out = nullptr;
357 return errors::InvalidArgument("Shape must be at most rank ", rank,
358 " but is rank ", existing);
359 }
360
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)361 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
362 DimensionHandle* out) {
363 const int64 existing = Value(dim);
364 if (existing == value) {
365 *out = dim;
366 return Status::OK();
367 }
368 if (existing == kUnknownDim) {
369 DimensionHandle d = MakeDim(value);
370 return Merge(dim, d, out);
371 }
372 *out = nullptr;
373 return errors::InvalidArgument("Dimension must be ", value, " but is ",
374 existing);
375 }
376
// Relaxes d_old with the newly observed d_new: *out is the most general
// dimension consistent with both (the counterpart of Merge, which makes
// dims more specific). Disagreement generalizes to unknown, never errors.
void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
                             DimensionHandle* out) {
  if (d_old.SameHandle(d_new)) {
    *out = d_old;
  } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
    // The node will be fed by the dimension d_new instead of d_old: any
    // equality assertion between d_old and other input dimension on this node
    // may not be true anymore, so forget them all.
    ForgetMerges();
    // Return the new shape handle to force the relaxation to propagate to the
    // fanout of the context.
    *out = d_new;
  } else if (!ValueKnown(d_new)) {
    // Known value relaxed by an unknown one: result is the unknown dim.
    ForgetMerges();
    *out = d_new;
  } else if (Value(d_old) == Value(d_new)) {
    // Return the old shape handle. This will stop the relaxation in the fanout
    // of the context.
    *out = d_old;
  } else {
    // Return a new handle that encodes a different unknown dim.
    ForgetMerges();
    *out = UnknownDim();
  }
}
402
// Merges d0 and d1 into the most specific compatible dimension. When one
// side is unknown the known side wins, and the pair is recorded in
// merged_dims_ so ForgetMerges() can later drop the assumed equivalence.
// Two known-but-different values are an error.
Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
                               DimensionHandle* out) {
  if (d0.SameHandle(d1)) {
    *out = d0;
    return Status::OK();
  } else if (!ValueKnown(d1)) {
    *out = d0;
    merged_dims_.emplace_back(d0, d1);
    return Status::OK();
  } else if (!ValueKnown(d0)) {
    *out = d1;
    merged_dims_.emplace_back(d0, d1);
    return Status::OK();
  } else if (Value(d0) == Value(d1)) {
    // Equal known values: no equivalence needs recording.
    *out = d0;
    return Status::OK();
  } else {
    *out = nullptr;
    return errors::InvalidArgument("Dimensions must be equal, but are ",
                                   Value(d0), " and ", Value(d1));
  }
}
425
// Merges `prefix` against the leading Rank(prefix) dims of `s`. On success
// *prefix_out holds the merged prefix, and *s_out holds the merged prefix
// followed by s's remaining dims. If either rank is unknown, the inputs are
// returned unchanged.
Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
                                     ShapeHandle* s_out,
                                     ShapeHandle* prefix_out) {
  *s_out = *prefix_out = nullptr;
  if (!RankKnown(prefix) || !RankKnown(s)) {
    *s_out = s;
    *prefix_out = prefix;
    return Status::OK();
  }
  const int32 rank = Rank(prefix);
  // Refine s in place so it has at least `rank` dims.
  TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));

  // Merge the prefix dims and create the new output shapes.
  const int32 rank_s = Rank(s);
  std::vector<DimensionHandle> dims;
  dims.reserve(std::max(rank, rank_s));
  dims.resize(rank);
  for (int i = 0; i < rank; ++i) {
    TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
  }
  // The first `rank` merged dims form the prefix; appending s's trailing
  // dims yields the full merged shape.
  *prefix_out = MakeShape(dims);
  for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
  *s_out = MakeShape(dims);
  return Status::OK();
}
451
// Relaxes s_old with the newly observed s_new, producing the most general
// shape consistent with both observations (the reverse of Merge: differing
// information generalizes to unknown instead of erroring).
void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
                             ShapeHandle* out) {
  if (s_old.SameHandle(s_new)) {
    *out = s_old;
    return;
  } else if (!RankKnown(s_new) || !s_old.IsSet()) {
    // New information is strictly weaker (or there was no old shape):
    // adopt s_new and drop any equivalences that assumed s_old.
    ForgetMerges();
    *out = s_new;
    return;
  }

  const int32 rank = Rank(s_old);
  if (rank != Rank(s_new)) {
    // Conflicting ranks relax to a fully unknown shape.
    ForgetMerges();
    *out = UnknownShape();
    return;
  }

  // If every dim pair is either the same handle or equal known values,
  // returning s_old stops the relaxation from propagating further.
  bool return_s_old = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s_old, i);
    auto d1 = Dim(s_new, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
      return_s_old = false;
      break;
    }
  }
  if (return_s_old) {
    *out = s_old;
    return;
  }

  // Relax dims.
  std::vector<DimensionHandle> dims(rank);
  for (int i = 0; i < rank; ++i) {
    Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
  }
  ForgetMerges();
  *out = MakeShape(dims);
}
496
// Merges s0 and s1 into the most specific compatible shape. Merged pairs are
// recorded in merged_shapes_ so ForgetMerges() can drop the assumed
// equivalences. Fails if ranks differ or any pair of known dims disagree.
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
                               ShapeHandle* out) {
  if (s0.SameHandle(s1)) {
    *out = s0;
    return Status::OK();
  } else if (!RankKnown(s1)) {
    // s1 carries no information; keep s0 and record the equivalence.
    *out = s0;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  } else if (!RankKnown(s0)) {
    *out = s1;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  }

  const int32 rank = Rank(s0);
  if (rank != Rank(s1)) {
    *out = nullptr;
    return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
                                   " and ", Rank(s1));
  }

  // Determine whether one input already subsumes the other: return_s0 stays
  // true while s0 is at least as specific as s1 dim-by-dim, and vice versa.
  bool return_s0 = true;
  bool return_s1 = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s0, i);
    auto d1 = Dim(s1, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim) {
      if (v1 != kUnknownDim) {
        return_s0 = false;
      }
    } else if (v1 == kUnknownDim) {
      return_s1 = false;
    } else if (v0 != v1) {
      *out = nullptr;
      return errors::InvalidArgument(
          "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
          " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
          DebugString(s1), ".");
    }
  }

  merged_shapes_.emplace_back(s0, s1);

  if (return_s0 || return_s1) {
    *out = return_s0 ? s0 : s1;
    return Status::OK();
  }

  // Merge dims.
  std::vector<DimensionHandle> dims(rank, nullptr);
  for (int i = 0; i < rank; ++i) {
    // Invariant for merge was checked earlier, so CHECK is ok.
    TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
  }

  Status s = ReturnCreatedShape(dims, out);
  if (s.ok()) {
    // Merge the new shape with s0. Since s0 and s1 are merged, this implies
    // that s1 and out are also merged.
    merged_shapes_.emplace_back(s0, *out);
  }
  return s;
}
565
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)566 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
567 ShapeHandle* out) {
568 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
569 }
570
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)571 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
572 ShapeHandle* out) {
573 return Subshape(s, start, end, 1 /* stride */, out);
574 }
575
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)576 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
577 int64 stride, ShapeHandle* out) {
578 int64 start_in = start;
579 int64 end_in = end;
580
581 const int32 rank = Rank(s);
582 if (start == 0 && stride == 1 &&
583 ((RankKnown(s) && end >= rank) ||
584 end == std::numeric_limits<int64>::max())) {
585 *out = s;
586 return Status::OK();
587 }
588 if (!RankKnown(s)) {
589 return ReturnUnknownShape(out);
590 }
591
592 if (start > rank) start = rank;
593 if (end > rank) end = rank;
594
595 if (stride < 0 && start == rank) --start;
596
597 if (start < 0) {
598 start = rank + start;
599 if (start < 0) {
600 *out = nullptr;
601 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
602 ", for shape with rank ", rank);
603 }
604 }
605
606 if (end < 0) {
607 end = rank + end;
608 if (end < 0) {
609 *out = nullptr;
610 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
611 ", for shape with rank ", rank);
612 }
613 }
614 if (stride > 0 && start > end) {
615 *out = nullptr;
616 return errors::InvalidArgument(
617 "Subshape must have computed start <= end, but is ", start, " and ",
618 end, " (computed from start ", start_in, " and end ", end_in,
619 " over shape with rank ", rank, ")");
620 } else if (stride < 0 && start < end) {
621 *out = nullptr;
622 return errors::InvalidArgument(
623 "Subshape must have computed start >= end since stride is negative, "
624 "but is ",
625 start, " and ", end, " (computed from start ", start_in, " and end ",
626 end_in, " over shape with rank ", rank, " and stride", stride, ")");
627 }
628
629 std::vector<DimensionHandle> dims;
630 for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
631 dims.push_back(Dim(s, i));
632 }
633 return ReturnCreatedShape(dims, out);
634 }
635
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)636 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
637 ShapeHandle* out) {
638 if (!RankKnown(s1) || !RankKnown(s2)) {
639 return ReturnUnknownShape(out);
640 }
641 const int32 s1_rank = Rank(s1);
642 const int32 s2_rank = Rank(s2);
643 const int32 rank = s1_rank + s2_rank;
644 std::vector<DimensionHandle> dims;
645 dims.reserve(rank);
646 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
647 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
648 return ReturnCreatedShape(dims, out);
649 }
650
// Returns in *out a copy of `s` with dimension `dim_index_in` replaced by
// `new_dim`. Negative indices count from the end; an unknown-rank shape
// passes through as unknown.
Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
                                    DimensionHandle new_dim, ShapeHandle* out) {
  if (!RankKnown(s)) {
    return ReturnUnknownShape(out);
  }
  int64 dim_index = dim_index_in;
  if (dim_index < 0) {
    dim_index = s->dims_.size() + dim_index;
  }
  // Rejects both still-negative and >= rank indices.
  if (!FastBoundsCheck(dim_index, s->dims_.size())) {
    *out = nullptr;
    return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
                                   " for shape with ", s->dims_.size(),
                                   " dimensions");
  }
  std::vector<DimensionHandle> dims(s->dims_);
  dims[dim_index] = new_dim;
  return ReturnCreatedShape(dims, out);
}
670
// Creates a new shape from the given dims, owned by this context's shape
// manager.
ShapeHandle InferenceContext::MakeShape(
    const std::vector<DimensionHandle>& dims) {
  return shape_manager_.MakeShape(dims);
}
675
MakeShape(std::initializer_list<DimensionOrConstant> dims)676 ShapeHandle InferenceContext::MakeShape(
677 std::initializer_list<DimensionOrConstant> dims) {
678 std::vector<DimensionHandle> dims_actual;
679 dims_actual.reserve(dims.size());
680 for (const DimensionOrConstant& d : dims) {
681 dims_actual.push_back(MakeDim(d));
682 }
683
684 return shape_manager_.MakeShape(dims_actual);
685 }
686
// Returns a shape with unknown rank, owned by the shape manager.
ShapeHandle InferenceContext::UnknownShape() {
  return shape_manager_.UnknownShape();
}
690
UnknownShapeOfRank(int64 rank)691 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
692 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
693 if (rank == kUnknownRank) {
694 return UnknownShape();
695 }
696 CHECK_GE(rank, 0) << "rank must not be negative";
697 std::vector<DimensionHandle> dims(rank);
698 for (int32 i = 0; i < rank; ++i) {
699 dims[i] = UnknownDim();
700 }
701 return MakeShape(dims);
702 }
703
Scalar()704 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
705
// Returns a rank-1 shape [dim].
ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
  return MakeShape({dim});
}
709
// Returns a rank-2 shape [dim1, dim2].
ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
                                     DimensionOrConstant dim2) {
  return MakeShape({dim1, dim2});
}
714
// Like MakeShapeFromShapeTensor, but additionally accepts a rank-0 input
// tensor, whose value -1 denotes a completely unknown shape.
Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
    int input_idx, ShapeHandle* out) {
  ShapeHandle input_shape;
  // Scalars are allowed here, so only bound the rank from above.
  TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));

  // Record that this input's *value* was requested as a partial shape, so
  // inference can be re-run once the tensor value becomes available.
  requested_input_tensor_as_partial_shape_[input_idx] = true;
  if (input_idx < input_tensors_as_shapes_.size() &&
      input_tensors_as_shapes_[input_idx].IsSet() &&
      RankKnown(input_tensors_as_shapes_[input_idx])) {
    // A precomputed shape hint is available; use it directly.
    *out = input_tensors_as_shapes_[input_idx];
    return Status::OK();
  }

  return InternalMakeShapeFromTensor(
      true /* treat_unknown_scalar_tensor_as_unknown_shape */,
      input_tensor(input_idx), input_shape, out);
}
732
// Converts input `input_idx` — which must be a rank-1 int32/int64 tensor
// describing a shape — into a ShapeHandle in *out.
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
                                                  ShapeHandle* out) {
  ShapeHandle input_shape;
  // The shape tensor itself must be a vector.
  TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));

  // Record that this input's *value* was requested as a partial shape, so
  // inference can be re-run once the tensor value becomes available.
  requested_input_tensor_as_partial_shape_[input_idx] = true;
  if (input_idx < input_tensors_as_shapes_.size() &&
      input_tensors_as_shapes_[input_idx].IsSet() &&
      RankKnown(input_tensors_as_shapes_[input_idx])) {
    // A precomputed shape hint is available; use it directly.
    *out = input_tensors_as_shapes_[input_idx];
    return Status::OK();
  }

  return InternalMakeShapeFromTensor(
      false /* treat_unknown_scalar_tensor_as_unknown_shape */,
      input_tensor(input_idx), input_shape, out);
}
750
// Public wrapper: converts shape tensor `t` (must be rank 1; `tensor_shape`
// is t's own inferred shape) into a ShapeHandle.
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
                                             ShapeHandle tensor_shape,
                                             ShapeHandle* out) {
  return InternalMakeShapeFromTensor(
      false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
      out);
}
758
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)759 Status InferenceContext::InternalMakeShapeFromTensor(
760 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
761 ShapeHandle tensor_shape, ShapeHandle* out) {
762 // Only callers who have set
763 if (!treat_unknown_scalar_tensor_as_unknown_shape) {
764 TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
765 }
766 if (t == nullptr) {
767 // This is guarded by the check above.
768 if (Rank(tensor_shape) == 0) {
769 return ReturnUnknownShape(out);
770 }
771 // Shape tensor is not known, but if the shape of the shape tensor is then
772 // the right number of unknown dims can be created.
773 DimensionHandle shape_dim = Dim(tensor_shape, 0);
774 if (!ValueKnown(shape_dim)) {
775 return ReturnUnknownShape(out);
776 }
777 const auto num_dims = Value(shape_dim);
778 std::vector<DimensionHandle> dims;
779 dims.reserve(num_dims);
780 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
781 return ReturnCreatedShape(dims, out);
782 }
783
784 if (t->shape().dims() == 0) {
785 if (t->dtype() == DataType::DT_INT32) {
786 auto flat_t = t->scalar<int32>();
787 if (flat_t() != -1) {
788 *out = nullptr;
789 return errors::InvalidArgument(
790 "Input tensor must be rank 1, or if its rank 0 it must have value "
791 "-1 "
792 "(representing an unknown shape). Saw value: ",
793 flat_t());
794 }
795 return ReturnUnknownShape(out);
796 } else if (t->dtype() == DataType::DT_INT64) {
797 auto flat_t = t->scalar<int64>();
798 if (flat_t() != -1) {
799 *out = nullptr;
800 return errors::InvalidArgument(
801 "Input tensor must be rank 1, or if its rank 0 it must have value "
802 "-1 "
803 "(representing an unknown shape). Saw value: ",
804 flat_t());
805 }
806 return ReturnUnknownShape(out);
807 } else {
808 *out = nullptr;
809 return errors::InvalidArgument(
810 "Input tensor must be int32 or int64, but was ",
811 DataTypeString(t->dtype()));
812 }
813 }
814
815 if (t->shape().dims() != 1) {
816 *out = nullptr;
817 return errors::InvalidArgument(
818 "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
819 ((t->shape().dims() == 0)
820 ? "If it is rank 0 rank 0 it must have statically known value -1 "
821 "(representing an unknown shape). "
822 : " "),
823 "Saw tensor shape ", t->shape().DebugString());
824 }
825 std::vector<DimensionHandle> dims;
826 if (t->dtype() == DataType::DT_INT32) {
827 auto flat_t = t->flat<int32>();
828 for (int i = 0; i < flat_t.size(); ++i) {
829 const int32 val = flat_t(i);
830 if (val < -1) {
831 return errors::InvalidArgument(
832 "Invalid value in tensor used for shape: ", val);
833 }
834 // -1 will become an unknown dim.
835 dims.push_back(MakeDim(val));
836 }
837 } else if (t->dtype() == DataType::DT_INT64) {
838 auto flat_t = t->flat<int64>();
839 for (int i = 0; i < flat_t.size(); ++i) {
840 const int64 val = flat_t(i);
841 if (val < -1) {
842 return errors::InvalidArgument(
843 "Invalid value in tensor used for shape: ", val);
844 }
845 // -1 will become an unknown dim.
846 dims.push_back(MakeDim(val));
847 }
848 } else {
849 *out = nullptr;
850 return errors::InvalidArgument(
851 "Input tensor must be int32 or int64, but was ",
852 DataTypeString(t->dtype()));
853 }
854
855 return ReturnCreatedShape(dims, out);
856 }
857
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)858 Status InferenceContext::MakeShapeFromPartialTensorShape(
859 const PartialTensorShape& partial_shape, ShapeHandle* out) {
860 *out = nullptr;
861 if (partial_shape.dims() == -1) {
862 return ReturnUnknownShape(out);
863 }
864 const int num_dims = partial_shape.dims();
865 std::vector<DimensionHandle> dims(num_dims);
866 for (int i = 0; i < num_dims; ++i) {
867 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
868 // can be passed directly to MakeDim.
869 dims[i] = MakeDim(partial_shape.dim_size(i));
870 }
871 return ReturnCreatedShape(dims, out);
872 }
873
// Converts a fully-defined TensorShape by round-tripping it through
// PartialTensorShape (all dims known).
Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
                                                  ShapeHandle* out) {
  return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
                                         out);
}
879
// Validates `proto` and converts it into a ShapeHandle. *out is cleared on
// entry, so it remains null if validation fails.
Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                                 ShapeHandle* out) {
  *out = nullptr;
  TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
  PartialTensorShape partial_shape(proto);
  return MakeShapeFromPartialTensorShape(partial_shape, out);
}
887
GetScalarFromTensor(const Tensor * t,int64 * val)888 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
889 // Caller must ensure that <t> is not NULL.
890 const int rank = t->dims();
891 if (rank != 0) {
892 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
893 }
894
895 if (t->dtype() == DT_INT32) {
896 *val = t->scalar<int32>()();
897 return Status::OK();
898 } else if (t->dtype() == DT_INT64) {
899 *val = t->scalar<int64>()();
900 return Status::OK();
901 } else {
902 return errors::InvalidArgument("Scalar input must be int32 or int64.");
903 }
904 }
905
906 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)907 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
908 int64 val;
909 const Tensor* t = input_tensor(idx);
910 if (t == nullptr) {
911 *out = UnknownDim();
912 return Status::OK();
913 }
914 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
915 if (val < 0) {
916 return errors::InvalidArgument("Dimension size, given by scalar input ",
917 idx, ", must be non-negative but is ", val);
918 }
919 *out = MakeDim(val);
920 return Status::OK();
921 }
922
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)923 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
924 int idx, int input_rank, DimensionHandle* out) {
925 int64 val;
926 const Tensor* t = input_tensor(idx);
927 if (t == nullptr) {
928 *out = UnknownDim();
929 return Status::OK();
930 }
931 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
932 if (val < 0) {
933 if (input_rank < 0) {
934 *out = UnknownDim();
935 return Status::OK();
936 } else if (val + input_rank < 0) {
937 return errors::InvalidArgument("Dimension size, given by scalar input ",
938 val, " must be in range [-", input_rank,
939 ", ", input_rank, ")");
940 } else {
941 val += input_rank;
942 }
943 } else if (input_rank >= 0 && val >= input_rank) {
944 return errors::InvalidArgument("Dimension size, given by scalar input ",
945 val, " must be in range [-", input_rank,
946 ", ", input_rank, ")");
947 }
948 *out = MakeDim(val);
949 return Status::OK();
950 }
951
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)952 Status InferenceContext::Divide(DimensionHandle dividend,
953 DimensionOrConstant divisor,
954 bool evenly_divisible, DimensionHandle* out) {
955 const int64 divisor_value = Value(divisor);
956 if (divisor_value == 1) {
957 *out = dividend;
958 } else if (!ValueKnown(dividend) ||
959 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
960 *out = UnknownDim();
961 } else {
962 const int64 v = Value(dividend);
963 if (divisor_value <= 0) {
964 return errors::InvalidArgument("Divisor must be positive but is ",
965 divisor_value);
966 }
967 if (evenly_divisible && (v % divisor_value) != 0) {
968 return errors::InvalidArgument(
969 "Dimension size must be evenly divisible by ", divisor_value,
970 " but is ", v);
971 }
972 *out = MakeDim(v / divisor_value);
973 }
974 return Status::OK();
975 }
976
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)977 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
978 DimensionHandle* out) {
979 const int64 first_value = Value(first);
980 const int64 second_value = Value(second);
981 // Special cases.
982 if (first_value == 0) {
983 *out = MakeDim(second);
984 } else if (second_value == 0) {
985 *out = first;
986 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
987 *out = UnknownDim();
988 } else {
989 // Invariant: Both values are known and positive. Still in run-time we can
990 // get pair of values which cannot be store in output. Check below will
991 // report error. We still need to avoid undefined behavior of signed
992 // overflow and use unsigned addition.
993 const int64 sum = static_cast<uint64>(first_value) + second_value;
994 if (sum < 0) {
995 return errors::InvalidArgument("Dimension size overflow from adding ",
996 first_value, " and ", second_value);
997 }
998 *out = MakeDim(sum);
999 }
1000 return Status::OK();
1001 }
1002
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1003 Status InferenceContext::Subtract(DimensionHandle first,
1004 DimensionOrConstant second,
1005 DimensionHandle* out) {
1006 const int64 first_value = Value(first);
1007 const int64 second_value = Value(second);
1008 // Special cases.
1009 if (second_value == 0) {
1010 *out = first;
1011 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1012 *out = UnknownDim();
1013 } else {
1014 // Invariant: Both values are known, first_value is non-negative, and
1015 // second_value is positive.
1016 if (first_value < second_value) {
1017 return errors::InvalidArgument(
1018 "Negative dimension size caused by subtracting ", second_value,
1019 " from ", first_value);
1020 }
1021 *out = MakeDim(first_value - second_value);
1022 }
1023 return Status::OK();
1024 }
1025
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1026 Status InferenceContext::Multiply(DimensionHandle first,
1027 DimensionOrConstant second,
1028 DimensionHandle* out) {
1029 const int64 first_value = Value(first);
1030 const int64 second_value = Value(second);
1031 // Special cases.
1032 if (first_value == 0) {
1033 *out = first;
1034 } else if (second_value == 0) {
1035 *out = MakeDim(second);
1036 } else if (first_value == 1) {
1037 *out = MakeDim(second);
1038 } else if (second_value == 1) {
1039 *out = first;
1040 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1041 *out = UnknownDim();
1042 } else {
1043 // Invariant: Both values are known and greater than 1.
1044 const int64 product = first_value * second_value;
1045 if (product < 0) {
1046 return errors::InvalidArgument(
1047 "Negative dimension size caused by overflow when multiplying ",
1048 first_value, " and ", second_value);
1049 }
1050 *out = MakeDim(product);
1051 }
1052 return Status::OK();
1053 }
1054
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1055 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1056 DimensionHandle* out) {
1057 const int64 first_value = Value(first);
1058 const int64 second_value = Value(second);
1059 if (first_value == 0) {
1060 *out = first;
1061 } else if (second_value == 0) {
1062 *out = MakeDim(second);
1063 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1064 *out = UnknownDim();
1065 } else {
1066 if (first_value <= second_value) {
1067 *out = first;
1068 } else {
1069 *out = MakeDim(second);
1070 }
1071 }
1072 return Status::OK();
1073 }
1074
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1075 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1076 DimensionHandle* out) {
1077 const int64 first_value = Value(first);
1078 const int64 second_value = Value(second);
1079 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1080 *out = UnknownDim();
1081 } else {
1082 if (first_value >= second_value) {
1083 *out = first;
1084 } else {
1085 *out = MakeDim(second);
1086 }
1087 }
1088 return Status::OK();
1089 }
1090
// Returns a copy of <status> whose error message is suffixed with a
// description of this context's inputs: the inferred input shapes, plus any
// input tensors (or input tensors interpreted as partial shapes) that the
// shape function requested during inference. This makes shape-function
// errors actionable for users.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& input_shape : inputs_) {
    input_shapes.emplace_back(DebugString(input_shape));
  }

  // Add information about the input tensors and partial tensor shapes used.
  std::vector<string> input_from_tensors_str;
  std::vector<string> input_from_tensors_as_shape_str;
  input_from_tensors_as_shape_str.reserve(inputs_.size());
  for (int i = 0; i < inputs_.size(); ++i) {
    // Prefer reporting the tensor-as-partial-shape interpretation when it was
    // requested and resolved to a set shape with known rank; otherwise fall
    // back to the raw tensor value, if one was requested and is available.
    if (requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i])) {
      input_from_tensors_as_shape_str.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
    } else if (requested_input_tensor_[i] && i < input_tensors_.size() &&
               input_tensors_[i] != nullptr) {
      // Tensor values are truncated to at most 256 elements in the message.
      input_from_tensors_str.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  // Assemble the suffix: node summary, input shapes, then the optional
  // tensor-value and tensor-as-shape sections.
  string error_context = strings::StrCat(
      " for '", attrs_.SummarizeNode(),
      "' with input shapes: ", absl::StrJoin(input_shapes, ", "));
  if (!input_from_tensors_str.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       absl::StrJoin(input_from_tensors_str, ", "));
  }
  if (!input_from_tensors_as_shape_str.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(input_from_tensors_as_shape_str, ","));
  }

  strings::StrAppend(&error_context, ".");
  // Preserve the original status code; only the message is augmented.
  return Status(status.code(),
                strings::StrCat(status.error_message(), error_context));
}
1134
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1135 bool InferenceContext::MergeHandleShapesAndTypes(
1136 const std::vector<ShapeAndType>& shapes_and_types,
1137 std::vector<ShapeAndType>* to_update) {
1138 if (shapes_and_types.size() != to_update->size()) {
1139 return false;
1140 }
1141 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1142 bool refined = false;
1143 for (int i = 0; i < shapes_and_types.size(); ++i) {
1144 const ShapeAndType& existing = (*to_update)[i];
1145 if (shapes_and_types[i].dtype == existing.dtype) {
1146 new_values[i].dtype = existing.dtype;
1147 } else {
1148 if (existing.dtype != DT_INVALID) {
1149 return false;
1150 } else {
1151 new_values[i].dtype = shapes_and_types[i].dtype;
1152 refined = true;
1153 }
1154 }
1155 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1156 .ok()) {
1157 // merge failed, ignore the new value.
1158 new_values[i].shape = existing.shape;
1159 }
1160 if (!existing.shape.SameHandle(new_values[i].shape)) {
1161 refined = true;
1162 }
1163 }
1164 if (!refined) {
1165 return false;
1166 }
1167 for (int i = 0; i < new_values.size(); ++i) {
1168 (*to_update)[i] = new_values[i];
1169 }
1170 return true;
1171 }
1172
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1173 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1174 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1175 if (output_handle_shapes_and_types_[idx] == nullptr) {
1176 output_handle_shapes_and_types_[idx].reset(
1177 new std::vector<ShapeAndType>(shapes_and_types));
1178 return true;
1179 }
1180 return MergeHandleShapesAndTypes(shapes_and_types,
1181 output_handle_shapes_and_types_[idx].get());
1182 }
1183
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1184 bool InferenceContext::MergeInputHandleShapesAndTypes(
1185 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1186 if (input_handle_shapes_and_types_[idx] == nullptr) {
1187 input_handle_shapes_and_types_[idx].reset(
1188 new std::vector<ShapeAndType>(shapes_and_types));
1189 return true;
1190 }
1191 return MergeHandleShapesAndTypes(shapes_and_types,
1192 input_handle_shapes_and_types_[idx].get());
1193 }
1194
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1195 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1196 const std::vector<ShapeAndType>& shapes_and_types,
1197 std::vector<ShapeAndType>* to_update) {
1198 if (shapes_and_types.size() != to_update->size()) {
1199 return false;
1200 }
1201 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1202 for (int i = 0; i < shapes_and_types.size(); ++i) {
1203 const ShapeAndType& existing = (*to_update)[i];
1204 if (shapes_and_types[i].dtype == existing.dtype) {
1205 new_values[i].dtype = existing.dtype;
1206 } else {
1207 if (existing.dtype != DT_INVALID) {
1208 return false;
1209 } else {
1210 new_values[i].dtype = shapes_and_types[i].dtype;
1211 }
1212 }
1213 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1214 }
1215 to_update->swap(new_values);
1216 return true;
1217 }
1218
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1219 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1220 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1221 if (output_handle_shapes_and_types_[idx] == nullptr) {
1222 output_handle_shapes_and_types_[idx].reset(
1223 new std::vector<ShapeAndType>(shapes_and_types));
1224 return true;
1225 }
1226 return RelaxHandleShapesAndMergeTypes(
1227 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1228 }
1229
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1230 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1231 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1232 if (input_handle_shapes_and_types_[idx] == nullptr) {
1233 input_handle_shapes_and_types_[idx].reset(
1234 new std::vector<ShapeAndType>(shapes_and_types));
1235 return true;
1236 }
1237 return RelaxHandleShapesAndMergeTypes(
1238 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1239 }
1240
1241 // -----------------------------------------------------------------------------
1242 // ShapeManager
1243 // -----------------------------------------------------------------------------
ShapeManager()1244 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1245 InferenceContext::ShapeManager::~ShapeManager() {
1246 for (auto* s : all_shapes_) delete s;
1247 for (auto* d : all_dims_) delete d;
1248 }
1249
MakeShape(const std::vector<DimensionHandle> & dims)1250 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1251 const std::vector<DimensionHandle>& dims) {
1252 all_shapes_.push_back(new Shape(dims));
1253 return all_shapes_.back();
1254 }
1255
UnknownShape()1256 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1257 all_shapes_.push_back(new Shape());
1258 return all_shapes_.back();
1259 }
1260
1261 } // namespace shape_inference
1262 } // namespace tensorflow
1263