• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include <cstdint>
18 
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/framework/full_type_util.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/util/overflow.h"
26 
27 namespace tensorflow {
28 namespace shape_inference {
29 
30 constexpr int32_t InferenceContext::kUnknownRank;
31 constexpr int64_t InferenceContext::kUnknownDim;
32 
33 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)34 InferenceContext::InferenceContext(
35     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
36     const std::vector<PartialTensorShape>& input_shapes,
37     const std::vector<const Tensor*>& input_tensors,
38     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
39     const std::vector<
40         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
41         input_handle_shapes_and_types)
42     : graph_def_version_(graph_def_version), attrs_(attrs) {
43   std::vector<ShapeHandle> input_tensors_as_shape_handles;
44   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
45   for (const PartialTensorShape& p : input_tensors_as_shapes) {
46     ShapeHandle shape;
47     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
48     if (!construction_status_.ok()) {
49       return;
50     }
51     input_tensors_as_shape_handles.push_back(shape);
52   }
53   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
54   if (!construction_status_.ok()) return;
55   inputs_.reserve(input_shapes.size());
56   for (const PartialTensorShape& p : input_shapes) {
57     ShapeHandle shape;
58     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
59     if (!construction_status_.ok()) {
60       return;
61     }
62     inputs_.push_back(shape);
63   }
64   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
65       input_shapes.size());
66   for (int i = 0, end = input_handle_shapes_and_types.size(); i < end; ++i) {
67     const auto& v = input_handle_shapes_and_types[i];
68     if (v == nullptr) {
69       continue;
70     }
71     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
72     auto& new_v = *handle_data[i];
73     for (int j = 0, end = v->size(); j < end; ++j) {
74       const auto& p = (*v)[j];
75       construction_status_.Update(
76           MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
77       if (!construction_status_.ok()) {
78         return;
79       }
80       new_v[j].dtype = p.second;
81     }
82   }
83   PostInputInit(std::move(handle_data));
84 }
85 
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)86 InferenceContext::InferenceContext(
87     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
88     const std::vector<ShapeHandle>& input_shapes,
89     const std::vector<const Tensor*>& input_tensors,
90     const std::vector<ShapeHandle>& input_tensors_as_shapes,
91     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
92         input_handle_shapes_and_types)
93     : graph_def_version_(graph_def_version), attrs_(attrs) {
94   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
95   if (!construction_status_.ok()) return;
96   inputs_ = input_shapes;
97 
98   PostInputInit(std::move(input_handle_shapes_and_types));
99 }
100 
~InferenceContext()101 InferenceContext::~InferenceContext() {}
102 
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)103 Status InferenceContext::Run(
104     const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
105   ForgetMerges();
106   Status s = fn(this);
107   if (!s.ok()) {
108     ForgetMerges();
109     return AttachContext(s);
110   }
111 #ifndef NDEBUG
112   for (int i = 0; i < num_outputs(); ++i) {
113     DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
114   }
115 #endif  // NDEBUG
116   return s;
117 }
118 
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)119 Status InferenceContext::set_output(StringPiece output_name,
120                                     const std::vector<ShapeHandle>& shapes) {
121   auto result = output_name_map_.find(output_name);
122   if (result == output_name_map_.end()) {
123     return errors::InvalidArgument("Unknown output name: ", output_name);
124   } else {
125     const int start = result->second.first;
126     const int size = result->second.second - start;
127     const int shapes_size = shapes.size();
128     if (size != shapes_size) {
129       return errors::InvalidArgument("Must provide exactly ", size, " shapes.");
130     }
131     for (int i = 0; i < shapes_size; ++i) {
132       outputs_[i + start] = shapes[i];
133     }
134   }
135   return OkStatus();
136 }
137 
input(StringPiece input_name,std::vector<ShapeHandle> * output) const138 Status InferenceContext::input(StringPiece input_name,
139                                std::vector<ShapeHandle>* output) const {
140   const auto result = input_name_map_.find(input_name);
141   if (result == input_name_map_.end()) {
142     return errors::InvalidArgument("Unknown input name: ", input_name);
143   } else {
144     output->clear();
145     for (int i = result->second.first; i < result->second.second; ++i) {
146       output->push_back(inputs_[i]);
147     }
148   }
149   return OkStatus();
150 }
151 
output(StringPiece output_name,std::vector<ShapeHandle> * output) const152 Status InferenceContext::output(StringPiece output_name,
153                                 std::vector<ShapeHandle>* output) const {
154   const auto result = output_name_map_.find(output_name);
155   if (result == output_name_map_.end()) {
156     return errors::InvalidArgument("Unknown output name: ", output_name);
157   } else {
158     output->clear();
159     for (int i = result->second.first; i < result->second.second; ++i) {
160       output->push_back(outputs_[i]);
161     }
162   }
163   return OkStatus();
164 }
165 
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)166 void InferenceContext::PreInputInit(
167     const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
168     const std::vector<ShapeHandle>& input_tensors_as_shapes) {
169   // TODO(mdan): This is also done at graph construction. Run only here instead?
170   Status s = full_type::SpecializeType(attrs_, op_def, ret_types_);
171   if (!s.ok()) {
172     construction_status_ = s;
173     return;
174   }
175 
176   input_tensors_ = input_tensors;
177   input_tensors_as_shapes_ = input_tensors_as_shapes;
178 
179   construction_status_ =
180       NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
181   if (!construction_status_.ok()) return;
182 
183   int num_outputs = 0;
184   for (const auto& e : output_name_map_) {
185     num_outputs = std::max(num_outputs, e.second.second);
186   }
187   outputs_.assign(num_outputs, nullptr);
188   output_handle_shapes_and_types_.resize(num_outputs);
189 }
190 
ExpandOutputs(int new_output_size)191 Status InferenceContext::ExpandOutputs(int new_output_size) {
192   const int outputs_size = outputs_.size();
193   if (new_output_size < outputs_size) {
194     return errors::InvalidArgument("Trying to reduce number of outputs of op.");
195   }
196   outputs_.resize(new_output_size, nullptr);
197   output_handle_shapes_and_types_.resize(new_output_size);
198   return OkStatus();
199 }
200 
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)201 void InferenceContext::PostInputInit(
202     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
203   int num_inputs_from_node_def = 0;
204   for (const auto& e : input_name_map_) {
205     num_inputs_from_node_def =
206         std::max(num_inputs_from_node_def, e.second.second);
207   }
208 
209   // Allow passing empty shapes/dtypes to avoid changing every single test.
210   if (input_handle_data.empty()) {
211     input_handle_shapes_and_types_.resize(inputs_.size());
212   } else {
213     if (input_handle_data.size() != inputs_.size()) {
214       construction_status_ = errors::InvalidArgument(
215           "Wrong number of handle shapes passed; expected ", inputs_.size(),
216           " got ", input_handle_data.size());
217       return;
218     }
219     input_handle_shapes_and_types_ = std::move(input_handle_data);
220   }
221   const int inputs_size = inputs_.size();
222   if (inputs_size != num_inputs_from_node_def) {
223     construction_status_ = errors::InvalidArgument(
224         "Wrong number of inputs passed: ", inputs_.size(), " while ",
225         num_inputs_from_node_def, " expected based on NodeDef");
226     return;
227   }
228 
229   CHECK_LE(input_tensors_.size(), inputs_.size());
230   input_tensors_.resize(inputs_.size());
231   requested_input_tensor_.resize(inputs_.size());
232   requested_input_tensor_as_partial_shape_.resize(inputs_.size());
233 }
234 
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)235 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
236                                           TensorShapeProto* proto) {
237   if (!RankKnown(handle)) {
238     proto->set_unknown_rank(true);
239     return;
240   }
241 
242   for (int32_t i = 0; i < Rank(handle); ++i) {
243     DimensionHandle dim = Dim(handle, i);
244     auto* dim_shape = proto->add_dim();
245     if (ValueKnown(dim)) {
246       dim_shape->set_size(Value(dim));
247     } else {
248       dim_shape->set_size(-1);
249     }
250   }
251 }
252 
FullyDefined(ShapeHandle s)253 bool InferenceContext::FullyDefined(ShapeHandle s) {
254   if (!RankKnown(s)) return false;
255   for (int i = 0; i < Rank(s); ++i) {
256     if (!ValueKnown(Dim(s, i))) return false;
257   }
258   return true;
259 }
260 
NumElements(ShapeHandle s)261 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
262   const auto rank = Rank(s);
263   if (rank == kUnknownRank) return UnknownDim();
264   bool found_unknown = false;
265   int64_t size = 1;
266   for (int i = 0; i < rank; ++i) {
267     int64_t dim_val = Value(Dim(s, i));
268     if (dim_val == kUnknownDim) {
269       found_unknown = true;
270     } else if (dim_val == 0) {
271       return MakeDim(0);
272     } else {
273       size *= dim_val;
274     }
275   }
276   if (found_unknown) {
277     return UnknownDim();
278   } else {
279     return MakeDim(size);
280   }
281 }
282 
DebugString(ShapeHandle s)283 string InferenceContext::DebugString(ShapeHandle s) {
284   if (RankKnown(s)) {
285     std::vector<string> vals;
286     for (auto d : s->dims_) vals.push_back(DebugString(d));
287     return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
288   } else {
289     return "?";
290   }
291 }
292 
DebugString(DimensionHandle d)293 string InferenceContext::DebugString(DimensionHandle d) {
294   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
295 }
296 
DebugString() const297 string InferenceContext::DebugString() const {
298   return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
299 }
300 
DebugString(const ShapeAndType & shape_and_type)301 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
302   return strings::StrCat(DebugString(shape_and_type.shape), ":",
303                          DataTypeString(shape_and_type.dtype));
304 }
305 
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)306 string InferenceContext::DebugString(
307     gtl::ArraySlice<ShapeAndType> shape_and_types) {
308   std::vector<string> pieces;
309   for (const ShapeAndType& s : shape_and_types) {
310     pieces.push_back(DebugString(s));
311   }
312   return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
313 }
314 
WithRank(ShapeHandle shape,int64_t rank,ShapeHandle * out)315 Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank,
316                                   ShapeHandle* out) {
317   if (rank > kint32max) {
318     return errors::InvalidArgument("Rank cannot exceed kint32max");
319   }
320   const int32_t existing = Rank(shape);
321   if (existing == rank) {
322     *out = shape;
323     return OkStatus();
324   }
325   if (existing == kUnknownRank) {
326     std::vector<DimensionHandle> dims;
327     dims.reserve(rank);
328     for (int i = 0; i < rank; ++i) {
329       dims.push_back(UnknownDim());
330     }
331     ShapeHandle shp = shape_manager_.MakeShape(dims);
332     return Merge(shape, shp, out);
333   }
334   *out = nullptr;
335 
336   return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
337                                  existing);
338 }
339 
WithRankAtLeast(ShapeHandle shape,int64_t rank,ShapeHandle * out)340 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank,
341                                          ShapeHandle* out) {
342   if (rank > kint32max) {
343     return errors::InvalidArgument("Rank cannot exceed kint32max");
344   }
345   const int32_t existing = Rank(shape);
346   if (existing >= rank || existing == kUnknownRank) {
347     *out = shape;
348     return OkStatus();
349   }
350   *out = nullptr;
351   return errors::InvalidArgument("Shape must be at least rank ", rank,
352                                  " but is rank ", existing);
353 }
354 
WithRankAtMost(ShapeHandle shape,int64_t rank,ShapeHandle * out)355 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank,
356                                         ShapeHandle* out) {
357   if (rank > kint32max) {
358     return errors::InvalidArgument("Rank cannot exceed kint32max");
359   }
360   const int32_t existing = Rank(shape);
361   if (existing <= rank || existing == kUnknownRank) {
362     *out = shape;
363     return OkStatus();
364   }
365   *out = nullptr;
366   return errors::InvalidArgument("Shape must be at most rank ", rank,
367                                  " but is rank ", existing);
368 }
369 
WithValue(DimensionHandle dim,int64_t value,DimensionHandle * out)370 Status InferenceContext::WithValue(DimensionHandle dim, int64_t value,
371                                    DimensionHandle* out) {
372   const int64_t existing = Value(dim);
373   if (existing == value) {
374     *out = dim;
375     return OkStatus();
376   }
377   if (existing == kUnknownDim) {
378     DimensionHandle d = MakeDim(value);
379     return Merge(dim, d, out);
380   }
381   *out = nullptr;
382   return errors::InvalidArgument("Dimension must be ", value, " but is ",
383                                  existing);
384 }
385 
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)386 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
387                              DimensionHandle* out) {
388   if (d_old.SameHandle(d_new)) {
389     *out = d_old;
390   } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
391     // The node will be fed by the dimension d_new instead of d_old: any
392     // equality assertion between d_old and other input dimension on this node
393     // may not be true anymore, so forget them all.
394     ForgetMerges();
395     // Return the new shape handle to force the relaxation to propagate to the
396     // fanout of the context.
397     *out = d_new;
398   } else if (!ValueKnown(d_new)) {
399     ForgetMerges();
400     *out = d_new;
401   } else if (Value(d_old) == Value(d_new)) {
402     // Return the old shape handle. This will stop the relaxation in the fanout
403     // of the context.
404     *out = d_old;
405   } else {
406     // Return a new handle that encodes a different unknown dim.
407     ForgetMerges();
408     *out = UnknownDim();
409   }
410 }
411 
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)412 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
413                                DimensionHandle* out) {
414   if (d0.SameHandle(d1)) {
415     *out = d0;
416     return OkStatus();
417   } else if (!ValueKnown(d1)) {
418     *out = d0;
419     merged_dims_.emplace_back(d0, d1);
420     return OkStatus();
421   } else if (!ValueKnown(d0)) {
422     *out = d1;
423     merged_dims_.emplace_back(d0, d1);
424     return OkStatus();
425   } else if (Value(d0) == Value(d1)) {
426     *out = d0;
427     return OkStatus();
428   } else {
429     *out = nullptr;
430     return errors::InvalidArgument("Dimensions must be equal, but are ",
431                                    Value(d0), " and ", Value(d1));
432   }
433 }
434 
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)435 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
436                                      ShapeHandle* s_out,
437                                      ShapeHandle* prefix_out) {
438   *s_out = *prefix_out = nullptr;
439   if (!RankKnown(prefix) || !RankKnown(s)) {
440     *s_out = s;
441     *prefix_out = prefix;
442     return OkStatus();
443   }
444   const int32_t rank = Rank(prefix);
445   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
446 
447   // Merge the prefix dims and create the new output shapes.
448   const int32_t rank_s = Rank(s);
449   std::vector<DimensionHandle> dims;
450   dims.reserve(std::max(rank, rank_s));
451   dims.resize(rank);
452   for (int i = 0; i < rank; ++i) {
453     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
454   }
455   *prefix_out = MakeShape(dims);
456   for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
457   *s_out = MakeShape(dims);
458   return OkStatus();
459 }
460 
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)461 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
462                              ShapeHandle* out) {
463   if (s_old.SameHandle(s_new)) {
464     *out = s_old;
465     return;
466   } else if (!RankKnown(s_new) || !s_old.IsSet()) {
467     ForgetMerges();
468     *out = s_new;
469     return;
470   }
471 
472   const int32_t rank = Rank(s_old);
473   if (rank != Rank(s_new)) {
474     ForgetMerges();
475     *out = UnknownShape();
476     return;
477   }
478 
479   bool return_s_old = true;
480   for (int i = 0; i < rank; ++i) {
481     auto d0 = Dim(s_old, i);
482     auto d1 = Dim(s_new, i);
483     if (d0.SameHandle(d1)) continue;
484 
485     auto v0 = Value(d0);
486     auto v1 = Value(d1);
487     if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
488       return_s_old = false;
489       break;
490     }
491   }
492   if (return_s_old) {
493     *out = s_old;
494     return;
495   }
496 
497   // Relax dims.
498   std::vector<DimensionHandle> dims(rank);
499   for (int i = 0; i < rank; ++i) {
500     Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
501   }
502   ForgetMerges();
503   *out = MakeShape(dims);
504 }
505 
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)506 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
507                                ShapeHandle* out) {
508   if (s0.SameHandle(s1)) {
509     *out = s0;
510     return OkStatus();
511   } else if (!RankKnown(s1)) {
512     *out = s0;
513     merged_shapes_.emplace_back(s0, s1);
514     return OkStatus();
515   } else if (!RankKnown(s0)) {
516     *out = s1;
517     merged_shapes_.emplace_back(s0, s1);
518     return OkStatus();
519   }
520 
521   const int32_t rank = Rank(s0);
522   if (rank != Rank(s1)) {
523     *out = nullptr;
524     return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
525                                    " and ", Rank(s1));
526   }
527 
528   bool return_s0 = true;
529   bool return_s1 = true;
530   for (int i = 0; i < rank; ++i) {
531     auto d0 = Dim(s0, i);
532     auto d1 = Dim(s1, i);
533     if (d0.SameHandle(d1)) continue;
534 
535     auto v0 = Value(d0);
536     auto v1 = Value(d1);
537     if (v0 == kUnknownDim) {
538       if (v1 != kUnknownDim) {
539         return_s0 = false;
540       }
541     } else if (v1 == kUnknownDim) {
542       return_s1 = false;
543     } else if (v0 != v1) {
544       *out = nullptr;
545       return errors::InvalidArgument(
546           "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
547           " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
548           DebugString(s1), ".");
549     }
550   }
551 
552   merged_shapes_.emplace_back(s0, s1);
553 
554   if (return_s0 || return_s1) {
555     *out = return_s0 ? s0 : s1;
556     return OkStatus();
557   }
558 
559   // Merge dims.
560   std::vector<DimensionHandle> dims(rank, nullptr);
561   for (int i = 0; i < rank; ++i) {
562     // Invariant for merge was checked earlier, so CHECK is ok.
563     TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
564   }
565 
566   Status s = ReturnCreatedShape(dims, out);
567   if (s.ok()) {
568     // Merge the new shape with s0. Since s0 and s1 are merged, this implies
569     // that s1 and out are also merged.
570     merged_shapes_.emplace_back(s0, *out);
571   }
572   return s;
573 }
574 
Subshape(ShapeHandle s,int64_t start,ShapeHandle * out)575 Status InferenceContext::Subshape(ShapeHandle s, int64_t start,
576                                   ShapeHandle* out) {
577   return Subshape(s, start, std::numeric_limits<int64_t>::max() /* end */, out);
578 }
579 
Subshape(ShapeHandle s,int64_t start,int64_t end,ShapeHandle * out)580 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
581                                   ShapeHandle* out) {
582   return Subshape(s, start, end, 1 /* stride */, out);
583 }
584 
Subshape(ShapeHandle s,int64_t start,int64_t end,int64_t stride,ShapeHandle * out)585 Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end,
586                                   int64_t stride, ShapeHandle* out) {
587   int64_t start_in = start;
588   int64_t end_in = end;
589 
590   const int32_t rank = Rank(s);
591   if (start == 0 && stride == 1 &&
592       ((RankKnown(s) && end >= rank) ||
593        end == std::numeric_limits<int64_t>::max())) {
594     *out = s;
595     return OkStatus();
596   }
597   if (!RankKnown(s)) {
598     return ReturnUnknownShape(out);
599   }
600 
601   if (start > rank) start = rank;
602   if (end > rank) end = rank;
603 
604   if (stride < 0 && start == rank) --start;
605 
606   if (start < 0) {
607     start = rank + start;
608     if (start < 0) {
609       *out = nullptr;
610       return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
611                                      ", for shape with rank ", rank);
612     }
613   }
614 
615   if (end < 0) {
616     end = rank + end;
617     if (end < 0) {
618       *out = nullptr;
619       return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
620                                      ", for shape with rank ", rank);
621     }
622   }
623   if (stride > 0 && start > end) {
624     *out = nullptr;
625     return errors::InvalidArgument(
626         "Subshape must have computed start <= end, but is ", start, " and ",
627         end, " (computed from start ", start_in, " and end ", end_in,
628         " over shape with rank ", rank, ")");
629   } else if (stride < 0 && start < end) {
630     *out = nullptr;
631     return errors::InvalidArgument(
632         "Subshape must have computed start >= end since stride is negative, "
633         "but is ",
634         start, " and ", end, " (computed from start ", start_in, " and end ",
635         end_in, " over shape with rank ", rank, " and stride", stride, ")");
636   }
637 
638   std::vector<DimensionHandle> dims;
639   for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
640     dims.push_back(Dim(s, i));
641   }
642   return ReturnCreatedShape(dims, out);
643 }
644 
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)645 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
646                                      ShapeHandle* out) {
647   if (!RankKnown(s1) || !RankKnown(s2)) {
648     return ReturnUnknownShape(out);
649   }
650   const int32_t s1_rank = Rank(s1);
651   const int32_t s2_rank = Rank(s2);
652   const int32_t rank = s1_rank + s2_rank;
653   std::vector<DimensionHandle> dims;
654   dims.reserve(rank);
655   for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
656   for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
657   return ReturnCreatedShape(dims, out);
658 }
659 
ReplaceDim(ShapeHandle s,int64_t dim_index_in,DimensionHandle new_dim,ShapeHandle * out)660 Status InferenceContext::ReplaceDim(ShapeHandle s, int64_t dim_index_in,
661                                     DimensionHandle new_dim, ShapeHandle* out) {
662   if (!RankKnown(s)) {
663     return ReturnUnknownShape(out);
664   }
665   int64_t dim_index = dim_index_in;
666   if (dim_index < 0) {
667     dim_index = s->dims_.size() + dim_index;
668   }
669   if (!FastBoundsCheck(dim_index, s->dims_.size())) {
670     *out = nullptr;
671     return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
672                                    " for shape with ", s->dims_.size(),
673                                    " dimensions");
674   }
675   std::vector<DimensionHandle> dims(s->dims_);
676   dims[dim_index] = new_dim;
677   return ReturnCreatedShape(dims, out);
678 }
679 
MakeShape(const std::vector<DimensionHandle> & dims)680 ShapeHandle InferenceContext::MakeShape(
681     const std::vector<DimensionHandle>& dims) {
682   return shape_manager_.MakeShape(dims);
683 }
684 
MakeShape(std::initializer_list<DimensionOrConstant> dims)685 ShapeHandle InferenceContext::MakeShape(
686     std::initializer_list<DimensionOrConstant> dims) {
687   std::vector<DimensionHandle> dims_actual;
688   dims_actual.reserve(dims.size());
689   for (const DimensionOrConstant& d : dims) {
690     dims_actual.push_back(MakeDim(d));
691   }
692 
693   return shape_manager_.MakeShape(dims_actual);
694 }
695 
UnknownShape()696 ShapeHandle InferenceContext::UnknownShape() {
697   return shape_manager_.UnknownShape();
698 }
699 
UnknownShapeOfRank(int64_t rank)700 ShapeHandle InferenceContext::UnknownShapeOfRank(int64_t rank) {
701   CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
702   if (rank == kUnknownRank) {
703     return UnknownShape();
704   }
705   CHECK_GE(rank, 0) << "rank must not be negative";
706   std::vector<DimensionHandle> dims(rank);
707   for (int32_t i = 0; i < rank; ++i) {
708     dims[i] = UnknownDim();
709   }
710   return MakeShape(dims);
711 }
712 
Scalar()713 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
714 
Vector(DimensionOrConstant dim)715 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
716   return MakeShape({dim});
717 }
718 
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)719 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
720                                      DimensionOrConstant dim2) {
721   return MakeShape({dim1, dim2});
722 }
723 
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)724 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
725     int input_idx, ShapeHandle* out) {
726   ShapeHandle input_shape;
727   TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
728 
729   request_input_tensor_as_partial_shape(input_idx);
730   const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
731   if (input_idx < input_tensors_as_shapes_size &&
732       input_tensors_as_shapes_[input_idx].IsSet() &&
733       RankKnown(input_tensors_as_shapes_[input_idx])) {
734     *out = input_tensors_as_shapes_[input_idx];
735     return OkStatus();
736   }
737 
738   return InternalMakeShapeFromTensor(
739       true /* treat_unknown_scalar_tensor_as_unknown_shape */,
740       input_tensor(input_idx), input_shape, out);
741 }
742 
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)743 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
744                                                   ShapeHandle* out) {
745   ShapeHandle input_shape;
746   TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
747 
748   request_input_tensor_as_partial_shape(input_idx);
749   const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
750   if (input_idx < input_tensors_as_shapes_size &&
751       input_tensors_as_shapes_[input_idx].IsSet() &&
752       RankKnown(input_tensors_as_shapes_[input_idx])) {
753     *out = input_tensors_as_shapes_[input_idx];
754     return OkStatus();
755   }
756 
757   return InternalMakeShapeFromTensor(
758       false /* treat_unknown_scalar_tensor_as_unknown_shape */,
759       input_tensor(input_idx), input_shape, out);
760 }
761 
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)762 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
763                                              ShapeHandle tensor_shape,
764                                              ShapeHandle* out) {
765   return InternalMakeShapeFromTensor(
766       false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
767       out);
768 }
769 
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)770 Status InferenceContext::InternalMakeShapeFromTensor(
771     bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
772     ShapeHandle tensor_shape, ShapeHandle* out) {
773   // Only callers who have set
774   if (!treat_unknown_scalar_tensor_as_unknown_shape) {
775     TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
776   }
777   if (t == nullptr) {
778     // This is guarded by the check above.
779     if (Rank(tensor_shape) == 0) {
780       return ReturnUnknownShape(out);
781     }
782     // Shape tensor is not known, but if the shape of the shape tensor is then
783     // the right number of unknown dims can be created.
784     DimensionHandle shape_dim = Dim(tensor_shape, 0);
785     if (!ValueKnown(shape_dim)) {
786       return ReturnUnknownShape(out);
787     }
788     const auto num_dims = Value(shape_dim);
789     // Note: This should be `TensorShape::MaxDimensions()` as we are not able to
790     // materialize shapes with more than this number of dimensions but then
791     // shape inference would fail for operations such as `tf.range`/`tf.ones`,
792     // etc. where the shape is not really materialized, only used during the
793     // inference. Hence, just prevent doing a `reserve` with a very large
794     // argument.
795     const int64_t max_dimensions = 1 << 25;
796     if (num_dims >= max_dimensions) {
797       return errors::Internal(
798           "Cannot create a tensor with ", num_dims,
799           " dimensions, as these would be more than maximum of ",
800           max_dimensions);
801     }
802     std::vector<DimensionHandle> dims;
803     dims.reserve(num_dims);
804     for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
805     return ReturnCreatedShape(dims, out);
806   }
807 
808   if (t->shape().dims() == 0) {
809     if (t->dtype() == DataType::DT_INT32) {
810       auto flat_t = t->scalar<int32>();
811       if (flat_t() != -1) {
812         *out = nullptr;
813         return errors::InvalidArgument(
814             "Input tensor must be rank 1, or if its rank 0 it must have value "
815             "-1 "
816             "(representing an unknown shape).  Saw value: ",
817             flat_t());
818       }
819       return ReturnUnknownShape(out);
820     } else if (t->dtype() == DataType::DT_INT64) {
821       auto flat_t = t->scalar<int64_t>();
822       if (flat_t() != -1) {
823         *out = nullptr;
824         return errors::InvalidArgument(
825             "Input tensor must be rank 1, or if its rank 0 it must have value "
826             "-1 "
827             "(representing an unknown shape).  Saw value: ",
828             flat_t());
829       }
830       return ReturnUnknownShape(out);
831     } else {
832       *out = nullptr;
833       return errors::InvalidArgument(
834           "Input tensor must be int32 or int64, but was ",
835           DataTypeString(t->dtype()));
836     }
837   }
838 
839   if (t->shape().dims() != 1) {
840     *out = nullptr;
841     return errors::InvalidArgument(
842         "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
843         ((t->shape().dims() == 0)
844              ? "If it is rank 0 it must have statically known value -1 "
845                "(representing an unknown shape). "
846              : " "),
847         "Saw tensor shape ", t->shape().DebugString());
848   }
849   std::vector<DimensionHandle> dims;
850   if (t->dtype() == DataType::DT_INT32) {
851     auto flat_t = t->flat<int32>();
852     for (int i = 0; i < flat_t.size(); ++i) {
853       const int32_t val = flat_t(i);
854       if (val < -1) {
855         return errors::InvalidArgument(
856             "Invalid value in tensor used for shape: ", val);
857       }
858       // -1 will become an unknown dim.
859       dims.push_back(MakeDim(val));
860     }
861   } else if (t->dtype() == DataType::DT_INT64) {
862     auto flat_t = t->flat<int64_t>();
863     for (int i = 0; i < flat_t.size(); ++i) {
864       const int64_t val = flat_t(i);
865       if (val < -1) {
866         return errors::InvalidArgument(
867             "Invalid value in tensor used for shape: ", val);
868       }
869       // -1 will become an unknown dim.
870       dims.push_back(MakeDim(val));
871     }
872   } else {
873     *out = nullptr;
874     return errors::InvalidArgument(
875         "Input tensor must be int32 or int64, but was ",
876         DataTypeString(t->dtype()));
877   }
878 
879   return ReturnCreatedShape(dims, out);
880 }
881 
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)882 Status InferenceContext::MakeShapeFromPartialTensorShape(
883     const PartialTensorShape& partial_shape, ShapeHandle* out) {
884   *out = nullptr;
885   if (partial_shape.dims() == -1) {
886     return ReturnUnknownShape(out);
887   }
888   const int num_dims = partial_shape.dims();
889   std::vector<DimensionHandle> dims(num_dims);
890   for (int i = 0; i < num_dims; ++i) {
891     // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
892     // can be passed directly to MakeDim.
893     dims[i] = MakeDim(partial_shape.dim_size(i));
894   }
895   return ReturnCreatedShape(dims, out);
896 }
897 
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)898 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
899                                                   ShapeHandle* out) {
900   return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
901                                          out);
902 }
903 
MakeShapeFromShapeTensor(const TensorShape & shape)904 StatusOr<ShapeHandle> InferenceContext::MakeShapeFromShapeTensor(
905     const TensorShape& shape) {
906   ShapeHandle out;
907   TF_RETURN_IF_ERROR(MakeShapeFromTensorShape(shape, &out));
908   return out;
909 }
910 
ShapeHandleToProto(ShapeHandle handle)911 TensorShapeProto InferenceContext::ShapeHandleToProto(ShapeHandle handle) {
912   TensorShapeProto out;
913   ShapeHandleToProto(handle, &out);
914   return out;
915 }
916 
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)917 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
918                                                  ShapeHandle* out) {
919   *out = nullptr;
920   TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
921   PartialTensorShape partial_shape(proto);
922   return MakeShapeFromPartialTensorShape(partial_shape, out);
923 }
924 
GetScalarFromTensor(const Tensor * t,int64_t * val)925 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t* val) {
926   // Caller must ensure that <t> is not NULL.
927   const int rank = t->dims();
928   if (rank != 0) {
929     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
930   }
931 
932   if (t->dtype() == DataType::DT_INT32) {
933     *val = t->scalar<int32>()();
934     return OkStatus();
935   } else if (t->dtype() == DataType::DT_INT64) {
936     *val = t->scalar<int64_t>()();
937     return OkStatus();
938   } else {
939     return errors::InvalidArgument("Scalar input must be int32 or int64.");
940   }
941 }
942 
GetScalarFromTensor(const Tensor * t,int64_t idx,int64_t * val)943 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx,
944                                              int64_t* val) {
945   // Caller must ensure that <t> is not NULL.
946   const int rank = t->dims();
947   if (rank != 1) {
948     return errors::InvalidArgument("Input must be 1D but has rank ", rank);
949   }
950 
951   if (t->dtype() == DataType::DT_INT32) {
952     auto flat_t = t->flat<int32>();
953     if (idx < 0 || idx >= flat_t.size()) {
954       return errors::InvalidArgument("Invalid index ", idx,
955                                      " for Tensor of size ", flat_t.size());
956     }
957     *val = flat_t(idx);
958     return OkStatus();
959   } else if (t->dtype() == DataType::DT_INT64) {
960     auto flat_t = t->flat<int64_t>();
961     if (idx < 0 || idx >= flat_t.size()) {
962       return errors::InvalidArgument("Invalid index ", idx,
963                                      " for Tensor of size ", flat_t.size());
964     }
965     *val = flat_t(idx);
966     return OkStatus();
967   } else {
968     return errors::InvalidArgument("Tensor input must be int32 or int64.");
969   }
970 }
971 
972 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)973 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
974   int64_t val;
975   const Tensor* t = input_tensor(idx);
976   if (t == nullptr) {
977     *out = UnknownDim();
978     return OkStatus();
979   }
980   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
981   if (val < 0) {
982     return errors::InvalidArgument("Dimension size, given by scalar input ",
983                                    idx, ", must be non-negative but is ", val);
984   }
985   *out = MakeDim(val);
986   return OkStatus();
987 }
988 
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)989 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
990     int idx, int input_rank, DimensionHandle* out) {
991   int64_t val;
992   const Tensor* t = input_tensor(idx);
993   if (t == nullptr) {
994     *out = UnknownDim();
995     return OkStatus();
996   }
997   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
998   if (val < 0) {
999     if (input_rank < 0) {
1000       *out = UnknownDim();
1001       return OkStatus();
1002     } else if (val + input_rank < 0) {
1003       return errors::InvalidArgument("Dimension size, given by scalar input ",
1004                                      val, " must be in range [-", input_rank,
1005                                      ", ", input_rank, ")");
1006     } else {
1007       val += input_rank;
1008     }
1009   } else if (input_rank >= 0 && val >= input_rank) {
1010     return errors::InvalidArgument("Dimension size, given by scalar input ",
1011                                    val, " must be in range [-", input_rank,
1012                                    ", ", input_rank, ")");
1013   }
1014   *out = MakeDim(val);
1015   return OkStatus();
1016 }
1017 
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)1018 Status InferenceContext::Divide(DimensionHandle dividend,
1019                                 DimensionOrConstant divisor,
1020                                 bool evenly_divisible, DimensionHandle* out) {
1021   const int64_t divisor_value = Value(divisor);
1022   if (divisor_value == 1) {
1023     *out = dividend;
1024   } else if (!ValueKnown(dividend) ||
1025              (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
1026     *out = UnknownDim();
1027   } else {
1028     const int64_t v = Value(dividend);
1029     if (divisor_value <= 0) {
1030       return errors::InvalidArgument("Divisor must be positive but is ",
1031                                      divisor_value);
1032     }
1033     if (evenly_divisible && (v % divisor_value) != 0) {
1034       return errors::InvalidArgument(
1035           "Dimension size must be evenly divisible by ", divisor_value,
1036           " but is ", v);
1037     }
1038     *out = MakeDim(v / divisor_value);
1039   }
1040   return OkStatus();
1041 }
1042 
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1043 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1044                              DimensionHandle* out) {
1045   const int64_t first_value = Value(first);
1046   const int64_t second_value = Value(second);
1047   // Special cases.
1048   if (first_value == 0) {
1049     *out = MakeDim(second);
1050   } else if (second_value == 0) {
1051     *out = first;
1052   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1053     *out = UnknownDim();
1054   } else {
1055     // Invariant: Both values are known and positive. Still in run-time we can
1056     // get pair of values which cannot be store in output. Check below will
1057     // report error. We still need to avoid undefined behavior of signed
1058     // overflow and use unsigned addition.
1059     const int64_t sum = static_cast<uint64>(first_value) + second_value;
1060     if (sum < 0) {
1061       return errors::InvalidArgument("Dimension size overflow from adding ",
1062                                      first_value, " and ", second_value);
1063     }
1064     *out = MakeDim(sum);
1065   }
1066   return OkStatus();
1067 }
1068 
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1069 Status InferenceContext::Subtract(DimensionHandle first,
1070                                   DimensionOrConstant second,
1071                                   DimensionHandle* out) {
1072   const int64_t first_value = Value(first);
1073   const int64_t second_value = Value(second);
1074   // Special cases.
1075   if (second_value == 0) {
1076     *out = first;
1077   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1078     *out = UnknownDim();
1079   } else {
1080     // Invariant: Both values are known, first_value is non-negative, and
1081     // second_value is positive.
1082     if (first_value < second_value) {
1083       return errors::InvalidArgument(
1084           "Negative dimension size caused by subtracting ", second_value,
1085           " from ", first_value);
1086     }
1087     *out = MakeDim(first_value - second_value);
1088   }
1089   return OkStatus();
1090 }
1091 
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1092 Status InferenceContext::Multiply(DimensionHandle first,
1093                                   DimensionOrConstant second,
1094                                   DimensionHandle* out) {
1095   const int64_t first_value = Value(first);
1096   const int64_t second_value = Value(second);
1097   // Special cases.
1098   if (first_value == 0) {
1099     *out = first;
1100   } else if (second_value == 0) {
1101     *out = MakeDim(second);
1102   } else if (first_value == 1) {
1103     *out = MakeDim(second);
1104   } else if (second_value == 1) {
1105     *out = first;
1106   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1107     *out = UnknownDim();
1108   } else {
1109     // Invariant: Both values are known and greater than 1.
1110     const int64_t product = MultiplyWithoutOverflow(first_value, second_value);
1111     if (product < 0) {
1112       return errors::InvalidArgument(
1113           "Negative dimension size caused by overflow when multiplying ",
1114           first_value, " and ", second_value);
1115     }
1116     *out = MakeDim(product);
1117   }
1118   return OkStatus();
1119 }
1120 
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1121 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1122                              DimensionHandle* out) {
1123   const int64_t first_value = Value(first);
1124   const int64_t second_value = Value(second);
1125   if (first_value == 0) {
1126     *out = first;
1127   } else if (second_value == 0) {
1128     *out = MakeDim(second);
1129   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1130     *out = UnknownDim();
1131   } else {
1132     if (first_value <= second_value) {
1133       *out = first;
1134     } else {
1135       *out = MakeDim(second);
1136     }
1137   }
1138   return OkStatus();
1139 }
1140 
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1141 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1142                              DimensionHandle* out) {
1143   const int64_t first_value = Value(first);
1144   const int64_t second_value = Value(second);
1145   if (first_value == kUnknownDim || second_value == kUnknownDim) {
1146     *out = UnknownDim();
1147   } else {
1148     if (first_value >= second_value) {
1149       *out = first;
1150     } else {
1151       *out = MakeDim(second);
1152     }
1153   }
1154   return OkStatus();
1155 }
1156 
// Returns a copy of `status` whose message is extended with context about
// this node:
// - its summary and input shapes;
// - input tensor values or partial shapes, for inputs whose
//   requested_input_tensor_* flags are set and whose values are available.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> input_shapes;
  input_shapes.reserve(inputs_.size());
  for (const ShapeHandle& input_shape : inputs_) {
    input_shapes.emplace_back(DebugString(input_shape));
  }

  // These sizes do not change inside the loop below, so compute them once.
  const int num_inputs = inputs_.size();
  const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
  const int input_tensors_size = input_tensors_.size();

  // Add information about the input tensors and partial tensor shapes used.
  std::vector<string> input_from_tensors_str;
  std::vector<string> input_from_tensors_as_shape_str;
  input_from_tensors_str.reserve(num_inputs);
  input_from_tensors_as_shape_str.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    // A known partial shape is preferred over the raw tensor value.
    if (requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_size &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i])) {
      input_from_tensors_as_shape_str.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
    } else if (requested_input_tensor_[i] && i < input_tensors_size &&
               input_tensors_[i] != nullptr) {
      input_from_tensors_str.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string error_context = strings::StrCat(
      " for '", attrs_.SummarizeNode(),
      "' with input shapes: ", absl::StrJoin(input_shapes, ", "));
  if (!input_from_tensors_str.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       absl::StrJoin(input_from_tensors_str, ", "));
  }
  if (!input_from_tensors_as_shape_str.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(input_from_tensors_as_shape_str, ","));
  }

  strings::StrAppend(&error_context, ".");
  return errors::CreateWithUpdatedMessage(
      status, strings::StrCat(status.error_message(), error_context));
}
1202 
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1203 bool InferenceContext::MergeHandleShapesAndTypes(
1204     const std::vector<ShapeAndType>& shapes_and_types,
1205     std::vector<ShapeAndType>* to_update) {
1206   if (shapes_and_types.size() != to_update->size()) {
1207     return false;
1208   }
1209   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1210   bool refined = false;
1211   for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1212     const ShapeAndType& existing = (*to_update)[i];
1213     if (shapes_and_types[i].dtype == existing.dtype) {
1214       new_values[i].dtype = existing.dtype;
1215     } else {
1216       if (existing.dtype != DT_INVALID) {
1217         return false;
1218       } else {
1219         new_values[i].dtype = shapes_and_types[i].dtype;
1220         refined = true;
1221       }
1222     }
1223     if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1224              .ok()) {
1225       // merge failed, ignore the new value.
1226       new_values[i].shape = existing.shape;
1227     }
1228     if (!existing.shape.SameHandle(new_values[i].shape)) {
1229       refined = true;
1230     }
1231   }
1232   if (!refined) {
1233     return false;
1234   }
1235   to_update->swap(new_values);
1236   return true;
1237 }
1238 
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1239 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1240     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1241   if (output_handle_shapes_and_types_[idx] == nullptr) {
1242     output_handle_shapes_and_types_[idx].reset(
1243         new std::vector<ShapeAndType>(shapes_and_types));
1244     return true;
1245   }
1246   return MergeHandleShapesAndTypes(shapes_and_types,
1247                                    output_handle_shapes_and_types_[idx].get());
1248 }
1249 
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1250 bool InferenceContext::MergeInputHandleShapesAndTypes(
1251     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1252   if (input_handle_shapes_and_types_[idx] == nullptr) {
1253     input_handle_shapes_and_types_[idx].reset(
1254         new std::vector<ShapeAndType>(shapes_and_types));
1255     return true;
1256   }
1257   return MergeHandleShapesAndTypes(shapes_and_types,
1258                                    input_handle_shapes_and_types_[idx].get());
1259 }
1260 
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1261 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1262     const std::vector<ShapeAndType>& shapes_and_types,
1263     std::vector<ShapeAndType>* to_update) {
1264   if (shapes_and_types.size() != to_update->size()) {
1265     return false;
1266   }
1267   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1268   for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1269     const ShapeAndType& existing = (*to_update)[i];
1270     if (shapes_and_types[i].dtype == existing.dtype) {
1271       new_values[i].dtype = existing.dtype;
1272     } else {
1273       if (existing.dtype != DT_INVALID) {
1274         return false;
1275       } else {
1276         new_values[i].dtype = shapes_and_types[i].dtype;
1277       }
1278     }
1279     Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1280   }
1281   to_update->swap(new_values);
1282   return true;
1283 }
1284 
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1285 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1286     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1287   if (output_handle_shapes_and_types_[idx] == nullptr) {
1288     output_handle_shapes_and_types_[idx].reset(
1289         new std::vector<ShapeAndType>(shapes_and_types));
1290     return true;
1291   }
1292   return RelaxHandleShapesAndMergeTypes(
1293       shapes_and_types, output_handle_shapes_and_types_[idx].get());
1294 }
1295 
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1296 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1297     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1298   if (input_handle_shapes_and_types_[idx] == nullptr) {
1299     input_handle_shapes_and_types_[idx].reset(
1300         new std::vector<ShapeAndType>(shapes_and_types));
1301     return true;
1302   }
1303   return RelaxHandleShapesAndMergeTypes(
1304       shapes_and_types, input_handle_shapes_and_types_[idx].get());
1305 }
1306 
1307 // -----------------------------------------------------------------------------
1308 // ShapeManager
1309 // -----------------------------------------------------------------------------
ShapeManager()1310 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1311 InferenceContext::ShapeManager::~ShapeManager() {
1312   for (auto* s : all_shapes_) delete s;
1313   for (auto* d : all_dims_) delete d;
1314 }
1315 
MakeShape(const std::vector<DimensionHandle> & dims)1316 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1317     const std::vector<DimensionHandle>& dims) {
1318   all_shapes_.push_back(new Shape(dims));
1319   return all_shapes_.back();
1320 }
1321 
UnknownShape()1322 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1323   all_shapes_.push_back(new Shape());
1324   return all_shapes_.back();
1325 }
1326 
1327 }  // namespace shape_inference
1328 }  // namespace tensorflow
1329