/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_

#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
}  // namespace grappler

namespace shape_inference {

struct DimensionOrConstant;
class InferenceContext;

// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  Dimension();
  Dimension(int64 value);
  ~Dimension() {}

  const int64 value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};

class DimensionHandle {
 public:
  DimensionHandle() {}
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  const int32 rank_;
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};

class ShapeHandle {
 public:
  ShapeHandle() {}
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64 val);

  // dim takes precedence. If dim != nullptr, val is ignored.
  DimensionHandle dim;
  int64 val;

 private:
  DimensionOrConstant();
};

struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
  ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
      : shape(s), dtype(t), specialized_type(specialized_t) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
  // The type of a variant-dtype tensor sometimes affects graph building
  // (e.g. for vectorization), and needs to be known statically in such cases.
  SpecializedType specialized_type = ST_INVALID;
};

// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext.  An InferenceContext
// is created by the framework and passed to a shape inference function.  The
// shape inference function calls functions on the context, and should call
// set_output() to set the shape on all outputs.
//
// To infer shapes for user-defined functions see ShapeRefiner.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
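//
// Illustrative usage sketch (an added example, not part of the original
// header): a shape function for a hypothetical matmul-like op. The op name and
// the .SetShapeFn() registration hook are assumptions for illustration; the
// InferenceContext calls (input, WithRank, Dim, Merge, Matrix, set_output) are
// the ones declared below.
//
//   REGISTER_OP("HypotheticalMatMul")
//       .Input("a: float")
//       .Input("b: float")
//       .Output("product: float")
//       .SetShapeFn([](shape_inference::InferenceContext* c) {
//         shape_inference::ShapeHandle a, b;
//         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
//         TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
//         shape_inference::DimensionHandle inner;
//         TF_RETURN_IF_ERROR(c->Merge(c->Dim(a, 1), c->Dim(b, 0), &inner));
//         c->set_output(0, c->Matrix(c->Dim(a, 0), c->Dim(b, 1)));
//         return Status::OK();
//       });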
class InferenceContext {
 public:
  static constexpr int64 kUnknownDim = -1;
  static constexpr int32 kUnknownRank = -1;

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape function
  // makes a call to MakeShapeFromShapeTensor; in particular, when the
  // input_tensors[i] is nullptr but the shape represented by it is partially
  // known from analysis of the graph.
  // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
  // Values of <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(int graph_def_version, const AttrSlice& attrs,
                   const OpDef& op_def,
                   const std::vector<ShapeHandle>& input_shapes,
                   const std::vector<const Tensor*>& input_tensors,
                   const std::vector<ShapeHandle>& input_tensors_as_shapes,
                   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // the input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(
      int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
      const std::vector<PartialTensorShape>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<PartialTensorShape>& input_tensors_as_shapes,
      const std::vector<std::unique_ptr<
          std::vector<std::pair<PartialTensorShape, DataType>>>>&
          input_handle_shapes_and_types);

  ~InferenceContext();

  // Runs the shape inference function 'fn' with 'this' as the
  // argument, returns the status of the inference.
  //
  // On error, additional context is provided in the error message.
  Status Run(
      const std::function<Status(shape_inference::InferenceContext* c)>& fn);

  // Merge the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same or <shape> is unknown, there will be no
  //   change. Otherwise if the stored shape is unknown, the new shape will be
  //   <shape>.
  // - If both shapes are known, then they must have the same rank.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known, then the values must match.
  // - If one shape has equal or more information than the other shape in every
  //   dimension, the new shape will become the shape with more information.
  // - Example: merging [2,?] and [?,2] results in [2,2]
  // - Example: [2,2] cannot be merged with [1,2]
  //
  // This requires idx to be in the [0, num_inputs) range. If the merge is
  // successful, return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
    inputs_[idx] = new_shape;
    return true;
  }

  // Relax the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same then the stored shape will be returned.
  // - If either of the ShapeHandles are unknown, then a new UnknownShape will
  //   be returned. A new shape must be returned because we cannot claim that
  //   the resulting shape is necessarily the same as either of the input
  //   shapes.
  // - If the shapes both have known ranks but their ranks are different, a new
  //   UnknownShape will be returned.
  // - For any one dimension, if the value for that dimension in either of the
  //   shapes is unknown, a new shape will be returned with a new UnknownDim in
  //   that dimension.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known but do not match, a new shape will be returned with a new
  //   UnknownDim in that dimension.
  // - If both shapes have the same known rank and match in every dimension,
  //   the stored shape will be returned.
  // - Example: relaxing [2,?] and [?,2] results in [?,?]
  // - Example: relaxing [2,2] and [3,2] results in [?,2]
  // - Example: relaxing [2,2] with [1,2,3] results in ?
  //
  // This requires idx to be in the [0, num_inputs) range. If the relax is
  // successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool RelaxInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    Relax(inputs_[idx], shape, &new_shape);
    if (inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }

  void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }

  ShapeHandle input(int64 idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Mark that this idx was requested.
    request_input_tensor(idx);
    return input_tensors_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // is needed. The shape refiner tries to statically compute this tensor,
  // and if successful re-runs the shape function with this tensor available
  // in the call to 'input_tensor(idx)'.
  void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }

  // Returns true iff input_tensor(idx) was called by the shape function.
  bool requested_input_tensor(int idx) const {
    return requested_input_tensor_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // as a partial shape is needed. The shape refiner tries to statically compute
  // this, and if successful re-runs the shape function with the
  // computed PartialTensorShape available in the call to
  // 'MakeShapeFromShapeTensor(idx, handle)' or
  // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
  void request_input_tensor_as_partial_shape(int idx) {
    requested_input_tensor_as_partial_shape_[idx] = true;
  }

  // Returns true if MakeShapeFromShapeTensor was called but the constant
  // input_tensor was not present.
  bool requested_input_tensor_as_partial_shape(int idx) const {
    return requested_input_tensor_as_partial_shape_[idx];
  }

  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
    return input_tensors_as_shapes_;
  }

  ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  int num_outputs() const { return outputs_.size(); }
  ShapeHandle output(int idx) const { return outputs_.at(idx); }
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  const AttrSlice& attrs() const { return attrs_; }

  // idx can be negative for an offset from end of dimensions.
  // idx must be in the range [-1 * s.rank, s.rank).
  DimensionHandle Dim(ShapeHandle s, int64 idx) {
    if (!s.Handle() || s->rank_ == kUnknownRank) {
      return UnknownDim();
    }
    return DimKnownRank(s, idx);
  }
  // As above, but asserts that the rank of the shape is known.
  static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
    CHECK_NE(s->rank_, kUnknownRank);
    if (idx < 0) {
      return s->dims_[s->dims_.size() + idx];
    }
    return s->dims_[idx];
  }

  static int32 Rank(ShapeHandle s) {
    return s.IsSet() ? s->rank_ : kUnknownRank;
  }
  static bool RankKnown(ShapeHandle s) {
    return (s.IsSet() && (Rank(s) != kUnknownRank));
  }
  static inline int64 Value(DimensionOrConstant d) {
    return d.dim.IsSet() ? d.dim->value_ : d.val;
  }
  static inline bool ValueKnown(DimensionOrConstant d) {
    return Value(d) != kUnknownDim;
  }

  // Fills the output proto with the shape defined by the handle.
  // "proto" is expected to be empty prior to the call.
  void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);

  // Returns true if the rank and all dimensions of the Shape are known.
  bool FullyDefined(ShapeHandle s);

  // Returns the total number of elements, or an unknown dimension for an
  // incomplete shape.
  DimensionHandle NumElements(ShapeHandle s);

  std::string DebugString(ShapeHandle s);
  std::string DebugString(DimensionHandle d);
  std::string DebugString(const ShapeAndType& shape_and_type);
  std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);

  // Describes the whole context, for debugging purposes.
  std::string DebugString() const;

  // If <shape> has rank <rank>, or its rank is unknown, return OK and return
  // the shape with asserted rank in <*out>. Otherwise return an error.
  //
  // Note that <*out> may be set to <shape>.
  Status WithRank(ShapeHandle shape, int64 rank,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtLeast(ShapeHandle shape, int64 rank,
                         ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtMost(ShapeHandle shape, int64 rank,
                        ShapeHandle* out) TF_MUST_USE_RESULT;

  // If <dim> has value <value>, or its value is unknown, returns OK and returns
  // the dimension with asserted value in <*out>. Otherwise returns an error.
  //
  // Note that <*out> may be set to <dim>.
  Status WithValue(DimensionHandle dim, int64 value,
                   DimensionHandle* out) TF_MUST_USE_RESULT;

  // Merges <s0> and <s1> and returns the merged shape in <*out>. See
  // 'MergeInput' function for full details and examples.
  Status Merge(ShapeHandle s0, ShapeHandle s1,
               ShapeHandle* out) TF_MUST_USE_RESULT;

  // Asserts that <s>'s rank >= <prefix>'s rank, and the first
  // <prefix.rank> dimensions of <s> are compatible with the dimensions of
  // <prefix>.
  // Returns the merged results in <*s_out> and <*prefix_out>.
  Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                     ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

  // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
  // and <d1> have incompatible values, returns an error.
  //
  // Note that <*out> may be set to <d0> or <d1>.
  Status Merge(DimensionHandle d0, DimensionHandle d1,
               DimensionHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  Status Subshape(ShapeHandle s, int64 start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  Status Subshape(ShapeHandle s, int64 start, int64 end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  // <stride> can be negative, to reverse <s>.
  Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> the result of appending the dimensions of <s2> to those
  // of <s1>.
  Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                     ShapeHandle* out) TF_MUST_USE_RESULT;
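
  // Illustrative sketch (added example, not part of the original header):
  // prepending a new leading dimension of size 1 to input 0, roughly what an
  // expand_dims-like shape function might do. 'c' is the InferenceContext
  // passed to the shape function.
  //
  //   ShapeHandle expanded;
  //   TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(1), c->input(0), &expanded));
  //   c->set_output(0, expanded);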

  // Returns in <out> the shape from replacing <s.dim[dim_index]> with
  // <new_dim>.
  Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
                    ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns a new shape with the given dims. The returned value is owned by
  // this context.
  ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
  ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

  // Returns a new unknown shape.
  ShapeHandle UnknownShape();

  // Returns a shape with specified rank but unknown dims.
  ShapeHandle UnknownShapeOfRank(int64 rank);

  // Returns a new shape of zero dimensions.
  ShapeHandle Scalar();

  // Returns a new shape of one dimension.
  ShapeHandle Vector(DimensionOrConstant dim);

  // Returns a new shape of two dimensions.
  ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

  // Returns in <out> a new shape whose dimension sizes come from input tensor
  // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
  // the input tensor is NULL, then an unknown shape is returned.
  Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);

  // Like the function above, but treats scalar values as unknown
  // shapes.  **NOTE** If the scalar is statically known, its value
  // must be -1 or an error is returned.
  Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
                                                           ShapeHandle* out);
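
  // Illustrative sketch (added example, not part of the original header): a
  // Reshape-like shape function that derives its output shape from a
  // shape-valued input tensor. The op layout (data at input 0, shape at
  // input 1) and the function name are assumptions for illustration.
  //
  //   Status ReshapeLikeShapeFn(shape_inference::InferenceContext* c) {
  //     shape_inference::ShapeHandle out;
  //     // Falls back to an unknown shape when input 1 is not statically known.
  //     TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
  //     c->set_output(0, out);
  //     return Status::OK();
  //   }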

  // Returns in <out> a new shape corresponding to <proto>.
  Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                 ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <partial_shape>.
  Status MakeShapeFromPartialTensorShape(
      const PartialTensorShape& partial_shape, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <shape>.
  Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);

  // Returns a new dimension of the given size.  The returned value is owned by
  // this context.
  inline DimensionHandle MakeDim(DimensionOrConstant d) {
    return shape_manager_.MakeDim(d);
  }

  inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }

  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a 0-dimensional int32 or int64 tensor.  Caller must ensure that the
  // input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64* val);

  // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
  // int64 elements. Caller must ensure that the input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64 idx, int64* val);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInput(int idx, DimensionHandle* out);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // This allows for a negative input dimension given the rank of a separate
  // tensor.  This rank can be negative if unknown.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                   DimensionHandle* out);

  // Look up the attr being evaluated with name attr_name and set *value to its
  // value. If no attr with attr_name is found in def(), or the attr does not
  // have a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;
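
  // Illustrative sketch (added example, not part of the original header):
  // reading an attr inside a shape function; the attr name "axis" and its
  // int type are hypothetical.
  //
  //   int64 axis;
  //   TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));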

  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor> is not positive or if <evenly_divisible>
  // and <divisor> does not evenly divide <dividend>.
  Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
                bool evenly_divisible, DimensionHandle* out);

  // Returns in <out> the sum of <first> and <second>.
  Status Add(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the dimension that is <first> minus <second>.
  Status Subtract(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the product of <first> and <second>.
  Status Multiply(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero the result is zero. Otherwise, if either <first> or
  // <second> is unknown the result is unknown.
  Status Min(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown the result is unknown.
  Status Max(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);
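
  // Illustrative sketch (added example, not part of the original header):
  // composing the arithmetic helpers to compute a VALID-pooling-style output
  // size, out = (in - window) / stride + 1. The input shape handle, the
  // window size (3) and the stride (2) are hypothetical values.
  //
  //   DimensionHandle in_dim = c->Dim(input_shape, 1);
  //   DimensionHandle out_dim;
  //   TF_RETURN_IF_ERROR(c->Subtract(in_dim, 3 /* window */, &out_dim));
  //   TF_RETURN_IF_ERROR(c->Divide(out_dim, 2 /* stride */,
  //                                /*evenly_divisible=*/false, &out_dim));
  //   TF_RETURN_IF_ERROR(c->Add(out_dim, 1, &out_dim));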

  Status construction_status() const { return construction_status_; }

  // Methods to propagate shape and dtype on edges of handles. Handles are
  // tensors of dtype DT_RESOURCE, which can be used to access state stored in
  // a ResourceManager. When ops (such as variables) consume these handles to
  // produce tensors they might need to know side-information about the shapes
  // and dtypes of tensors which can be accessed via the handle. These methods
  // propagate that information. Output handle dtypes and shapes are ignored if
  // the output tensor is not of type DT_RESOURCE.

  // Merge the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the merge is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'MergeInput' function for full details and examples.
  bool MergeInputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As MergeInputHandleShapesAndTypes, but for an output.
  bool MergeOutputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Relaxes the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the relax is successful (sizes are the same, old dtypes match new ones
  // or are DT_INVALID), then store the relaxed shapes and return true.
  // Return false otherwise.
  //
  // See 'RelaxInput' function for full details and examples.
  bool RelaxInputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  void set_input_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    input_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }

  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shape and types were never set.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    return output_handle_shapes_and_types_[idx].get();
  }

  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    return input_handle_shapes_and_types_[idx].get();
  }

  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    output_handle_shapes_and_types_[idx].reset(
        new std::vector<ShapeAndType>(shapes_and_types));
  }
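
  // Illustrative sketch (added example, not part of the original header): a
  // shape function for a hypothetical resource-producing op that records the
  // shape and dtype carried by the handle. The function name, value shape,
  // and dtype are assumptions for illustration.
  //
  //   Status HypotheticalVarHandleShapeFn(
  //       shape_inference::InferenceContext* c) {
  //     c->set_output(0, c->Scalar());  // The resource handle is a scalar.
  //     std::vector<shape_inference::ShapeAndType> handle_data;
  //     handle_data.emplace_back(c->Matrix(128, 64), DT_FLOAT);
  //     c->set_output_handle_shapes_and_types(0, handle_data);
  //     return Status::OK();
  //   }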

  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
  // then an unknown shape is returned.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);

  int graph_def_version() const { return graph_def_version_; }

  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }

  // Adds new outputs; useful when mutating the graph.
  Status ExpandOutputs(int new_output_size);

 private:
  // Creates and stores shapes for use in InferenceContext.
  class ShapeManager {
   public:
    ShapeManager();
    ~ShapeManager();

    // Returns a new shape with the given dims. The returned value is owned by
    // this class.
    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);

    // Returns a new unknown shape.
    ShapeHandle UnknownShape();

    // Returns a new dimension of the given size.  The returned value
    // is owned by this class.
    inline DimensionHandle MakeDim(DimensionOrConstant d) {
      if (d.dim.IsSet()) {
        return d.dim;
      } else {
        all_dims_.push_back(new Dimension(d.val));
        return all_dims_.back();
      }
    }

   private:
    std::vector<Shape*> all_shapes_;    // values are owned.
    std::vector<Dimension*> all_dims_;  // values are owned.
  };

  friend class ::tensorflow::grappler::GraphProperties;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.

  // Shared initialization across the two constructors.  Remove
  // once we get rid of one of them.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);

  Status ReturnUnknownShape(ShapeHandle* out) {
    *out = UnknownShape();
    return Status::OK();
  }
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    *out = MakeShape(dims);
    return Status::OK();
  }

  // Adds additional context to the given status.
  Status AttachContext(const Status& status);

  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
  // values, returns an error.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);

  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;

  // Forget all the previous merged shapes and dims.
  void ForgetMerges() {
    merged_shapes_.clear();
    merged_dims_.clear();
  }

  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);

  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  const int graph_def_version_;
  AttrSlice attrs_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dim handles that are equivalent, i.e. that represent the
  // same underlying shape or dimension. Note that for each pair at least one of
  // the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}

inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}

inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}

template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(attrs_, attr_name, value);
}

}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_