• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/node_def.pb_text.h"
18 #include "tensorflow/core/framework/partial_tensor_shape.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/kernels/bounds_check.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 
26 namespace tensorflow {
27 namespace shape_inference {
28 
29 constexpr int32 InferenceContext::kUnknownRank;
30 constexpr int64 InferenceContext::kUnknownDim;
31 
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<TensorShapeProto> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<TensorShapeProto> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<TensorShapeProto,DataType>>>> & input_handle_shapes_and_types)32 InferenceContext::InferenceContext(
33     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
34     const std::vector<TensorShapeProto>& input_shapes,
35     const std::vector<const Tensor*>& input_tensors,
36     const std::vector<TensorShapeProto>& input_tensors_as_shapes,
37     const std::vector<
38         std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
39         input_handle_shapes_and_types)
40     : graph_def_version_(graph_def_version),
41       node_def_(CHECK_NOTNULL(node_def)) {
42   std::vector<ShapeHandle> input_tensors_as_shape_handles;
43   for (const TensorShapeProto& p : input_tensors_as_shapes) {
44     ShapeHandle shape;
45     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
46     if (!construction_status_.ok()) {
47       return;
48     }
49     input_tensors_as_shape_handles.push_back(shape);
50   }
51   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
52   if (!construction_status_.ok()) return;
53   for (const TensorShapeProto& p : input_shapes) {
54     ShapeHandle shape;
55     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
56     if (!construction_status_.ok()) {
57       return;
58     }
59     inputs_.push_back(shape);
60   }
61 
62   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
63       input_shapes.size());
64   for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
65     const auto& v = input_handle_shapes_and_types[i];
66     if (v == nullptr) {
67       continue;
68     }
69     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
70     auto& new_v = *handle_data[i];
71     for (int j = 0; j < v->size(); ++j) {
72       const auto& p = (*v)[j];
73       construction_status_.Update(
74           MakeShapeFromShapeProto(p.first, &new_v[j].shape));
75       if (!construction_status_.ok()) {
76         return;
77       }
78       new_v[j].dtype = p.second;
79     }
80   }
81   PostInputInit(std::move(handle_data));
82 }
83 
84 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)85 InferenceContext::InferenceContext(
86     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
87     const std::vector<PartialTensorShape>& input_shapes,
88     const std::vector<const Tensor*>& input_tensors,
89     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
90     const std::vector<
91         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
92         input_handle_shapes_and_types)
93     : graph_def_version_(graph_def_version),
94       node_def_(CHECK_NOTNULL(node_def)) {
95   std::vector<ShapeHandle> input_tensors_as_shape_handles;
96   for (const PartialTensorShape& p : input_tensors_as_shapes) {
97     ShapeHandle shape;
98     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
99     if (!construction_status_.ok()) {
100       return;
101     }
102     input_tensors_as_shape_handles.push_back(shape);
103   }
104   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
105   if (!construction_status_.ok()) return;
106   for (const PartialTensorShape& p : input_shapes) {
107     ShapeHandle shape;
108     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
109     if (!construction_status_.ok()) {
110       return;
111     }
112     inputs_.push_back(shape);
113   }
114   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
115       input_shapes.size());
116   for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
117     const auto& v = input_handle_shapes_and_types[i];
118     if (v == nullptr) {
119       continue;
120     }
121     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
122     auto& new_v = *handle_data[i];
123     for (int j = 0; j < v->size(); ++j) {
124       const auto& p = (*v)[j];
125       construction_status_.Update(
126           MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
127       if (!construction_status_.ok()) {
128         return;
129       }
130       new_v[j].dtype = p.second;
131     }
132   }
133   PostInputInit(std::move(handle_data));
134 }
135 
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)136 InferenceContext::InferenceContext(
137     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
138     const std::vector<ShapeHandle>& input_shapes,
139     const std::vector<const Tensor*>& input_tensors,
140     const std::vector<ShapeHandle>& input_tensors_as_shapes,
141     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
142         input_handle_shapes_and_types)
143     : graph_def_version_(graph_def_version),
144       node_def_(CHECK_NOTNULL(node_def)) {
145   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
146   if (!construction_status_.ok()) return;
147   inputs_ = input_shapes;
148 
149   PostInputInit(std::move(input_handle_shapes_and_types));
150 }
151 
~InferenceContext()152 InferenceContext::~InferenceContext() {}
153 
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)154 Status InferenceContext::Run(
155     const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
156   Status s = fn(this);
157   if (!s.ok()) {
158     return AttachContext(s);
159   }
160 #ifndef NDEBUG
161   for (int i = 0; i < num_outputs(); ++i) {
162     DCHECK(output(i).IsSet())
163         << i << " for " << node_def_->name() << " of type " << node_def_->op();
164   }
165 #endif  // NDEBUG
166   return s;
167 }
168 
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)169 Status InferenceContext::set_output(StringPiece output_name,
170                                     const std::vector<ShapeHandle>& shapes) {
171   auto result = output_name_map_.find(output_name);
172   if (result == output_name_map_.end()) {
173     return errors::InvalidArgument("Unknown output name: ", output_name);
174   } else {
175     const int start = result->second.first;
176     const int size = result->second.second - start;
177     if (size != shapes.size()) {
178       return errors::InvalidArgument("Must have exactly ", shapes.size(),
179                                      " shapes.");
180     }
181     for (int i = 0; i < size; ++i) {
182       outputs_[i + start] = shapes[i];
183     }
184   }
185   return Status::OK();
186 }
187 
input(StringPiece input_name,std::vector<ShapeHandle> * output) const188 Status InferenceContext::input(StringPiece input_name,
189                                std::vector<ShapeHandle>* output) const {
190   const auto result = input_name_map_.find(input_name);
191   if (result == input_name_map_.end()) {
192     return errors::InvalidArgument("Unknown input name: ", input_name);
193   } else {
194     output->clear();
195     for (int i = result->second.first; i < result->second.second; ++i) {
196       output->push_back(inputs_[i]);
197     }
198   }
199   return Status::OK();
200 }
201 
output(StringPiece output_name,std::vector<ShapeHandle> * output) const202 Status InferenceContext::output(StringPiece output_name,
203                                 std::vector<ShapeHandle>* output) const {
204   const auto result = output_name_map_.find(output_name);
205   if (result == output_name_map_.end()) {
206     return errors::InvalidArgument("Unknown output name: ", output_name);
207   } else {
208     output->clear();
209     for (int i = result->second.first; i < result->second.second; ++i) {
210       output->push_back(outputs_[i]);
211     }
212   }
213   return Status::OK();
214 }
215 
op() const216 string InferenceContext::op() const { return node_def_->op(); }
217 
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)218 void InferenceContext::PreInputInit(
219     const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
220     const std::vector<ShapeHandle>& input_tensors_as_shapes) {
221   input_tensors_ = input_tensors;
222   input_tensors_as_shapes_ = input_tensors_as_shapes;
223 
224   construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
225                                            &output_name_map_);
226   if (!construction_status_.ok()) return;
227 
228   int num_outputs = 0;
229   for (const auto& e : output_name_map_) {
230     num_outputs = std::max(num_outputs, e.second.second);
231   }
232   for (int i = 0; i < num_outputs; ++i) {
233     outputs_.push_back(nullptr);
234   }
235   output_handle_shapes_and_types_.resize(num_outputs);
236 }
237 
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)238 void InferenceContext::PostInputInit(
239     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
240   int num_inputs_from_node_def = 0;
241   for (const auto& e : input_name_map_) {
242     num_inputs_from_node_def =
243         std::max(num_inputs_from_node_def, e.second.second);
244   }
245 
246   // Allow passing empty shapes/dtypes to avoid changing every single test.
247   if (input_handle_data.empty()) {
248     input_handle_shapes_and_types_.resize(inputs_.size());
249   } else {
250     if (input_handle_data.size() != inputs_.size()) {
251       construction_status_ = errors::InvalidArgument(
252           "Wrong number of handle shapes passed; expected ", inputs_.size(),
253           " got ", input_handle_data.size());
254       return;
255     }
256     input_handle_shapes_and_types_ = std::move(input_handle_data);
257   }
258 
259   if (inputs_.size() != num_inputs_from_node_def) {
260     construction_status_ = errors::InvalidArgument(
261         "Wrong number of inputs passed: ", inputs_.size(), " while ",
262         num_inputs_from_node_def, " expected based on NodeDef");
263     return;
264   }
265 
266   CHECK_LE(input_tensors_.size(), inputs_.size());
267   input_tensors_.resize(inputs_.size());
268   requested_input_tensor_.resize(inputs_.size());
269   requested_input_tensor_as_partial_shape_.resize(inputs_.size());
270 }
271 
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)272 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
273                                           TensorShapeProto* proto) {
274   if (!RankKnown(handle)) {
275     proto->set_unknown_rank(true);
276     return;
277   }
278 
279   for (int32 i = 0; i < Rank(handle); ++i) {
280     DimensionHandle dim = Dim(handle, i);
281     auto* dim_shape = proto->add_dim();
282     if (ValueKnown(dim)) {
283       dim_shape->set_size(Value(dim));
284     } else {
285       dim_shape->set_size(-1);
286     }
287   }
288 }
289 
FullyDefined(ShapeHandle s)290 bool InferenceContext::FullyDefined(ShapeHandle s) {
291   if (!RankKnown(s)) return false;
292   for (int i = 0; i < Rank(s); ++i) {
293     if (!ValueKnown(Dim(s, i))) return false;
294   }
295   return true;
296 }
297 
NumElements(ShapeHandle s)298 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
299   const auto rank = Rank(s);
300   if (rank == kUnknownRank) return UnknownDim();
301   int64 size = 1;
302   for (int i = 0; i < rank; ++i) {
303     int64 dim_val = Value(Dim(s, i));
304     if (dim_val == kUnknownDim) return UnknownDim();
305     size *= dim_val;
306   }
307   return MakeDim(size);
308 }
309 
DebugString(ShapeHandle s)310 string InferenceContext::DebugString(ShapeHandle s) {
311   if (RankKnown(s)) {
312     std::vector<string> vals;
313     for (auto d : s->dims_) vals.push_back(DebugString(d));
314     return strings::StrCat("[", str_util::Join(vals, ","), "]");
315   } else {
316     return "?";
317   }
318 }
319 
DebugString(DimensionHandle d)320 string InferenceContext::DebugString(DimensionHandle d) {
321   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
322 }
323 
DebugString() const324 string InferenceContext::DebugString() const {
325   return strings::StrCat("InferenceContext for node: ",
326                          ProtoDebugString(*node_def_));
327 }
328 
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)329 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
330                                   ShapeHandle* out) {
331   if (rank > kint32max) {
332     return errors::InvalidArgument("Rank cannot exceed kint32max");
333   }
334   const int32 existing = Rank(shape);
335   if (existing == rank) {
336     *out = shape;
337     return Status::OK();
338   }
339   if (existing == kUnknownRank) {
340     std::vector<DimensionHandle> dims;
341     dims.reserve(rank);
342     for (int i = 0; i < rank; ++i) {
343       dims.push_back(UnknownDim());
344     }
345     ShapeHandle shp = shape_manager_.MakeShape(dims);
346     return Merge(shape, shp, out);
347   }
348   *out = nullptr;
349 
350   return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
351                                  existing);
352 }
353 
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)354 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
355                                          ShapeHandle* out) {
356   if (rank > kint32max) {
357     return errors::InvalidArgument("Rank cannot exceed kint32max");
358   }
359   const int32 existing = Rank(shape);
360   if (existing >= rank || existing == kUnknownRank) {
361     *out = shape;
362     return Status::OK();
363   }
364   *out = nullptr;
365   return errors::InvalidArgument("Shape must be at least rank ", rank,
366                                  " but is rank ", existing);
367 }
368 
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)369 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
370                                         ShapeHandle* out) {
371   if (rank > kint32max) {
372     return errors::InvalidArgument("Rank cannot exceed kint32max");
373   }
374   const int32 existing = Rank(shape);
375   if (existing <= rank || existing == kUnknownRank) {
376     *out = shape;
377     return Status::OK();
378   }
379   *out = nullptr;
380   return errors::InvalidArgument("Shape must be at most rank ", rank,
381                                  " but is rank ", existing);
382 }
383 
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)384 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
385                                    DimensionHandle* out) {
386   const int64 existing = Value(dim);
387   if (existing == value) {
388     *out = dim;
389     return Status::OK();
390   }
391   if (existing == kUnknownDim) {
392     DimensionHandle d = MakeDim(value);
393     return Merge(dim, d, out);
394   }
395   *out = nullptr;
396   return errors::InvalidArgument("Dimension must be ", value, " but is ",
397                                  existing);
398 }
399 
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)400 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
401                              DimensionHandle* out) {
402   if (d_old.SameHandle(d_new)) {
403     *out = d_old;
404   } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
405     // The node will be fed by the dimension d_new instead of d_old: any
406     // equality assertion between d_old and other input dimension on this node
407     // may not be true anymore, so forget them all.
408     ForgetMerges();
409     // Return the new shape handle to force the relaxation to propagate to the
410     // fanout of the context.
411     *out = d_new;
412   } else if (!ValueKnown(d_new)) {
413     ForgetMerges();
414     *out = d_new;
415   } else if (Value(d_old) == Value(d_new)) {
416     // Return the old shape handle. This will stop the relaxation in the fanout
417     // of the context.
418     *out = d_old;
419   } else {
420     // Return a new handle that encodes a different unknown dim.
421     ForgetMerges();
422     *out = UnknownDim();
423   }
424 }
425 
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)426 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
427                                DimensionHandle* out) {
428   if (d0.SameHandle(d1)) {
429     *out = d0;
430     return Status::OK();
431   } else if (!ValueKnown(d1)) {
432     *out = d0;
433     merged_dims_.emplace_back(d0, d1);
434     return Status::OK();
435   } else if (!ValueKnown(d0)) {
436     *out = d1;
437     merged_dims_.emplace_back(d0, d1);
438     return Status::OK();
439   } else if (Value(d0) == Value(d1)) {
440     *out = d0;
441     return Status::OK();
442   } else {
443     *out = nullptr;
444     return errors::InvalidArgument("Dimensions must be equal, but are ",
445                                    Value(d0), " and ", Value(d1));
446   }
447 }
448 
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)449 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
450                                      ShapeHandle* s_out,
451                                      ShapeHandle* prefix_out) {
452   *s_out = *prefix_out = nullptr;
453   if (!RankKnown(prefix) || !RankKnown(s)) {
454     *s_out = s;
455     *prefix_out = prefix;
456     return Status::OK();
457   }
458   const int32 rank = Rank(prefix);
459   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
460 
461   // Merge the prefix dims and create the new output shapes.
462   std::vector<DimensionHandle> dims;
463   dims.resize(rank);
464   for (int i = 0; i < rank; ++i) {
465     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
466   }
467   *prefix_out = MakeShape(dims);
468   for (int i = rank; i < Rank(s); ++i) dims.push_back(Dim(s, i));
469   *s_out = MakeShape(dims);
470   return Status::OK();
471 }
472 
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)473 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
474                              ShapeHandle* out) {
475   if (s_old.SameHandle(s_new)) {
476     *out = s_old;
477     return;
478   } else if (!RankKnown(s_new) || !s_old.IsSet()) {
479     ForgetMerges();
480     *out = s_new;
481     return;
482   }
483 
484   const int32 rank = Rank(s_old);
485   if (rank != Rank(s_new)) {
486     ForgetMerges();
487     *out = UnknownShape();
488     return;
489   }
490 
491   bool return_s_old = true;
492   for (int i = 0; i < rank; ++i) {
493     auto d0 = Dim(s_old, i);
494     auto d1 = Dim(s_new, i);
495     if (d0.SameHandle(d1)) continue;
496 
497     auto v0 = Value(d0);
498     auto v1 = Value(d1);
499     if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
500       return_s_old = false;
501       break;
502     }
503   }
504   if (return_s_old) {
505     *out = s_old;
506     return;
507   }
508 
509   // Relax dims.
510   std::vector<DimensionHandle> dims(rank);
511   for (int i = 0; i < rank; ++i) {
512     Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
513   }
514   ForgetMerges();
515   *out = MakeShape(dims);
516 }
517 
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)518 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
519                                ShapeHandle* out) {
520   if (s0.SameHandle(s1)) {
521     *out = s0;
522     return Status::OK();
523   } else if (!RankKnown(s1)) {
524     *out = s0;
525     merged_shapes_.emplace_back(s0, s1);
526     return Status::OK();
527   } else if (!RankKnown(s0)) {
528     *out = s1;
529     merged_shapes_.emplace_back(s0, s1);
530     return Status::OK();
531   }
532 
533   const int32 rank = Rank(s0);
534   if (rank != Rank(s1)) {
535     *out = nullptr;
536     return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
537                                    " and ", Rank(s1));
538   }
539 
540   bool return_s0 = true;
541   bool return_s1 = true;
542   for (int i = 0; i < rank; ++i) {
543     auto d0 = Dim(s0, i);
544     auto d1 = Dim(s1, i);
545     if (d0.SameHandle(d1)) continue;
546 
547     auto v0 = Value(d0);
548     auto v1 = Value(d1);
549     if (v0 == kUnknownDim) {
550       if (v1 != kUnknownDim) {
551         return_s0 = false;
552       }
553     } else if (v1 == kUnknownDim) {
554       return_s1 = false;
555     } else if (v0 != v1) {
556       *out = nullptr;
557       return errors::InvalidArgument(
558           "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
559           " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
560           DebugString(s1), ".");
561     }
562   }
563 
564   merged_shapes_.emplace_back(s0, s1);
565 
566   if (return_s0 || return_s1) {
567     *out = return_s0 ? s0 : s1;
568     return Status::OK();
569   }
570 
571   // Merge dims.
572   std::vector<DimensionHandle> dims(rank, nullptr);
573   for (int i = 0; i < rank; ++i) {
574     // Invariant for merge was checked earlier, so CHECK is ok.
575     TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
576   }
577 
578   Status s = ReturnCreatedShape(dims, out);
579   if (s.ok()) {
580     // Merge the new shape with s0. Since s0 and s1 are merged, this implies
581     // that s1 and out are also merged.
582     merged_shapes_.emplace_back(s0, *out);
583   }
584   return s;
585 }
586 
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)587 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
588                                   ShapeHandle* out) {
589   return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
590 }
591 
Subshape(ShapeHandle s,int64 start_in,int64 end_in,ShapeHandle * out)592 Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
593                                   ShapeHandle* out) {
594   int64 start = start_in;
595   int64 end = end_in;
596   const int32 rank = Rank(s);
597   if (start == 0 && ((RankKnown(s) && end >= rank) ||
598                      end == std::numeric_limits<int64>::max())) {
599     *out = s;
600     return Status::OK();
601   }
602   if (!RankKnown(s)) {
603     return ReturnUnknownShape(out);
604   }
605 
606   if (start > rank) start = rank;
607   if (end > rank) end = rank;
608   if (start < 0) {
609     start = rank + start;
610     if (start < 0) {
611       *out = nullptr;
612       return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
613                                      ", for shape with rank ", rank);
614     }
615   }
616 
617   if (end < 0) {
618     end = rank + end;
619     if (end < 0) {
620       *out = nullptr;
621       return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
622                                      ", for shape with rank ", rank);
623     }
624   }
625   if (start > end) {
626     *out = nullptr;
627     return errors::InvalidArgument(
628         "Subshape must have computed start <= end, but is ", start, " and ",
629         end, " (computed from start ", start_in, " and end ", end_in,
630         " over shape with rank ", rank, ")");
631   }
632   std::vector<DimensionHandle> dims;
633   dims.reserve(end - start);
634   for (int i = start; i < end; ++i) {
635     dims.push_back(Dim(s, i));
636   }
637   return ReturnCreatedShape(dims, out);
638 }
639 
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)640 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
641                                      ShapeHandle* out) {
642   if (!RankKnown(s1) || !RankKnown(s2)) {
643     return ReturnUnknownShape(out);
644   }
645   const int32 s1_rank = Rank(s1);
646   const int32 s2_rank = Rank(s2);
647   const int32 rank = s1_rank + s2_rank;
648   std::vector<DimensionHandle> dims;
649   dims.reserve(rank);
650   for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
651   for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
652   return ReturnCreatedShape(dims, out);
653 }
654 
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)655 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
656                                     DimensionHandle new_dim, ShapeHandle* out) {
657   if (!RankKnown(s)) {
658     return ReturnUnknownShape(out);
659   }
660   int64 dim_index = dim_index_in;
661   if (dim_index < 0) {
662     dim_index = s->dims_.size() + dim_index;
663   }
664   if (!FastBoundsCheck(dim_index, s->dims_.size())) {
665     *out = nullptr;
666     return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
667                                    " for shape with ", s->dims_.size(),
668                                    " dimensions");
669   }
670   std::vector<DimensionHandle> dims(s->dims_);
671   dims[dim_index] = new_dim;
672   return ReturnCreatedShape(dims, out);
673 }
674 
MakeShape(const std::vector<DimensionHandle> & dims)675 ShapeHandle InferenceContext::MakeShape(
676     const std::vector<DimensionHandle>& dims) {
677   return shape_manager_.MakeShape(dims);
678 }
679 
MakeShape(std::initializer_list<DimensionOrConstant> dims)680 ShapeHandle InferenceContext::MakeShape(
681     std::initializer_list<DimensionOrConstant> dims) {
682   std::vector<DimensionHandle> dims_actual;
683   dims_actual.reserve(dims.size());
684   for (const DimensionOrConstant& d : dims) {
685     dims_actual.push_back(MakeDim(d));
686   }
687 
688   return shape_manager_.MakeShape(dims_actual);
689 }
690 
UnknownShape()691 ShapeHandle InferenceContext::UnknownShape() {
692   return shape_manager_.UnknownShape();
693 }
694 
UnknownShapeOfRank(int64 rank)695 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
696   CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
697   if (rank == kUnknownRank) {
698     return UnknownShape();
699   }
700   CHECK_GE(rank, 0) << "rank must not be negative";
701   std::vector<DimensionHandle> dims(rank);
702   for (int32 i = 0; i < rank; ++i) {
703     dims[i] = UnknownDim();
704   }
705   return MakeShape(dims);
706 }
707 
Scalar()708 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
709 
Vector(DimensionOrConstant dim)710 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
711   return MakeShape({dim});
712 }
713 
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)714 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
715                                      DimensionOrConstant dim2) {
716   return MakeShape({dim1, dim2});
717 }
718 
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)719 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
720                                                   ShapeHandle* out) {
721   ShapeHandle input_shape;
722   TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
723 
724   requested_input_tensor_as_partial_shape_[input_idx] = true;
725   if (input_idx < input_tensors_as_shapes_.size() &&
726       input_tensors_as_shapes_[input_idx].IsSet() &&
727       RankKnown(input_tensors_as_shapes_[input_idx])) {
728     *out = input_tensors_as_shapes_[input_idx];
729     return Status::OK();
730   }
731 
732   return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
733 }
734 
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)735 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
736                                              ShapeHandle tensor_shape,
737                                              ShapeHandle* out) {
738   if (t == nullptr) {
739     // Shape tensor is not known, but if the shape of the shape tensor is then
740     // the right number of unknown dims can be created.
741     DimensionHandle shape_dim = Dim(tensor_shape, 0);
742     if (!ValueKnown(shape_dim)) {
743       return ReturnUnknownShape(out);
744     }
745     const auto num_dims = Value(shape_dim);
746     std::vector<DimensionHandle> dims;
747     dims.reserve(num_dims);
748     for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
749     return ReturnCreatedShape(dims, out);
750   }
751 
752   if (t->shape().dims() != 1) {
753     *out = nullptr;
754     return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
755                                    t->shape().dims());
756   }
757   std::vector<DimensionHandle> dims;
758   if (t->dtype() == DataType::DT_INT32) {
759     auto flat_t = t->flat<int32>();
760     for (int i = 0; i < flat_t.size(); ++i) {
761       const int32 val = flat_t(i);
762       if (val < -1) {
763         return errors::InvalidArgument(
764             "Invalid value in tensor used for shape: ", val);
765       }
766       // -1 will become an unknown dim.
767       dims.push_back(MakeDim(val));
768     }
769   } else if (t->dtype() == DataType::DT_INT64) {
770     auto flat_t = t->flat<int64>();
771     for (int i = 0; i < flat_t.size(); ++i) {
772       const int64 val = flat_t(i);
773       if (val < -1) {
774         return errors::InvalidArgument(
775             "Invalid value in tensor used for shape: ", val);
776       }
777       // -1 will become an unknown dim.
778       dims.push_back(MakeDim(val));
779     }
780   } else {
781     *out = nullptr;
782     return errors::InvalidArgument(
783         "Input tensor must be int32 or int64, but was ",
784         DataTypeString(t->dtype()));
785   }
786 
787   return ReturnCreatedShape(dims, out);
788 }
789 
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)790 Status InferenceContext::MakeShapeFromPartialTensorShape(
791     const PartialTensorShape& partial_shape, ShapeHandle* out) {
792   *out = nullptr;
793   if (partial_shape.dims() == -1) {
794     return ReturnUnknownShape(out);
795   }
796   const int num_dims = partial_shape.dims();
797   std::vector<DimensionHandle> dims(num_dims);
798   for (int i = 0; i < num_dims; ++i) {
799     // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
800     // can be passed directly to MakeDim.
801     dims[i] = MakeDim(partial_shape.dim_size(i));
802   }
803   return ReturnCreatedShape(dims, out);
804 }
805 
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)806 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
807                                                   ShapeHandle* out) {
808   return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
809                                          out);
810 }
811 
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)812 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
813                                                  ShapeHandle* out) {
814   *out = nullptr;
815   TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
816   PartialTensorShape partial_shape(proto);
817   return MakeShapeFromPartialTensorShape(partial_shape, out);
818 }
819 
GetScalarFromTensor(const Tensor * t,int64 * val)820 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
821   // Caller must ensure that <t> is not NULL.
822   const int rank = t->dims();
823   if (rank != 0) {
824     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
825   }
826 
827   if (t->dtype() == DT_INT32) {
828     *val = t->scalar<int32>()();
829     return Status::OK();
830   } else if (t->dtype() == DT_INT64) {
831     *val = t->scalar<int64>()();
832     return Status::OK();
833   } else {
834     return errors::InvalidArgument(
835         "Scalar input for dim size must be int32 or int64");
836   }
837 }
838 
839 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)840 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
841   int64 val;
842   const Tensor* t = input_tensor(idx);
843   if (t == nullptr) {
844     *out = UnknownDim();
845     return Status::OK();
846   }
847   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
848   if (val < 0) {
849     return errors::InvalidArgument("Dimension size, given by scalar input ",
850                                    idx, ", must be non-negative but is ", val);
851   }
852   *out = MakeDim(val);
853   return Status::OK();
854 }
855 
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)856 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
857     int idx, int input_rank, DimensionHandle* out) {
858   int64 val;
859   const Tensor* t = input_tensor(idx);
860   if (t == nullptr) {
861     *out = UnknownDim();
862     return Status::OK();
863   }
864   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
865   if (val < 0) {
866     if (input_rank < 0) {
867       *out = UnknownDim();
868       return Status::OK();
869     } else if (val + input_rank < 0) {
870       return errors::InvalidArgument("Dimension size, given by scalar input ",
871                                      val, " must be in range [-", input_rank,
872                                      ", ", input_rank, ")");
873     } else {
874       val += input_rank;
875     }
876   } else if (input_rank >= 0 && val >= input_rank) {
877     return errors::InvalidArgument("Dimension size, given by scalar input ",
878                                    val, " must be in range [-", input_rank,
879                                    ", ", input_rank, ")");
880   }
881   *out = MakeDim(val);
882   return Status::OK();
883 }
884 
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)885 Status InferenceContext::Divide(DimensionHandle dividend,
886                                 DimensionOrConstant divisor,
887                                 bool evenly_divisible, DimensionHandle* out) {
888   const int64 divisor_value = Value(divisor);
889   if (divisor_value == 1) {
890     *out = dividend;
891   } else if (!ValueKnown(dividend) ||
892              (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
893     *out = UnknownDim();
894   } else {
895     const int64 v = Value(dividend);
896     if (divisor_value <= 0) {
897       return errors::InvalidArgument("Divisor must be positive but is ",
898                                      divisor_value);
899     }
900     if (evenly_divisible && (v % divisor_value) != 0) {
901       return errors::InvalidArgument(
902           "Dimension size must be evenly divisible by ", divisor_value,
903           " but is ", v);
904     }
905     *out = MakeDim(v / divisor_value);
906   }
907   return Status::OK();
908 }
909 
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)910 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
911                              DimensionHandle* out) {
912   const int64 first_value = Value(first);
913   const int64 second_value = Value(second);
914   // Special cases.
915   if (first_value == 0) {
916     *out = MakeDim(second);
917   } else if (second_value == 0) {
918     *out = first;
919   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
920     *out = UnknownDim();
921   } else {
922     // Invariant: Both values are known and positive. Still in run-time we can
923     // get pair of values which cannot be store in output. Check below will
924     // report error. We still need to avoid undefined behavior of signed
925     // overflow and use unsigned addition.
926     const int64 sum = static_cast<uint64>(first_value) + second_value;
927     if (sum < 0) {
928       return errors::InvalidArgument("Dimension size overflow from adding ",
929                                      first_value, " and ", second_value);
930     }
931     *out = MakeDim(sum);
932   }
933   return Status::OK();
934 }
935 
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)936 Status InferenceContext::Subtract(DimensionHandle first,
937                                   DimensionOrConstant second,
938                                   DimensionHandle* out) {
939   const int64 first_value = Value(first);
940   const int64 second_value = Value(second);
941   // Special cases.
942   if (second_value == 0) {
943     *out = first;
944   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
945     *out = UnknownDim();
946   } else {
947     // Invariant: Both values are known, first_value is non-negative, and
948     // second_value is positive.
949     if (first_value < second_value) {
950       return errors::InvalidArgument(
951           "Negative dimension size caused by subtracting ", second_value,
952           " from ", first_value);
953     }
954     *out = MakeDim(first_value - second_value);
955   }
956   return Status::OK();
957 }
958 
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)959 Status InferenceContext::Multiply(DimensionHandle first,
960                                   DimensionOrConstant second,
961                                   DimensionHandle* out) {
962   const int64 first_value = Value(first);
963   const int64 second_value = Value(second);
964   // Special cases.
965   if (first_value == 0) {
966     *out = first;
967   } else if (second_value == 0) {
968     *out = MakeDim(second);
969   } else if (first_value == 1) {
970     *out = MakeDim(second);
971   } else if (second_value == 1) {
972     *out = first;
973   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
974     *out = UnknownDim();
975   } else {
976     // Invariant: Both values are known and greater than 1.
977     const int64 product = first_value * second_value;
978     if (product < 0) {
979       return errors::InvalidArgument(
980           "Negative dimension size caused by overflow when multiplying ",
981           first_value, " and ", second_value);
982     }
983     *out = MakeDim(product);
984   }
985   return Status::OK();
986 }
987 
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)988 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
989                              DimensionHandle* out) {
990   const int64 first_value = Value(first);
991   const int64 second_value = Value(second);
992   if (first_value == 0) {
993     *out = first;
994   } else if (second_value == 0) {
995     *out = MakeDim(second);
996   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
997     *out = UnknownDim();
998   } else {
999     if (first_value <= second_value) {
1000       *out = first;
1001     } else {
1002       *out = MakeDim(second);
1003     }
1004   }
1005   return Status::OK();
1006 }
1007 
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1008 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1009                              DimensionHandle* out) {
1010   const int64 first_value = Value(first);
1011   const int64 second_value = Value(second);
1012   if (first_value == kUnknownDim || second_value == kUnknownDim) {
1013     *out = UnknownDim();
1014   } else {
1015     if (first_value >= second_value) {
1016       *out = first;
1017     } else {
1018       *out = MakeDim(second);
1019     }
1020   }
1021   return Status::OK();
1022 }
1023 
// Returns a Status with the same code as `status`, whose message is extended
// with a description of this node. The description covers:
// - the node's name and op;
// - the debug string of every input shape;
// - for each input, either
//   - its requested tensor-as-partial-shape (if set and of known rank), or
//   - its requested input tensor value (if present).
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& shape : inputs_) {
    input_shapes.push_back(DebugString(shape));
  }

  // At most one of the two descriptions is recorded per input; the
  // partial-shape form takes precedence.
  std::vector<string> tensor_values;
  std::vector<string> tensor_shapes;
  for (int i = 0; i < inputs_.size(); ++i) {
    const bool have_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (have_partial_shape) {
      tensor_shapes.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }

    const bool have_tensor = requested_input_tensor_[i] &&
                             i < input_tensors_.size() &&
                             input_tensors_[i] != nullptr;
    if (have_tensor) {
      tensor_values.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string error_context = strings::StrCat(
      " for '", node_def_->name(), "' (op: '", node_def_->op(),
      "') with input shapes: ", str_util::Join(input_shapes, ", "));
  if (!tensor_values.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       str_util::Join(tensor_values, ", "));
  }
  if (!tensor_shapes.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       str_util::Join(tensor_shapes, ","));
  }
  strings::StrAppend(&error_context, ".");

  return Status(status.code(),
                strings::StrCat(status.error_message(), error_context));
}
1065 
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1066 bool InferenceContext::MergeHandleShapesAndTypes(
1067     const std::vector<ShapeAndType>& shapes_and_types,
1068     std::vector<ShapeAndType>* to_update) {
1069   if (shapes_and_types.size() != to_update->size()) {
1070     return false;
1071   }
1072   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1073   bool refined = false;
1074   for (int i = 0; i < shapes_and_types.size(); ++i) {
1075     const ShapeAndType& existing = (*to_update)[i];
1076     if (shapes_and_types[i].dtype == existing.dtype) {
1077       new_values[i].dtype = existing.dtype;
1078     } else {
1079       if (existing.dtype != DT_INVALID) {
1080         return false;
1081       } else {
1082         new_values[i].dtype = shapes_and_types[i].dtype;
1083         refined = true;
1084       }
1085     }
1086     if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1087              .ok()) {
1088       // merge failed, ignore the new value.
1089       new_values[i].shape = existing.shape;
1090     }
1091     if (!existing.shape.SameHandle(new_values[i].shape)) {
1092       refined = true;
1093     }
1094   }
1095   if (!refined) {
1096     return false;
1097   }
1098   for (int i = 0; i < new_values.size(); ++i) {
1099     (*to_update)[i] = new_values[i];
1100   }
1101   return true;
1102 }
1103 
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1104 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1105     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1106   if (output_handle_shapes_and_types_[idx] == nullptr) {
1107     output_handle_shapes_and_types_[idx].reset(
1108         new std::vector<ShapeAndType>(shapes_and_types));
1109     return true;
1110   }
1111   return MergeHandleShapesAndTypes(shapes_and_types,
1112                                    output_handle_shapes_and_types_[idx].get());
1113 }
1114 
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1115 bool InferenceContext::MergeInputHandleShapesAndTypes(
1116     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1117   if (input_handle_shapes_and_types_[idx] == nullptr) {
1118     input_handle_shapes_and_types_[idx].reset(
1119         new std::vector<ShapeAndType>(shapes_and_types));
1120     return true;
1121   }
1122   return MergeHandleShapesAndTypes(shapes_and_types,
1123                                    input_handle_shapes_and_types_[idx].get());
1124 }
1125 
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1126 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1127     const std::vector<ShapeAndType>& shapes_and_types,
1128     std::vector<ShapeAndType>* to_update) {
1129   if (shapes_and_types.size() != to_update->size()) {
1130     return false;
1131   }
1132   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1133   bool refined = false;
1134   for (int i = 0; i < shapes_and_types.size(); ++i) {
1135     const ShapeAndType& existing = (*to_update)[i];
1136     if (shapes_and_types[i].dtype == existing.dtype) {
1137       new_values[i].dtype = existing.dtype;
1138     } else {
1139       if (existing.dtype != DT_INVALID) {
1140         return false;
1141       } else {
1142         new_values[i].dtype = shapes_and_types[i].dtype;
1143         refined = true;
1144       }
1145     }
1146     Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1147     if (!existing.shape.SameHandle(new_values[i].shape)) {
1148       refined = true;
1149     }
1150   }
1151   if (!refined) {
1152     return false;
1153   }
1154   for (int i = 0; i < new_values.size(); ++i) {
1155     (*to_update)[i] = new_values[i];
1156   }
1157   return true;
1158 }
1159 
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1160 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1161     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1162   if (output_handle_shapes_and_types_[idx] == nullptr) {
1163     output_handle_shapes_and_types_[idx].reset(
1164         new std::vector<ShapeAndType>(shapes_and_types));
1165     return true;
1166   }
1167   return RelaxHandleShapesAndMergeTypes(
1168       shapes_and_types, output_handle_shapes_and_types_[idx].get());
1169 }
1170 
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1171 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1172     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1173   if (input_handle_shapes_and_types_[idx] == nullptr) {
1174     input_handle_shapes_and_types_[idx].reset(
1175         new std::vector<ShapeAndType>(shapes_and_types));
1176     return true;
1177   }
1178   return RelaxHandleShapesAndMergeTypes(
1179       shapes_and_types, input_handle_shapes_and_types_[idx].get());
1180 }
1181 
1182 // -----------------------------------------------------------------------------
1183 // ShapeManager
1184 // -----------------------------------------------------------------------------
ShapeManager()1185 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1186 InferenceContext::ShapeManager::~ShapeManager() {
1187   for (auto* s : all_shapes_) delete s;
1188   for (auto* d : all_dims_) delete d;
1189 }
1190 
MakeShape(const std::vector<DimensionHandle> & dims)1191 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1192     const std::vector<DimensionHandle>& dims) {
1193   all_shapes_.push_back(new Shape(dims));
1194   return all_shapes_.back();
1195 }
1196 
UnknownShape()1197 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1198   all_shapes_.push_back(new Shape());
1199   return all_shapes_.back();
1200 }
1201 
1202 }  // namespace shape_inference
1203 }  // namespace tensorflow
1204