/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_

#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
}  // namespace grappler

namespace shape_inference {

struct DimensionOrConstant;
class InferenceContext;

// This header contains the InferenceContext that ShapeRefiner uses to infer
// the shapes of an operation's results, or to flag an operation with invalid
// inputs (e.g., mismatched shapes for an elementwise operation). The shape of
// an operation is computed using the OpShapeInferenceFn set via SetShapeFn in
// the op registration. The OpShapeInferenceFn uses a per-op InferenceContext
// populated with input shapes to compute the resultant shapes (including
// resource shapes).
//
// The shapes created in an InferenceContext are bound to the lifetime of the
// InferenceContext in which they were created. E.g., in
//
// ```c++
//  InferenceContext c;
//  // Below a ShapeHandle is returned by MakeShape, while UnknownDim returns a
//  // DimensionHandle.
//  ShapeHandle in0 = c.MakeShape({10, c.UnknownDim()});
// ```
//
// the ShapeHandle `in0` (and the nested unknown dim inside) is only valid while
// `c` is in scope, as ShapeHandle and DimensionHandle are effectively wrappers
// around pointers stored inside the context, with the lifetime of the
// pointed-to values managed by the context. The results from one operation's
// inference context are passed as inputs to the inference of consumer
// operations. Hence the ShapeHandles seen while inferring one node's shape may
// be owned by different InferenceContexts. While inferring the shapes of a
// Graph, the InferenceContexts of all nodes/operations in the Graph remain
// resident for the lifetime of the Graph (e.g., there is a map from each node
// to its InferenceContext, technically its ExtendedInferenceContext, which
// additionally stores the element types of inputs & outputs; this map remains
// resident).
//
// For functions, the body of the function is instantiated as a Graph while
// inferring the result shapes of a function call node. The rules above apply
// while the function's shapes are being inferred, but the contexts associated
// with nodes in the function body are released once the function call's
// resultant shapes are inferred. The shapes of the results returned by a
// function are propagated to the InferenceContext of the function call's op
// (which is associated with a Graph of nodes whose shapes are being inferred),
// because the return values of a function call node are the inputs of its
// consumers, while those return values are produced by nodes inside the
// function whose InferenceContexts (which own the values pointed to by
// ShapeHandle and DimensionHandle) are reclaimed after inferring the function's
// result shapes. Recursive user-defined functions are not supported, hence
// inference of functions is fully nested, with the InferenceContexts of
// function calls forming a stack.
//
// For example, consider the following call and function:
//
// ```python
// @tf.function
// def g(st):
//   d = tf.add(st, st)
//   return d
//
// @tf.function
// def f():
//   st = tf.A()
//   result = g(st)
//   return h(result)
// ```
//
// During inference of f, the shape of `A` will be inferred and the results from
// its InferenceContext used as inputs to the function call `g(st)`. The call
// node will have an InferenceContext created (call it the outer context) and
// the graph corresponding to function `g` will be instantiated. The result
// shapes of the Arg nodes of the function will be associated with the inputs
// from the outer context. During inference of `g` (for the callsite `g(st)` in
// `f`), the InferenceContexts of all nodes inside `g` will remain alive. Thus,
// when the shape of `tf.add` is computed, it may rely on all of its inputs.
// Once the RetVal nodes of the function are reached, we know their input shapes
// may correspond to shapes queried in the outer context, so they are explicitly
// copied to the outer context. In this case that means that the shape of `d` is
// copied to the InferenceContext of `g(st)`, and so when the shape of
// `h(result)` is inferred this shape may be queried. Furthermore, no shapes
// computed due to the call `g(st)` can be queried after this point, and, as the
// RetVal shapes have been copied into the outer context, all InferenceContexts
// associated with nodes in function `g` instantiated for `g(st)` may be, and
// are, released.

// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  Dimension();
  Dimension(int64_t value);
  ~Dimension() {}

  const int64_t value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};

class DimensionHandle {
 public:
  DimensionHandle() {}
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  const int32 rank_;
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};

class ShapeHandle {
 public:
  ShapeHandle() {}
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};

// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64_t val);

  // dim takes precedence. If dim != nullptr, val is ignored.
  DimensionHandle dim;
  int64_t val;

 private:
  DimensionOrConstant();
};
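
// A brief usage sketch: any parameter typed DimensionOrConstant accepts either
// a DimensionHandle or a plain integer, assuming an InferenceContext* `c` is
// in scope (e.g. inside a shape function):
//
// ```c++
//  DimensionHandle d = c->UnknownDim();
//  ShapeHandle v1 = c->Vector(d);   // dimension supplied as a handle
//  ShapeHandle v2 = c->Vector(64);  // dimension supplied as a constant value
// ```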

struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
  // TODO(mdan): Remove dtype from constructor, and use type_ instead.
  // dtype is kept here for backward compatibility. Its information should
  // be redundant with that in type.
  ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_)
      : shape(s), dtype(t), type(type_) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
  FullTypeDef type;
};

// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext.  An InferenceContext
// is created by the framework and passed to a shape inference function.  The
// shape inference function calls functions on the context, and should call
// set_output() to set the shape on all outputs.
//
// To infer shapes for user-defined functions see ShapeRefiner.
//
// All Shape* and Dimension* returned by functions of InferenceContext are owned
// by the InferenceContext.
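//
// As a minimal sketch (using a hypothetical op name "MyMatMul" with two
// rank-2 inputs whose inner dimensions must agree), a typical shape function
// looks roughly like:
//
// ```c++
// REGISTER_OP("MyMatMul")
//     .Input("a: float")
//     .Input("b: float")
//     .Output("product: float")
//     .SetShapeFn([](shape_inference::InferenceContext* c) {
//       shape_inference::ShapeHandle a, b;
//       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
//       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
//       // a is [m, k] and b is [k, n]; the inner dimensions must merge.
//       shape_inference::DimensionHandle inner;
//       TF_RETURN_IF_ERROR(c->Merge(c->Dim(a, 1), c->Dim(b, 0), &inner));
//       c->set_output(0, c->MakeShape({c->Dim(a, 0), c->Dim(b, 1)}));
//       return OkStatus();
//     });
// ```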
class InferenceContext {
 public:
  static constexpr int64_t kUnknownDim = -1;
  static constexpr int32_t kUnknownRank = -1;

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape function
  // makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape represented by it is partially
  // known from analysis of the graph.
  // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
  // Values of <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(int graph_def_version, const AttrSlice& attrs,
                   const OpDef& op_def,
                   const std::vector<ShapeHandle>& input_shapes,
                   const std::vector<const Tensor*>& input_tensors,
                   const std::vector<ShapeHandle>& input_tensors_as_shapes,
                   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_shapes_and_types);

  // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
  //
  // Elements of <input_tensors_as_shapes> are used when a shape
  // function makes a call to MakeShapeFromShapeTensor; in particular, when
  // input_tensors[i] is nullptr but the shape represented by it is
  // partially known from analysis of the graph. <input_tensors_as_shapes>
  // can have fewer elements than <input_shapes>. Values of
  // <input_tensors_as_shapes> do not need to outlive the context.
  InferenceContext(
      int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
      const std::vector<PartialTensorShape>& input_shapes,
      const std::vector<const Tensor*>& input_tensors,
      const std::vector<PartialTensorShape>& input_tensors_as_shapes,
      const std::vector<std::unique_ptr<
          std::vector<std::pair<PartialTensorShape, DataType>>>>&
          input_handle_shapes_and_types);

  ~InferenceContext();

  // Runs the shape inference function 'fn' with 'this' as the
  // argument and returns the status of the inference.
  //
  // On error, additional context is provided in the error message.
  Status Run(
      const std::function<Status(shape_inference::InferenceContext* c)>& fn);

  // Merge the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same or <shape> is unknown, there will be no
  //   change. Otherwise if the stored shape is unknown, the new shape will be
  //   <shape>.
  // - If both shapes are known, then they must have the same rank.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known, then the values must match.
  // - If one shape has equal or more information than the other shape in every
  //   dimension, the new shape will become the shape with more information.
  // - Example: merging [2,?] and [?,2] results in [2,2]
  // - Example: [2,2] cannot be merged with [1,2]
  //
  // This requires idx to be in the [0, num_inputs) range. If the merge is
  // successful, return true. Return false otherwise.
  bool MergeInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
    inputs_[idx] = new_shape;
    return true;
  }

  // Relax the stored shape of the input in position idx with <shape> according
  // to the following rules:
  //
  // - If the ShapeHandles are the same then the stored shape will be returned.
  // - If either of the ShapeHandles is unknown, then a new UnknownShape will
  //   be returned. A new shape must be returned because we cannot claim that
  //   the resulting shape is necessarily the same as either of the input
  //   shapes.
  // - If the shapes both have known ranks but their ranks are different, a new
  //   UnknownShape will be returned.
  // - For any one dimension, if the value for that dimension in either of the
  //   shapes is unknown, a new shape will be returned with a new UnknownDim in
  //   that dimension.
  // - For any one dimension, if the values for that dimension in both shapes
  //   are known but do not match, a new shape will be returned with a new
  //   UnknownDim in that dimension.
  // - If both shapes have the same known rank and match in every dimension,
  //   the stored shape will be returned.
  // - Example: relaxing [2,?] and [?,2] results in [?,?]
  // - Example: relaxing [2,2] and [3,2] results in [?,2]
  // - Example: relaxing [2,2] with [1,2,3] results in ?
  //
  // This requires idx to be in the [0, num_inputs) range. If the relax is
  // successful and the new shape differs from the old one, store the new
  // shape and return true. Return false otherwise.
  bool RelaxInput(int idx, ShapeHandle shape) {
    ShapeHandle new_shape;
    Relax(inputs_[idx], shape, &new_shape);
    if (inputs_[idx].SameHandle(new_shape)) {
      return false;
    }
    inputs_[idx] = new_shape;
    return true;
  }
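
  // A quick sketch of how MergeInput and RelaxInput differ: assume the stored
  // shape of input 0 is [2,?] and `incoming` is a ShapeHandle for [3,?] (with
  // `c` an InferenceContext*).
  //
  // ```c++
  //  c->MergeInput(0, incoming);  // returns false: known dims 2 and 3 conflict
  //  c->RelaxInput(0, incoming);  // returns true: stored input 0 becomes [?,?]
  // ```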

  void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }

  ShapeHandle input(int64_t idx) const { return inputs_[idx]; }
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Mark that this idx was requested.
    request_input_tensor(idx);
    return input_tensors_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // is needed. The shape refiner tries to statically compute this tensor, and
  // if successful re-runs the shape function with this tensor available in the
  // call to 'input_tensor(idx)'.
  void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }

  // Returns true iff input_tensor(idx) was called by the shape function.
  bool requested_input_tensor(int idx) const {
    return requested_input_tensor_[idx];
  }

  // Notifies the shape refiner that the value of the tensor at index <idx>
  // as a partial shape is needed. The shape refiner tries to statically compute
  // this, and if successful re-runs the shape function with the
  // computed PartialTensorShape available in the call to
  // 'MakeShapeFromShapeTensor(idx, handle)' or
  // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
  void request_input_tensor_as_partial_shape(int idx) {
    requested_input_tensor_as_partial_shape_[idx] = true;
  }

  // Returns true if MakeShapeFromShapeTensor was called but the constant
  // input_tensor was not present.
  bool requested_input_tensor_as_partial_shape(int idx) const {
    return requested_input_tensor_as_partial_shape_[idx];
  }

  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
    return input_tensors_as_shapes_;
  }

  ShapeHandle output(int64_t idx) const { return outputs_.at(idx); }
  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  int num_outputs() const { return outputs_.size(); }
  ShapeHandle output(int idx) const { return outputs_.at(idx); }
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  // Returns the value for attribute named `attr_name`.
  Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const {
    return attrs_.Find(attr_name, attr_value);
  }
  const AttrValue* GetAttr(StringPiece attr_name) const {
    return attrs_.Find(attr_name);
  }

  const FullTypeDef& ret_types() const { return ret_types_; }

  // idx can be negative for an offset from end of dimensions.
  // idx must be in the range [-1 * s.rank, s.rank).
  DimensionHandle Dim(ShapeHandle s, int64_t idx) {
    if (!s.Handle() || s->rank_ == kUnknownRank) {
      return UnknownDim();
    }
    return DimKnownRank(s, idx);
  }
  // As above, but asserts that the rank of the shape is known.
  static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) {
    CHECK_NE(s->rank_, kUnknownRank);
    if (idx < 0) {
      return s->dims_[s->dims_.size() + idx];
    }
    return s->dims_[idx];
  }

  static int32 Rank(ShapeHandle s) {
    return s.IsSet() ? s->rank_ : kUnknownRank;
  }
  static bool RankKnown(ShapeHandle s) {
    return (s.IsSet() && (Rank(s) != kUnknownRank));
  }
  static inline int64_t Value(DimensionOrConstant d) {
    return d.dim.IsSet() ? d.dim->value_ : d.val;
  }
  static inline bool ValueKnown(DimensionOrConstant d) {
    return Value(d) != kUnknownDim;
  }

  // Fills the output proto with the shape defined by the handle.
  // "proto" is expected to be empty prior to the call.
  void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
  TensorShapeProto ShapeHandleToProto(ShapeHandle handle);

  // Returns true if the rank and all dimensions of the Shape are known.
  bool FullyDefined(ShapeHandle s);

  // Returns the total number of elements, or an unknown dimension for an
  // incomplete shape.
  DimensionHandle NumElements(ShapeHandle s);

  std::string DebugString(ShapeHandle s);
  std::string DebugString(DimensionHandle d);
  std::string DebugString(const ShapeAndType& shape_and_type);
  std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);

  // Describes the whole context, for debugging purposes.
  std::string DebugString() const;

  // If <shape> has rank <rank>, or its rank is unknown, return OK and return
  // the shape with asserted rank in <*out>. Otherwise return an error.
  //
  // Note that <*out> may be set to <shape>.
  Status WithRank(ShapeHandle shape, int64_t rank,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtLeast(ShapeHandle shape, int64_t rank,
                         ShapeHandle* out) TF_MUST_USE_RESULT;
  Status WithRankAtMost(ShapeHandle shape, int64_t rank,
                        ShapeHandle* out) TF_MUST_USE_RESULT;

  // If <dim> has value <value>, or its value is unknown, returns OK and returns
  // the dimension with asserted value in <*out>. Otherwise returns an error.
  //
  // Note that <*out> may be set to <dim>.
  Status WithValue(DimensionHandle dim, int64_t value,
                   DimensionHandle* out) TF_MUST_USE_RESULT;

  // Merges <s0> and <s1> and returns the merged shape in <*out>. See
  // 'MergeInput' function for full details and examples.
  Status Merge(ShapeHandle s0, ShapeHandle s1,
               ShapeHandle* out) TF_MUST_USE_RESULT;

  // Asserts that <s>'s rank >= <prefix>'s rank, and the first
  // <prefix.rank> dimensions of <s> are compatible with the dimensions of
  // <prefix>.
  // Returns the merged results in <*s_out> and <*prefix_out>.
  Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                     ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

  // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
  // and <d1> have incompatible values, returns an error.
  //
  // Note that <*out> may be set to <d0> or <d1>.
  Status Merge(DimensionHandle d0, DimensionHandle d1,
               DimensionHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  Status Subshape(ShapeHandle s, int64_t start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  Status Subshape(ShapeHandle s, int64_t start, int64_t end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  // <stride> can be negative, to reverse <s>.
  Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <out> the result of appending the dimensions of <s2> to those
  // of <s1>.
  Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                     ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <out> the shape from replacing <s.dim[dim_index]> with
  // <new_dim>.
  Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim,
                    ShapeHandle* out) TF_MUST_USE_RESULT;
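
  // A quick sketch of composing these shape-manipulation helpers inside a
  // shape function (hypothetical shapes; `c` is the InferenceContext*):
  //
  // ```c++
  //  ShapeHandle s = c->MakeShape({2, 3, 4});
  //  ShapeHandle tail, out;
  //  TF_RETURN_IF_ERROR(c->Subshape(s, 1, &tail));                  // [3,4]
  //  TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(5), tail, &out));  // [5,3,4]
  //  TF_RETURN_IF_ERROR(c->ReplaceDim(out, 2, c->UnknownDim(), &out));  // [5,3,?]
  // ```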

  // Returns a new shape with the given dims. The returned value is owned by
  // this context.
  ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
  ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

  // Returns a new unknown shape.
  ShapeHandle UnknownShape();

  // Returns a shape with specified rank but unknown dims.
  ShapeHandle UnknownShapeOfRank(int64_t rank);

  // Returns a new shape of zero dimensions.
  ShapeHandle Scalar();

  // Returns a new shape of one dimension.
  ShapeHandle Vector(DimensionOrConstant dim);

  // Returns a new shape of two dimensions.
  ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

  // Returns in <out> a new shape whose dimension sizes come from input tensor
  // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
  // the input tensor is NULL, then an unknown shape is returned.
  Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);

  // Like the function above, but treats scalar values as unknown
  // shapes.  **NOTE** If the scalar is statically known, its value
  // must be -1 or an error is returned.
  Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
                                                           ShapeHandle* out);
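
  // For instance, a Reshape-like shape function would typically derive its
  // output from a shape tensor input roughly as follows (a sketch; it assumes
  // the target shape arrives as input 1):
  //
  // ```c++
  //  ShapeHandle out;
  //  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
  //  c->set_output(0, out);
  //  return OkStatus();
  // ```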

  // Returns in <out> a new shape corresponding to <proto>.
  Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                 ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <partial_shape>.
  Status MakeShapeFromPartialTensorShape(
      const PartialTensorShape& partial_shape, ShapeHandle* out);

  // Returns in <out> a new shape corresponding to <shape>.
  Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
  StatusOr<ShapeHandle> MakeShapeFromShapeTensor(const TensorShape& shape);

  // Returns a new dimension of the given size.  The returned value is owned by
  // this context.
  inline DimensionHandle MakeDim(DimensionOrConstant d) {
    return shape_manager_.MakeDim(d);
  }

  inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }

  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a 0-dimensional int32 or int64 tensor.  Caller must ensure that the
  // input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64_t* val);

  // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
  // int64 elements. Caller must ensure that the input tensor is not NULL.
  Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInput(int idx, DimensionHandle* out);

  // Returns a new dimension whose value is given by a scalar input tensor.
  // This allows for a negative input dimension given the rank of a separate
  // tensor.  This rank can be negative if unknown.
  // The input tensor must be in host memory, since it is dereferenced to get
  // the value.
  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                   DimensionHandle* out);

  // Look up the attr being evaluated with name attr_name and set *value to its
  // value. If no attr with attr_name is found in def(), or the attr does not
  // have a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor> is not positive or if <evenly_divisible>
  // and <divisor> does not evenly divide <dividend>.
  Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
                bool evenly_divisible, DimensionHandle* out);

  // Returns in <out> the sum of <first> and <second>.
  Status Add(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the dimension that is <first> minus <second>.
  Status Subtract(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the product of <first> and <second>.
  Status Multiply(DimensionHandle first, DimensionOrConstant second,
                  DimensionHandle* out);

  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero, the result is zero. Otherwise, if either <first> or
  // <second> is unknown, the result is unknown.
  Status Min(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);

  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown, the result is unknown.
  Status Max(DimensionHandle first, DimensionOrConstant second,
             DimensionHandle* out);
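
  // A sketch of chaining this dimension arithmetic, e.g. a VALID-style window
  // computation out = (input - window) / stride + 1, assuming DimensionHandles
  // `input_dim`, `window` and `stride` are already in hand:
  //
  // ```c++
  //  DimensionHandle tmp, out_dim;
  //  TF_RETURN_IF_ERROR(c->Subtract(input_dim, window, &tmp));
  //  TF_RETURN_IF_ERROR(c->Divide(tmp, stride, /*evenly_divisible=*/false, &tmp));
  //  TF_RETURN_IF_ERROR(c->Add(tmp, 1, &out_dim));
  // ```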

  Status construction_status() const { return construction_status_; }

  // Methods to propagate shape and dtype on edges of handles. Handles are
  // tensors of dtype DT_RESOURCE, which can be used to access state stored in
  // a ResourceManager. When ops (such as variables) consume these handles to
  // produce tensors, they might need to know side information about the shapes
  // and dtypes of the tensors that can be accessed via the handle. These
  // methods propagate that information. Output handle dtypes and shapes are
  // ignored if the output tensor is not of type DT_RESOURCE.

  // Merge the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the merge is successful and any of the new shapes differs from the old
  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
  // return true.  Return false otherwise.
  //
  // See 'MergeInput' function for full details and examples.
  bool MergeInputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As MergeInputHandleShapesAndTypes, but for an output.
  bool MergeOutputHandleShapesAndTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // Relaxes the stored shapes and types corresponding to the input handle in
  // position idx with the specified shapes and types. This requires idx to be
  // in the [0, num_inputs) range.
  //
  // If the relax is successful (sizes are the same, old dtypes match new ones
  // or are DT_INVALID), then store the relaxed shapes and return true.
  // Return false otherwise.
  //
  // See 'RelaxInput' function for full details and examples.
  bool RelaxInputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

  void set_input_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    input_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }

  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shape and types were never set.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    return output_handle_shapes_and_types_[idx].get();
  }

  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    return input_handle_shapes_and_types_[idx].get();
  }

  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    output_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }
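
  // For example, an op producing a variable resource handle could record the
  // dtype and shape of the variable it refers to roughly like this (a sketch;
  // `shape` and `dtype` stand for values derived from the op's attrs):
  //
  // ```c++
  //  c->set_output(0, c->Scalar());  // the resource handle itself is a scalar
  //  c->set_output_handle_shapes_and_types(
  //      0, std::vector<ShapeAndType>{{shape, dtype}});
  // ```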

  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
  // then an unknown shape is returned.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);

  int graph_def_version() const { return graph_def_version_; }

  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }

  // Adds new outputs; useful when mutating the graph.
  Status ExpandOutputs(int new_output_size);

 private:
  // Creates and stores shapes for use in InferenceContext.
  class ShapeManager {
   public:
    ShapeManager();
    ~ShapeManager();

    // Returns a new shape with the given dims. The returned value is owned by
    // this class.
    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);

    // Returns a new unknown shape.
    ShapeHandle UnknownShape();

    // Returns a new dimension of the given size.  The returned value
    // is owned by this class.
    inline DimensionHandle MakeDim(DimensionOrConstant d) {
      if (d.dim.IsSet()) {
        return d.dim;
      } else {
        all_dims_.push_back(new Dimension(d.val));
        return all_dims_.back();
      }
    }

   private:
    std::vector<Shape*> all_shapes_;    // values are owned.
    std::vector<Dimension*> all_dims_;  // values are owned.
  };

  friend class ::tensorflow::grappler::GraphProperties;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.

  // Shared initialization across the two constructors.  Remove
  // once we get rid of one of them.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);

  Status ReturnUnknownShape(ShapeHandle* out) {
    *out = UnknownShape();
    return OkStatus();
  }
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    *out = MakeShape(dims);
    return OkStatus();
  }

  // Adds additional context to the given status.
  Status AttachContext(const Status& status);

  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. If <d_old> and <d_new> have known but
  // differing values, <*out> is set to a new unknown dimension.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);

  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;

  // Forget all the previous merged shapes and dims.
  void ForgetMerges() {
    merged_shapes_.clear();
    merged_dims_.clear();
  }

  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);

  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Values may be NULL.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  // Return types for the node this context is associated with. This
  // information will eventually consolidate all the dtype and shape info,
  // allowing output_handle_shapes_and_types_ to be removed.
  FullTypeDef ret_types_;

  const int graph_def_version_;
  AttrSlice attrs_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dim handles that are equivalent, i.e. that represent the
  // same underlying shape or dimension. Note that for each pair at least one of
  // the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64_t value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}

inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}

inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}

template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(attrs_, attr_name, value);
}

}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_