/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_

#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
}  // namespace grappler

namespace shape_inference {

struct DimensionOrConstant;
class InferenceContext;

// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  Dimension();
  Dimension(int64_t value);
  ~Dimension() {}

  const int64 value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};

class DimensionHandle {
 public:
  DimensionHandle() {}
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  const int32 rank_;
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};

class ShapeHandle {
 public:
  ShapeHandle() {}
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64_t val);

  // dim takes precedence. If dim != nullptr, val is ignored.
  DimensionHandle dim;
  int64 val;

 private:
  DimensionOrConstant();
};

struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
  // TODO(mdan): Remove dtype from constructor, and use type_ instead.
  // dtype is kept here for backward compatibility. Its information should
  // be redundant to that in type.
  ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_)
      : shape(s), dtype(t), type(type_) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
  FullTypeDef type;
};

// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext.  An InferenceContext
// is created by the framework and passed to a shape inference function.  The
// shape inference function calls functions on the context, and should call
// set_output() to set the shape on all outputs.
//
// To infer shapes for user-defined functions see ShapeRefiner.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
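//
// A minimal illustrative sketch (not part of the original header): a shape
// function for a hypothetical "ExampleMatMul" op with two rank-2 inputs
// might look like the following. The op name, its signature, and the lambda
// body are assumptions for illustration only.
//
//   REGISTER_OP("ExampleMatMul")
//       .Input("a: float")
//       .Input("b: float")
//       .Output("product: float")
//       .SetShapeFn([](shape_inference::InferenceContext* c) {
//         shape_inference::ShapeHandle a, b;
//         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
//         TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
//         shape_inference::DimensionHandle inner;
//         TF_RETURN_IF_ERROR(c->Merge(c->Dim(a, 1), c->Dim(b, 0), &inner));
//         c->set_output(0, c->Matrix(c->Dim(a, 0), c->Dim(b, 1)));
//         return Status::OK();
//       });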
class InferenceContext {
 public:
  static constexpr int64_t kUnknownDim = -1;
  static constexpr int32_t kUnknownRank = -1;

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape function
  // makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape it represents is partially
  // known from analysis of the graph.
  // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
  // Values of <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(int graph_def_version, const AttrSlice& attrs,
                   const OpDef& op_def,
                   const std::vector<ShapeHandle>& input_shapes,
                   const std::vector<const Tensor*>& input_tensors,
                   const std::vector<ShapeHandle>& input_tensors_as_shapes,
                   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape it represents is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(
      int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
      const std::vector<PartialTensorShape>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<PartialTensorShape>& input_tensors_as_shapes,
      const std::vector<std::unique_ptr<
          std::vector<std::pair<PartialTensorShape, DataType>>>>&
          input_handle_shapes_and_types);

  ~InferenceContext();

  // Runs the shape inference function 'fn' with 'this' as the
  // argument, returns the status of the inference.
  //
  // On error, additional context is provided in the error message.
  Status Run(
      const std::function<Status(shape_inference::InferenceContext* c)>& fn);

  // Merge the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same or <shape> is unknown, there will be no
  //   change. Otherwise if the stored shape is unknown, the new shape will be
  //   <shape>.
  // - If both shapes are known, then they must have the same rank.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known, then the values must match.
  // - If one shape has equal or more information than the other shape in every
  //   dimension, the new shape will become the shape with more information.
  // - Example: merging [2,?] and [?,2] results in [2,2]
  // - Example: [2,2] cannot be merged with [1,2]
  //
  // This requires idx to be in the [0, num_inputs) range. If the merge is
  // successful, return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
    inputs_[idx] = new_shape;
    return true;
  }

  // Relax the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same then the stored shape will be returned.
  // - If either of the ShapeHandles are unknown, then a new UnknownShape will
  //   be returned. A new shape must be returned because we cannot claim that
  //   the resulting shape is necessarily the same as either of the input
  //   shapes.
  // - If the shapes both have known ranks but their ranks are different, a new
  //   UnknownShape will be returned.
  // - For any one dimension, if the value for that dimension in either of the
  //   shapes is unknown, a new shape will be returned with a new UnknownDim in
  //   that dimension.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known but do not match, a new shape will be returned with a new
  //   UnknownDim in that dimension.
  // - If both shapes have the same known rank and match in every dimension,
  //   the stored shape will be returned.
  // - Example: relaxing [2,?] and [?,2] results in [?,?]
  // - Example: relaxing [2,2] and [3,2] results in [?,2]
  // - Example: relaxing [2,2] with [1,2,3] results in ?
  //
  // This requires idx to be in the [0, num_inputs) range. If the relax is
  // successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool RelaxInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    Relax(inputs_[idx], shape, &new_shape);
    if (inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }

  void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }

  ShapeHandle input(int64_t idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Mark that this idx was requested.
    request_input_tensor(idx);
    return input_tensors_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // is needed. The shape refiner tries to statically compute this tensor,
  // and if successful re-runs the shape function with this tensor available
  // in the call to 'input_tensor(idx)'.
  void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }

  // Returns true iff input_tensor(idx) was called by the shape function.
  bool requested_input_tensor(int idx) const {
    return requested_input_tensor_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // as a partial shape is needed. The shape refiner tries to statically compute
  // this, and if successful re-runs the shape function with the
  // computed PartialTensorShape available in the call to
  // 'MakeShapeFromShapeTensor(idx, handle)' or
  // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
  void request_input_tensor_as_partial_shape(int idx) {
    requested_input_tensor_as_partial_shape_[idx] = true;
  }

  // Returns true if MakeShapeFromShapeTensor was called but the constant
  // input_tensor was not present.
  bool requested_input_tensor_as_partial_shape(int idx) const {
    return requested_input_tensor_as_partial_shape_[idx];
  }

  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
    return input_tensors_as_shapes_;
  }

  ShapeHandle output(int64_t idx) const { return outputs_.at(idx); }
  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  int num_outputs() const { return outputs_.size(); }
  ShapeHandle output(int idx) const { return outputs_.at(idx); }
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  // Returns the value for attribute named `attr_name`.
  Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const {
    return attrs_.Find(attr_name, attr_value);
  }
  const AttrValue* GetAttr(StringPiece attr_name) const {
    return attrs_.Find(attr_name);
  }

  const FullTypeDef& ret_types() const { return ret_types_; }

  // idx can be negative for an offset from end of dimensions.
  // idx must be in the range [-1 * s.rank, s.rank).
  DimensionHandle Dim(ShapeHandle s, int64_t idx) {
    if (!s.Handle() || s->rank_ == kUnknownRank) {
      return UnknownDim();
    }
    return DimKnownRank(s, idx);
  }
  // As above, but asserts that the rank of the shape is known.
  static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) {
    CHECK_NE(s->rank_, kUnknownRank);
    if (idx < 0) {
      return s->dims_[s->dims_.size() + idx];
    }
    return s->dims_[idx];
  }

  static int32 Rank(ShapeHandle s) {
    return s.IsSet() ? s->rank_ : kUnknownRank;
  }
  static bool RankKnown(ShapeHandle s) {
    return (s.IsSet() && (Rank(s) != kUnknownRank));
  }
  static inline int64 Value(DimensionOrConstant d) {
    return d.dim.IsSet() ? d.dim->value_ : d.val;
  }
  static inline bool ValueKnown(DimensionOrConstant d) {
    return Value(d) != kUnknownDim;
  }

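  // Illustrative sketch (an assumption added for clarity, not original
  // documentation): inside a shape function these accessors are typically
  // used defensively, e.g.
  //
  //   ShapeHandle s = c->input(0);
  //   if (c->RankKnown(s) && c->Rank(s) > 0) {
  //     DimensionHandle d = c->Dim(s, -1);  // last dimension
  //     if (c->ValueKnown(d)) {
  //       int64 last = c->Value(d);
  //       // ... use the concrete size of the last dimension ...
  //     }
  //   }
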
  // Fills the output proto with the shape defined by the handle.
  // "proto" is expected to be empty prior to the call.
  void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);

  // Returns true if the rank and all dimensions of the Shape are known.
  bool FullyDefined(ShapeHandle s);

  // Returns the total number of elements, or an unknown dimension for an
  // incomplete shape.
  DimensionHandle NumElements(ShapeHandle s);

  std::string DebugString(ShapeHandle s);
  std::string DebugString(DimensionHandle d);
  std::string DebugString(const ShapeAndType& shape_and_type);
  std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);

  // Describes the whole context, for debugging purposes.
  std::string DebugString() const;

  // If <shape> has rank <rank>, or its rank is unknown, return OK and return
  // the shape with asserted rank in <*out>. Otherwise return an error.
  //
  // Note that <*out> may be set to <shape>.
  Status WithRank(ShapeHandle shape, int64_t rank,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtLeast(ShapeHandle shape, int64_t rank,
                         ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtMost(ShapeHandle shape, int64_t rank,
                        ShapeHandle* out) TF_MUST_USE_RESULT;

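  // A usage sketch (assumed, not from the original docs): e.g. to require a
  // rank >= 1 input and a rank-1 bias whose length matches the input's last
  // dimension:
  //
  //   ShapeHandle input, bias;
  //   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
  //   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias));
  //   DimensionHandle merged;
  //   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(bias, 0), &merged));
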
  // If <dim> has value <value>, or its value is unknown, returns OK and returns
  // the dimension with asserted value in <*out>. Otherwise returns an error.
  //
  // Note that <*out> may be set to <dim>.
  Status WithValue(DimensionHandle dim, int64_t value,
                   DimensionHandle* out) TF_MUST_USE_RESULT;

  // Merges <s0> and <s1> and returns the merged shape in <*out>. See
  // 'MergeInput' function for full details and examples.
  Status Merge(ShapeHandle s0, ShapeHandle s1,
               ShapeHandle* out) TF_MUST_USE_RESULT;

  // Asserts that <s>'s rank >= <prefix>'s rank, and the first
  // <prefix.rank> dimensions of <s> are compatible with the dimensions of
  // <prefix>.
  // Returns the merged results in <*s_out> and <*prefix_out>.
  Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                     ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

  // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
  // and <d1> have incompatible values, returns an error.
  //
  // Note that <*out> may be set to <d0> or <d1>.
  Status Merge(DimensionHandle d0, DimensionHandle d1,
               DimensionHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  Status Subshape(ShapeHandle s, int64_t start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  Status Subshape(ShapeHandle s, int64_t start, int64_t end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  // <stride> can be negative, to reverse <s>.
  Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> the result of appending the dimensions of <s2> to those
  // of <s1>.
  Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                     ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <out> the shape from replacing <s.dim[dim_index]> with
  // <new_dim>.
  Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim,
                    ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns a new shape with the given dims. The returned value is owned by
  // this context.
  ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
  ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

  // Returns a new unknown shape.
  ShapeHandle UnknownShape();

  // Returns a shape with specified rank but unknown dims.
  ShapeHandle UnknownShapeOfRank(int64_t rank);

  // Returns a new shape of zero dimensions.
  ShapeHandle Scalar();

  // Returns a new shape of one dimension.
  ShapeHandle Vector(DimensionOrConstant dim);

  // Returns a new shape of two dimensions.
  ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

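  // A small sketch (added for illustration; `batch` is a hypothetical
  // DimensionHandle) of how output shapes are typically constructed:
  //
  //   ShapeHandle unknown = c->UnknownShape();            // ?
  //   ShapeHandle rank3 = c->UnknownShapeOfRank(3);       // [?,?,?]
  //   ShapeHandle scalar = c->Scalar();                   // []
  //   ShapeHandle vec = c->Vector(64);                    // [64]
  //   ShapeHandle mat = c->Matrix(batch, 10);             // [batch,10]
  //   ShapeHandle out = c->MakeShape({batch, c->UnknownDim(), 4});
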
  // Returns in <out> a new shape whose dimension sizes come from input tensor
  // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
  // the input tensor is NULL, then an unknown shape is returned.
  Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);

  // Like the function above, but treats scalar values as unknown
  // shapes.  **NOTE** If the scalar is statically known, its value
  // must be -1 or an error is returned.
  Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
                                                           ShapeHandle* out);

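  // Sketch (an assumption, mirroring how reshape-like ops commonly use this
  // call): when input 1 is a 1-D "shape" tensor describing the output,
  //
  //   ShapeHandle out;
  //   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
  //   c->set_output(0, out);
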
  // Returns in <out> a new shape corresponding to <proto>.
  Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                 ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <partial_shape>.
  Status MakeShapeFromPartialTensorShape(
      const PartialTensorShape& partial_shape, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <shape>.
  Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);

  // Returns a new dimension of the given size.  The returned value is owned by
  // this context.
  inline DimensionHandle MakeDim(DimensionOrConstant d) {
    return shape_manager_.MakeDim(d);
  }

  inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }

  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a 0-dimensional int32 or int64 tensor.  Caller must ensure that the
  // input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64* val);

  // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
  // int64 elements. Caller must ensure that the input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64* val);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInput(int idx, DimensionHandle* out);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // This allows for a negative input dimension given the rank of a separate
  // tensor.  This rank can be negative if unknown.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                   DimensionHandle* out);

  // Look up the attr being evaluated with name attr_name and set *value to its
  // value. If no attr with attr_name is found in def(), or the attr does not
  // have a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

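  // A brief sketch (assumed; "strides" is a hypothetical list(int) attr on
  // the op being inferred) of reading attrs inside a shape function:
  //
  //   std::vector<int32> strides;
  //   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  //   if (strides.size() != 4) {
  //     return errors::InvalidArgument("strides must have 4 elements");
  //   }
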
  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor> is not positive or if <evenly_divisible> is
  // true and <divisor> does not evenly divide <dividend>.
  Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
                bool evenly_divisible, DimensionHandle* out);

  // Returns in <out> the sum of <first> and <second>.
  Status Add(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the dimension that is <first> minus <second>.
  Status Subtract(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the product of <first> and <second>.
  Status Multiply(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero the result is zero. Otherwise, if either <first> or
  // <second> is unknown the result is unknown.
  Status Min(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown the result is unknown.
  Status Max(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

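  // Illustrative sketch (an assumption, not original documentation): e.g.
  // computing the output length of a VALID-padded window along one dimension
  // as (in - window) / stride + 1, where `in` is a DimensionHandle and
  // `window` and `stride` are known int64 values:
  //
  //   DimensionHandle diff, quot, out_dim;
  //   TF_RETURN_IF_ERROR(c->Subtract(in, window, &diff));
  //   TF_RETURN_IF_ERROR(
  //       c->Divide(diff, stride, /*evenly_divisible=*/false, &quot));
  //   TF_RETURN_IF_ERROR(c->Add(quot, 1, &out_dim));
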
  Status construction_status() const { return construction_status_; }

  // Methods to propagate shape and dtype on edges of handles. Handles have
  // dtype DT_RESOURCE, which can be used to access state stored in a
  // ResourceManager. When ops (such as variables) consume these handles to
  // produce tensors they might need to know side-information about the shapes
  // and dtypes of tensors which can be accessed via the handle. These methods
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.

  // Merge the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the merge is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'MergeInput' function for full details and examples.
  bool MergeInputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As MergeInputHandleShapesAndTypes, but for an output.
  bool MergeOutputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Relaxes the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the relax is successful (sizes are the same, old dtypes match new ones
  // or are DT_INVALID), then store the relaxed shapes and return true.
  // Return false otherwise.
  //
  // See 'RelaxInput' function for full details and examples.
  bool RelaxInputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  void set_input_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    input_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }

  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shape and types were never set.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    return output_handle_shapes_and_types_[idx].get();
  }

  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    return input_handle_shapes_and_types_[idx].get();
  }

  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    output_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
  }

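  // Rough sketch (an assumption for illustration; "dtype" and "shape" are
  // hypothetical attrs of a variable-like resource op): a shape function can
  // advertise what the produced resource handle refers to via
  //
  //   c->set_output(0, c->Scalar());  // the handle tensor itself is a scalar
  //   DataType dt;
  //   PartialTensorShape ps;
  //   TF_RETURN_IF_ERROR(c->GetAttr("dtype", &dt));
  //   TF_RETURN_IF_ERROR(c->GetAttr("shape", &ps));
  //   ShapeHandle var_shape;
  //   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(ps, &var_shape));
  //   c->set_output_handle_shapes_and_types(0, {ShapeAndType(var_shape, dt)});
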
  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
  // then an unknown shape is returned.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);

  int graph_def_version() const { return graph_def_version_; }

  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }

  // Adds new outputs; useful when mutating the graph.
  Status ExpandOutputs(int new_output_size);

 private:
  // Creates and stores shapes for use in InferenceContext.
  class ShapeManager {
   public:
    ShapeManager();
    ~ShapeManager();

    // Returns a new shape with the given dims. The returned value is owned by
    // this class.
    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);

    // Returns a new unknown shape.
    ShapeHandle UnknownShape();

    // Returns a new dimension of the given size.  The returned value
    // is owned by this class.
    inline DimensionHandle MakeDim(DimensionOrConstant d) {
      if (d.dim.IsSet()) {
        return d.dim;
      } else {
        all_dims_.push_back(new Dimension(d.val));
        return all_dims_.back();
      }
    }

   private:
    std::vector<Shape*> all_shapes_;    // values are owned.
    std::vector<Dimension*> all_dims_;  // values are owned.
  };

  friend class ::tensorflow::grappler::GraphProperties;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.

  // Shared initialization across the two constructors.  Remove
  // once we get rid of one of them.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);

  Status ReturnUnknownShape(ShapeHandle* out) {
    *out = UnknownShape();
    return Status::OK();
  }
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    *out = MakeShape(dims);
    return Status::OK();
  }

  // Adds additional context to the given status.
  Status AttachContext(const Status& status);

  // Relaxes an existing dimension <d_old> with a new dimension <d_new> and
  // returns the relaxed dimension in <*out>. If <d_old> and <d_new> have
  // incompatible known values, <*out> is set to a new unknown dimension.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);

  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;

  // Forget all the previous merged shapes and dims.
  void ForgetMerges() {
    merged_shapes_.clear();
    merged_dims_.clear();
  }

  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);

  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  // Return types for the node this context is associated with. This
  // information will eventually consolidate all the dtype and shape info,
  // allowing output_handle_shapes_and_types_ to be removed.
  FullTypeDef ret_types_;

  const int graph_def_version_;
  AttrSlice attrs_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dim handles that are equivalent, i.e. that represent the
  // same underlying shape or dimension. Note that for each pair at least one of
  // the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64_t value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}

inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}

inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}

template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(attrs_, attr_name, value);
}

}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_