• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 
26 namespace tensorflow {
27 namespace shape_inference {
28 
29 constexpr int32 InferenceContext::kUnknownRank;
30 constexpr int64 InferenceContext::kUnknownDim;
31 
32 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)33 InferenceContext::InferenceContext(
34     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
35     const std::vector<PartialTensorShape>& input_shapes,
36     const std::vector<const Tensor*>& input_tensors,
37     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
38     const std::vector<
39         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
40         input_handle_shapes_and_types)
41     : graph_def_version_(graph_def_version), attrs_(attrs) {
42   std::vector<ShapeHandle> input_tensors_as_shape_handles;
43   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
44   for (const PartialTensorShape& p : input_tensors_as_shapes) {
45     ShapeHandle shape;
46     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
47     if (!construction_status_.ok()) {
48       return;
49     }
50     input_tensors_as_shape_handles.push_back(shape);
51   }
52   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
53   if (!construction_status_.ok()) return;
54   inputs_.reserve(input_shapes.size());
55   for (const PartialTensorShape& p : input_shapes) {
56     ShapeHandle shape;
57     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
58     if (!construction_status_.ok()) {
59       return;
60     }
61     inputs_.push_back(shape);
62   }
63   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
64       input_shapes.size());
65   for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
66     const auto& v = input_handle_shapes_and_types[i];
67     if (v == nullptr) {
68       continue;
69     }
70     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
71     auto& new_v = *handle_data[i];
72     for (int j = 0; j < v->size(); ++j) {
73       const auto& p = (*v)[j];
74       construction_status_.Update(
75           MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
76       if (!construction_status_.ok()) {
77         return;
78       }
79       new_v[j].dtype = p.second;
80     }
81   }
82   PostInputInit(std::move(handle_data));
83 }
84 
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)85 InferenceContext::InferenceContext(
86     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
87     const std::vector<ShapeHandle>& input_shapes,
88     const std::vector<const Tensor*>& input_tensors,
89     const std::vector<ShapeHandle>& input_tensors_as_shapes,
90     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
91         input_handle_shapes_and_types)
92     : graph_def_version_(graph_def_version), attrs_(attrs) {
93   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
94   if (!construction_status_.ok()) return;
95   inputs_ = input_shapes;
96 
97   PostInputInit(std::move(input_handle_shapes_and_types));
98 }
99 
~InferenceContext()100 InferenceContext::~InferenceContext() {}
101 
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)102 Status InferenceContext::Run(
103     const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
104   ForgetMerges();
105   Status s = fn(this);
106   if (!s.ok()) {
107     ForgetMerges();
108     return AttachContext(s);
109   }
110 #ifndef NDEBUG
111   for (int i = 0; i < num_outputs(); ++i) {
112     DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
113   }
114 #endif  // NDEBUG
115   return s;
116 }
117 
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)118 Status InferenceContext::set_output(StringPiece output_name,
119                                     const std::vector<ShapeHandle>& shapes) {
120   auto result = output_name_map_.find(output_name);
121   if (result == output_name_map_.end()) {
122     return errors::InvalidArgument("Unknown output name: ", output_name);
123   } else {
124     const int start = result->second.first;
125     const int size = result->second.second - start;
126     if (size != shapes.size()) {
127       return errors::InvalidArgument("Must have exactly ", shapes.size(),
128                                      " shapes.");
129     }
130     for (int i = 0; i < size; ++i) {
131       outputs_[i + start] = shapes[i];
132     }
133   }
134   return Status::OK();
135 }
136 
input(StringPiece input_name,std::vector<ShapeHandle> * output) const137 Status InferenceContext::input(StringPiece input_name,
138                                std::vector<ShapeHandle>* output) const {
139   const auto result = input_name_map_.find(input_name);
140   if (result == input_name_map_.end()) {
141     return errors::InvalidArgument("Unknown input name: ", input_name);
142   } else {
143     output->clear();
144     for (int i = result->second.first; i < result->second.second; ++i) {
145       output->push_back(inputs_[i]);
146     }
147   }
148   return Status::OK();
149 }
150 
output(StringPiece output_name,std::vector<ShapeHandle> * output) const151 Status InferenceContext::output(StringPiece output_name,
152                                 std::vector<ShapeHandle>* output) const {
153   const auto result = output_name_map_.find(output_name);
154   if (result == output_name_map_.end()) {
155     return errors::InvalidArgument("Unknown output name: ", output_name);
156   } else {
157     output->clear();
158     for (int i = result->second.first; i < result->second.second; ++i) {
159       output->push_back(outputs_[i]);
160     }
161   }
162   return Status::OK();
163 }
164 
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)165 void InferenceContext::PreInputInit(
166     const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
167     const std::vector<ShapeHandle>& input_tensors_as_shapes) {
168   input_tensors_ = input_tensors;
169   input_tensors_as_shapes_ = input_tensors_as_shapes;
170 
171   construction_status_ =
172       NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
173   if (!construction_status_.ok()) return;
174 
175   int num_outputs = 0;
176   for (const auto& e : output_name_map_) {
177     num_outputs = std::max(num_outputs, e.second.second);
178   }
179   outputs_.assign(num_outputs, nullptr);
180   output_handle_shapes_and_types_.resize(num_outputs);
181 }
182 
ExpandOutputs(int new_output_size)183 Status InferenceContext::ExpandOutputs(int new_output_size) {
184   if (new_output_size < outputs_.size()) {
185     return errors::InvalidArgument("Trying to reduce number of outputs of op.");
186   }
187   outputs_.resize(new_output_size, nullptr);
188   output_handle_shapes_and_types_.resize(new_output_size);
189   return Status::OK();
190 }
191 
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)192 void InferenceContext::PostInputInit(
193     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
194   int num_inputs_from_node_def = 0;
195   for (const auto& e : input_name_map_) {
196     num_inputs_from_node_def =
197         std::max(num_inputs_from_node_def, e.second.second);
198   }
199 
200   // Allow passing empty shapes/dtypes to avoid changing every single test.
201   if (input_handle_data.empty()) {
202     input_handle_shapes_and_types_.resize(inputs_.size());
203   } else {
204     if (input_handle_data.size() != inputs_.size()) {
205       construction_status_ = errors::InvalidArgument(
206           "Wrong number of handle shapes passed; expected ", inputs_.size(),
207           " got ", input_handle_data.size());
208       return;
209     }
210     input_handle_shapes_and_types_ = std::move(input_handle_data);
211   }
212 
213   if (inputs_.size() != num_inputs_from_node_def) {
214     construction_status_ = errors::InvalidArgument(
215         "Wrong number of inputs passed: ", inputs_.size(), " while ",
216         num_inputs_from_node_def, " expected based on NodeDef");
217     return;
218   }
219 
220   CHECK_LE(input_tensors_.size(), inputs_.size());
221   input_tensors_.resize(inputs_.size());
222   requested_input_tensor_.resize(inputs_.size());
223   requested_input_tensor_as_partial_shape_.resize(inputs_.size());
224 }
225 
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)226 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
227                                           TensorShapeProto* proto) {
228   if (!RankKnown(handle)) {
229     proto->set_unknown_rank(true);
230     return;
231   }
232 
233   for (int32 i = 0; i < Rank(handle); ++i) {
234     DimensionHandle dim = Dim(handle, i);
235     auto* dim_shape = proto->add_dim();
236     if (ValueKnown(dim)) {
237       dim_shape->set_size(Value(dim));
238     } else {
239       dim_shape->set_size(-1);
240     }
241   }
242 }
243 
FullyDefined(ShapeHandle s)244 bool InferenceContext::FullyDefined(ShapeHandle s) {
245   if (!RankKnown(s)) return false;
246   for (int i = 0; i < Rank(s); ++i) {
247     if (!ValueKnown(Dim(s, i))) return false;
248   }
249   return true;
250 }
251 
NumElements(ShapeHandle s)252 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
253   const auto rank = Rank(s);
254   if (rank == kUnknownRank) return UnknownDim();
255   bool found_unknown = false;
256   int64 size = 1;
257   for (int i = 0; i < rank; ++i) {
258     int64 dim_val = Value(Dim(s, i));
259     if (dim_val == kUnknownDim) {
260       found_unknown = true;
261     } else if (dim_val == 0) {
262       return MakeDim(0);
263     } else {
264       size *= dim_val;
265     }
266   }
267   if (found_unknown) {
268     return UnknownDim();
269   } else {
270     return MakeDim(size);
271   }
272 }
273 
DebugString(ShapeHandle s)274 string InferenceContext::DebugString(ShapeHandle s) {
275   if (RankKnown(s)) {
276     std::vector<string> vals;
277     for (auto d : s->dims_) vals.push_back(DebugString(d));
278     return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
279   } else {
280     return "?";
281   }
282 }
283 
DebugString(DimensionHandle d)284 string InferenceContext::DebugString(DimensionHandle d) {
285   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
286 }
287 
DebugString() const288 string InferenceContext::DebugString() const {
289   return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
290 }
291 
DebugString(const ShapeAndType & shape_and_type)292 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
293   return strings::StrCat(DebugString(shape_and_type.shape), ":",
294                          DataTypeString(shape_and_type.dtype));
295 }
296 
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)297 string InferenceContext::DebugString(
298     gtl::ArraySlice<ShapeAndType> shape_and_types) {
299   std::vector<string> pieces;
300   for (const ShapeAndType& s : shape_and_types) {
301     pieces.push_back(DebugString(s));
302   }
303   return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
304 }
305 
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)306 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
307                                   ShapeHandle* out) {
308   if (rank > kint32max) {
309     return errors::InvalidArgument("Rank cannot exceed kint32max");
310   }
311   const int32 existing = Rank(shape);
312   if (existing == rank) {
313     *out = shape;
314     return Status::OK();
315   }
316   if (existing == kUnknownRank) {
317     std::vector<DimensionHandle> dims;
318     dims.reserve(rank);
319     for (int i = 0; i < rank; ++i) {
320       dims.push_back(UnknownDim());
321     }
322     ShapeHandle shp = shape_manager_.MakeShape(dims);
323     return Merge(shape, shp, out);
324   }
325   *out = nullptr;
326 
327   return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
328                                  existing);
329 }
330 
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)331 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
332                                          ShapeHandle* out) {
333   if (rank > kint32max) {
334     return errors::InvalidArgument("Rank cannot exceed kint32max");
335   }
336   const int32 existing = Rank(shape);
337   if (existing >= rank || existing == kUnknownRank) {
338     *out = shape;
339     return Status::OK();
340   }
341   *out = nullptr;
342   return errors::InvalidArgument("Shape must be at least rank ", rank,
343                                  " but is rank ", existing);
344 }
345 
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)346 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
347                                         ShapeHandle* out) {
348   if (rank > kint32max) {
349     return errors::InvalidArgument("Rank cannot exceed kint32max");
350   }
351   const int32 existing = Rank(shape);
352   if (existing <= rank || existing == kUnknownRank) {
353     *out = shape;
354     return Status::OK();
355   }
356   *out = nullptr;
357   return errors::InvalidArgument("Shape must be at most rank ", rank,
358                                  " but is rank ", existing);
359 }
360 
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)361 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
362                                    DimensionHandle* out) {
363   const int64 existing = Value(dim);
364   if (existing == value) {
365     *out = dim;
366     return Status::OK();
367   }
368   if (existing == kUnknownDim) {
369     DimensionHandle d = MakeDim(value);
370     return Merge(dim, d, out);
371   }
372   *out = nullptr;
373   return errors::InvalidArgument("Dimension must be ", value, " but is ",
374                                  existing);
375 }
376 
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)377 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
378                              DimensionHandle* out) {
379   if (d_old.SameHandle(d_new)) {
380     *out = d_old;
381   } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
382     // The node will be fed by the dimension d_new instead of d_old: any
383     // equality assertion between d_old and other input dimension on this node
384     // may not be true anymore, so forget them all.
385     ForgetMerges();
386     // Return the new shape handle to force the relaxation to propagate to the
387     // fanout of the context.
388     *out = d_new;
389   } else if (!ValueKnown(d_new)) {
390     ForgetMerges();
391     *out = d_new;
392   } else if (Value(d_old) == Value(d_new)) {
393     // Return the old shape handle. This will stop the relaxation in the fanout
394     // of the context.
395     *out = d_old;
396   } else {
397     // Return a new handle that encodes a different unknown dim.
398     ForgetMerges();
399     *out = UnknownDim();
400   }
401 }
402 
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)403 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
404                                DimensionHandle* out) {
405   if (d0.SameHandle(d1)) {
406     *out = d0;
407     return Status::OK();
408   } else if (!ValueKnown(d1)) {
409     *out = d0;
410     merged_dims_.emplace_back(d0, d1);
411     return Status::OK();
412   } else if (!ValueKnown(d0)) {
413     *out = d1;
414     merged_dims_.emplace_back(d0, d1);
415     return Status::OK();
416   } else if (Value(d0) == Value(d1)) {
417     *out = d0;
418     return Status::OK();
419   } else {
420     *out = nullptr;
421     return errors::InvalidArgument("Dimensions must be equal, but are ",
422                                    Value(d0), " and ", Value(d1));
423   }
424 }
425 
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)426 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
427                                      ShapeHandle* s_out,
428                                      ShapeHandle* prefix_out) {
429   *s_out = *prefix_out = nullptr;
430   if (!RankKnown(prefix) || !RankKnown(s)) {
431     *s_out = s;
432     *prefix_out = prefix;
433     return Status::OK();
434   }
435   const int32 rank = Rank(prefix);
436   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
437 
438   // Merge the prefix dims and create the new output shapes.
439   const int32 rank_s = Rank(s);
440   std::vector<DimensionHandle> dims;
441   dims.reserve(std::max(rank, rank_s));
442   dims.resize(rank);
443   for (int i = 0; i < rank; ++i) {
444     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
445   }
446   *prefix_out = MakeShape(dims);
447   for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
448   *s_out = MakeShape(dims);
449   return Status::OK();
450 }
451 
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)452 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
453                              ShapeHandle* out) {
454   if (s_old.SameHandle(s_new)) {
455     *out = s_old;
456     return;
457   } else if (!RankKnown(s_new) || !s_old.IsSet()) {
458     ForgetMerges();
459     *out = s_new;
460     return;
461   }
462 
463   const int32 rank = Rank(s_old);
464   if (rank != Rank(s_new)) {
465     ForgetMerges();
466     *out = UnknownShape();
467     return;
468   }
469 
470   bool return_s_old = true;
471   for (int i = 0; i < rank; ++i) {
472     auto d0 = Dim(s_old, i);
473     auto d1 = Dim(s_new, i);
474     if (d0.SameHandle(d1)) continue;
475 
476     auto v0 = Value(d0);
477     auto v1 = Value(d1);
478     if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
479       return_s_old = false;
480       break;
481     }
482   }
483   if (return_s_old) {
484     *out = s_old;
485     return;
486   }
487 
488   // Relax dims.
489   std::vector<DimensionHandle> dims(rank);
490   for (int i = 0; i < rank; ++i) {
491     Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
492   }
493   ForgetMerges();
494   *out = MakeShape(dims);
495 }
496 
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)497 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
498                                ShapeHandle* out) {
499   if (s0.SameHandle(s1)) {
500     *out = s0;
501     return Status::OK();
502   } else if (!RankKnown(s1)) {
503     *out = s0;
504     merged_shapes_.emplace_back(s0, s1);
505     return Status::OK();
506   } else if (!RankKnown(s0)) {
507     *out = s1;
508     merged_shapes_.emplace_back(s0, s1);
509     return Status::OK();
510   }
511 
512   const int32 rank = Rank(s0);
513   if (rank != Rank(s1)) {
514     *out = nullptr;
515     return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
516                                    " and ", Rank(s1));
517   }
518 
519   bool return_s0 = true;
520   bool return_s1 = true;
521   for (int i = 0; i < rank; ++i) {
522     auto d0 = Dim(s0, i);
523     auto d1 = Dim(s1, i);
524     if (d0.SameHandle(d1)) continue;
525 
526     auto v0 = Value(d0);
527     auto v1 = Value(d1);
528     if (v0 == kUnknownDim) {
529       if (v1 != kUnknownDim) {
530         return_s0 = false;
531       }
532     } else if (v1 == kUnknownDim) {
533       return_s1 = false;
534     } else if (v0 != v1) {
535       *out = nullptr;
536       return errors::InvalidArgument(
537           "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
538           " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
539           DebugString(s1), ".");
540     }
541   }
542 
543   merged_shapes_.emplace_back(s0, s1);
544 
545   if (return_s0 || return_s1) {
546     *out = return_s0 ? s0 : s1;
547     return Status::OK();
548   }
549 
550   // Merge dims.
551   std::vector<DimensionHandle> dims(rank, nullptr);
552   for (int i = 0; i < rank; ++i) {
553     // Invariant for merge was checked earlier, so CHECK is ok.
554     TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
555   }
556 
557   Status s = ReturnCreatedShape(dims, out);
558   if (s.ok()) {
559     // Merge the new shape with s0. Since s0 and s1 are merged, this implies
560     // that s1 and out are also merged.
561     merged_shapes_.emplace_back(s0, *out);
562   }
563   return s;
564 }
565 
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)566 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
567                                   ShapeHandle* out) {
568   return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
569 }
570 
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)571 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
572                                   ShapeHandle* out) {
573   return Subshape(s, start, end, 1 /* stride */, out);
574 }
575 
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)576 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
577                                   int64 stride, ShapeHandle* out) {
578   int64 start_in = start;
579   int64 end_in = end;
580 
581   const int32 rank = Rank(s);
582   if (start == 0 && stride == 1 &&
583       ((RankKnown(s) && end >= rank) ||
584        end == std::numeric_limits<int64>::max())) {
585     *out = s;
586     return Status::OK();
587   }
588   if (!RankKnown(s)) {
589     return ReturnUnknownShape(out);
590   }
591 
592   if (start > rank) start = rank;
593   if (end > rank) end = rank;
594 
595   if (stride < 0 && start == rank) --start;
596 
597   if (start < 0) {
598     start = rank + start;
599     if (start < 0) {
600       *out = nullptr;
601       return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
602                                      ", for shape with rank ", rank);
603     }
604   }
605 
606   if (end < 0) {
607     end = rank + end;
608     if (end < 0) {
609       *out = nullptr;
610       return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
611                                      ", for shape with rank ", rank);
612     }
613   }
614   if (stride > 0 && start > end) {
615     *out = nullptr;
616     return errors::InvalidArgument(
617         "Subshape must have computed start <= end, but is ", start, " and ",
618         end, " (computed from start ", start_in, " and end ", end_in,
619         " over shape with rank ", rank, ")");
620   } else if (stride < 0 && start < end) {
621     *out = nullptr;
622     return errors::InvalidArgument(
623         "Subshape must have computed start >= end since stride is negative, "
624         "but is ",
625         start, " and ", end, " (computed from start ", start_in, " and end ",
626         end_in, " over shape with rank ", rank, " and stride", stride, ")");
627   }
628 
629   std::vector<DimensionHandle> dims;
630   for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
631     dims.push_back(Dim(s, i));
632   }
633   return ReturnCreatedShape(dims, out);
634 }
635 
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)636 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
637                                      ShapeHandle* out) {
638   if (!RankKnown(s1) || !RankKnown(s2)) {
639     return ReturnUnknownShape(out);
640   }
641   const int32 s1_rank = Rank(s1);
642   const int32 s2_rank = Rank(s2);
643   const int32 rank = s1_rank + s2_rank;
644   std::vector<DimensionHandle> dims;
645   dims.reserve(rank);
646   for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
647   for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
648   return ReturnCreatedShape(dims, out);
649 }
650 
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)651 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
652                                     DimensionHandle new_dim, ShapeHandle* out) {
653   if (!RankKnown(s)) {
654     return ReturnUnknownShape(out);
655   }
656   int64 dim_index = dim_index_in;
657   if (dim_index < 0) {
658     dim_index = s->dims_.size() + dim_index;
659   }
660   if (!FastBoundsCheck(dim_index, s->dims_.size())) {
661     *out = nullptr;
662     return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
663                                    " for shape with ", s->dims_.size(),
664                                    " dimensions");
665   }
666   std::vector<DimensionHandle> dims(s->dims_);
667   dims[dim_index] = new_dim;
668   return ReturnCreatedShape(dims, out);
669 }
670 
MakeShape(const std::vector<DimensionHandle> & dims)671 ShapeHandle InferenceContext::MakeShape(
672     const std::vector<DimensionHandle>& dims) {
673   return shape_manager_.MakeShape(dims);
674 }
675 
MakeShape(std::initializer_list<DimensionOrConstant> dims)676 ShapeHandle InferenceContext::MakeShape(
677     std::initializer_list<DimensionOrConstant> dims) {
678   std::vector<DimensionHandle> dims_actual;
679   dims_actual.reserve(dims.size());
680   for (const DimensionOrConstant& d : dims) {
681     dims_actual.push_back(MakeDim(d));
682   }
683 
684   return shape_manager_.MakeShape(dims_actual);
685 }
686 
UnknownShape()687 ShapeHandle InferenceContext::UnknownShape() {
688   return shape_manager_.UnknownShape();
689 }
690 
UnknownShapeOfRank(int64 rank)691 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
692   CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
693   if (rank == kUnknownRank) {
694     return UnknownShape();
695   }
696   CHECK_GE(rank, 0) << "rank must not be negative";
697   std::vector<DimensionHandle> dims(rank);
698   for (int32 i = 0; i < rank; ++i) {
699     dims[i] = UnknownDim();
700   }
701   return MakeShape(dims);
702 }
703 
Scalar()704 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
705 
Vector(DimensionOrConstant dim)706 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
707   return MakeShape({dim});
708 }
709 
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)710 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
711                                      DimensionOrConstant dim2) {
712   return MakeShape({dim1, dim2});
713 }
714 
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)715 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
716     int input_idx, ShapeHandle* out) {
717   ShapeHandle input_shape;
718   TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
719 
720   requested_input_tensor_as_partial_shape_[input_idx] = true;
721   if (input_idx < input_tensors_as_shapes_.size() &&
722       input_tensors_as_shapes_[input_idx].IsSet() &&
723       RankKnown(input_tensors_as_shapes_[input_idx])) {
724     *out = input_tensors_as_shapes_[input_idx];
725     return Status::OK();
726   }
727 
728   return InternalMakeShapeFromTensor(
729       true /* treat_unknown_scalar_tensor_as_unknown_shape */,
730       input_tensor(input_idx), input_shape, out);
731 }
732 
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)733 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
734                                                   ShapeHandle* out) {
735   ShapeHandle input_shape;
736   TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
737 
738   requested_input_tensor_as_partial_shape_[input_idx] = true;
739   if (input_idx < input_tensors_as_shapes_.size() &&
740       input_tensors_as_shapes_[input_idx].IsSet() &&
741       RankKnown(input_tensors_as_shapes_[input_idx])) {
742     *out = input_tensors_as_shapes_[input_idx];
743     return Status::OK();
744   }
745 
746   return InternalMakeShapeFromTensor(
747       false /* treat_unknown_scalar_tensor_as_unknown_shape */,
748       input_tensor(input_idx), input_shape, out);
749 }
750 
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)751 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
752                                              ShapeHandle tensor_shape,
753                                              ShapeHandle* out) {
754   return InternalMakeShapeFromTensor(
755       false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
756       out);
757 }
758 
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)759 Status InferenceContext::InternalMakeShapeFromTensor(
760     bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
761     ShapeHandle tensor_shape, ShapeHandle* out) {
762   // Only callers who have set
763   if (!treat_unknown_scalar_tensor_as_unknown_shape) {
764     TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
765   }
766   if (t == nullptr) {
767     // This is guarded by the check above.
768     if (Rank(tensor_shape) == 0) {
769       return ReturnUnknownShape(out);
770     }
771     // Shape tensor is not known, but if the shape of the shape tensor is then
772     // the right number of unknown dims can be created.
773     DimensionHandle shape_dim = Dim(tensor_shape, 0);
774     if (!ValueKnown(shape_dim)) {
775       return ReturnUnknownShape(out);
776     }
777     const auto num_dims = Value(shape_dim);
778     std::vector<DimensionHandle> dims;
779     dims.reserve(num_dims);
780     for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
781     return ReturnCreatedShape(dims, out);
782   }
783 
784   if (t->shape().dims() == 0) {
785     if (t->dtype() == DataType::DT_INT32) {
786       auto flat_t = t->scalar<int32>();
787       if (flat_t() != -1) {
788         *out = nullptr;
789         return errors::InvalidArgument(
790             "Input tensor must be rank 1, or if its rank 0 it must have value "
791             "-1 "
792             "(representing an unknown shape).  Saw value: ",
793             flat_t());
794       }
795       return ReturnUnknownShape(out);
796     } else if (t->dtype() == DataType::DT_INT64) {
797       auto flat_t = t->scalar<int64>();
798       if (flat_t() != -1) {
799         *out = nullptr;
800         return errors::InvalidArgument(
801             "Input tensor must be rank 1, or if its rank 0 it must have value "
802             "-1 "
803             "(representing an unknown shape).  Saw value: ",
804             flat_t());
805       }
806       return ReturnUnknownShape(out);
807     } else {
808       *out = nullptr;
809       return errors::InvalidArgument(
810           "Input tensor must be int32 or int64, but was ",
811           DataTypeString(t->dtype()));
812     }
813   }
814 
815   if (t->shape().dims() != 1) {
816     *out = nullptr;
817     return errors::InvalidArgument(
818         "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
819         ((t->shape().dims() == 0)
820              ? "If it is rank 0 rank 0 it must have statically known value -1 "
821                "(representing an unknown shape). "
822              : " "),
823         "Saw tensor shape ", t->shape().DebugString());
824   }
825   std::vector<DimensionHandle> dims;
826   if (t->dtype() == DataType::DT_INT32) {
827     auto flat_t = t->flat<int32>();
828     for (int i = 0; i < flat_t.size(); ++i) {
829       const int32 val = flat_t(i);
830       if (val < -1) {
831         return errors::InvalidArgument(
832             "Invalid value in tensor used for shape: ", val);
833       }
834       // -1 will become an unknown dim.
835       dims.push_back(MakeDim(val));
836     }
837   } else if (t->dtype() == DataType::DT_INT64) {
838     auto flat_t = t->flat<int64>();
839     for (int i = 0; i < flat_t.size(); ++i) {
840       const int64 val = flat_t(i);
841       if (val < -1) {
842         return errors::InvalidArgument(
843             "Invalid value in tensor used for shape: ", val);
844       }
845       // -1 will become an unknown dim.
846       dims.push_back(MakeDim(val));
847     }
848   } else {
849     *out = nullptr;
850     return errors::InvalidArgument(
851         "Input tensor must be int32 or int64, but was ",
852         DataTypeString(t->dtype()));
853   }
854 
855   return ReturnCreatedShape(dims, out);
856 }
857 
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)858 Status InferenceContext::MakeShapeFromPartialTensorShape(
859     const PartialTensorShape& partial_shape, ShapeHandle* out) {
860   *out = nullptr;
861   if (partial_shape.dims() == -1) {
862     return ReturnUnknownShape(out);
863   }
864   const int num_dims = partial_shape.dims();
865   std::vector<DimensionHandle> dims(num_dims);
866   for (int i = 0; i < num_dims; ++i) {
867     // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
868     // can be passed directly to MakeDim.
869     dims[i] = MakeDim(partial_shape.dim_size(i));
870   }
871   return ReturnCreatedShape(dims, out);
872 }
873 
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)874 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
875                                                   ShapeHandle* out) {
876   return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
877                                          out);
878 }
879 
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)880 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
881                                                  ShapeHandle* out) {
882   *out = nullptr;
883   TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
884   PartialTensorShape partial_shape(proto);
885   return MakeShapeFromPartialTensorShape(partial_shape, out);
886 }
887 
GetScalarFromTensor(const Tensor * t,int64 * val)888 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
889   // Caller must ensure that <t> is not NULL.
890   const int rank = t->dims();
891   if (rank != 0) {
892     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
893   }
894 
895   if (t->dtype() == DT_INT32) {
896     *val = t->scalar<int32>()();
897     return Status::OK();
898   } else if (t->dtype() == DT_INT64) {
899     *val = t->scalar<int64>()();
900     return Status::OK();
901   } else {
902     return errors::InvalidArgument("Scalar input must be int32 or int64.");
903   }
904 }
905 
906 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)907 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
908   int64 val;
909   const Tensor* t = input_tensor(idx);
910   if (t == nullptr) {
911     *out = UnknownDim();
912     return Status::OK();
913   }
914   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
915   if (val < 0) {
916     return errors::InvalidArgument("Dimension size, given by scalar input ",
917                                    idx, ", must be non-negative but is ", val);
918   }
919   *out = MakeDim(val);
920   return Status::OK();
921 }
922 
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)923 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
924     int idx, int input_rank, DimensionHandle* out) {
925   int64 val;
926   const Tensor* t = input_tensor(idx);
927   if (t == nullptr) {
928     *out = UnknownDim();
929     return Status::OK();
930   }
931   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
932   if (val < 0) {
933     if (input_rank < 0) {
934       *out = UnknownDim();
935       return Status::OK();
936     } else if (val + input_rank < 0) {
937       return errors::InvalidArgument("Dimension size, given by scalar input ",
938                                      val, " must be in range [-", input_rank,
939                                      ", ", input_rank, ")");
940     } else {
941       val += input_rank;
942     }
943   } else if (input_rank >= 0 && val >= input_rank) {
944     return errors::InvalidArgument("Dimension size, given by scalar input ",
945                                    val, " must be in range [-", input_rank,
946                                    ", ", input_rank, ")");
947   }
948   *out = MakeDim(val);
949   return Status::OK();
950 }
951 
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)952 Status InferenceContext::Divide(DimensionHandle dividend,
953                                 DimensionOrConstant divisor,
954                                 bool evenly_divisible, DimensionHandle* out) {
955   const int64 divisor_value = Value(divisor);
956   if (divisor_value == 1) {
957     *out = dividend;
958   } else if (!ValueKnown(dividend) ||
959              (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
960     *out = UnknownDim();
961   } else {
962     const int64 v = Value(dividend);
963     if (divisor_value <= 0) {
964       return errors::InvalidArgument("Divisor must be positive but is ",
965                                      divisor_value);
966     }
967     if (evenly_divisible && (v % divisor_value) != 0) {
968       return errors::InvalidArgument(
969           "Dimension size must be evenly divisible by ", divisor_value,
970           " but is ", v);
971     }
972     *out = MakeDim(v / divisor_value);
973   }
974   return Status::OK();
975 }
976 
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)977 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
978                              DimensionHandle* out) {
979   const int64 first_value = Value(first);
980   const int64 second_value = Value(second);
981   // Special cases.
982   if (first_value == 0) {
983     *out = MakeDim(second);
984   } else if (second_value == 0) {
985     *out = first;
986   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
987     *out = UnknownDim();
988   } else {
989     // Invariant: Both values are known and positive. Still in run-time we can
990     // get pair of values which cannot be store in output. Check below will
991     // report error. We still need to avoid undefined behavior of signed
992     // overflow and use unsigned addition.
993     const int64 sum = static_cast<uint64>(first_value) + second_value;
994     if (sum < 0) {
995       return errors::InvalidArgument("Dimension size overflow from adding ",
996                                      first_value, " and ", second_value);
997     }
998     *out = MakeDim(sum);
999   }
1000   return Status::OK();
1001 }
1002 
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1003 Status InferenceContext::Subtract(DimensionHandle first,
1004                                   DimensionOrConstant second,
1005                                   DimensionHandle* out) {
1006   const int64 first_value = Value(first);
1007   const int64 second_value = Value(second);
1008   // Special cases.
1009   if (second_value == 0) {
1010     *out = first;
1011   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1012     *out = UnknownDim();
1013   } else {
1014     // Invariant: Both values are known, first_value is non-negative, and
1015     // second_value is positive.
1016     if (first_value < second_value) {
1017       return errors::InvalidArgument(
1018           "Negative dimension size caused by subtracting ", second_value,
1019           " from ", first_value);
1020     }
1021     *out = MakeDim(first_value - second_value);
1022   }
1023   return Status::OK();
1024 }
1025 
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1026 Status InferenceContext::Multiply(DimensionHandle first,
1027                                   DimensionOrConstant second,
1028                                   DimensionHandle* out) {
1029   const int64 first_value = Value(first);
1030   const int64 second_value = Value(second);
1031   // Special cases.
1032   if (first_value == 0) {
1033     *out = first;
1034   } else if (second_value == 0) {
1035     *out = MakeDim(second);
1036   } else if (first_value == 1) {
1037     *out = MakeDim(second);
1038   } else if (second_value == 1) {
1039     *out = first;
1040   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1041     *out = UnknownDim();
1042   } else {
1043     // Invariant: Both values are known and greater than 1.
1044     const int64 product = first_value * second_value;
1045     if (product < 0) {
1046       return errors::InvalidArgument(
1047           "Negative dimension size caused by overflow when multiplying ",
1048           first_value, " and ", second_value);
1049     }
1050     *out = MakeDim(product);
1051   }
1052   return Status::OK();
1053 }
1054 
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1055 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1056                              DimensionHandle* out) {
1057   const int64 first_value = Value(first);
1058   const int64 second_value = Value(second);
1059   if (first_value == 0) {
1060     *out = first;
1061   } else if (second_value == 0) {
1062     *out = MakeDim(second);
1063   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1064     *out = UnknownDim();
1065   } else {
1066     if (first_value <= second_value) {
1067       *out = first;
1068     } else {
1069       *out = MakeDim(second);
1070     }
1071   }
1072   return Status::OK();
1073 }
1074 
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1075 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1076                              DimensionHandle* out) {
1077   const int64 first_value = Value(first);
1078   const int64 second_value = Value(second);
1079   if (first_value == kUnknownDim || second_value == kUnknownDim) {
1080     *out = UnknownDim();
1081   } else {
1082     if (first_value >= second_value) {
1083       *out = first;
1084     } else {
1085       *out = MakeDim(second);
1086     }
1087   }
1088   return Status::OK();
1089 }
1090 
// Returns a Status with the same code as `status`, with this node's context
// appended to the message: the node summary, the debug string of every input
// shape, and any requested input tensor values or tensors-as-partial-shapes.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> shape_strs;
  shape_strs.reserve(inputs_.size());
  for (const ShapeHandle& s : inputs_) shape_strs.emplace_back(DebugString(s));

  // For each input, describe the requested tensor-as-partial-shape when it
  // has a known rank; otherwise fall back to the requested tensor value.
  std::vector<string> tensor_strs;
  std::vector<string> partial_shape_strs;
  partial_shape_strs.reserve(inputs_.size());
  for (size_t i = 0; i < inputs_.size(); ++i) {
    const bool use_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (use_partial_shape) {
      partial_shape_strs.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }
    const bool use_tensor = requested_input_tensor_[i] &&
                            i < input_tensors_.size() &&
                            input_tensors_[i] != nullptr;
    if (use_tensor) {
      tensor_strs.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string context = strings::StrCat(" for '", attrs_.SummarizeNode(),
                                   "' with input shapes: ",
                                   absl::StrJoin(shape_strs, ", "));
  if (!tensor_strs.empty()) {
    strings::StrAppend(&context, " and with computed input tensors: ",
                       absl::StrJoin(tensor_strs, ", "));
  }
  if (!partial_shape_strs.empty()) {
    strings::StrAppend(&context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(partial_shape_strs, ","));
  }
  strings::StrAppend(&context, ".");
  return Status(status.code(),
                strings::StrCat(status.error_message(), context));
}
1134 
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1135 bool InferenceContext::MergeHandleShapesAndTypes(
1136     const std::vector<ShapeAndType>& shapes_and_types,
1137     std::vector<ShapeAndType>* to_update) {
1138   if (shapes_and_types.size() != to_update->size()) {
1139     return false;
1140   }
1141   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1142   bool refined = false;
1143   for (int i = 0; i < shapes_and_types.size(); ++i) {
1144     const ShapeAndType& existing = (*to_update)[i];
1145     if (shapes_and_types[i].dtype == existing.dtype) {
1146       new_values[i].dtype = existing.dtype;
1147     } else {
1148       if (existing.dtype != DT_INVALID) {
1149         return false;
1150       } else {
1151         new_values[i].dtype = shapes_and_types[i].dtype;
1152         refined = true;
1153       }
1154     }
1155     if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1156              .ok()) {
1157       // merge failed, ignore the new value.
1158       new_values[i].shape = existing.shape;
1159     }
1160     if (!existing.shape.SameHandle(new_values[i].shape)) {
1161       refined = true;
1162     }
1163   }
1164   if (!refined) {
1165     return false;
1166   }
1167   for (int i = 0; i < new_values.size(); ++i) {
1168     (*to_update)[i] = new_values[i];
1169   }
1170   return true;
1171 }
1172 
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1173 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1174     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1175   if (output_handle_shapes_and_types_[idx] == nullptr) {
1176     output_handle_shapes_and_types_[idx].reset(
1177         new std::vector<ShapeAndType>(shapes_and_types));
1178     return true;
1179   }
1180   return MergeHandleShapesAndTypes(shapes_and_types,
1181                                    output_handle_shapes_and_types_[idx].get());
1182 }
1183 
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1184 bool InferenceContext::MergeInputHandleShapesAndTypes(
1185     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1186   if (input_handle_shapes_and_types_[idx] == nullptr) {
1187     input_handle_shapes_and_types_[idx].reset(
1188         new std::vector<ShapeAndType>(shapes_and_types));
1189     return true;
1190   }
1191   return MergeHandleShapesAndTypes(shapes_and_types,
1192                                    input_handle_shapes_and_types_[idx].get());
1193 }
1194 
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1195 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1196     const std::vector<ShapeAndType>& shapes_and_types,
1197     std::vector<ShapeAndType>* to_update) {
1198   if (shapes_and_types.size() != to_update->size()) {
1199     return false;
1200   }
1201   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1202   for (int i = 0; i < shapes_and_types.size(); ++i) {
1203     const ShapeAndType& existing = (*to_update)[i];
1204     if (shapes_and_types[i].dtype == existing.dtype) {
1205       new_values[i].dtype = existing.dtype;
1206     } else {
1207       if (existing.dtype != DT_INVALID) {
1208         return false;
1209       } else {
1210         new_values[i].dtype = shapes_and_types[i].dtype;
1211       }
1212     }
1213     Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1214   }
1215   to_update->swap(new_values);
1216   return true;
1217 }
1218 
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1219 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1220     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1221   if (output_handle_shapes_and_types_[idx] == nullptr) {
1222     output_handle_shapes_and_types_[idx].reset(
1223         new std::vector<ShapeAndType>(shapes_and_types));
1224     return true;
1225   }
1226   return RelaxHandleShapesAndMergeTypes(
1227       shapes_and_types, output_handle_shapes_and_types_[idx].get());
1228 }
1229 
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1230 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1231     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1232   if (input_handle_shapes_and_types_[idx] == nullptr) {
1233     input_handle_shapes_and_types_[idx].reset(
1234         new std::vector<ShapeAndType>(shapes_and_types));
1235     return true;
1236   }
1237   return RelaxHandleShapesAndMergeTypes(
1238       shapes_and_types, input_handle_shapes_and_types_[idx].get());
1239 }
1240 
1241 // -----------------------------------------------------------------------------
1242 // ShapeManager
1243 // -----------------------------------------------------------------------------
ShapeManager()1244 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1245 InferenceContext::ShapeManager::~ShapeManager() {
1246   for (auto* s : all_shapes_) delete s;
1247   for (auto* d : all_dims_) delete d;
1248 }
1249 
MakeShape(const std::vector<DimensionHandle> & dims)1250 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1251     const std::vector<DimensionHandle>& dims) {
1252   all_shapes_.push_back(new Shape(dims));
1253   return all_shapes_.back();
1254 }
1255 
UnknownShape()1256 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1257   all_shapes_.push_back(new Shape());
1258   return all_shapes_.back();
1259 }
1260 
1261 }  // namespace shape_inference
1262 }  // namespace tensorflow
1263