• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 
26 namespace tensorflow {
27 namespace shape_inference {
28 
29 constexpr int32 InferenceContext::kUnknownRank;
30 constexpr int64 InferenceContext::kUnknownDim;
31 
32 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)33 InferenceContext::InferenceContext(
34     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
35     const std::vector<PartialTensorShape>& input_shapes,
36     const std::vector<const Tensor*>& input_tensors,
37     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
38     const std::vector<
39         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
40         input_handle_shapes_and_types)
41     : graph_def_version_(graph_def_version), attrs_(attrs) {
42   std::vector<ShapeHandle> input_tensors_as_shape_handles;
43   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
44   for (const PartialTensorShape& p : input_tensors_as_shapes) {
45     ShapeHandle shape;
46     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
47     if (!construction_status_.ok()) {
48       return;
49     }
50     input_tensors_as_shape_handles.push_back(shape);
51   }
52   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
53   if (!construction_status_.ok()) return;
54   inputs_.reserve(input_shapes.size());
55   for (const PartialTensorShape& p : input_shapes) {
56     ShapeHandle shape;
57     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
58     if (!construction_status_.ok()) {
59       return;
60     }
61     inputs_.push_back(shape);
62   }
63   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
64       input_shapes.size());
65   for (int i = 0, end = input_handle_shapes_and_types.size(); i < end; ++i) {
66     const auto& v = input_handle_shapes_and_types[i];
67     if (v == nullptr) {
68       continue;
69     }
70     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
71     auto& new_v = *handle_data[i];
72     for (int j = 0, end = v->size(); j < end; ++j) {
73       const auto& p = (*v)[j];
74       construction_status_.Update(
75           MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
76       if (!construction_status_.ok()) {
77         return;
78       }
79       new_v[j].dtype = p.second;
80     }
81   }
82   PostInputInit(std::move(handle_data));
83 }
84 
InferenceContext(int graph_def_version,const AttrSlice & attrs,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)85 InferenceContext::InferenceContext(
86     int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
87     const std::vector<ShapeHandle>& input_shapes,
88     const std::vector<const Tensor*>& input_tensors,
89     const std::vector<ShapeHandle>& input_tensors_as_shapes,
90     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
91         input_handle_shapes_and_types)
92     : graph_def_version_(graph_def_version), attrs_(attrs) {
93   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
94   if (!construction_status_.ok()) return;
95   inputs_ = input_shapes;
96 
97   PostInputInit(std::move(input_handle_shapes_and_types));
98 }
99 
~InferenceContext()100 InferenceContext::~InferenceContext() {}
101 
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)102 Status InferenceContext::Run(
103     const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
104   ForgetMerges();
105   Status s = fn(this);
106   if (!s.ok()) {
107     ForgetMerges();
108     return AttachContext(s);
109   }
110 #ifndef NDEBUG
111   for (int i = 0; i < num_outputs(); ++i) {
112     DCHECK(output(i).IsSet()) << i << " for " << attrs_.SummarizeNode();
113   }
114 #endif  // NDEBUG
115   return s;
116 }
117 
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)118 Status InferenceContext::set_output(StringPiece output_name,
119                                     const std::vector<ShapeHandle>& shapes) {
120   auto result = output_name_map_.find(output_name);
121   if (result == output_name_map_.end()) {
122     return errors::InvalidArgument("Unknown output name: ", output_name);
123   } else {
124     const int start = result->second.first;
125     const int size = result->second.second - start;
126     const int shapes_size = shapes.size();
127     if (size != shapes_size) {
128       return errors::InvalidArgument("Must have exactly ", shapes.size(),
129                                      " shapes.");
130     }
131     for (int i = 0; i < shapes_size; ++i) {
132       outputs_[i + start] = shapes[i];
133     }
134   }
135   return Status::OK();
136 }
137 
input(StringPiece input_name,std::vector<ShapeHandle> * output) const138 Status InferenceContext::input(StringPiece input_name,
139                                std::vector<ShapeHandle>* output) const {
140   const auto result = input_name_map_.find(input_name);
141   if (result == input_name_map_.end()) {
142     return errors::InvalidArgument("Unknown input name: ", input_name);
143   } else {
144     output->clear();
145     for (int i = result->second.first; i < result->second.second; ++i) {
146       output->push_back(inputs_[i]);
147     }
148   }
149   return Status::OK();
150 }
151 
output(StringPiece output_name,std::vector<ShapeHandle> * output) const152 Status InferenceContext::output(StringPiece output_name,
153                                 std::vector<ShapeHandle>* output) const {
154   const auto result = output_name_map_.find(output_name);
155   if (result == output_name_map_.end()) {
156     return errors::InvalidArgument("Unknown output name: ", output_name);
157   } else {
158     output->clear();
159     for (int i = result->second.first; i < result->second.second; ++i) {
160       output->push_back(outputs_[i]);
161     }
162   }
163   return Status::OK();
164 }
165 
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)166 void InferenceContext::PreInputInit(
167     const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
168     const std::vector<ShapeHandle>& input_tensors_as_shapes) {
169   input_tensors_ = input_tensors;
170   input_tensors_as_shapes_ = input_tensors_as_shapes;
171 
172   construction_status_ =
173       NameRangesForNode(attrs_, op_def, &input_name_map_, &output_name_map_);
174   if (!construction_status_.ok()) return;
175 
176   int num_outputs = 0;
177   for (const auto& e : output_name_map_) {
178     num_outputs = std::max(num_outputs, e.second.second);
179   }
180   outputs_.assign(num_outputs, nullptr);
181   output_handle_shapes_and_types_.resize(num_outputs);
182 }
183 
ExpandOutputs(int new_output_size)184 Status InferenceContext::ExpandOutputs(int new_output_size) {
185   const int outputs_size = outputs_.size();
186   if (new_output_size < outputs_size) {
187     return errors::InvalidArgument("Trying to reduce number of outputs of op.");
188   }
189   outputs_.resize(new_output_size, nullptr);
190   output_handle_shapes_and_types_.resize(new_output_size);
191   return Status::OK();
192 }
193 
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)194 void InferenceContext::PostInputInit(
195     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
196   int num_inputs_from_node_def = 0;
197   for (const auto& e : input_name_map_) {
198     num_inputs_from_node_def =
199         std::max(num_inputs_from_node_def, e.second.second);
200   }
201 
202   // Allow passing empty shapes/dtypes to avoid changing every single test.
203   if (input_handle_data.empty()) {
204     input_handle_shapes_and_types_.resize(inputs_.size());
205   } else {
206     if (input_handle_data.size() != inputs_.size()) {
207       construction_status_ = errors::InvalidArgument(
208           "Wrong number of handle shapes passed; expected ", inputs_.size(),
209           " got ", input_handle_data.size());
210       return;
211     }
212     input_handle_shapes_and_types_ = std::move(input_handle_data);
213   }
214   const int inputs_size = inputs_.size();
215   if (inputs_size != num_inputs_from_node_def) {
216     construction_status_ = errors::InvalidArgument(
217         "Wrong number of inputs passed: ", inputs_.size(), " while ",
218         num_inputs_from_node_def, " expected based on NodeDef");
219     return;
220   }
221 
222   CHECK_LE(input_tensors_.size(), inputs_.size());
223   input_tensors_.resize(inputs_.size());
224   requested_input_tensor_.resize(inputs_.size());
225   requested_input_tensor_as_partial_shape_.resize(inputs_.size());
226 }
227 
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)228 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
229                                           TensorShapeProto* proto) {
230   if (!RankKnown(handle)) {
231     proto->set_unknown_rank(true);
232     return;
233   }
234 
235   for (int32 i = 0; i < Rank(handle); ++i) {
236     DimensionHandle dim = Dim(handle, i);
237     auto* dim_shape = proto->add_dim();
238     if (ValueKnown(dim)) {
239       dim_shape->set_size(Value(dim));
240     } else {
241       dim_shape->set_size(-1);
242     }
243   }
244 }
245 
FullyDefined(ShapeHandle s)246 bool InferenceContext::FullyDefined(ShapeHandle s) {
247   if (!RankKnown(s)) return false;
248   for (int i = 0; i < Rank(s); ++i) {
249     if (!ValueKnown(Dim(s, i))) return false;
250   }
251   return true;
252 }
253 
NumElements(ShapeHandle s)254 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
255   const auto rank = Rank(s);
256   if (rank == kUnknownRank) return UnknownDim();
257   bool found_unknown = false;
258   int64 size = 1;
259   for (int i = 0; i < rank; ++i) {
260     int64 dim_val = Value(Dim(s, i));
261     if (dim_val == kUnknownDim) {
262       found_unknown = true;
263     } else if (dim_val == 0) {
264       return MakeDim(0);
265     } else {
266       size *= dim_val;
267     }
268   }
269   if (found_unknown) {
270     return UnknownDim();
271   } else {
272     return MakeDim(size);
273   }
274 }
275 
DebugString(ShapeHandle s)276 string InferenceContext::DebugString(ShapeHandle s) {
277   if (RankKnown(s)) {
278     std::vector<string> vals;
279     for (auto d : s->dims_) vals.push_back(DebugString(d));
280     return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
281   } else {
282     return "?";
283   }
284 }
285 
DebugString(DimensionHandle d)286 string InferenceContext::DebugString(DimensionHandle d) {
287   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
288 }
289 
DebugString() const290 string InferenceContext::DebugString() const {
291   return strings::StrCat("InferenceContext for node: ", attrs_.SummarizeNode());
292 }
293 
DebugString(const ShapeAndType & shape_and_type)294 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
295   return strings::StrCat(DebugString(shape_and_type.shape), ":",
296                          DataTypeString(shape_and_type.dtype));
297 }
298 
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)299 string InferenceContext::DebugString(
300     gtl::ArraySlice<ShapeAndType> shape_and_types) {
301   std::vector<string> pieces;
302   for (const ShapeAndType& s : shape_and_types) {
303     pieces.push_back(DebugString(s));
304   }
305   return strings::StrCat("[", absl::StrJoin(pieces, ","), "]");
306 }
307 
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)308 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
309                                   ShapeHandle* out) {
310   if (rank > kint32max) {
311     return errors::InvalidArgument("Rank cannot exceed kint32max");
312   }
313   const int32 existing = Rank(shape);
314   if (existing == rank) {
315     *out = shape;
316     return Status::OK();
317   }
318   if (existing == kUnknownRank) {
319     std::vector<DimensionHandle> dims;
320     dims.reserve(rank);
321     for (int i = 0; i < rank; ++i) {
322       dims.push_back(UnknownDim());
323     }
324     ShapeHandle shp = shape_manager_.MakeShape(dims);
325     return Merge(shape, shp, out);
326   }
327   *out = nullptr;
328 
329   return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
330                                  existing);
331 }
332 
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)333 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
334                                          ShapeHandle* out) {
335   if (rank > kint32max) {
336     return errors::InvalidArgument("Rank cannot exceed kint32max");
337   }
338   const int32 existing = Rank(shape);
339   if (existing >= rank || existing == kUnknownRank) {
340     *out = shape;
341     return Status::OK();
342   }
343   *out = nullptr;
344   return errors::InvalidArgument("Shape must be at least rank ", rank,
345                                  " but is rank ", existing);
346 }
347 
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)348 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
349                                         ShapeHandle* out) {
350   if (rank > kint32max) {
351     return errors::InvalidArgument("Rank cannot exceed kint32max");
352   }
353   const int32 existing = Rank(shape);
354   if (existing <= rank || existing == kUnknownRank) {
355     *out = shape;
356     return Status::OK();
357   }
358   *out = nullptr;
359   return errors::InvalidArgument("Shape must be at most rank ", rank,
360                                  " but is rank ", existing);
361 }
362 
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)363 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
364                                    DimensionHandle* out) {
365   const int64 existing = Value(dim);
366   if (existing == value) {
367     *out = dim;
368     return Status::OK();
369   }
370   if (existing == kUnknownDim) {
371     DimensionHandle d = MakeDim(value);
372     return Merge(dim, d, out);
373   }
374   *out = nullptr;
375   return errors::InvalidArgument("Dimension must be ", value, " but is ",
376                                  existing);
377 }
378 
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)379 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
380                              DimensionHandle* out) {
381   if (d_old.SameHandle(d_new)) {
382     *out = d_old;
383   } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
384     // The node will be fed by the dimension d_new instead of d_old: any
385     // equality assertion between d_old and other input dimension on this node
386     // may not be true anymore, so forget them all.
387     ForgetMerges();
388     // Return the new shape handle to force the relaxation to propagate to the
389     // fanout of the context.
390     *out = d_new;
391   } else if (!ValueKnown(d_new)) {
392     ForgetMerges();
393     *out = d_new;
394   } else if (Value(d_old) == Value(d_new)) {
395     // Return the old shape handle. This will stop the relaxation in the fanout
396     // of the context.
397     *out = d_old;
398   } else {
399     // Return a new handle that encodes a different unknown dim.
400     ForgetMerges();
401     *out = UnknownDim();
402   }
403 }
404 
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)405 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
406                                DimensionHandle* out) {
407   if (d0.SameHandle(d1)) {
408     *out = d0;
409     return Status::OK();
410   } else if (!ValueKnown(d1)) {
411     *out = d0;
412     merged_dims_.emplace_back(d0, d1);
413     return Status::OK();
414   } else if (!ValueKnown(d0)) {
415     *out = d1;
416     merged_dims_.emplace_back(d0, d1);
417     return Status::OK();
418   } else if (Value(d0) == Value(d1)) {
419     *out = d0;
420     return Status::OK();
421   } else {
422     *out = nullptr;
423     return errors::InvalidArgument("Dimensions must be equal, but are ",
424                                    Value(d0), " and ", Value(d1));
425   }
426 }
427 
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)428 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
429                                      ShapeHandle* s_out,
430                                      ShapeHandle* prefix_out) {
431   *s_out = *prefix_out = nullptr;
432   if (!RankKnown(prefix) || !RankKnown(s)) {
433     *s_out = s;
434     *prefix_out = prefix;
435     return Status::OK();
436   }
437   const int32 rank = Rank(prefix);
438   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
439 
440   // Merge the prefix dims and create the new output shapes.
441   const int32 rank_s = Rank(s);
442   std::vector<DimensionHandle> dims;
443   dims.reserve(std::max(rank, rank_s));
444   dims.resize(rank);
445   for (int i = 0; i < rank; ++i) {
446     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
447   }
448   *prefix_out = MakeShape(dims);
449   for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
450   *s_out = MakeShape(dims);
451   return Status::OK();
452 }
453 
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)454 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
455                              ShapeHandle* out) {
456   if (s_old.SameHandle(s_new)) {
457     *out = s_old;
458     return;
459   } else if (!RankKnown(s_new) || !s_old.IsSet()) {
460     ForgetMerges();
461     *out = s_new;
462     return;
463   }
464 
465   const int32 rank = Rank(s_old);
466   if (rank != Rank(s_new)) {
467     ForgetMerges();
468     *out = UnknownShape();
469     return;
470   }
471 
472   bool return_s_old = true;
473   for (int i = 0; i < rank; ++i) {
474     auto d0 = Dim(s_old, i);
475     auto d1 = Dim(s_new, i);
476     if (d0.SameHandle(d1)) continue;
477 
478     auto v0 = Value(d0);
479     auto v1 = Value(d1);
480     if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
481       return_s_old = false;
482       break;
483     }
484   }
485   if (return_s_old) {
486     *out = s_old;
487     return;
488   }
489 
490   // Relax dims.
491   std::vector<DimensionHandle> dims(rank);
492   for (int i = 0; i < rank; ++i) {
493     Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
494   }
495   ForgetMerges();
496   *out = MakeShape(dims);
497 }
498 
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)499 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
500                                ShapeHandle* out) {
501   if (s0.SameHandle(s1)) {
502     *out = s0;
503     return Status::OK();
504   } else if (!RankKnown(s1)) {
505     *out = s0;
506     merged_shapes_.emplace_back(s0, s1);
507     return Status::OK();
508   } else if (!RankKnown(s0)) {
509     *out = s1;
510     merged_shapes_.emplace_back(s0, s1);
511     return Status::OK();
512   }
513 
514   const int32 rank = Rank(s0);
515   if (rank != Rank(s1)) {
516     *out = nullptr;
517     return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
518                                    " and ", Rank(s1));
519   }
520 
521   bool return_s0 = true;
522   bool return_s1 = true;
523   for (int i = 0; i < rank; ++i) {
524     auto d0 = Dim(s0, i);
525     auto d1 = Dim(s1, i);
526     if (d0.SameHandle(d1)) continue;
527 
528     auto v0 = Value(d0);
529     auto v1 = Value(d1);
530     if (v0 == kUnknownDim) {
531       if (v1 != kUnknownDim) {
532         return_s0 = false;
533       }
534     } else if (v1 == kUnknownDim) {
535       return_s1 = false;
536     } else if (v0 != v1) {
537       *out = nullptr;
538       return errors::InvalidArgument(
539           "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
540           " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
541           DebugString(s1), ".");
542     }
543   }
544 
545   merged_shapes_.emplace_back(s0, s1);
546 
547   if (return_s0 || return_s1) {
548     *out = return_s0 ? s0 : s1;
549     return Status::OK();
550   }
551 
552   // Merge dims.
553   std::vector<DimensionHandle> dims(rank, nullptr);
554   for (int i = 0; i < rank; ++i) {
555     // Invariant for merge was checked earlier, so CHECK is ok.
556     TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
557   }
558 
559   Status s = ReturnCreatedShape(dims, out);
560   if (s.ok()) {
561     // Merge the new shape with s0. Since s0 and s1 are merged, this implies
562     // that s1 and out are also merged.
563     merged_shapes_.emplace_back(s0, *out);
564   }
565   return s;
566 }
567 
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)568 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
569                                   ShapeHandle* out) {
570   return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
571 }
572 
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)573 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
574                                   ShapeHandle* out) {
575   return Subshape(s, start, end, 1 /* stride */, out);
576 }
577 
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)578 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
579                                   int64 stride, ShapeHandle* out) {
580   int64 start_in = start;
581   int64 end_in = end;
582 
583   const int32 rank = Rank(s);
584   if (start == 0 && stride == 1 &&
585       ((RankKnown(s) && end >= rank) ||
586        end == std::numeric_limits<int64>::max())) {
587     *out = s;
588     return Status::OK();
589   }
590   if (!RankKnown(s)) {
591     return ReturnUnknownShape(out);
592   }
593 
594   if (start > rank) start = rank;
595   if (end > rank) end = rank;
596 
597   if (stride < 0 && start == rank) --start;
598 
599   if (start < 0) {
600     start = rank + start;
601     if (start < 0) {
602       *out = nullptr;
603       return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
604                                      ", for shape with rank ", rank);
605     }
606   }
607 
608   if (end < 0) {
609     end = rank + end;
610     if (end < 0) {
611       *out = nullptr;
612       return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
613                                      ", for shape with rank ", rank);
614     }
615   }
616   if (stride > 0 && start > end) {
617     *out = nullptr;
618     return errors::InvalidArgument(
619         "Subshape must have computed start <= end, but is ", start, " and ",
620         end, " (computed from start ", start_in, " and end ", end_in,
621         " over shape with rank ", rank, ")");
622   } else if (stride < 0 && start < end) {
623     *out = nullptr;
624     return errors::InvalidArgument(
625         "Subshape must have computed start >= end since stride is negative, "
626         "but is ",
627         start, " and ", end, " (computed from start ", start_in, " and end ",
628         end_in, " over shape with rank ", rank, " and stride", stride, ")");
629   }
630 
631   std::vector<DimensionHandle> dims;
632   for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
633     dims.push_back(Dim(s, i));
634   }
635   return ReturnCreatedShape(dims, out);
636 }
637 
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)638 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
639                                      ShapeHandle* out) {
640   if (!RankKnown(s1) || !RankKnown(s2)) {
641     return ReturnUnknownShape(out);
642   }
643   const int32 s1_rank = Rank(s1);
644   const int32 s2_rank = Rank(s2);
645   const int32 rank = s1_rank + s2_rank;
646   std::vector<DimensionHandle> dims;
647   dims.reserve(rank);
648   for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
649   for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
650   return ReturnCreatedShape(dims, out);
651 }
652 
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)653 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
654                                     DimensionHandle new_dim, ShapeHandle* out) {
655   if (!RankKnown(s)) {
656     return ReturnUnknownShape(out);
657   }
658   int64 dim_index = dim_index_in;
659   if (dim_index < 0) {
660     dim_index = s->dims_.size() + dim_index;
661   }
662   if (!FastBoundsCheck(dim_index, s->dims_.size())) {
663     *out = nullptr;
664     return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
665                                    " for shape with ", s->dims_.size(),
666                                    " dimensions");
667   }
668   std::vector<DimensionHandle> dims(s->dims_);
669   dims[dim_index] = new_dim;
670   return ReturnCreatedShape(dims, out);
671 }
672 
MakeShape(const std::vector<DimensionHandle> & dims)673 ShapeHandle InferenceContext::MakeShape(
674     const std::vector<DimensionHandle>& dims) {
675   return shape_manager_.MakeShape(dims);
676 }
677 
MakeShape(std::initializer_list<DimensionOrConstant> dims)678 ShapeHandle InferenceContext::MakeShape(
679     std::initializer_list<DimensionOrConstant> dims) {
680   std::vector<DimensionHandle> dims_actual;
681   dims_actual.reserve(dims.size());
682   for (const DimensionOrConstant& d : dims) {
683     dims_actual.push_back(MakeDim(d));
684   }
685 
686   return shape_manager_.MakeShape(dims_actual);
687 }
688 
UnknownShape()689 ShapeHandle InferenceContext::UnknownShape() {
690   return shape_manager_.UnknownShape();
691 }
692 
UnknownShapeOfRank(int64 rank)693 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
694   CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
695   if (rank == kUnknownRank) {
696     return UnknownShape();
697   }
698   CHECK_GE(rank, 0) << "rank must not be negative";
699   std::vector<DimensionHandle> dims(rank);
700   for (int32 i = 0; i < rank; ++i) {
701     dims[i] = UnknownDim();
702   }
703   return MakeShape(dims);
704 }
705 
Scalar()706 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
707 
Vector(DimensionOrConstant dim)708 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
709   return MakeShape({dim});
710 }
711 
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)712 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
713                                      DimensionOrConstant dim2) {
714   return MakeShape({dim1, dim2});
715 }
716 
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)717 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
718     int input_idx, ShapeHandle* out) {
719   ShapeHandle input_shape;
720   TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
721 
722   request_input_tensor_as_partial_shape(input_idx);
723   const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
724   if (input_idx < input_tensors_as_shapes_size &&
725       input_tensors_as_shapes_[input_idx].IsSet() &&
726       RankKnown(input_tensors_as_shapes_[input_idx])) {
727     *out = input_tensors_as_shapes_[input_idx];
728     return Status::OK();
729   }
730 
731   return InternalMakeShapeFromTensor(
732       true /* treat_unknown_scalar_tensor_as_unknown_shape */,
733       input_tensor(input_idx), input_shape, out);
734 }
735 
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)736 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
737                                                   ShapeHandle* out) {
738   ShapeHandle input_shape;
739   TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
740 
741   request_input_tensor_as_partial_shape(input_idx);
742   const int input_tensors_as_shapes_size = input_tensors_as_shapes_.size();
743   if (input_idx < input_tensors_as_shapes_size &&
744       input_tensors_as_shapes_[input_idx].IsSet() &&
745       RankKnown(input_tensors_as_shapes_[input_idx])) {
746     *out = input_tensors_as_shapes_[input_idx];
747     return Status::OK();
748   }
749 
750   return InternalMakeShapeFromTensor(
751       false /* treat_unknown_scalar_tensor_as_unknown_shape */,
752       input_tensor(input_idx), input_shape, out);
753 }
754 
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)755 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
756                                              ShapeHandle tensor_shape,
757                                              ShapeHandle* out) {
758   return InternalMakeShapeFromTensor(
759       false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
760       out);
761 }
762 
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)763 Status InferenceContext::InternalMakeShapeFromTensor(
764     bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
765     ShapeHandle tensor_shape, ShapeHandle* out) {
766   // Only callers who have set
767   if (!treat_unknown_scalar_tensor_as_unknown_shape) {
768     TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
769   }
770   if (t == nullptr) {
771     // This is guarded by the check above.
772     if (Rank(tensor_shape) == 0) {
773       return ReturnUnknownShape(out);
774     }
775     // Shape tensor is not known, but if the shape of the shape tensor is then
776     // the right number of unknown dims can be created.
777     DimensionHandle shape_dim = Dim(tensor_shape, 0);
778     if (!ValueKnown(shape_dim)) {
779       return ReturnUnknownShape(out);
780     }
781     const auto num_dims = Value(shape_dim);
782     std::vector<DimensionHandle> dims;
783     dims.reserve(num_dims);
784     for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
785     return ReturnCreatedShape(dims, out);
786   }
787 
788   if (t->shape().dims() == 0) {
789     if (t->dtype() == DataType::DT_INT32) {
790       auto flat_t = t->scalar<int32>();
791       if (flat_t() != -1) {
792         *out = nullptr;
793         return errors::InvalidArgument(
794             "Input tensor must be rank 1, or if its rank 0 it must have value "
795             "-1 "
796             "(representing an unknown shape).  Saw value: ",
797             flat_t());
798       }
799       return ReturnUnknownShape(out);
800     } else if (t->dtype() == DataType::DT_INT64) {
801       auto flat_t = t->scalar<int64>();
802       if (flat_t() != -1) {
803         *out = nullptr;
804         return errors::InvalidArgument(
805             "Input tensor must be rank 1, or if its rank 0 it must have value "
806             "-1 "
807             "(representing an unknown shape).  Saw value: ",
808             flat_t());
809       }
810       return ReturnUnknownShape(out);
811     } else {
812       *out = nullptr;
813       return errors::InvalidArgument(
814           "Input tensor must be int32 or int64, but was ",
815           DataTypeString(t->dtype()));
816     }
817   }
818 
819   if (t->shape().dims() != 1) {
820     *out = nullptr;
821     return errors::InvalidArgument(
822         "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
823         ((t->shape().dims() == 0)
824              ? "If it is rank 0 rank 0 it must have statically known value -1 "
825                "(representing an unknown shape). "
826              : " "),
827         "Saw tensor shape ", t->shape().DebugString());
828   }
829   std::vector<DimensionHandle> dims;
830   if (t->dtype() == DataType::DT_INT32) {
831     auto flat_t = t->flat<int32>();
832     for (int i = 0; i < flat_t.size(); ++i) {
833       const int32 val = flat_t(i);
834       if (val < -1) {
835         return errors::InvalidArgument(
836             "Invalid value in tensor used for shape: ", val);
837       }
838       // -1 will become an unknown dim.
839       dims.push_back(MakeDim(val));
840     }
841   } else if (t->dtype() == DataType::DT_INT64) {
842     auto flat_t = t->flat<int64>();
843     for (int i = 0; i < flat_t.size(); ++i) {
844       const int64 val = flat_t(i);
845       if (val < -1) {
846         return errors::InvalidArgument(
847             "Invalid value in tensor used for shape: ", val);
848       }
849       // -1 will become an unknown dim.
850       dims.push_back(MakeDim(val));
851     }
852   } else {
853     *out = nullptr;
854     return errors::InvalidArgument(
855         "Input tensor must be int32 or int64, but was ",
856         DataTypeString(t->dtype()));
857   }
858 
859   return ReturnCreatedShape(dims, out);
860 }
861 
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)862 Status InferenceContext::MakeShapeFromPartialTensorShape(
863     const PartialTensorShape& partial_shape, ShapeHandle* out) {
864   *out = nullptr;
865   if (partial_shape.dims() == -1) {
866     return ReturnUnknownShape(out);
867   }
868   const int num_dims = partial_shape.dims();
869   std::vector<DimensionHandle> dims(num_dims);
870   for (int i = 0; i < num_dims; ++i) {
871     // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
872     // can be passed directly to MakeDim.
873     dims[i] = MakeDim(partial_shape.dim_size(i));
874   }
875   return ReturnCreatedShape(dims, out);
876 }
877 
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)878 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
879                                                   ShapeHandle* out) {
880   return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
881                                          out);
882 }
883 
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)884 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
885                                                  ShapeHandle* out) {
886   *out = nullptr;
887   TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
888   PartialTensorShape partial_shape(proto);
889   return MakeShapeFromPartialTensorShape(partial_shape, out);
890 }
891 
GetScalarFromTensor(const Tensor * t,int64 * val)892 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
893   // Caller must ensure that <t> is not NULL.
894   const int rank = t->dims();
895   if (rank != 0) {
896     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
897   }
898 
899   if (t->dtype() == DataType::DT_INT32) {
900     *val = t->scalar<int32>()();
901     return Status::OK();
902   } else if (t->dtype() == DataType::DT_INT64) {
903     *val = t->scalar<int64>()();
904     return Status::OK();
905   } else {
906     return errors::InvalidArgument("Scalar input must be int32 or int64.");
907   }
908 }
909 
GetScalarFromTensor(const Tensor * t,int64 idx,int64 * val)910 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64 idx,
911                                              int64* val) {
912   // Caller must ensure that <t> is not NULL.
913   const int rank = t->dims();
914   if (rank != 1) {
915     return errors::InvalidArgument("Input must be 1D but has rank ", rank);
916   }
917 
918   if (t->dtype() == DataType::DT_INT32) {
919     auto flat_t = t->flat<int32>();
920     if (idx < 0 || idx >= flat_t.size()) {
921       return errors::InvalidArgument("Invalid index ", idx,
922                                      " for Tensor of size ", flat_t.size());
923     }
924     *val = flat_t(idx);
925     return Status::OK();
926   } else if (t->dtype() == DataType::DT_INT64) {
927     auto flat_t = t->flat<int64>();
928     if (idx < 0 || idx >= flat_t.size()) {
929       return errors::InvalidArgument("Invalid index ", idx,
930                                      " for Tensor of size ", flat_t.size());
931     }
932     *val = flat_t(idx);
933     return Status::OK();
934   } else {
935     return errors::InvalidArgument("Tensor input must be int32 or int64.");
936   }
937 }
938 
939 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)940 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
941   int64 val;
942   const Tensor* t = input_tensor(idx);
943   if (t == nullptr) {
944     *out = UnknownDim();
945     return Status::OK();
946   }
947   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
948   if (val < 0) {
949     return errors::InvalidArgument("Dimension size, given by scalar input ",
950                                    idx, ", must be non-negative but is ", val);
951   }
952   *out = MakeDim(val);
953   return Status::OK();
954 }
955 
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)956 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
957     int idx, int input_rank, DimensionHandle* out) {
958   int64 val;
959   const Tensor* t = input_tensor(idx);
960   if (t == nullptr) {
961     *out = UnknownDim();
962     return Status::OK();
963   }
964   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
965   if (val < 0) {
966     if (input_rank < 0) {
967       *out = UnknownDim();
968       return Status::OK();
969     } else if (val + input_rank < 0) {
970       return errors::InvalidArgument("Dimension size, given by scalar input ",
971                                      val, " must be in range [-", input_rank,
972                                      ", ", input_rank, ")");
973     } else {
974       val += input_rank;
975     }
976   } else if (input_rank >= 0 && val >= input_rank) {
977     return errors::InvalidArgument("Dimension size, given by scalar input ",
978                                    val, " must be in range [-", input_rank,
979                                    ", ", input_rank, ")");
980   }
981   *out = MakeDim(val);
982   return Status::OK();
983 }
984 
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)985 Status InferenceContext::Divide(DimensionHandle dividend,
986                                 DimensionOrConstant divisor,
987                                 bool evenly_divisible, DimensionHandle* out) {
988   const int64 divisor_value = Value(divisor);
989   if (divisor_value == 1) {
990     *out = dividend;
991   } else if (!ValueKnown(dividend) ||
992              (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
993     *out = UnknownDim();
994   } else {
995     const int64 v = Value(dividend);
996     if (divisor_value <= 0) {
997       return errors::InvalidArgument("Divisor must be positive but is ",
998                                      divisor_value);
999     }
1000     if (evenly_divisible && (v % divisor_value) != 0) {
1001       return errors::InvalidArgument(
1002           "Dimension size must be evenly divisible by ", divisor_value,
1003           " but is ", v);
1004     }
1005     *out = MakeDim(v / divisor_value);
1006   }
1007   return Status::OK();
1008 }
1009 
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1010 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1011                              DimensionHandle* out) {
1012   const int64 first_value = Value(first);
1013   const int64 second_value = Value(second);
1014   // Special cases.
1015   if (first_value == 0) {
1016     *out = MakeDim(second);
1017   } else if (second_value == 0) {
1018     *out = first;
1019   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1020     *out = UnknownDim();
1021   } else {
1022     // Invariant: Both values are known and positive. Still in run-time we can
1023     // get pair of values which cannot be store in output. Check below will
1024     // report error. We still need to avoid undefined behavior of signed
1025     // overflow and use unsigned addition.
1026     const int64 sum = static_cast<uint64>(first_value) + second_value;
1027     if (sum < 0) {
1028       return errors::InvalidArgument("Dimension size overflow from adding ",
1029                                      first_value, " and ", second_value);
1030     }
1031     *out = MakeDim(sum);
1032   }
1033   return Status::OK();
1034 }
1035 
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1036 Status InferenceContext::Subtract(DimensionHandle first,
1037                                   DimensionOrConstant second,
1038                                   DimensionHandle* out) {
1039   const int64 first_value = Value(first);
1040   const int64 second_value = Value(second);
1041   // Special cases.
1042   if (second_value == 0) {
1043     *out = first;
1044   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1045     *out = UnknownDim();
1046   } else {
1047     // Invariant: Both values are known, first_value is non-negative, and
1048     // second_value is positive.
1049     if (first_value < second_value) {
1050       return errors::InvalidArgument(
1051           "Negative dimension size caused by subtracting ", second_value,
1052           " from ", first_value);
1053     }
1054     *out = MakeDim(first_value - second_value);
1055   }
1056   return Status::OK();
1057 }
1058 
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1059 Status InferenceContext::Multiply(DimensionHandle first,
1060                                   DimensionOrConstant second,
1061                                   DimensionHandle* out) {
1062   const int64 first_value = Value(first);
1063   const int64 second_value = Value(second);
1064   // Special cases.
1065   if (first_value == 0) {
1066     *out = first;
1067   } else if (second_value == 0) {
1068     *out = MakeDim(second);
1069   } else if (first_value == 1) {
1070     *out = MakeDim(second);
1071   } else if (second_value == 1) {
1072     *out = first;
1073   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1074     *out = UnknownDim();
1075   } else {
1076     // Invariant: Both values are known and greater than 1.
1077     const int64 product = first_value * second_value;
1078     if (product < 0) {
1079       return errors::InvalidArgument(
1080           "Negative dimension size caused by overflow when multiplying ",
1081           first_value, " and ", second_value);
1082     }
1083     *out = MakeDim(product);
1084   }
1085   return Status::OK();
1086 }
1087 
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1088 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1089                              DimensionHandle* out) {
1090   const int64 first_value = Value(first);
1091   const int64 second_value = Value(second);
1092   if (first_value == 0) {
1093     *out = first;
1094   } else if (second_value == 0) {
1095     *out = MakeDim(second);
1096   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1097     *out = UnknownDim();
1098   } else {
1099     if (first_value <= second_value) {
1100       *out = first;
1101     } else {
1102       *out = MakeDim(second);
1103     }
1104   }
1105   return Status::OK();
1106 }
1107 
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1108 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1109                              DimensionHandle* out) {
1110   const int64 first_value = Value(first);
1111   const int64 second_value = Value(second);
1112   if (first_value == kUnknownDim || second_value == kUnknownDim) {
1113     *out = UnknownDim();
1114   } else {
1115     if (first_value >= second_value) {
1116       *out = first;
1117     } else {
1118       *out = MakeDim(second);
1119     }
1120   }
1121   return Status::OK();
1122 }
1123 
// Returns a copy of `status` (same code) whose message is extended with a
// description of this inference context: the summarized node, the input
// shapes, and any input tensors / tensors-as-partial-shapes that were
// requested and are available.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> shape_strs;
  shape_strs.reserve(inputs_.size());
  for (const ShapeHandle& shape : inputs_) {
    shape_strs.emplace_back(DebugString(shape));
  }

  // Add information about the input tensors and partial tensor shapes used.
  const int num_inputs = inputs_.size();
  const int num_tensors_as_shapes = input_tensors_as_shapes_.size();
  const int num_tensors = input_tensors_.size();
  std::vector<string> tensor_strs;
  std::vector<string> tensor_as_shape_strs;
  tensor_as_shape_strs.reserve(inputs_.size());
  for (int i = 0; i < num_inputs; ++i) {
    // A requested, rank-known partial shape takes precedence over the raw
    // tensor value for the same input.
    const bool has_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < num_tensors_as_shapes && input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (has_partial_shape) {
      tensor_as_shape_strs.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }

    const bool has_tensor = requested_input_tensor_[i] && i < num_tensors &&
                            input_tensors_[i] != nullptr;
    if (has_tensor) {
      tensor_strs.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string error_context = strings::StrCat(
      " for '", attrs_.SummarizeNode(),
      "' with input shapes: ", absl::StrJoin(shape_strs, ", "));
  if (!tensor_strs.empty()) {
    strings::StrAppend(&error_context, " and with computed input tensors: ",
                       absl::StrJoin(tensor_strs, ", "));
  }
  if (!tensor_as_shape_strs.empty()) {
    strings::StrAppend(&error_context,
                       " and with input tensors computed as partial shapes: ",
                       absl::StrJoin(tensor_as_shape_strs, ","));
  }
  strings::StrAppend(&error_context, ".");

  return Status(status.code(),
                strings::StrCat(status.error_message(), error_context));
}
1169 
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1170 bool InferenceContext::MergeHandleShapesAndTypes(
1171     const std::vector<ShapeAndType>& shapes_and_types,
1172     std::vector<ShapeAndType>* to_update) {
1173   if (shapes_and_types.size() != to_update->size()) {
1174     return false;
1175   }
1176   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1177   bool refined = false;
1178   for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1179     const ShapeAndType& existing = (*to_update)[i];
1180     if (shapes_and_types[i].dtype == existing.dtype) {
1181       new_values[i].dtype = existing.dtype;
1182     } else {
1183       if (existing.dtype != DT_INVALID) {
1184         return false;
1185       } else {
1186         new_values[i].dtype = shapes_and_types[i].dtype;
1187         refined = true;
1188       }
1189     }
1190     if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1191              .ok()) {
1192       // merge failed, ignore the new value.
1193       new_values[i].shape = existing.shape;
1194     }
1195     if (!existing.shape.SameHandle(new_values[i].shape)) {
1196       refined = true;
1197     }
1198   }
1199   if (!refined) {
1200     return false;
1201   }
1202   for (int i = 0, end = new_values.size(); i < end; ++i) {
1203     (*to_update)[i] = new_values[i];
1204   }
1205   return true;
1206 }
1207 
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1208 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1209     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1210   if (output_handle_shapes_and_types_[idx] == nullptr) {
1211     output_handle_shapes_and_types_[idx].reset(
1212         new std::vector<ShapeAndType>(shapes_and_types));
1213     return true;
1214   }
1215   return MergeHandleShapesAndTypes(shapes_and_types,
1216                                    output_handle_shapes_and_types_[idx].get());
1217 }
1218 
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1219 bool InferenceContext::MergeInputHandleShapesAndTypes(
1220     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1221   if (input_handle_shapes_and_types_[idx] == nullptr) {
1222     input_handle_shapes_and_types_[idx].reset(
1223         new std::vector<ShapeAndType>(shapes_and_types));
1224     return true;
1225   }
1226   return MergeHandleShapesAndTypes(shapes_and_types,
1227                                    input_handle_shapes_and_types_[idx].get());
1228 }
1229 
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1230 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1231     const std::vector<ShapeAndType>& shapes_and_types,
1232     std::vector<ShapeAndType>* to_update) {
1233   if (shapes_and_types.size() != to_update->size()) {
1234     return false;
1235   }
1236   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1237   for (int i = 0, end = shapes_and_types.size(); i < end; ++i) {
1238     const ShapeAndType& existing = (*to_update)[i];
1239     if (shapes_and_types[i].dtype == existing.dtype) {
1240       new_values[i].dtype = existing.dtype;
1241     } else {
1242       if (existing.dtype != DT_INVALID) {
1243         return false;
1244       } else {
1245         new_values[i].dtype = shapes_and_types[i].dtype;
1246       }
1247     }
1248     Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1249   }
1250   to_update->swap(new_values);
1251   return true;
1252 }
1253 
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1254 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1255     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1256   if (output_handle_shapes_and_types_[idx] == nullptr) {
1257     output_handle_shapes_and_types_[idx].reset(
1258         new std::vector<ShapeAndType>(shapes_and_types));
1259     return true;
1260   }
1261   return RelaxHandleShapesAndMergeTypes(
1262       shapes_and_types, output_handle_shapes_and_types_[idx].get());
1263 }
1264 
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1265 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1266     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1267   if (input_handle_shapes_and_types_[idx] == nullptr) {
1268     input_handle_shapes_and_types_[idx].reset(
1269         new std::vector<ShapeAndType>(shapes_and_types));
1270     return true;
1271   }
1272   return RelaxHandleShapesAndMergeTypes(
1273       shapes_and_types, input_handle_shapes_and_types_[idx].get());
1274 }
1275 
1276 // -----------------------------------------------------------------------------
1277 // ShapeManager
1278 // -----------------------------------------------------------------------------
ShapeManager()1279 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1280 InferenceContext::ShapeManager::~ShapeManager() {
1281   for (auto* s : all_shapes_) delete s;
1282   for (auto* d : all_dims_) delete d;
1283 }
1284 
MakeShape(const std::vector<DimensionHandle> & dims)1285 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1286     const std::vector<DimensionHandle>& dims) {
1287   all_shapes_.push_back(new Shape(dims));
1288   return all_shapes_.back();
1289 }
1290 
UnknownShape()1291 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1292   all_shapes_.push_back(new Shape());
1293   return all_shapes_.back();
1294 }
1295 
1296 }  // namespace shape_inference
1297 }  // namespace tensorflow
1298