1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25
26 namespace tensorflow {
27 namespace shape_inference {
28
// Out-of-line definitions for the static constexpr sentinel constants
// declared in shape_inference.h (required pre-C++17 when they are odr-used).
constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
31
32 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)33 InferenceContext::InferenceContext(
34 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
35 const std::vector<PartialTensorShape>& input_shapes,
36 const std::vector<const Tensor*>& input_tensors,
37 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
38 const std::vector<
39 std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
40 input_handle_shapes_and_types)
41 : graph_def_version_(graph_def_version), attrs_(attrs) {
42 std::vector<ShapeHandle> input_tensors_as_shape_handles;
43 input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
44 for (const PartialTensorShape& p : input_tensors_as_shapes) {
45 ShapeHandle shape;
46 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
47 if (!construction_status_.ok()) {
48 return;
49 }
50 input_tensors_as_shape_handles.push_back(shape);
51 }
52 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
53 if (!construction_status_.ok()) return;
54 inputs_.reserve(input_shapes.size());
55 for (const PartialTensorShape& p : input_shapes) {
56 ShapeHandle shape;
57 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
58 if (!construction_status_.ok()) {
59 return;
60 }
61 inputs_.push_back(shape);
62 }
63 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
64 input_shapes.size());
65 for (int i = 0, end = input_handle_shapes_and_types.size(); i < end; ++i) {
66 const auto& v = input_handle_shapes_and_types[i];
67 if (v == nullptr) {
68 continue;
69 }
70 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
71 auto& new_v = *handle_data[i];
72 for (int j = 0, end = v->size(); j < end; ++j) {
73 const auto& p = (*v)[j];
74 construction_status_.Update(
75 MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
76 if (!construction_status_.ok()) {
77 return;
78 }
79 new_v[j].dtype = p.second;
80 }
81 }
82 PostInputInit(std::move(handle_data));
83 }
84
// Constructor taking inputs already converted to ShapeHandles. Any failure is
// recorded in construction_status_ rather than thrown.
InferenceContext::InferenceContext(
    int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
    const std::vector<ShapeHandle>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<ShapeHandle>& input_tensors_as_shapes,
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
        input_handle_shapes_and_types)
    : graph_def_version_(graph_def_version), attrs_(attrs) {
  // Resolves name ranges from the NodeDef/OpDef and pre-sizes the outputs.
  PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
  if (!construction_status_.ok()) return;
  inputs_ = input_shapes;

  // Validates input counts and sizes per-input bookkeeping vectors.
  PostInputInit(std::move(input_handle_shapes_and_types));
}
99
~InferenceContext()100 InferenceContext::~InferenceContext() {}
101
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)102 Status InferenceContext::Run(
103 const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
104 ForgetMerges();
105 Status s = fn(this);
106 if (!s.ok()) {
107 ForgetMerges();
108 return AttachContext(s);
109 }
110 #ifndef NDEBUG
111 for (int i = 0; i < num_outputs(); ++i) {
112 DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
113 }
114 #endif // NDEBUG
115 return s;
116 }
117
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)118 Status InferenceContext::set_output(StringPiece output_name,
119 const std::vector<ShapeHandle>& shapes) {
120 auto result = output_name_map_.find(output_name);
121 if (result == output_name_map_.end()) {
122 return errors::InvalidArgument("Unknown output name: ", output_name);
123 } else {
124 const int start = result->second.first;
125 const int size = result->second.second - start;
126 const int shapes_size = shapes.size();
127 if (size != shapes_size) {
128 return errors::InvalidArgument("Must have exactly ", shapes.size(),
129 " shapes.");
130 }
131 for (int i = 0; i < shapes_size; ++i) {
132 outputs_[i + start] = shapes[i];
133 }
134 }
135 return Status::OK();
136 }
137
input(StringPiece input_name,std::vector<ShapeHandle> * output) const138 Status InferenceContext::input(StringPiece input_name,
139 std::vector<ShapeHandle>* output) const {
140 const auto result = input_name_map_.find(input_name);
141 if (result == input_name_map_.end()) {
142 return errors::InvalidArgument("Unknown input name: ", input_name);
143 } else {
144 output->clear();
145 for (int i = result->second.first; i < result->second.second; ++i) {
146 output->push_back(inputs_[i]);
147 }
148 }
149 return Status::OK();
150 }
151
output(StringPiece output_name,std::vector<ShapeHandle> * output) const152 Status InferenceContext::output(StringPiece output_name,
153 std::vector<ShapeHandle>* output) const {
154 const auto result = output_name_map_.find(output_name);
155 if (result == output_name_map_.end()) {
156 return errors::InvalidArgument("Unknown output name: ", output_name);
157 } else {
158 output->clear();
159 for (int i = result->second.first; i < result->second.second; ++i) {
160 output->push_back(outputs_[i]);
161 }
162 }
163 return Status::OK();
164 }
165
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)166 void InferenceContext::PreInputInit(
167 const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
168 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
169 input_tensors_ = input_tensors;
170 input_tensors_as_shapes_ = input_tensors_as_shapes;
171
172 construction_status_ =
173 NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
174 if (!construction_status_.ok()) return;
175
176 int num_outputs = 0;
177 for (const auto& e : output_name_map_) {
178 num_outputs = std::max(num_outputs, e.second.second);
179 }
180 outputs_.assign(num_outputs, nullptr);
181 output_handle_shapes_and_types_.resize(num_outputs);
182 }
183
ExpandOutputs(int new_output_size)184 Status InferenceContext::ExpandOutputs(int new_output_size) {
185 const int outputs_size = outputs_.size();
186 if (new_output_size < outputs_size) {
187 return errors::InvalidArgument("Trying to reduce number of outputs of op.");
188 }
189 outputs_.resize(new_output_size, nullptr);
190 output_handle_shapes_and_types_.resize(new_output_size);
191 return Status::OK();
192 }
193
// Final construction phase: validates the input count against the NodeDef and
// sizes all per-input bookkeeping vectors. Sets construction_status_ and
// returns early on any inconsistency.
void InferenceContext::PostInputInit(
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
  // The number of inputs implied by the NodeDef is the largest end index
  // across all named input ranges.
  int num_inputs_from_node_def = 0;
  for (const auto& e : input_name_map_) {
    num_inputs_from_node_def =
        std::max(num_inputs_from_node_def, e.second.second);
  }

  // Allow passing empty shapes/dtypes to avoid changing every single test.
  if (input_handle_data.empty()) {
    input_handle_shapes_and_types_.resize(inputs_.size());
  } else {
    if (input_handle_data.size() != inputs_.size()) {
      construction_status_ = errors::InvalidArgument(
          "Wrong number of handle shapes passed; expected ", inputs_.size(),
          " got ", input_handle_data.size());
      return;
    }
    input_handle_shapes_and_types_ = std::move(input_handle_data);
  }
  const int inputs_size = inputs_.size();
  if (inputs_size != num_inputs_from_node_def) {
    construction_status_ = errors::InvalidArgument(
        "Wrong number of inputs passed: ", inputs_.size(), " while ",
        num_inputs_from_node_def, " expected based on NodeDef");
    return;
  }

  // Callers may provide fewer input tensors than inputs; pad with nulls so
  // every per-input vector is indexable by input position.
  CHECK_LE(input_tensors_.size(), inputs_.size());
  input_tensors_.resize(inputs_.size());
  requested_input_tensor_.resize(inputs_.size());
  requested_input_tensor_as_partial_shape_.resize(inputs_.size());
}
227
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)228 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
229 TensorShapeProto* proto) {
230 if (!RankKnown(handle)) {
231 proto->set_unknown_rank(true);
232 return;
233 }
234
235 for (int32 i = 0; i < Rank(handle); ++i) {
236 DimensionHandle dim = Dim(handle, i);
237 auto* dim_shape = proto->add_dim();
238 if (ValueKnown(dim)) {
239 dim_shape->set_size(Value(dim));
240 } else {
241 dim_shape->set_size(-1);
242 }
243 }
244 }
245
FullyDefined(ShapeHandle s)246 bool InferenceContext::FullyDefined(ShapeHandle s) {
247 if (!RankKnown(s)) return false;
248 for (int i = 0; i < Rank(s); ++i) {
249 if (!ValueKnown(Dim(s, i))) return false;
250 }
251 return true;
252 }
253
NumElements(ShapeHandle s)254 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
255 const auto rank = Rank(s);
256 if (rank == kUnknownRank) return UnknownDim();
257 bool found_unknown = false;
258 int64 size = 1;
259 for (int i = 0; i < rank; ++i) {
260 int64 dim_val = Value(Dim(s, i));
261 if (dim_val == kUnknownDim) {
262 found_unknown = true;
263 } else if (dim_val == 0) {
264 return MakeDim(0);
265 } else {
266 size *= dim_val;
267 }
268 }
269 if (found_unknown) {
270 return UnknownDim();
271 } else {
272 return MakeDim(size);
273 }
274 }
275
DebugString(ShapeHandle s)276 string InferenceContext::DebugString(ShapeHandle s) {
277 if (RankKnown(s)) {
278 std::vector<string> vals;
279 for (auto d : s->dims_) vals.push_back(DebugString(d));
280 return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
281 } else {
282 return "?";
283 }
284 }
285
DebugString(DimensionHandle d)286 string InferenceContext::DebugString(DimensionHandle d) {
287 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
288 }
289
// One-line summary of this context, identified by the node it describes.
string InferenceContext::DebugString() const {
  return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
}
293
// Renders a (shape, dtype) pair as "shape:dtype".
string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
  return strings::StrCat(DebugString(shape_and_type.shape), ":",
                         DataTypeString(shape_and_type.dtype));
}
298
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)299 string InferenceContext::DebugString(
300 gtl::ArraySlice<ShapeAndType> shape_and_types) {
301 std::vector<string> pieces;
302 for (const ShapeAndType& s : shape_and_types) {
303 pieces.push_back(DebugString(s));
304 }
305 return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
306 }
307
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)308 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
309 ShapeHandle* out) {
310 if (rank > kint32max) {
311 return errors::InvalidArgument("Rank cannot exceed kint32max");
312 }
313 const int32 existing = Rank(shape);
314 if (existing == rank) {
315 *out = shape;
316 return Status::OK();
317 }
318 if (existing == kUnknownRank) {
319 std::vector<DimensionHandle> dims;
320 dims.reserve(rank);
321 for (int i = 0; i < rank; ++i) {
322 dims.push_back(UnknownDim());
323 }
324 ShapeHandle shp = shape_manager_.MakeShape(dims);
325 return Merge(shape, shp, out);
326 }
327 *out = nullptr;
328
329 return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
330 existing);
331 }
332
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)333 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
334 ShapeHandle* out) {
335 if (rank > kint32max) {
336 return errors::InvalidArgument("Rank cannot exceed kint32max");
337 }
338 const int32 existing = Rank(shape);
339 if (existing >= rank || existing == kUnknownRank) {
340 *out = shape;
341 return Status::OK();
342 }
343 *out = nullptr;
344 return errors::InvalidArgument("Shape must be at least rank ", rank,
345 " but is rank ", existing);
346 }
347
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)348 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
349 ShapeHandle* out) {
350 if (rank > kint32max) {
351 return errors::InvalidArgument("Rank cannot exceed kint32max");
352 }
353 const int32 existing = Rank(shape);
354 if (existing <= rank || existing == kUnknownRank) {
355 *out = shape;
356 return Status::OK();
357 }
358 *out = nullptr;
359 return errors::InvalidArgument("Shape must be at most rank ", rank,
360 " but is rank ", existing);
361 }
362
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)363 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
364 DimensionHandle* out) {
365 const int64 existing = Value(dim);
366 if (existing == value) {
367 *out = dim;
368 return Status::OK();
369 }
370 if (existing == kUnknownDim) {
371 DimensionHandle d = MakeDim(value);
372 return Merge(dim, d, out);
373 }
374 *out = nullptr;
375 return errors::InvalidArgument("Dimension must be ", value, " but is ",
376 existing);
377 }
378
// Relaxes d_old with d_new: the result is compatible with both. Returning
// d_old signals "no change" and stops propagation; returning anything else
// forces the relaxation to propagate. Branch order matters below.
void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
                             DimensionHandle* out) {
  if (d_old.SameHandle(d_new)) {
    *out = d_old;
  } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
    // The node will be fed by the dimension d_new instead of d_old: any
    // equality assertion between d_old and other input dimension on this node
    // may not be true anymore, so forget them all.
    ForgetMerges();
    // Return the new shape handle to force the relaxation to propagate to the
    // fanout of the context.
    *out = d_new;
  } else if (!ValueKnown(d_new)) {
    // d_old is known but d_new is not: the relaxed result is unknown (d_new).
    ForgetMerges();
    *out = d_new;
  } else if (Value(d_old) == Value(d_new)) {
    // Return the old shape handle. This will stop the relaxation in the fanout
    // of the context.
    *out = d_old;
  } else {
    // Values differ (or d_old is unknown while d_new is known).
    // Return a new handle that encodes a different unknown dim.
    ForgetMerges();
    *out = UnknownDim();
  }
}
404
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)405 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
406 DimensionHandle* out) {
407 if (d0.SameHandle(d1)) {
408 *out = d0;
409 return Status::OK();
410 } else if (!ValueKnown(d1)) {
411 *out = d0;
412 merged_dims_.emplace_back(d0, d1);
413 return Status::OK();
414 } else if (!ValueKnown(d0)) {
415 *out = d1;
416 merged_dims_.emplace_back(d0, d1);
417 return Status::OK();
418 } else if (Value(d0) == Value(d1)) {
419 *out = d0;
420 return Status::OK();
421 } else {
422 *out = nullptr;
423 return errors::InvalidArgument("Dimensions must be equal, but are ",
424 Value(d0), " and ", Value(d1));
425 }
426 }
427
// Merges `prefix` against the leading Rank(prefix) dimensions of `s`.
// Produces the merged prefix in *prefix_out and `s` with its prefix replaced
// by the merged dims in *s_out. If either rank is unknown, both inputs are
// passed through unchanged.
Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
                                     ShapeHandle* s_out,
                                     ShapeHandle* prefix_out) {
  *s_out = *prefix_out = nullptr;
  if (!RankKnown(prefix) || !RankKnown(s)) {
    *s_out = s;
    *prefix_out = prefix;
    return Status::OK();
  }
  const int32 rank = Rank(prefix);
  // Ensure s has at least `rank` dims (may refine s in place).
  TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));

  // Merge the prefix dims and create the new output shapes.
  const int32 rank_s = Rank(s);
  std::vector<DimensionHandle> dims;
  dims.reserve(std::max(rank, rank_s));
  dims.resize(rank);
  for (int i = 0; i < rank; ++i) {
    TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
  }
  // The same vector builds both outputs: first the merged prefix alone, then
  // the prefix extended with the remaining dims of s.
  *prefix_out = MakeShape(dims);
  for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
  *s_out = MakeShape(dims);
  return Status::OK();
}
453
// Relaxes s_old with s_new: the result is compatible with both shapes.
// Returning s_old signals "no change" and stops propagation in the fanout;
// any other return forces propagation and forgets recorded merges.
void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
                             ShapeHandle* out) {
  if (s_old.SameHandle(s_new)) {
    *out = s_old;
    return;
  } else if (!RankKnown(s_new) || !s_old.IsSet()) {
    // New shape has unknown rank (or there was no old shape): result is s_new.
    ForgetMerges();
    *out = s_new;
    return;
  }

  const int32 rank = Rank(s_old);
  if (rank != Rank(s_new)) {
    // Ranks disagree; the only shape compatible with both is fully unknown.
    ForgetMerges();
    *out = UnknownShape();
    return;
  }

  // If every dimension pair is either the same handle or the same known
  // value, s_old already subsumes s_new and can be returned unchanged.
  bool return_s_old = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s_old, i);
    auto d1 = Dim(s_new, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
      return_s_old = false;
      break;
    }
  }
  if (return_s_old) {
    *out = s_old;
    return;
  }

  // Relax dims.
  std::vector<DimensionHandle> dims(rank);
  for (int i = 0; i < rank; ++i) {
    Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
  }
  ForgetMerges();
  *out = MakeShape(dims);
}
498
// Merges two shapes that must describe the same tensor. Ranks must agree
// (unknown rank defers to the other shape); each dimension pair is merged,
// preferring known values. Every unknown/known merge is recorded in
// merged_shapes_ so it can be invalidated later.
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
                               ShapeHandle* out) {
  if (s0.SameHandle(s1)) {
    *out = s0;
    return Status::OK();
  } else if (!RankKnown(s1)) {
    *out = s0;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  } else if (!RankKnown(s0)) {
    *out = s1;
    merged_shapes_.emplace_back(s0, s1);
    return Status::OK();
  }

  const int32 rank = Rank(s0);
  if (rank != Rank(s1)) {
    *out = nullptr;
    return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
                                   " and ", Rank(s1));
  }

  // Determine whether one of the inputs already subsumes the other, in which
  // case it can be returned without building a new shape.
  bool return_s0 = true;
  bool return_s1 = true;
  for (int i = 0; i < rank; ++i) {
    auto d0 = Dim(s0, i);
    auto d1 = Dim(s1, i);
    if (d0.SameHandle(d1)) continue;

    auto v0 = Value(d0);
    auto v1 = Value(d1);
    if (v0 == kUnknownDim) {
      if (v1 != kUnknownDim) {
        return_s0 = false;
      }
    } else if (v1 == kUnknownDim) {
      return_s1 = false;
    } else if (v0 != v1) {
      *out = nullptr;
      return errors::InvalidArgument(
          "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
          " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
          DebugString(s1), ".");
    }
  }

  merged_shapes_.emplace_back(s0, s1);

  if (return_s0 || return_s1) {
    *out = return_s0 ? s0 : s1;
    return Status::OK();
  }

  // Merge dims.
  std::vector<DimensionHandle> dims(rank, nullptr);
  for (int i = 0; i < rank; ++i) {
    // Invariant for merge was checked earlier, so CHECK is ok.
    TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
  }

  Status s = ReturnCreatedShape(dims, out);
  if (s.ok()) {
    // Merge the new shape with s0. Since s0 and s1 are merged, this implies
    // that s1 and out are also merged.
    merged_shapes_.emplace_back(s0, *out);
  }
  return s;
}
567
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)568 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
569 ShapeHandle* out) {
570 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
571 }
572
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)573 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
574 ShapeHandle* out) {
575 return Subshape(s, start, end, 1 /* stride */, out);
576 }
577
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)578 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
579 int64 stride, ShapeHandle* out) {
580 int64 start_in = start;
581 int64 end_in = end;
582
583 const int32 rank = Rank(s);
584 if (start == 0 && stride == 1 &&
585 ((RankKnown(s) && end >= rank) ||
586 end == std::numeric_limits<int64>::max())) {
587 *out = s;
588 return Status::OK();
589 }
590 if (!RankKnown(s)) {
591 return ReturnUnknownShape(out);
592 }
593
594 if (start > rank) start = rank;
595 if (end > rank) end = rank;
596
597 if (stride < 0 && start == rank) --start;
598
599 if (start < 0) {
600 start = rank + start;
601 if (start < 0) {
602 *out = nullptr;
603 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
604 ", for shape with rank ", rank);
605 }
606 }
607
608 if (end < 0) {
609 end = rank + end;
610 if (end < 0) {
611 *out = nullptr;
612 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
613 ", for shape with rank ", rank);
614 }
615 }
616 if (stride > 0 && start > end) {
617 *out = nullptr;
618 return errors::InvalidArgument(
619 "Subshape must have computed start <= end, but is ", start, " and ",
620 end, " (computed from start ", start_in, " and end ", end_in,
621 " over shape with rank ", rank, ")");
622 } else if (stride < 0 && start < end) {
623 *out = nullptr;
624 return errors::InvalidArgument(
625 "Subshape must have computed start >= end since stride is negative, "
626 "but is ",
627 start, " and ", end, " (computed from start ", start_in, " and end ",
628 end_in, " over shape with rank ", rank, " and stride", stride, ")");
629 }
630
631 std::vector<DimensionHandle> dims;
632 for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
633 dims.push_back(Dim(s, i));
634 }
635 return ReturnCreatedShape(dims, out);
636 }
637
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)638 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
639 ShapeHandle* out) {
640 if (!RankKnown(s1) || !RankKnown(s2)) {
641 return ReturnUnknownShape(out);
642 }
643 const int32 s1_rank = Rank(s1);
644 const int32 s2_rank = Rank(s2);
645 const int32 rank = s1_rank + s2_rank;
646 std::vector<DimensionHandle> dims;
647 dims.reserve(rank);
648 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
649 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
650 return ReturnCreatedShape(dims, out);
651 }
652
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)653 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
654 DimensionHandle new_dim, ShapeHandle* out) {
655 if (!RankKnown(s)) {
656 return ReturnUnknownShape(out);
657 }
658 int64 dim_index = dim_index_in;
659 if (dim_index < 0) {
660 dim_index = s->dims_.size() + dim_index;
661 }
662 if (!FastBoundsCheck(dim_index, s->dims_.size())) {
663 *out = nullptr;
664 return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
665 " for shape with ", s->dims_.size(),
666 " dimensions");
667 }
668 std::vector<DimensionHandle> dims(s->dims_);
669 dims[dim_index] = new_dim;
670 return ReturnCreatedShape(dims, out);
671 }
672
// Creates (and takes ownership of) a new shape with the given dims.
ShapeHandle InferenceContext::MakeShape(
    const std::vector<DimensionHandle>& dims) {
  return shape_manager_.MakeShape(dims);
}
677
MakeShape(std::initializer_list<DimensionOrConstant> dims)678 ShapeHandle InferenceContext::MakeShape(
679 std::initializer_list<DimensionOrConstant> dims) {
680 std::vector<DimensionHandle> dims_actual;
681 dims_actual.reserve(dims.size());
682 for (const DimensionOrConstant& d : dims) {
683 dims_actual.push_back(MakeDim(d));
684 }
685
686 return shape_manager_.MakeShape(dims_actual);
687 }
688
// Returns a shape with unknown rank.
ShapeHandle InferenceContext::UnknownShape() {
  return shape_manager_.UnknownShape();
}
692
UnknownShapeOfRank(int64 rank)693 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
694 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
695 if (rank == kUnknownRank) {
696 return UnknownShape();
697 }
698 CHECK_GE(rank, 0) << "rank must not be negative";
699 std::vector<DimensionHandle> dims(rank);
700 for (int32 i = 0; i < rank; ++i) {
701 dims[i] = UnknownDim();
702 }
703 return MakeShape(dims);
704 }
705
// Returns a rank-0 (scalar) shape.
ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
707
// Returns a rank-1 shape with the given dimension.
ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
  return MakeShape({dim});
}
711
// Returns a rank-2 shape with the given dimensions.
ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
                                     DimensionOrConstant dim2) {
  return MakeShape({dim1, dim2});
}
716
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)717 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
718 int input_idx, ShapeHandle* out) {
719 ShapeHandle input_shape;
720 TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
721
722 request_input_tensor_as_partial_shape(input_idx);
723 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
724 if (input_idx < input_tensors_as_shapes_size &&
725 input_tensors_as_shapes_[input_idx].IsSet() &&
726 RankKnown(input_tensors_as_shapes_[input_idx])) {
727 *out = input_tensors_as_shapes_[input_idx];
728 return Status::OK();
729 }
730
731 return InternalMakeShapeFromTensor(
732 true /* treat_unknown_scalar_tensor_as_unknown_shape */,
733 input_tensor(input_idx), input_shape, out);
734 }
735
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)736 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
737 ShapeHandle* out) {
738 ShapeHandle input_shape;
739 TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
740
741 request_input_tensor_as_partial_shape(input_idx);
742 const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
743 if (input_idx < input_tensors_as_shapes_size &&
744 input_tensors_as_shapes_[input_idx].IsSet() &&
745 RankKnown(input_tensors_as_shapes_[input_idx])) {
746 *out = input_tensors_as_shapes_[input_idx];
747 return Status::OK();
748 }
749
750 return InternalMakeShapeFromTensor(
751 false /* treat_unknown_scalar_tensor_as_unknown_shape */,
752 input_tensor(input_idx), input_shape, out);
753 }
754
// Builds a shape from the contents of shape tensor `t` (which has shape
// `tensor_shape`); scalars are not treated as "unknown shape" here.
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
                                             ShapeHandle tensor_shape,
                                             ShapeHandle* out) {
  return InternalMakeShapeFromTensor(
      false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
      out);
}
762
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)763 Status InferenceContext::InternalMakeShapeFromTensor(
764 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
765 ShapeHandle tensor_shape, ShapeHandle* out) {
766 // Only callers who have set
767 if (!treat_unknown_scalar_tensor_as_unknown_shape) {
768 TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
769 }
770 if (t == nullptr) {
771 // This is guarded by the check above.
772 if (Rank(tensor_shape) == 0) {
773 return ReturnUnknownShape(out);
774 }
775 // Shape tensor is not known, but if the shape of the shape tensor is then
776 // the right number of unknown dims can be created.
777 DimensionHandle shape_dim = Dim(tensor_shape, 0);
778 if (!ValueKnown(shape_dim)) {
779 return ReturnUnknownShape(out);
780 }
781 const auto num_dims = Value(shape_dim);
782 std::vector<DimensionHandle> dims;
783 dims.reserve(num_dims);
784 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
785 return ReturnCreatedShape(dims, out);
786 }
787
788 if (t->shape().dims() == 0) {
789 if (t->dtype() == DataType::DT_INT32) {
790 auto flat_t = t->scalar<int32>();
791 if (flat_t() != -1) {
792 *out = nullptr;
793 return errors::InvalidArgument(
794 "Input tensor must be rank 1, or if its rank 0 it must have value "
795 "-1 "
796 "(representing an unknown shape). Saw value: ",
797 flat_t());
798 }
799 return ReturnUnknownShape(out);
800 } else if (t->dtype() == DataType::DT_INT64) {
801 auto flat_t = t->scalar<int64>();
802 if (flat_t() != -1) {
803 *out = nullptr;
804 return errors::InvalidArgument(
805 "Input tensor must be rank 1, or if its rank 0 it must have value "
806 "-1 "
807 "(representing an unknown shape). Saw value: ",
808 flat_t());
809 }
810 return ReturnUnknownShape(out);
811 } else {
812 *out = nullptr;
813 return errors::InvalidArgument(
814 "Input tensor must be int32 or int64, but was ",
815 DataTypeString(t->dtype()));
816 }
817 }
818
819 if (t->shape().dims() != 1) {
820 *out = nullptr;
821 return errors::InvalidArgument(
822 "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
823 ((t->shape().dims() == 0)
824 ? "If it is rank 0 rank 0 it must have statically known value -1 "
825 "(representing an unknown shape). "
826 : " "),
827 "Saw tensor shape ", t->shape().DebugString());
828 }
829 std::vector<DimensionHandle> dims;
830 if (t->dtype() == DataType::DT_INT32) {
831 auto flat_t = t->flat<int32>();
832 for (int i = 0; i < flat_t.size(); ++i) {
833 const int32 val = flat_t(i);
834 if (val < -1) {
835 return errors::InvalidArgument(
836 "Invalid value in tensor used for shape: ", val);
837 }
838 // -1 will become an unknown dim.
839 dims.push_back(MakeDim(val));
840 }
841 } else if (t->dtype() == DataType::DT_INT64) {
842 auto flat_t = t->flat<int64>();
843 for (int i = 0; i < flat_t.size(); ++i) {
844 const int64 val = flat_t(i);
845 if (val < -1) {
846 return errors::InvalidArgument(
847 "Invalid value in tensor used for shape: ", val);
848 }
849 // -1 will become an unknown dim.
850 dims.push_back(MakeDim(val));
851 }
852 } else {
853 *out = nullptr;
854 return errors::InvalidArgument(
855 "Input tensor must be int32 or int64, but was ",
856 DataTypeString(t->dtype()));
857 }
858
859 return ReturnCreatedShape(dims, out);
860 }
861
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)862 Status InferenceContext::MakeShapeFromPartialTensorShape(
863 const PartialTensorShape& partial_shape, ShapeHandle* out) {
864 *out = nullptr;
865 if (partial_shape.dims() == -1) {
866 return ReturnUnknownShape(out);
867 }
868 const int num_dims = partial_shape.dims();
869 std::vector<DimensionHandle> dims(num_dims);
870 for (int i = 0; i < num_dims; ++i) {
871 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
872 // can be passed directly to MakeDim.
873 dims[i] = MakeDim(partial_shape.dim_size(i));
874 }
875 return ReturnCreatedShape(dims, out);
876 }
877
// Converts a fully-defined TensorShape to a ShapeHandle by routing through
// the PartialTensorShape conversion.
Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
                                                  ShapeHandle* out) {
  return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
                                         out);
}
883
// Converts a TensorShapeProto to a ShapeHandle, validating the proto first.
Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                                 ShapeHandle* out) {
  *out = nullptr;
  TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
  PartialTensorShape partial_shape(proto);
  return MakeShapeFromPartialTensorShape(partial_shape, out);
}
891
GetScalarFromTensor(const Tensor * t,int64 * val)892 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
893 // Caller must ensure that <t> is not NULL.
894 const int rank = t->dims();
895 if (rank != 0) {
896 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
897 }
898
899 if (t->dtype() == DataType::DT_INT32) {
900 *val = t->scalar<int32>()();
901 return Status::OK();
902 } else if (t->dtype() == DataType::DT_INT64) {
903 *val = t->scalar<int64>()();
904 return Status::OK();
905 } else {
906 return errors::InvalidArgument("Scalar input must be int32 or int64.");
907 }
908 }
909
GetScalarFromTensor(const Tensor * t,int64 idx,int64 * val)910 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64 idx,
911 int64* val) {
912 // Caller must ensure that <t> is not NULL.
913 const int rank = t->dims();
914 if (rank != 1) {
915 return errors::InvalidArgument("Input must be 1D but has rank ", rank);
916 }
917
918 if (t->dtype() == DataType::DT_INT32) {
919 auto flat_t = t->flat<int32>();
920 if (idx < 0 || idx >= flat_t.size()) {
921 return errors::InvalidArgument("Invalid index ", idx,
922 " for Tensor of size ", flat_t.size());
923 }
924 *val = flat_t(idx);
925 return Status::OK();
926 } else if (t->dtype() == DataType::DT_INT64) {
927 auto flat_t = t->flat<int64>();
928 if (idx < 0 || idx >= flat_t.size()) {
929 return errors::InvalidArgument("Invalid index ", idx,
930 " for Tensor of size ", flat_t.size());
931 }
932 *val = flat_t(idx);
933 return Status::OK();
934 } else {
935 return errors::InvalidArgument("Tensor input must be int32 or int64.");
936 }
937 }
938
939 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)940 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
941 int64 val;
942 const Tensor* t = input_tensor(idx);
943 if (t == nullptr) {
944 *out = UnknownDim();
945 return Status::OK();
946 }
947 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
948 if (val < 0) {
949 return errors::InvalidArgument("Dimension size, given by scalar input ",
950 idx, ", must be non-negative but is ", val);
951 }
952 *out = MakeDim(val);
953 return Status::OK();
954 }
955
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)956 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
957 int idx, int input_rank, DimensionHandle* out) {
958 int64 val;
959 const Tensor* t = input_tensor(idx);
960 if (t == nullptr) {
961 *out = UnknownDim();
962 return Status::OK();
963 }
964 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
965 if (val < 0) {
966 if (input_rank < 0) {
967 *out = UnknownDim();
968 return Status::OK();
969 } else if (val + input_rank < 0) {
970 return errors::InvalidArgument("Dimension size, given by scalar input ",
971 val, " must be in range [-", input_rank,
972 ", ", input_rank, ")");
973 } else {
974 val += input_rank;
975 }
976 } else if (input_rank >= 0 && val >= input_rank) {
977 return errors::InvalidArgument("Dimension size, given by scalar input ",
978 val, " must be in range [-", input_rank,
979 ", ", input_rank, ")");
980 }
981 *out = MakeDim(val);
982 return Status::OK();
983 }
984
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)985 Status InferenceContext::Divide(DimensionHandle dividend,
986 DimensionOrConstant divisor,
987 bool evenly_divisible, DimensionHandle* out) {
988 const int64 divisor_value = Value(divisor);
989 if (divisor_value == 1) {
990 *out = dividend;
991 } else if (!ValueKnown(dividend) ||
992 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
993 *out = UnknownDim();
994 } else {
995 const int64 v = Value(dividend);
996 if (divisor_value <= 0) {
997 return errors::InvalidArgument("Divisor must be positive but is ",
998 divisor_value);
999 }
1000 if (evenly_divisible && (v % divisor_value) != 0) {
1001 return errors::InvalidArgument(
1002 "Dimension size must be evenly divisible by ", divisor_value,
1003 " but is ", v);
1004 }
1005 *out = MakeDim(v / divisor_value);
1006 }
1007 return Status::OK();
1008 }
1009
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1010 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1011 DimensionHandle* out) {
1012 const int64 first_value = Value(first);
1013 const int64 second_value = Value(second);
1014 // Special cases.
1015 if (first_value == 0) {
1016 *out = MakeDim(second);
1017 } else if (second_value == 0) {
1018 *out = first;
1019 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1020 *out = UnknownDim();
1021 } else {
1022 // Invariant: Both values are known and positive. Still in run-time we can
1023 // get pair of values which cannot be store in output. Check below will
1024 // report error. We still need to avoid undefined behavior of signed
1025 // overflow and use unsigned addition.
1026 const int64 sum = static_cast<uint64>(first_value) + second_value;
1027 if (sum < 0) {
1028 return errors::InvalidArgument("Dimension size overflow from adding ",
1029 first_value, " and ", second_value);
1030 }
1031 *out = MakeDim(sum);
1032 }
1033 return Status::OK();
1034 }
1035
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1036 Status InferenceContext::Subtract(DimensionHandle first,
1037 DimensionOrConstant second,
1038 DimensionHandle* out) {
1039 const int64 first_value = Value(first);
1040 const int64 second_value = Value(second);
1041 // Special cases.
1042 if (second_value == 0) {
1043 *out = first;
1044 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1045 *out = UnknownDim();
1046 } else {
1047 // Invariant: Both values are known, first_value is non-negative, and
1048 // second_value is positive.
1049 if (first_value < second_value) {
1050 return errors::InvalidArgument(
1051 "Negative dimension size caused by subtracting ", second_value,
1052 " from ", first_value);
1053 }
1054 *out = MakeDim(first_value - second_value);
1055 }
1056 return Status::OK();
1057 }
1058
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1059 Status InferenceContext::Multiply(DimensionHandle first,
1060 DimensionOrConstant second,
1061 DimensionHandle* out) {
1062 const int64 first_value = Value(first);
1063 const int64 second_value = Value(second);
1064 // Special cases.
1065 if (first_value == 0) {
1066 *out = first;
1067 } else if (second_value == 0) {
1068 *out = MakeDim(second);
1069 } else if (first_value == 1) {
1070 *out = MakeDim(second);
1071 } else if (second_value == 1) {
1072 *out = first;
1073 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1074 *out = UnknownDim();
1075 } else {
1076 // Invariant: Both values are known and greater than 1.
1077 const int64 product = first_value * second_value;
1078 if (product < 0) {
1079 return errors::InvalidArgument(
1080 "Negative dimension size caused by overflow when multiplying ",
1081 first_value, " and ", second_value);
1082 }
1083 *out = MakeDim(product);
1084 }
1085 return Status::OK();
1086 }
1087
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1088 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1089 DimensionHandle* out) {
1090 const int64 first_value = Value(first);
1091 const int64 second_value = Value(second);
1092 if (first_value == 0) {
1093 *out = first;
1094 } else if (second_value == 0) {
1095 *out = MakeDim(second);
1096 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1097 *out = UnknownDim();
1098 } else {
1099 if (first_value <= second_value) {
1100 *out = first;
1101 } else {
1102 *out = MakeDim(second);
1103 }
1104 }
1105 return Status::OK();
1106 }
1107
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1108 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1109 DimensionHandle* out) {
1110 const int64 first_value = Value(first);
1111 const int64 second_value = Value(second);
1112 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1113 *out = UnknownDim();
1114 } else {
1115 if (first_value >= second_value) {
1116 *out = first;
1117 } else {
1118 *out = MakeDim(second);
1119 }
1120 }
1121 return Status::OK();
1122 }
1123
// Returns a copy of <status> with a description of this node's inputs
// appended to the error message, to make shape-inference failures easier to
// debug. Includes the symbolic input shapes and, where they were requested
// during inference, the constant input tensors and tensors-as-shapes.
Status InferenceContext::AttachContext(const Status& status) {
  // Render every input's symbolic shape (e.g. "[?,2]").
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& input_shape : inputs_) {
    input_shapes.emplace_back(DebugString(input_shape));
  }

  // Add information about the input tensors and partial tensor shapes used.
  std::vector<string> input_from_tensors_str;
  std::vector<string> input_from_tensors_as_shape_str;
  input_from_tensors_as_shape_str.reserve(inputs_.size());
  for (int i = 0, end = inputs_.size(); i < end; ++i) {
    const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
    const int input_tensors_size = input_tensors_.size();
    // Only report inputs the shape function actually asked for; the two
    // vectors may be shorter than inputs_, hence the bounds checks.
    if (requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_size &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i])) {
      input_from_tensors_as_shape_str.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
    } else if (requested_input_tensor_[i] && i < input_tensors_size &&
               input_tensors_[i] != nullptr) {
      // Cap the summary so huge constant tensors don't bloat the message.
      input_from_tensors_str.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  // Assemble the context suffix, appending the optional sections only when
  // they are non-empty.
  string error_context = strings::StrCat(
      " for '", attrs_.SummarizeNode(),
      "' with input shapes: ", absl::StrJoin(input_shapes, ", "));
  if (!input_from_tensors_str.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       absl::StrJoin(input_from_tensors_str, ", "));
  }
  if (!input_from_tensors_as_shape_str.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(input_from_tensors_as_shape_str, ","));
  }

  strings::StrAppend(&error_context, ".");
  // Preserve the original status code; only the message is augmented.
  return Status(status.code(),
                strings::StrCat(status.error_message(), error_context));
}
1169
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1170 bool InferenceContext::MergeHandleShapesAndTypes(
1171 const std::vector<ShapeAndType>& shapes_and_types,
1172 std::vector<ShapeAndType>* to_update) {
1173 if (shapes_and_types.size() != to_update->size()) {
1174 return false;
1175 }
1176 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1177 bool refined = false;
1178 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1179 const ShapeAndType& existing = (*to_update)[i];
1180 if (shapes_and_types[i].dtype == existing.dtype) {
1181 new_values[i].dtype = existing.dtype;
1182 } else {
1183 if (existing.dtype != DT_INVALID) {
1184 return false;
1185 } else {
1186 new_values[i].dtype = shapes_and_types[i].dtype;
1187 refined = true;
1188 }
1189 }
1190 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1191 .ok()) {
1192 // merge failed, ignore the new value.
1193 new_values[i].shape = existing.shape;
1194 }
1195 if (!existing.shape.SameHandle(new_values[i].shape)) {
1196 refined = true;
1197 }
1198 }
1199 if (!refined) {
1200 return false;
1201 }
1202 for (int i = 0, end = new_values.size(); i < end; ++i) {
1203 (*to_update)[i] = new_values[i];
1204 }
1205 return true;
1206 }
1207
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1208 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1209 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1210 if (output_handle_shapes_and_types_[idx] == nullptr) {
1211 output_handle_shapes_and_types_[idx].reset(
1212 new std::vector<ShapeAndType>(shapes_and_types));
1213 return true;
1214 }
1215 return MergeHandleShapesAndTypes(shapes_and_types,
1216 output_handle_shapes_and_types_[idx].get());
1217 }
1218
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1219 bool InferenceContext::MergeInputHandleShapesAndTypes(
1220 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1221 if (input_handle_shapes_and_types_[idx] == nullptr) {
1222 input_handle_shapes_and_types_[idx].reset(
1223 new std::vector<ShapeAndType>(shapes_and_types));
1224 return true;
1225 }
1226 return MergeHandleShapesAndTypes(shapes_and_types,
1227 input_handle_shapes_and_types_[idx].get());
1228 }
1229
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1230 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1231 const std::vector<ShapeAndType>& shapes_and_types,
1232 std::vector<ShapeAndType>* to_update) {
1233 if (shapes_and_types.size() != to_update->size()) {
1234 return false;
1235 }
1236 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1237 for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1238 const ShapeAndType& existing = (*to_update)[i];
1239 if (shapes_and_types[i].dtype == existing.dtype) {
1240 new_values[i].dtype = existing.dtype;
1241 } else {
1242 if (existing.dtype != DT_INVALID) {
1243 return false;
1244 } else {
1245 new_values[i].dtype = shapes_and_types[i].dtype;
1246 }
1247 }
1248 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1249 }
1250 to_update->swap(new_values);
1251 return true;
1252 }
1253
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1254 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1255 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1256 if (output_handle_shapes_and_types_[idx] == nullptr) {
1257 output_handle_shapes_and_types_[idx].reset(
1258 new std::vector<ShapeAndType>(shapes_and_types));
1259 return true;
1260 }
1261 return RelaxHandleShapesAndMergeTypes(
1262 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1263 }
1264
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1265 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1266 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1267 if (input_handle_shapes_and_types_[idx] == nullptr) {
1268 input_handle_shapes_and_types_[idx].reset(
1269 new std::vector<ShapeAndType>(shapes_and_types));
1270 return true;
1271 }
1272 return RelaxHandleShapesAndMergeTypes(
1273 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1274 }
1275
1276 // -----------------------------------------------------------------------------
1277 // ShapeManager
1278 // -----------------------------------------------------------------------------
ShapeManager()1279 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1280 InferenceContext::ShapeManager::~ShapeManager() {
1281 for (auto* s : all_shapes_) delete s;
1282 for (auto* d : all_dims_) delete d;
1283 }
1284
MakeShape(const std::vector<DimensionHandle> & dims)1285 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1286 const std::vector<DimensionHandle>& dims) {
1287 all_shapes_.push_back(new Shape(dims));
1288 return all_shapes_.back();
1289 }
1290
UnknownShape()1291 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1292 all_shapes_.push_back(new Shape());
1293 return all_shapes_.back();
1294 }
1295
1296 } // namespace shape_inference
1297 } // namespace tensorflow
1298