1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "tensorflow/core/framework/node_def.pb_text.h"
18 #include "tensorflow/core/framework/partial_tensor_shape.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/kernels/bounds_check.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25
26 namespace tensorflow {
27 namespace shape_inference {
28
// Namespace-scope definitions of InferenceContext's static constexpr members
// kUnknownRank and kUnknownDim. No initializer is given here, so the values
// presumably come from the in-class declarations in shape_inference.h
// (not visible from this file -- confirm there).
29 constexpr int32 InferenceContext::kUnknownRank;
30 constexpr int64 InferenceContext::kUnknownDim;
31
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<TensorShapeProto> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<TensorShapeProto> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<TensorShapeProto,DataType>>>> & input_handle_shapes_and_types)32 InferenceContext::InferenceContext(
33 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
34 const std::vector<TensorShapeProto>& input_shapes,
35 const std::vector<const Tensor*>& input_tensors,
36 const std::vector<TensorShapeProto>& input_tensors_as_shapes,
37 const std::vector<
38 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
39 input_handle_shapes_and_types)
40 : graph_def_version_(graph_def_version),
41 node_def_(CHECK_NOTNULL(node_def)) {
42 std::vector<ShapeHandle> input_tensors_as_shape_handles;
43 for (const TensorShapeProto& p : input_tensors_as_shapes) {
44 ShapeHandle shape;
45 construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
46 if (!construction_status_.ok()) {
47 return;
48 }
49 input_tensors_as_shape_handles.push_back(shape);
50 }
51 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
52 if (!construction_status_.ok()) return;
53 for (const TensorShapeProto& p : input_shapes) {
54 ShapeHandle shape;
55 construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
56 if (!construction_status_.ok()) {
57 return;
58 }
59 inputs_.push_back(shape);
60 }
61
62 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
63 input_shapes.size());
64 for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
65 const auto& v = input_handle_shapes_and_types[i];
66 if (v == nullptr) {
67 continue;
68 }
69 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
70 auto& new_v = *handle_data[i];
71 for (int j = 0; j < v->size(); ++j) {
72 const auto& p = (*v)[j];
73 construction_status_.Update(
74 MakeShapeFromShapeProto(p.first, &new_v[j].shape));
75 if (!construction_status_.ok()) {
76 return;
77 }
78 new_v[j].dtype = p.second;
79 }
80 }
81 PostInputInit(std::move(handle_data));
82 }
83
84 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)85 InferenceContext::InferenceContext(
86 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
87 const std::vector<PartialTensorShape>& input_shapes,
88 const std::vector<const Tensor*>& input_tensors,
89 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
90 const std::vector<
91 std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
92 input_handle_shapes_and_types)
93 : graph_def_version_(graph_def_version),
94 node_def_(CHECK_NOTNULL(node_def)) {
95 std::vector<ShapeHandle> input_tensors_as_shape_handles;
96 for (const PartialTensorShape& p : input_tensors_as_shapes) {
97 ShapeHandle shape;
98 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
99 if (!construction_status_.ok()) {
100 return;
101 }
102 input_tensors_as_shape_handles.push_back(shape);
103 }
104 PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
105 if (!construction_status_.ok()) return;
106 for (const PartialTensorShape& p : input_shapes) {
107 ShapeHandle shape;
108 construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
109 if (!construction_status_.ok()) {
110 return;
111 }
112 inputs_.push_back(shape);
113 }
114 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
115 input_shapes.size());
116 for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
117 const auto& v = input_handle_shapes_and_types[i];
118 if (v == nullptr) {
119 continue;
120 }
121 handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
122 auto& new_v = *handle_data[i];
123 for (int j = 0; j < v->size(); ++j) {
124 const auto& p = (*v)[j];
125 construction_status_.Update(
126 MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
127 if (!construction_status_.ok()) {
128 return;
129 }
130 new_v[j].dtype = p.second;
131 }
132 }
133 PostInputInit(std::move(handle_data));
134 }
135
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)136 InferenceContext::InferenceContext(
137 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
138 const std::vector<ShapeHandle>& input_shapes,
139 const std::vector<const Tensor*>& input_tensors,
140 const std::vector<ShapeHandle>& input_tensors_as_shapes,
141 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
142 input_handle_shapes_and_types)
143 : graph_def_version_(graph_def_version),
144 node_def_(CHECK_NOTNULL(node_def)) {
145 PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
146 if (!construction_status_.ok()) return;
147 inputs_ = input_shapes;
148
149 PostInputInit(std::move(input_handle_shapes_and_types));
150 }
151
~InferenceContext()152 InferenceContext::~InferenceContext() {}
153
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)154 Status InferenceContext::Run(
155 const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
156 Status s = fn(this);
157 if (!s.ok()) {
158 return AttachContext(s);
159 }
160 #ifndef NDEBUG
161 for (int i = 0; i < num_outputs(); ++i) {
162 DCHECK(output(i).IsSet())
163 << i << " for " << node_def_->name() << " of type " << node_def_->op();
164 }
165 #endif // NDEBUG
166 return s;
167 }
168
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)169 Status InferenceContext::set_output(StringPiece output_name,
170 const std::vector<ShapeHandle>& shapes) {
171 auto result = output_name_map_.find(output_name);
172 if (result == output_name_map_.end()) {
173 return errors::InvalidArgument("Unknown output name: ", output_name);
174 } else {
175 const int start = result->second.first;
176 const int size = result->second.second - start;
177 if (size != shapes.size()) {
178 return errors::InvalidArgument("Must have exactly ", shapes.size(),
179 " shapes.");
180 }
181 for (int i = 0; i < size; ++i) {
182 outputs_[i + start] = shapes[i];
183 }
184 }
185 return Status::OK();
186 }
187
input(StringPiece input_name,std::vector<ShapeHandle> * output) const188 Status InferenceContext::input(StringPiece input_name,
189 std::vector<ShapeHandle>* output) const {
190 const auto result = input_name_map_.find(input_name);
191 if (result == input_name_map_.end()) {
192 return errors::InvalidArgument("Unknown input name: ", input_name);
193 } else {
194 output->clear();
195 for (int i = result->second.first; i < result->second.second; ++i) {
196 output->push_back(inputs_[i]);
197 }
198 }
199 return Status::OK();
200 }
201
output(StringPiece output_name,std::vector<ShapeHandle> * output) const202 Status InferenceContext::output(StringPiece output_name,
203 std::vector<ShapeHandle>* output) const {
204 const auto result = output_name_map_.find(output_name);
205 if (result == output_name_map_.end()) {
206 return errors::InvalidArgument("Unknown output name: ", output_name);
207 } else {
208 output->clear();
209 for (int i = result->second.first; i < result->second.second; ++i) {
210 output->push_back(outputs_[i]);
211 }
212 }
213 return Status::OK();
214 }
215
op() const216 string InferenceContext::op() const { return node_def_->op(); }
217
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)218 void InferenceContext::PreInputInit(
219 const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
220 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
221 input_tensors_ = input_tensors;
222 input_tensors_as_shapes_ = input_tensors_as_shapes;
223
224 construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
225 &output_name_map_);
226 if (!construction_status_.ok()) return;
227
228 int num_outputs = 0;
229 for (const auto& e : output_name_map_) {
230 num_outputs = std::max(num_outputs, e.second.second);
231 }
232 for (int i = 0; i < num_outputs; ++i) {
233 outputs_.push_back(nullptr);
234 }
235 output_handle_shapes_and_types_.resize(num_outputs);
236 }
237
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)238 void InferenceContext::PostInputInit(
239 std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
240 int num_inputs_from_node_def = 0;
241 for (const auto& e : input_name_map_) {
242 num_inputs_from_node_def =
243 std::max(num_inputs_from_node_def, e.second.second);
244 }
245
246 // Allow passing empty shapes/dtypes to avoid changing every single test.
247 if (input_handle_data.empty()) {
248 input_handle_shapes_and_types_.resize(inputs_.size());
249 } else {
250 if (input_handle_data.size() != inputs_.size()) {
251 construction_status_ = errors::InvalidArgument(
252 "Wrong number of handle shapes passed; expected ", inputs_.size(),
253 " got ", input_handle_data.size());
254 return;
255 }
256 input_handle_shapes_and_types_ = std::move(input_handle_data);
257 }
258
259 if (inputs_.size() != num_inputs_from_node_def) {
260 construction_status_ = errors::InvalidArgument(
261 "Wrong number of inputs passed: ", inputs_.size(), " while ",
262 num_inputs_from_node_def, " expected based on NodeDef");
263 return;
264 }
265
266 CHECK_LE(input_tensors_.size(), inputs_.size());
267 input_tensors_.resize(inputs_.size());
268 requested_input_tensor_.resize(inputs_.size());
269 requested_input_tensor_as_partial_shape_.resize(inputs_.size());
270 }
271
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)272 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
273 TensorShapeProto* proto) {
274 if (!RankKnown(handle)) {
275 proto->set_unknown_rank(true);
276 return;
277 }
278
279 for (int32 i = 0; i < Rank(handle); ++i) {
280 DimensionHandle dim = Dim(handle, i);
281 auto* dim_shape = proto->add_dim();
282 if (ValueKnown(dim)) {
283 dim_shape->set_size(Value(dim));
284 } else {
285 dim_shape->set_size(-1);
286 }
287 }
288 }
289
FullyDefined(ShapeHandle s)290 bool InferenceContext::FullyDefined(ShapeHandle s) {
291 if (!RankKnown(s)) return false;
292 for (int i = 0; i < Rank(s); ++i) {
293 if (!ValueKnown(Dim(s, i))) return false;
294 }
295 return true;
296 }
297
NumElements(ShapeHandle s)298 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
299 const auto rank = Rank(s);
300 if (rank == kUnknownRank) return UnknownDim();
301 int64 size = 1;
302 for (int i = 0; i < rank; ++i) {
303 int64 dim_val = Value(Dim(s, i));
304 if (dim_val == kUnknownDim) return UnknownDim();
305 size *= dim_val;
306 }
307 return MakeDim(size);
308 }
309
DebugString(ShapeHandle s)310 string InferenceContext::DebugString(ShapeHandle s) {
311 if (RankKnown(s)) {
312 std::vector<string> vals;
313 for (auto d : s->dims_) vals.push_back(DebugString(d));
314 return strings::StrCat("[", str_util::Join(vals, ","), "]");
315 } else {
316 return "?";
317 }
318 }
319
DebugString(DimensionHandle d)320 string InferenceContext::DebugString(DimensionHandle d) {
321 return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
322 }
323
DebugString() const324 string InferenceContext::DebugString() const {
325 return strings::StrCat("InferenceContext for node: ",
326 ProtoDebugString(*node_def_));
327 }
328
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)329 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
330 ShapeHandle* out) {
331 if (rank > kint32max) {
332 return errors::InvalidArgument("Rank cannot exceed kint32max");
333 }
334 const int32 existing = Rank(shape);
335 if (existing == rank) {
336 *out = shape;
337 return Status::OK();
338 }
339 if (existing == kUnknownRank) {
340 std::vector<DimensionHandle> dims;
341 dims.reserve(rank);
342 for (int i = 0; i < rank; ++i) {
343 dims.push_back(UnknownDim());
344 }
345 ShapeHandle shp = shape_manager_.MakeShape(dims);
346 return Merge(shape, shp, out);
347 }
348 *out = nullptr;
349
350 return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
351 existing);
352 }
353
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)354 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
355 ShapeHandle* out) {
356 if (rank > kint32max) {
357 return errors::InvalidArgument("Rank cannot exceed kint32max");
358 }
359 const int32 existing = Rank(shape);
360 if (existing >= rank || existing == kUnknownRank) {
361 *out = shape;
362 return Status::OK();
363 }
364 *out = nullptr;
365 return errors::InvalidArgument("Shape must be at least rank ", rank,
366 " but is rank ", existing);
367 }
368
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)369 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
370 ShapeHandle* out) {
371 if (rank > kint32max) {
372 return errors::InvalidArgument("Rank cannot exceed kint32max");
373 }
374 const int32 existing = Rank(shape);
375 if (existing <= rank || existing == kUnknownRank) {
376 *out = shape;
377 return Status::OK();
378 }
379 *out = nullptr;
380 return errors::InvalidArgument("Shape must be at most rank ", rank,
381 " but is rank ", existing);
382 }
383
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)384 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
385 DimensionHandle* out) {
386 const int64 existing = Value(dim);
387 if (existing == value) {
388 *out = dim;
389 return Status::OK();
390 }
391 if (existing == kUnknownDim) {
392 DimensionHandle d = MakeDim(value);
393 return Merge(dim, d, out);
394 }
395 *out = nullptr;
396 return errors::InvalidArgument("Dimension must be ", value, " but is ",
397 existing);
398 }
399
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)400 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
401 DimensionHandle* out) {
402 if (d_old.SameHandle(d_new)) {
403 *out = d_old;
404 } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
405 // The node will be fed by the dimension d_new instead of d_old: any
406 // equality assertion between d_old and other input dimension on this node
407 // may not be true anymore, so forget them all.
408 ForgetMerges();
409 // Return the new shape handle to force the relaxation to propagate to the
410 // fanout of the context.
411 *out = d_new;
412 } else if (!ValueKnown(d_new)) {
413 ForgetMerges();
414 *out = d_new;
415 } else if (Value(d_old) == Value(d_new)) {
416 // Return the old shape handle. This will stop the relaxation in the fanout
417 // of the context.
418 *out = d_old;
419 } else {
420 // Return a new handle that encodes a different unknown dim.
421 ForgetMerges();
422 *out = UnknownDim();
423 }
424 }
425
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)426 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
427 DimensionHandle* out) {
428 if (d0.SameHandle(d1)) {
429 *out = d0;
430 return Status::OK();
431 } else if (!ValueKnown(d1)) {
432 *out = d0;
433 merged_dims_.emplace_back(d0, d1);
434 return Status::OK();
435 } else if (!ValueKnown(d0)) {
436 *out = d1;
437 merged_dims_.emplace_back(d0, d1);
438 return Status::OK();
439 } else if (Value(d0) == Value(d1)) {
440 *out = d0;
441 return Status::OK();
442 } else {
443 *out = nullptr;
444 return errors::InvalidArgument("Dimensions must be equal, but are ",
445 Value(d0), " and ", Value(d1));
446 }
447 }
448
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)449 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
450 ShapeHandle* s_out,
451 ShapeHandle* prefix_out) {
452 *s_out = *prefix_out = nullptr;
453 if (!RankKnown(prefix) || !RankKnown(s)) {
454 *s_out = s;
455 *prefix_out = prefix;
456 return Status::OK();
457 }
458 const int32 rank = Rank(prefix);
459 TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
460
461 // Merge the prefix dims and create the new output shapes.
462 std::vector<DimensionHandle> dims;
463 dims.resize(rank);
464 for (int i = 0; i < rank; ++i) {
465 TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
466 }
467 *prefix_out = MakeShape(dims);
468 for (int i = rank; i < Rank(s); ++i) dims.push_back(Dim(s, i));
469 *s_out = MakeShape(dims);
470 return Status::OK();
471 }
472
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)473 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
474 ShapeHandle* out) {
475 if (s_old.SameHandle(s_new)) {
476 *out = s_old;
477 return;
478 } else if (!RankKnown(s_new) || !s_old.IsSet()) {
479 ForgetMerges();
480 *out = s_new;
481 return;
482 }
483
484 const int32 rank = Rank(s_old);
485 if (rank != Rank(s_new)) {
486 ForgetMerges();
487 *out = UnknownShape();
488 return;
489 }
490
491 bool return_s_old = true;
492 for (int i = 0; i < rank; ++i) {
493 auto d0 = Dim(s_old, i);
494 auto d1 = Dim(s_new, i);
495 if (d0.SameHandle(d1)) continue;
496
497 auto v0 = Value(d0);
498 auto v1 = Value(d1);
499 if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
500 return_s_old = false;
501 break;
502 }
503 }
504 if (return_s_old) {
505 *out = s_old;
506 return;
507 }
508
509 // Relax dims.
510 std::vector<DimensionHandle> dims(rank);
511 for (int i = 0; i < rank; ++i) {
512 Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
513 }
514 ForgetMerges();
515 *out = MakeShape(dims);
516 }
517
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)518 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
519 ShapeHandle* out) {
520 if (s0.SameHandle(s1)) {
521 *out = s0;
522 return Status::OK();
523 } else if (!RankKnown(s1)) {
524 *out = s0;
525 merged_shapes_.emplace_back(s0, s1);
526 return Status::OK();
527 } else if (!RankKnown(s0)) {
528 *out = s1;
529 merged_shapes_.emplace_back(s0, s1);
530 return Status::OK();
531 }
532
533 const int32 rank = Rank(s0);
534 if (rank != Rank(s1)) {
535 *out = nullptr;
536 return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
537 " and ", Rank(s1));
538 }
539
540 bool return_s0 = true;
541 bool return_s1 = true;
542 for (int i = 0; i < rank; ++i) {
543 auto d0 = Dim(s0, i);
544 auto d1 = Dim(s1, i);
545 if (d0.SameHandle(d1)) continue;
546
547 auto v0 = Value(d0);
548 auto v1 = Value(d1);
549 if (v0 == kUnknownDim) {
550 if (v1 != kUnknownDim) {
551 return_s0 = false;
552 }
553 } else if (v1 == kUnknownDim) {
554 return_s1 = false;
555 } else if (v0 != v1) {
556 *out = nullptr;
557 return errors::InvalidArgument(
558 "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
559 " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
560 DebugString(s1), ".");
561 }
562 }
563
564 merged_shapes_.emplace_back(s0, s1);
565
566 if (return_s0 || return_s1) {
567 *out = return_s0 ? s0 : s1;
568 return Status::OK();
569 }
570
571 // Merge dims.
572 std::vector<DimensionHandle> dims(rank, nullptr);
573 for (int i = 0; i < rank; ++i) {
574 // Invariant for merge was checked earlier, so CHECK is ok.
575 TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
576 }
577
578 Status s = ReturnCreatedShape(dims, out);
579 if (s.ok()) {
580 // Merge the new shape with s0. Since s0 and s1 are merged, this implies
581 // that s1 and out are also merged.
582 merged_shapes_.emplace_back(s0, *out);
583 }
584 return s;
585 }
586
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)587 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
588 ShapeHandle* out) {
589 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
590 }
591
Subshape(ShapeHandle s,int64 start_in,int64 end_in,ShapeHandle * out)592 Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
593 ShapeHandle* out) {
594 int64 start = start_in;
595 int64 end = end_in;
596 const int32 rank = Rank(s);
597 if (start == 0 && ((RankKnown(s) && end >= rank) ||
598 end == std::numeric_limits<int64>::max())) {
599 *out = s;
600 return Status::OK();
601 }
602 if (!RankKnown(s)) {
603 return ReturnUnknownShape(out);
604 }
605
606 if (start > rank) start = rank;
607 if (end > rank) end = rank;
608 if (start < 0) {
609 start = rank + start;
610 if (start < 0) {
611 *out = nullptr;
612 return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
613 ", for shape with rank ", rank);
614 }
615 }
616
617 if (end < 0) {
618 end = rank + end;
619 if (end < 0) {
620 *out = nullptr;
621 return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
622 ", for shape with rank ", rank);
623 }
624 }
625 if (start > end) {
626 *out = nullptr;
627 return errors::InvalidArgument(
628 "Subshape must have computed start <= end, but is ", start, " and ",
629 end, " (computed from start ", start_in, " and end ", end_in,
630 " over shape with rank ", rank, ")");
631 }
632 std::vector<DimensionHandle> dims;
633 dims.reserve(end - start);
634 for (int i = start; i < end; ++i) {
635 dims.push_back(Dim(s, i));
636 }
637 return ReturnCreatedShape(dims, out);
638 }
639
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)640 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
641 ShapeHandle* out) {
642 if (!RankKnown(s1) || !RankKnown(s2)) {
643 return ReturnUnknownShape(out);
644 }
645 const int32 s1_rank = Rank(s1);
646 const int32 s2_rank = Rank(s2);
647 const int32 rank = s1_rank + s2_rank;
648 std::vector<DimensionHandle> dims;
649 dims.reserve(rank);
650 for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
651 for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
652 return ReturnCreatedShape(dims, out);
653 }
654
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)655 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
656 DimensionHandle new_dim, ShapeHandle* out) {
657 if (!RankKnown(s)) {
658 return ReturnUnknownShape(out);
659 }
660 int64 dim_index = dim_index_in;
661 if (dim_index < 0) {
662 dim_index = s->dims_.size() + dim_index;
663 }
664 if (!FastBoundsCheck(dim_index, s->dims_.size())) {
665 *out = nullptr;
666 return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
667 " for shape with ", s->dims_.size(),
668 " dimensions");
669 }
670 std::vector<DimensionHandle> dims(s->dims_);
671 dims[dim_index] = new_dim;
672 return ReturnCreatedShape(dims, out);
673 }
674
MakeShape(const std::vector<DimensionHandle> & dims)675 ShapeHandle InferenceContext::MakeShape(
676 const std::vector<DimensionHandle>& dims) {
677 return shape_manager_.MakeShape(dims);
678 }
679
MakeShape(std::initializer_list<DimensionOrConstant> dims)680 ShapeHandle InferenceContext::MakeShape(
681 std::initializer_list<DimensionOrConstant> dims) {
682 std::vector<DimensionHandle> dims_actual;
683 dims_actual.reserve(dims.size());
684 for (const DimensionOrConstant& d : dims) {
685 dims_actual.push_back(MakeDim(d));
686 }
687
688 return shape_manager_.MakeShape(dims_actual);
689 }
690
UnknownShape()691 ShapeHandle InferenceContext::UnknownShape() {
692 return shape_manager_.UnknownShape();
693 }
694
UnknownShapeOfRank(int64 rank)695 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
696 CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
697 if (rank == kUnknownRank) {
698 return UnknownShape();
699 }
700 CHECK_GE(rank, 0) << "rank must not be negative";
701 std::vector<DimensionHandle> dims(rank);
702 for (int32 i = 0; i < rank; ++i) {
703 dims[i] = UnknownDim();
704 }
705 return MakeShape(dims);
706 }
707
Scalar()708 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
709
Vector(DimensionOrConstant dim)710 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
711 return MakeShape({dim});
712 }
713
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)714 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
715 DimensionOrConstant dim2) {
716 return MakeShape({dim1, dim2});
717 }
718
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)719 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
720 ShapeHandle* out) {
721 ShapeHandle input_shape;
722 TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
723
724 requested_input_tensor_as_partial_shape_[input_idx] = true;
725 if (input_idx < input_tensors_as_shapes_.size() &&
726 input_tensors_as_shapes_[input_idx].IsSet() &&
727 RankKnown(input_tensors_as_shapes_[input_idx])) {
728 *out = input_tensors_as_shapes_[input_idx];
729 return Status::OK();
730 }
731
732 return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
733 }
734
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)735 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
736 ShapeHandle tensor_shape,
737 ShapeHandle* out) {
738 if (t == nullptr) {
739 // Shape tensor is not known, but if the shape of the shape tensor is then
740 // the right number of unknown dims can be created.
741 DimensionHandle shape_dim = Dim(tensor_shape, 0);
742 if (!ValueKnown(shape_dim)) {
743 return ReturnUnknownShape(out);
744 }
745 const auto num_dims = Value(shape_dim);
746 std::vector<DimensionHandle> dims;
747 dims.reserve(num_dims);
748 for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
749 return ReturnCreatedShape(dims, out);
750 }
751
752 if (t->shape().dims() != 1) {
753 *out = nullptr;
754 return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
755 t->shape().dims());
756 }
757 std::vector<DimensionHandle> dims;
758 if (t->dtype() == DataType::DT_INT32) {
759 auto flat_t = t->flat<int32>();
760 for (int i = 0; i < flat_t.size(); ++i) {
761 const int32 val = flat_t(i);
762 if (val < -1) {
763 return errors::InvalidArgument(
764 "Invalid value in tensor used for shape: ", val);
765 }
766 // -1 will become an unknown dim.
767 dims.push_back(MakeDim(val));
768 }
769 } else if (t->dtype() == DataType::DT_INT64) {
770 auto flat_t = t->flat<int64>();
771 for (int i = 0; i < flat_t.size(); ++i) {
772 const int64 val = flat_t(i);
773 if (val < -1) {
774 return errors::InvalidArgument(
775 "Invalid value in tensor used for shape: ", val);
776 }
777 // -1 will become an unknown dim.
778 dims.push_back(MakeDim(val));
779 }
780 } else {
781 *out = nullptr;
782 return errors::InvalidArgument(
783 "Input tensor must be int32 or int64, but was ",
784 DataTypeString(t->dtype()));
785 }
786
787 return ReturnCreatedShape(dims, out);
788 }
789
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)790 Status InferenceContext::MakeShapeFromPartialTensorShape(
791 const PartialTensorShape& partial_shape, ShapeHandle* out) {
792 *out = nullptr;
793 if (partial_shape.dims() == -1) {
794 return ReturnUnknownShape(out);
795 }
796 const int num_dims = partial_shape.dims();
797 std::vector<DimensionHandle> dims(num_dims);
798 for (int i = 0; i < num_dims; ++i) {
799 // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
800 // can be passed directly to MakeDim.
801 dims[i] = MakeDim(partial_shape.dim_size(i));
802 }
803 return ReturnCreatedShape(dims, out);
804 }
805
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)806 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
807 ShapeHandle* out) {
808 return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
809 out);
810 }
811
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)812 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
813 ShapeHandle* out) {
814 *out = nullptr;
815 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
816 PartialTensorShape partial_shape(proto);
817 return MakeShapeFromPartialTensorShape(partial_shape, out);
818 }
819
GetScalarFromTensor(const Tensor * t,int64 * val)820 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
821 // Caller must ensure that <t> is not NULL.
822 const int rank = t->dims();
823 if (rank != 0) {
824 return errors::InvalidArgument("Input must be scalar but has rank ", rank);
825 }
826
827 if (t->dtype() == DT_INT32) {
828 *val = t->scalar<int32>()();
829 return Status::OK();
830 } else if (t->dtype() == DT_INT64) {
831 *val = t->scalar<int64>()();
832 return Status::OK();
833 } else {
834 return errors::InvalidArgument(
835 "Scalar input for dim size must be int32 or int64");
836 }
837 }
838
839 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)840 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
841 int64 val;
842 const Tensor* t = input_tensor(idx);
843 if (t == nullptr) {
844 *out = UnknownDim();
845 return Status::OK();
846 }
847 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
848 if (val < 0) {
849 return errors::InvalidArgument("Dimension size, given by scalar input ",
850 idx, ", must be non-negative but is ", val);
851 }
852 *out = MakeDim(val);
853 return Status::OK();
854 }
855
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)856 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
857 int idx, int input_rank, DimensionHandle* out) {
858 int64 val;
859 const Tensor* t = input_tensor(idx);
860 if (t == nullptr) {
861 *out = UnknownDim();
862 return Status::OK();
863 }
864 TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
865 if (val < 0) {
866 if (input_rank < 0) {
867 *out = UnknownDim();
868 return Status::OK();
869 } else if (val + input_rank < 0) {
870 return errors::InvalidArgument("Dimension size, given by scalar input ",
871 val, " must be in range [-", input_rank,
872 ", ", input_rank, ")");
873 } else {
874 val += input_rank;
875 }
876 } else if (input_rank >= 0 && val >= input_rank) {
877 return errors::InvalidArgument("Dimension size, given by scalar input ",
878 val, " must be in range [-", input_rank,
879 ", ", input_rank, ")");
880 }
881 *out = MakeDim(val);
882 return Status::OK();
883 }
884
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)885 Status InferenceContext::Divide(DimensionHandle dividend,
886 DimensionOrConstant divisor,
887 bool evenly_divisible, DimensionHandle* out) {
888 const int64 divisor_value = Value(divisor);
889 if (divisor_value == 1) {
890 *out = dividend;
891 } else if (!ValueKnown(dividend) ||
892 (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
893 *out = UnknownDim();
894 } else {
895 const int64 v = Value(dividend);
896 if (divisor_value <= 0) {
897 return errors::InvalidArgument("Divisor must be positive but is ",
898 divisor_value);
899 }
900 if (evenly_divisible && (v % divisor_value) != 0) {
901 return errors::InvalidArgument(
902 "Dimension size must be evenly divisible by ", divisor_value,
903 " but is ", v);
904 }
905 *out = MakeDim(v / divisor_value);
906 }
907 return Status::OK();
908 }
909
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)910 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
911 DimensionHandle* out) {
912 const int64 first_value = Value(first);
913 const int64 second_value = Value(second);
914 // Special cases.
915 if (first_value == 0) {
916 *out = MakeDim(second);
917 } else if (second_value == 0) {
918 *out = first;
919 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
920 *out = UnknownDim();
921 } else {
922 // Invariant: Both values are known and positive. Still in run-time we can
923 // get pair of values which cannot be store in output. Check below will
924 // report error. We still need to avoid undefined behavior of signed
925 // overflow and use unsigned addition.
926 const int64 sum = static_cast<uint64>(first_value) + second_value;
927 if (sum < 0) {
928 return errors::InvalidArgument("Dimension size overflow from adding ",
929 first_value, " and ", second_value);
930 }
931 *out = MakeDim(sum);
932 }
933 return Status::OK();
934 }
935
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)936 Status InferenceContext::Subtract(DimensionHandle first,
937 DimensionOrConstant second,
938 DimensionHandle* out) {
939 const int64 first_value = Value(first);
940 const int64 second_value = Value(second);
941 // Special cases.
942 if (second_value == 0) {
943 *out = first;
944 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
945 *out = UnknownDim();
946 } else {
947 // Invariant: Both values are known, first_value is non-negative, and
948 // second_value is positive.
949 if (first_value < second_value) {
950 return errors::InvalidArgument(
951 "Negative dimension size caused by subtracting ", second_value,
952 " from ", first_value);
953 }
954 *out = MakeDim(first_value - second_value);
955 }
956 return Status::OK();
957 }
958
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)959 Status InferenceContext::Multiply(DimensionHandle first,
960 DimensionOrConstant second,
961 DimensionHandle* out) {
962 const int64 first_value = Value(first);
963 const int64 second_value = Value(second);
964 // Special cases.
965 if (first_value == 0) {
966 *out = first;
967 } else if (second_value == 0) {
968 *out = MakeDim(second);
969 } else if (first_value == 1) {
970 *out = MakeDim(second);
971 } else if (second_value == 1) {
972 *out = first;
973 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
974 *out = UnknownDim();
975 } else {
976 // Invariant: Both values are known and greater than 1.
977 const int64 product = first_value * second_value;
978 if (product < 0) {
979 return errors::InvalidArgument(
980 "Negative dimension size caused by overflow when multiplying ",
981 first_value, " and ", second_value);
982 }
983 *out = MakeDim(product);
984 }
985 return Status::OK();
986 }
987
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)988 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
989 DimensionHandle* out) {
990 const int64 first_value = Value(first);
991 const int64 second_value = Value(second);
992 if (first_value == 0) {
993 *out = first;
994 } else if (second_value == 0) {
995 *out = MakeDim(second);
996 } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
997 *out = UnknownDim();
998 } else {
999 if (first_value <= second_value) {
1000 *out = first;
1001 } else {
1002 *out = MakeDim(second);
1003 }
1004 }
1005 return Status::OK();
1006 }
1007
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1008 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1009 DimensionHandle* out) {
1010 const int64 first_value = Value(first);
1011 const int64 second_value = Value(second);
1012 if (first_value == kUnknownDim || second_value == kUnknownDim) {
1013 *out = UnknownDim();
1014 } else {
1015 if (first_value >= second_value) {
1016 *out = first;
1017 } else {
1018 *out = MakeDim(second);
1019 }
1020 }
1021 return Status::OK();
1022 }
1023
// Returns a copy of `status` (same code) whose message is extended with a
// description of this node: its name, its op, the debug string of every input
// shape and, where they were requested and are available, the input tensor
// values or the partial shapes derived from input tensors.
Status InferenceContext::AttachContext(const Status& status) {
  // Textual form of every input shape.
  std::vector<string> shape_strs;
  for (const ShapeHandle& shape : inputs_) {
    shape_strs.push_back(DebugString(shape));
  }

  // Inputs whose tensor value (or tensor-as-partial-shape) was requested. The
  // partial-shape form takes precedence when both apply.
  std::vector<string> tensor_strs;
  std::vector<string> partial_shape_strs;
  const int num_inputs = inputs_.size();
  for (int i = 0; i < num_inputs; ++i) {
    const bool has_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (has_partial_shape) {
      partial_shape_strs.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }

    const bool has_tensor = requested_input_tensor_[i] &&
                            i < input_tensors_.size() &&
                            input_tensors_[i] != nullptr;
    if (has_tensor) {
      tensor_strs.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string context_suffix = strings::StrCat(
      " for '", node_def_->name(), "' (op: '", node_def_->op(),
      "') with input shapes: ", str_util::Join(shape_strs, ", "));
  if (!tensor_strs.empty()) {
    strings::StrAppend(&context_suffix, " and with computed input tensors: ",
                       str_util::Join(tensor_strs, ", "));
  }
  if (!partial_shape_strs.empty()) {
    strings::StrAppend(&context_suffix,
                       " and with input tensors computed as partial shapes: ",
                       str_util::Join(partial_shape_strs, ","));
  }
  strings::StrAppend(&context_suffix, ".");

  return Status(status.code(),
                strings::StrCat(status.error_message(), context_suffix));
}
1065
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1066 bool InferenceContext::MergeHandleShapesAndTypes(
1067 const std::vector<ShapeAndType>& shapes_and_types,
1068 std::vector<ShapeAndType>* to_update) {
1069 if (shapes_and_types.size() != to_update->size()) {
1070 return false;
1071 }
1072 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1073 bool refined = false;
1074 for (int i = 0; i < shapes_and_types.size(); ++i) {
1075 const ShapeAndType& existing = (*to_update)[i];
1076 if (shapes_and_types[i].dtype == existing.dtype) {
1077 new_values[i].dtype = existing.dtype;
1078 } else {
1079 if (existing.dtype != DT_INVALID) {
1080 return false;
1081 } else {
1082 new_values[i].dtype = shapes_and_types[i].dtype;
1083 refined = true;
1084 }
1085 }
1086 if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1087 .ok()) {
1088 // merge failed, ignore the new value.
1089 new_values[i].shape = existing.shape;
1090 }
1091 if (!existing.shape.SameHandle(new_values[i].shape)) {
1092 refined = true;
1093 }
1094 }
1095 if (!refined) {
1096 return false;
1097 }
1098 for (int i = 0; i < new_values.size(); ++i) {
1099 (*to_update)[i] = new_values[i];
1100 }
1101 return true;
1102 }
1103
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1104 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1105 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1106 if (output_handle_shapes_and_types_[idx] == nullptr) {
1107 output_handle_shapes_and_types_[idx].reset(
1108 new std::vector<ShapeAndType>(shapes_and_types));
1109 return true;
1110 }
1111 return MergeHandleShapesAndTypes(shapes_and_types,
1112 output_handle_shapes_and_types_[idx].get());
1113 }
1114
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1115 bool InferenceContext::MergeInputHandleShapesAndTypes(
1116 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1117 if (input_handle_shapes_and_types_[idx] == nullptr) {
1118 input_handle_shapes_and_types_[idx].reset(
1119 new std::vector<ShapeAndType>(shapes_and_types));
1120 return true;
1121 }
1122 return MergeHandleShapesAndTypes(shapes_and_types,
1123 input_handle_shapes_and_types_[idx].get());
1124 }
1125
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1126 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1127 const std::vector<ShapeAndType>& shapes_and_types,
1128 std::vector<ShapeAndType>* to_update) {
1129 if (shapes_and_types.size() != to_update->size()) {
1130 return false;
1131 }
1132 std::vector<ShapeAndType> new_values(shapes_and_types.size());
1133 bool refined = false;
1134 for (int i = 0; i < shapes_and_types.size(); ++i) {
1135 const ShapeAndType& existing = (*to_update)[i];
1136 if (shapes_and_types[i].dtype == existing.dtype) {
1137 new_values[i].dtype = existing.dtype;
1138 } else {
1139 if (existing.dtype != DT_INVALID) {
1140 return false;
1141 } else {
1142 new_values[i].dtype = shapes_and_types[i].dtype;
1143 refined = true;
1144 }
1145 }
1146 Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1147 if (!existing.shape.SameHandle(new_values[i].shape)) {
1148 refined = true;
1149 }
1150 }
1151 if (!refined) {
1152 return false;
1153 }
1154 for (int i = 0; i < new_values.size(); ++i) {
1155 (*to_update)[i] = new_values[i];
1156 }
1157 return true;
1158 }
1159
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1160 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1161 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1162 if (output_handle_shapes_and_types_[idx] == nullptr) {
1163 output_handle_shapes_and_types_[idx].reset(
1164 new std::vector<ShapeAndType>(shapes_and_types));
1165 return true;
1166 }
1167 return RelaxHandleShapesAndMergeTypes(
1168 shapes_and_types, output_handle_shapes_and_types_[idx].get());
1169 }
1170
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1171 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1172 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1173 if (input_handle_shapes_and_types_[idx] == nullptr) {
1174 input_handle_shapes_and_types_[idx].reset(
1175 new std::vector<ShapeAndType>(shapes_and_types));
1176 return true;
1177 }
1178 return RelaxHandleShapesAndMergeTypes(
1179 shapes_and_types, input_handle_shapes_and_types_[idx].get());
1180 }
1181
1182 // -----------------------------------------------------------------------------
1183 // ShapeManager
1184 // -----------------------------------------------------------------------------
ShapeManager()1185 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1186 InferenceContext::ShapeManager::~ShapeManager() {
1187 for (auto* s : all_shapes_) delete s;
1188 for (auto* d : all_dims_) delete d;
1189 }
1190
MakeShape(const std::vector<DimensionHandle> & dims)1191 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1192 const std::vector<DimensionHandle>& dims) {
1193 all_shapes_.push_back(new Shape(dims));
1194 return all_shapes_.back();
1195 }
1196
UnknownShape()1197 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1198 all_shapes_.push_back(new Shape());
1199 return all_shapes_.back();
1200 }
1201
1202 } // namespace shape_inference
1203 } // namespace tensorflow
1204