• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/platform/macros.h"
26 
27 namespace tensorflow {
28 
29 class ShapeRefiner;
30 class ShapeRefinerTest;
31 
32 namespace grappler {
33 class GraphProperties;
34 class SymbolicShapeManager;
35 }  // namespace grappler
36 
37 namespace shape_inference {
38 
39 struct DimensionOrConstant;
40 class InferenceContext;
41 
42 // Dimension values are accessed through InferenceContext.
43 class Dimension {
44  private:
45   Dimension();
46   Dimension(int64 value);
~Dimension()47   ~Dimension() {}
48 
49   const int64 value_;
50 
51   friend class InferenceContext;
52   friend class ShapeManager;
53   TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
54 };
55 
56 class DimensionHandle {
57  public:
DimensionHandle()58   DimensionHandle() {}
SameHandle(DimensionHandle d)59   bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()60   std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
61 
62  private:
DimensionHandle(const Dimension * dim)63   DimensionHandle(const Dimension* dim) { ptr_ = dim; }
64 
65   const Dimension* operator->() const { return ptr_; }
IsSet()66   bool IsSet() const { return ptr_ != nullptr; }
67 
68   const Dimension* ptr_ = nullptr;
69 
70   friend struct DimensionOrConstant;
71   friend class InferenceContext;
72   friend class ShapeInferenceTest;
73   friend class ShapeInferenceTestutil;
74   friend class ::tensorflow::ShapeRefinerTest;
75   friend class ShapeManager;
76   friend class ::tensorflow::grappler::GraphProperties;
77   friend class ::tensorflow::grappler::SymbolicShapeManager;
78 
79   // Intentionally copyable.
80 };
81 
82 // Shape rank and dimensions are accessed through InferenceContext.
83 class Shape {
84  private:
85   Shape();
86   Shape(const std::vector<DimensionHandle>& dims);
~Shape()87   ~Shape() {}
88 
89   const int32 rank_;
90   const std::vector<DimensionHandle> dims_;
91 
92   friend class InferenceContext;
93   friend class ShapeManager;
94   friend class ::tensorflow::grappler::SymbolicShapeManager;
95 
96   TF_DISALLOW_COPY_AND_ASSIGN(Shape);
97 };
98 
99 class ShapeHandle {
100  public:
ShapeHandle()101   ShapeHandle() {}
SameHandle(ShapeHandle s)102   bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()103   std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
104 
105  private:
ShapeHandle(const Shape * shape)106   ShapeHandle(const Shape* shape) { ptr_ = shape; }
107   const Shape* operator->() const { return ptr_; }
IsSet()108   bool IsSet() const { return ptr_ != nullptr; }
109 
110   const Shape* ptr_ = nullptr;
111 
112   friend class InferenceContext;
113   friend class ShapeInferenceTest;
114   friend class ShapeInferenceTestutil;
115   friend class ::tensorflow::ShapeRefinerTest;
116   friend class ShapeManager;
117   friend class ::tensorflow::grappler::SymbolicShapeManager;
118 
119   // Intentionally copyable.
120 };
121 
122 // Struct used to allow functions to take DimensionHandle or a dimension value.
123 // Not meant to be constructed directly.
124 struct DimensionOrConstant {
125  public:
126   // Intentionally not explicit.
127   DimensionOrConstant(DimensionHandle dim);
128 
129   // val must be non-negative or InferenceContext::kUnknownDim.
130   DimensionOrConstant(int64 val);
131 
132   // dim takes precedence. If dim != nullptr, val is ignored.
133   DimensionHandle dim;
134   int64 val;
135 
136  private:
137   DimensionOrConstant();
138 };
139 
140 struct ShapeAndType {
ShapeAndTypeShapeAndType141   ShapeAndType() {}
ShapeAndTypeShapeAndType142   ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
143 
144   ShapeHandle shape;
145   DataType dtype = DT_INVALID;
146 };
147 
148 // Shape inference functions registered on ops in REGISTER_OP implement
149 // their shape functions in terms of this InferenceContext.  An InferenceContext
150 // is created by the framework and passed to a shape inference function.  The
151 // shape inference function calls functions on the context, and should call
152 // set_output() to set the shape on all outputs.
153 //
154 // To infer shapes for user-defined functions see ShapeRefiner.
155 //
156 // All Shape* and Dimension* returned by functions of InferenceContext are owned
157 // by the InferenceContext.
158 class InferenceContext {
159  public:
160   static constexpr int64 kUnknownDim = -1;
161   static constexpr int32 kUnknownRank = -1;
162 
163   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
164   //
165   // Elements of <input_tensors_as_shapes> are used for when a shape function
166   // makes a call to MakeShapeFromShapeTensor; in particular, when the
167   // input_tensors[i] is nullptr but the shape represented by it is partially
168   // known from analysis of the graph.
169   // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
170   // Values of <input_tensors_as_shapes> do not need to outlive the context.
171   //
172   // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
173   InferenceContext(int graph_def_version, const NodeDef* node_def,
174                    const OpDef& op_def,
175                    const std::vector<ShapeHandle>& input_shapes,
176                    const std::vector<const Tensor*>& input_tensors,
177                    const std::vector<ShapeHandle>& input_tensors_as_shapes,
178                    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
179                        input_handle_shapes_and_types);
180 
181   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
182   //
183   // Elements of <input_tensors_as_shapes> are used for when a shape
184   // function makes a call to MakeShapeFromShapeTensor; in particular, when
185   // the input_tensors[i] is nullptr but the shape represented by it is
186   // partially known from analysis of the graph. <input_tensors_as_shapes>
187   // can have fewer elements than <input_shapes>. Values of
188   // <input_tensors_as_shapes> do not need to outlive the context.
189   //
190   // REQUIRES: <node_def> is not NULL, and must outlive the
191   // InferenceContext.
192   InferenceContext(
193       int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
194       const std::vector<TensorShapeProto>& input_shapes,
195       const std::vector<const Tensor*>& input_tensors,
196       const std::vector<TensorShapeProto>& input_tensors_as_shapes,
197       const std::vector<
198           std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
199           input_handle_shapes_and_types);
200 
201   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
202   //
203   // Elements of <input_tensors_as_shapes> are used for when a shape
204   // function makes a call to MakeShapeFromShapeTensor; in particular, when
205   // the input_tensors[i] is nullptr but the shape represented by it is
206   // partially known from analysis of the graph. <input_tensors_as_shapes>
207   // can have fewer elements than <input_shapes>. Values of
208   // <input_tensors_as_shapes> do not need to outlive the context.
209   //
210   // REQUIRES: <node_def> is not NULL, and must outlive the
211   // InferenceContext.
212   InferenceContext(
213       int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
214       const std::vector<PartialTensorShape>& input_shapes,
215       const std::vector<const Tensor*>& input_tensors,
216       const std::vector<PartialTensorShape>& input_tensors_as_shapes,
217       const std::vector<std::unique_ptr<
218           std::vector<std::pair<PartialTensorShape, DataType>>>>&
219           input_handle_shapes_and_types);
220 
221   ~InferenceContext();
222 
223   // Runs the shape inference function 'fn' with 'this' as the
224   // argument, returns the status of the inference.
225   //
226   // On error, additional context is provided in the error message.
227   Status Run(
228       const std::function<Status(shape_inference::InferenceContext* c)>& fn);
229 
230   // Merge the stored shape of the input in position idx with <shape> according
231   // to the following rules:
232   //
233   // - If the ShapeHandles are the same or <shape> is unknown, there will be no
234   //   change. Otherwise if the stored shape is unknown, the new shape will be
235   //   <shape>.
236   // - If both shapes are known, then they must have the same rank.
237   // - For any one dimension, if the values for that dimension in both shapes
238   //   are known, then the values must match.
239   // - If one shape has equal or more information than the other shape in every
240   //   dimension, the new shape will become the shape with more information.
241   // - Example: merging [2,?] and [?,2] results in [2,2]
242   // - Example: [2,2] cannot be merged with [1,2]
243   //
244   // This requires idx to be in the [0, num_inputs) range. If the merge is
245   // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)246   bool MergeInput(int idx, ShapeHandle shape) {
247     ShapeHandle new_shape;
248     if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
249     inputs_[idx] = new_shape;
250     return true;
251   }
252 
253   // Relax the stored shape of the input in position idx with <shape> according
254   // to the following rules:
255   //
256   // - If the ShapeHandles are the same then the stored shape will be returned.
257   // - If either of the ShapeHandles are unknown, then a new UnknownShape will
258   //   be returned. A new shape must be returned because we cannot claim that
259   //   the resulting shape is necessarily the same as either of the input
260   //   shapes.
261   // - If the shapes both have known ranks but their ranks are different, a new
262   //   UnknownShape will be returned.
263   // - For any one dimension, if the value for that dimension in either of the
264   //   shapes is unknown, a new shape will be returned with a new UnknownDim in
265   //   that dimension.
266   // - For any one dimension, if the values for that dimension in both shapes
267   //   are known but do not match, a new shape will be returned with a new
268   //   UnknownDim in that dimension.
269   // - If both shapes have the same known rank and match in every dimension,
270   //   the stored shape will be returned.
271   // - Example: relaxing [2,?] and [?,2] results in [?,?]
272   // - Example: relaxing [2,2] and [3,2] results in [?,2]
273   // - Example: relaxing [2,2] with [1,2,3] results in ?
274   //
275   // This requires idx to be in the [0, num_inputs) range. If the relax is
276   // successful and the new shape differs from the old one, store the new
277   // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)278   bool RelaxInput(int idx, ShapeHandle shape) {
279     ShapeHandle new_shape;
280     Relax(inputs_[idx], shape, &new_shape);
281     if (inputs_[idx].SameHandle(new_shape)) {
282       return false;
283     }
284     inputs_[idx] = new_shape;
285     return true;
286   }
287 
SetInput(int idx,ShapeHandle shape)288   void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
289 
input(int64 idx)290   ShapeHandle input(int64 idx) const { return inputs_[idx]; }
291   Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()292   int num_inputs() const { return inputs_.size(); }
293 
294   // Returns the input tensor at index <idx>, or nullptr if the input tensor is
295   // not available at the time of shape inference.
input_tensor(int idx)296   const Tensor* input_tensor(int idx) {
297     // Mark that this idx was requested.
298     requested_input_tensor_[idx] = true;
299     return input_tensors_[idx];
300   }
301 
302   // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)303   bool requested_input_tensor(int idx) const {
304     return requested_input_tensor_[idx];
305   }
306 
307   // Returns true if MakeShapeFromInputTensor was called but the constant
308   // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)309   bool requested_input_tensor_as_partial_shape(int idx) const {
310     return requested_input_tensor_as_partial_shape_[idx];
311   }
312 
set_input_tensors(const std::vector<const Tensor * > & input_tensors)313   void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
314     input_tensors_ = input_tensors;
315   }
316 
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)317   void set_input_tensors_as_shapes(
318       const std::vector<ShapeHandle>& input_tensors_as_shapes) {
319     input_tensors_as_shapes_ = input_tensors_as_shapes;
320   }
321 
input_tensors_as_shapes()322   const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
323     return input_tensors_as_shapes_;
324   }
325 
output(int64 idx)326   ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)327   void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
328   Status set_output(StringPiece output_name,
329                     const std::vector<ShapeHandle>& shapes);
330 
num_outputs()331   int num_outputs() const { return outputs_.size(); }
output(int idx)332   ShapeHandle output(int idx) const { return outputs_.at(idx); }
333   Status output(StringPiece output_name,
334                 std::vector<ShapeHandle>* output) const;
335 
attrs()336   AttrSlice attrs() const { return AttrSlice(*node_def_); }
337 
338   string op() const;
339 
340   // idx can be negative for an offset from end of dimensions.
341   // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64 idx)342   DimensionHandle Dim(ShapeHandle s, int64 idx) {
343     if (s->rank_ == kUnknownRank) {
344       return UnknownDim();
345     }
346     return DimKnownRank(s, idx);
347   }
348   // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64 idx)349   static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
350     CHECK_NE(s->rank_, kUnknownRank);
351     if (idx < 0) {
352       return s->dims_[s->dims_.size() + idx];
353     }
354     return s->dims_[idx];
355   }
356 
Rank(ShapeHandle s)357   static int32 Rank(ShapeHandle s) {
358     DCHECK(s.IsSet());
359     return s.IsSet() ? s->rank_ : kUnknownRank;
360   }
RankKnown(ShapeHandle s)361   static bool RankKnown(ShapeHandle s) {
362     return (s.IsSet() && (Rank(s) != kUnknownRank));
363   }
Value(DimensionOrConstant d)364   static inline int64 Value(DimensionOrConstant d) {
365     return d.dim.IsSet() ? d.dim->value_ : d.val;
366   }
ValueKnown(DimensionOrConstant d)367   static inline bool ValueKnown(DimensionOrConstant d) {
368     return Value(d) != kUnknownDim;
369   }
370 
371   // Fills the output proto with the shape defined by the handle.
372   // "proto" is expected to be empty prior to the call.
373   void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
374 
375   // Returns true if the rank and all dimensions of the Shape are known.
376   bool FullyDefined(ShapeHandle s);
377 
378   // Returns the total number of elements, or an unknown dimension for an
379   // incomplete shape.
380   DimensionHandle NumElements(ShapeHandle s);
381 
382   string DebugString(ShapeHandle s);
383   string DebugString(DimensionHandle d);
384   string DebugString(const ShapeAndType& shape_and_type);
385   string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
386 
387   // Describes the whole context, for debugging purposes.
388   string DebugString() const;
389 
390   // If <shape> has rank <rank>, or its rank is unknown, return OK and return
391   // the shape with asserted rank in <*out>. Otherwise return an error.
392   //
393   // Note that <*out> may be set to <shape>.
394   Status WithRank(ShapeHandle shape, int64 rank,
395                   ShapeHandle* out) TF_MUST_USE_RESULT;
396   Status WithRankAtLeast(ShapeHandle shape, int64 rank,
397                          ShapeHandle* out) TF_MUST_USE_RESULT;
398   Status WithRankAtMost(ShapeHandle shape, int64 rank,
399                         ShapeHandle* out) TF_MUST_USE_RESULT;
400 
401   // If <dim> has value <value>, or its value is unknown, returns OK and returns
402   // the dimension with asserted value in <*out>. Otherwise returns an error.
403   //
404   // Note that <*out> may be set to <dim>.
405   Status WithValue(DimensionHandle dim, int64 value,
406                    DimensionHandle* out) TF_MUST_USE_RESULT;
407 
408   // Merges <s0> and <s1> and returns the merged shape in <*out>. See
409   // 'MergeInput' function for full details and examples.
410   Status Merge(ShapeHandle s0, ShapeHandle s1,
411                ShapeHandle* out) TF_MUST_USE_RESULT;
412 
413   // Asserts that <s>'s rank >= <prefix>'s rank, and the first
414   // <prefix.rank> dimensions of <s> are compatible with the dimensions of
415   // <prefix>.
416   // Returns the merged results in <*s_out> and <*prefix_out>.
417   Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
418                      ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
419 
420   // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
421   // and <d1> have incompatible values, returns an error.
422   //
423   // Note that <*out> may be set to <d0> or <d1>.
424   Status Merge(DimensionHandle d0, DimensionHandle d1,
425                DimensionHandle* out) TF_MUST_USE_RESULT;
426 
427   // Returns in <*out> a sub-shape of <s> with dimensions [start:].
428   // <start> can be negative to index from the end of the shape. If <start> >
429   // rank of <s>, then an empty subshape is returned.
430   Status Subshape(ShapeHandle s, int64 start,
431                   ShapeHandle* out) TF_MUST_USE_RESULT;
432 
433   // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
434   // <start> and <end> can be negative, to index from the end of the shape.
435   // <start> and <end> are set to the rank of <s> if > rank of <s>.
436   Status Subshape(ShapeHandle s, int64 start, int64 end,
437                   ShapeHandle* out) TF_MUST_USE_RESULT;
438 
439   // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
440   // <start> and <end> can be negative, to index from the end of the shape.
441   // <start> and <end> are set to the rank of <s> if > rank of <s>.
442   // <stride> can be negative, to reverse the <s>.
443   Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
444                   ShapeHandle* out) TF_MUST_USE_RESULT;
445 
446   // Returns in <*out> the result of appending the dimensions of <s2> to those
447   // of <s1>.
448   Status Concatenate(ShapeHandle s1, ShapeHandle s2,
449                      ShapeHandle* out) TF_MUST_USE_RESULT;
450 
451   // Returns in <out> the shape from replacing <s.dim[dim_index]> with
452   // <new_dim>.
453   Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
454                     ShapeHandle* out) TF_MUST_USE_RESULT;
455 
456   // Returns a new shape with the given dims. The returned value is owned by
457   // this context.
458   ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
459   ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
460 
461   // Returns a new unknown shape.
462   ShapeHandle UnknownShape();
463 
464   // Returns a shape with specified rank but unknown dims.
465   ShapeHandle UnknownShapeOfRank(int64 rank);
466 
467   // Returns a new shape of zero dimensions.
468   ShapeHandle Scalar();
469 
470   // Returns a new shape of one dimension.
471   ShapeHandle Vector(DimensionOrConstant dim);
472 
473   // Returns a new shape of two dimensions.
474   ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
475 
476   // Returns in <out> a new shape whose dimension sizes come from input tensor
477   // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
478   // the input tensor is NULL, then an unknown shape is returned.
479   Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
480 
481   // Like the function above, but treats scalar values as unknown
482   // shapes.  **NOTE** If the scalar is statically known, its value
483   // must be -1 or an error is returned.
484   Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
485                                                            ShapeHandle* out);
486 
487   // Returns in <out> a new shape corresponding to <proto>.
488   Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
489                                  ShapeHandle* out);
490 
491   // Returns in <out> a new shape corresponding to <partial_shape>.
492   Status MakeShapeFromPartialTensorShape(
493       const PartialTensorShape& partial_shape, ShapeHandle* out);
494 
495   // Returns in <out> a new shape corresponding to <shape>.
496   Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
497 
498   // Returns a new dimension of the given size.  The returned value is owned by
499   // this context.
MakeDim(DimensionOrConstant d)500   inline DimensionHandle MakeDim(DimensionOrConstant d) {
501     return shape_manager_.MakeDim(d);
502   }
503 
UnknownDim()504   inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
505 
  // Returns in <val> a scalar value from an input tensor <t>.  The input tensor
  // must be a scalar (0-dimensional) int32 or int64 tensor.  Caller must ensure
  // that the input tensor is not NULL.
509   Status GetScalarFromTensor(const Tensor* t, int64* val);
510 
511   // Returns a new dimension whose value is given by a scalar input tensor.
512   // The input tensor must be in host memory, since it is dereferenced to get
513   // the value.
514   Status MakeDimForScalarInput(int idx, DimensionHandle* out);
515 
516   // Returns a new dimension whose value is given by a scalar input tensor.
517   // This allows for a negative input dimension given the rank of a separate
518   // tensor.  This rank can be negative if unknown.
519   // The input tensor must be in host memory, since it is dereferenced to get
520   // the value.
521   Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
522                                                    DimensionHandle* out);
523 
524   // Look up the attr for the NodeDef being evaluated with name attr_name and
525   // set *value to its value.  If no attr with attr_name is found in def(), or
526   // the attr does not have a matching type, a non-ok status will be returned.
527   template <class T>
528   Status GetAttr(StringPiece attr_name, T* value) const;
529 
  // Returns in <out> the result of dividing <dividend> by <divisor>.
  // Returns an error if <divisor> is not positive, or if <evenly_divisible> is
  // true and <divisor> does not evenly divide <dividend>.
533   Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
534                 bool evenly_divisible, DimensionHandle* out);
535 
536   // Returns in <out> the sum of <first> and <second>.
537   Status Add(DimensionHandle first, DimensionOrConstant second,
538              DimensionHandle* out);
539 
540   // Returns in <out> the dimension that is <first> minus <second>.
541   Status Subtract(DimensionHandle first, DimensionOrConstant second,
542                   DimensionHandle* out);
543 
544   // Returns in <out> the product of <first> and <second>.
545   Status Multiply(DimensionHandle first, DimensionOrConstant second,
546                   DimensionHandle* out);
547 
  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero, the result is zero. Otherwise, if either <first> or
  // <second> is unknown, the result is unknown.
551   Status Min(DimensionHandle first, DimensionOrConstant second,
552              DimensionHandle* out);
553 
  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown, the result is unknown.
556   Status Max(DimensionHandle first, DimensionOrConstant second,
557              DimensionHandle* out);
558 
construction_status()559   Status construction_status() const { return construction_status_; }
560 
561   // Methods to propagate shape and dtype on edges of handles. Handles are the
562   // dtype DT_RESOURCE which can be used to access state stored in a
563   // ResourceManager. When ops (such as variables) consume these handles to
564   // produce tensors they might need to know side-information about the shapes
565   // and dtypes of tensors which can be accessed via the handle. These methods
566   // propagate that information. Output handle dtypes and shapes are ignored if
567   // the output tensor is not of type DT_RESOURCE.
568 
569   // Merge the stored shapes and types corresponding to the input handle in
570   // position idx with the specified shapes and types. This requires idx to be
571   // in the [0, num_inputs) range.
572   //
573   // If the merge is successful and any of the new shapes differs from the old
574   // one, or any of the old dtypes was DT_INVALID, store the new shapes and
575   // return true.  Return false otherwise.
576   //
577   // See 'MergeInput' function for full details and examples.
578   bool MergeInputHandleShapesAndTypes(
579       int idx,
580       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
581 
582   // As MergeInputHandleShapesAndTypes, but for an output.
583   bool MergeOutputHandleShapesAndTypes(
584       int idx,
585       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
586 
587   // Relaxes the stored shapes and types corresponding to the input handle in
588   // position idx with the specified shapes and types. This requires idx to be
589   // in the [0, num_inputs) range.
590   //
591   // If the relax is successful (sizes are the same, old dtypes match new ones
592   // or are DT_INVALID), then store the relaxed shapes and return true.
593   // Return false otherwise.
594   //
595   // See 'RelaxInput' function for full details and examples.
596   bool RelaxInputHandleShapesAndMergeTypes(
597       int idx,
598       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
599 
  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
601   bool RelaxOutputHandleShapesAndMergeTypes(
602       int idx,
603       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
604 
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)605   void set_input_handle_shapes_and_types(
606       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
607     input_handle_shapes_and_types_[idx].reset(
608         new std::vector<ShapeAndType>(shapes_and_types));
609   }
610 
611   // Returns the output handle shapes and types, for the resource tensor output
612   // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)613   const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
614     return output_handle_shapes_and_types_[idx].get();
615   }
616 
617   // Returns the inputs handle shapes and types, for the resource tensor output
618   // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)619   const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
620     return input_handle_shapes_and_types_[idx].get();
621   }
622 
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)623   void set_output_handle_shapes_and_types(
624       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
625     output_handle_shapes_and_types_[idx].reset(
626         new std::vector<ShapeAndType>(shapes_and_types));
627   }
628 
629   // Note that shape functions should usually call MakeShapeFromShapeTensor,
630   // as it does more analysis to provide partial shapes.
631   //
632   // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
633   // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
634   // then an unknown shape is returned.
635   Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
636                              ShapeHandle* out);
637 
graph_def_version()638   int graph_def_version() const { return graph_def_version_; }
639 
MergedShapes()640   const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
641     return merged_shapes_;
642   }
MergedDims()643   const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
644       const {
645     return merged_dims_;
646   }
647 
648   // Adds new outputs; useful when mutating the graph.
649   Status ExpandOutputs(int new_output_size);
650 
651  private:
652   // Creates and stores shapes for use in InferenceContext.
653   class ShapeManager {
654    public:
655     ShapeManager();
656     ~ShapeManager();
657 
658     // Returns a new shape with the given dims. The returned value is owned by
659     // this class.
660     ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
661 
662     // Returns a new unknown shape.
663     ShapeHandle UnknownShape();
664 
665     // Returns a new dimension of the given size.  The returned value
666     // is owned by this class.
MakeDim(DimensionOrConstant d)667     inline DimensionHandle MakeDim(DimensionOrConstant d) {
668       if (d.dim.IsSet()) {
669         return d.dim;
670       } else {
671         all_dims_.push_back(new Dimension(d.val));
672         return all_dims_.back();
673       }
674     }
675 
676    private:
677     std::vector<Shape*> all_shapes_;    // values are owned.
678     std::vector<Dimension*> all_dims_;  // values are owned.
679   };
680 
681   friend class ::tensorflow::grappler::GraphProperties;
682 
683   // Friend for user-defined function shape inference purposes.
684   friend class ::tensorflow::ShapeRefiner;
685 
686   friend class ShapeInferenceTest;      // For testing Relax functions.
687   friend class ShapeInferenceTestutil;  // For testing shapes.
688 
  // Shared initialization across the constructors. Can be removed if the
  // constructors are ever consolidated into one.
691   void PreInputInit(const OpDef& op_def,
692                     const std::vector<const Tensor*>& input_tensors,
693                     const std::vector<ShapeHandle>& input_tensors_as_shapes);
694   void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
695                          input_handle_data);
696 
697   DimensionHandle GetDimension(const DimensionOrConstant& d);
698 
ReturnUnknownShape(ShapeHandle * out)699   Status ReturnUnknownShape(ShapeHandle* out) {
700     *out = UnknownShape();
701     return Status::OK();
702   }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)703   Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
704                             ShapeHandle* out) {
705     *out = MakeShape(dims);
706     return Status::OK();
707   }
708 
  // Adds additional context to the given status.
  //
  // NOTE(review): presumably returns a copy of `status` whose message is
  // extended with information about this node — confirm in the .cc file.
  Status AttachContext(const Status& status);
711 
712   // Relaxes an existing value <d_old> with a new value <d_new> and returns the
713   // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
714   // values, returns an error.
715   //
716   // Note that <*out> may be set to <d_old> or <d_new>.
717   void Relax(DimensionHandle d_old, DimensionHandle d_new,
718              DimensionHandle* out);
719   // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
720   // relaxed shape in <*out>. See 'RelaxInput' function for full details and
721   // examples.
722   void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
723 
  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes. Reads `shapes_and_types` and may modify
  // `*to_update`; the bool result must be consumed (TF_MUST_USE_RESULT).
  // NOTE(review): presumably returns true iff *to_update was changed — confirm.
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes. Same parameter roles and must-use
  // bool result as MergeHandleShapesAndTypes above; NOTE(review): the meaning
  // of the bool is presumably "changed" as well — confirm.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
734 
735   // Forget all the previous merged shapes and dims.
ForgetMerges()736   void ForgetMerges() {
737     merged_shapes_.clear();
738     merged_dims_.clear();
739   }
740 
  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  //
  // `t` is a pointer and may presumably be null when the tensor value is not
  // available, in which case `tensor_shape` is used instead — TODO confirm.
  // The resulting shape is written to *out.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);
745 
  // Owns the Shape/Dimension objects behind the handles stored below. Declared
  // before those members so it is constructed first and destroyed last.
  ShapeManager shape_manager_;

  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
  // `shape_manager_`.
  std::vector<ShapeHandle> inputs_;
  // Non-owning; entries are pointers and may presumably be null when an
  // input's value is unknown — TODO confirm.
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;

  // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
  // through the resource handle passed along input i of the node.
  //
  // Entries may be null.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types_;

  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
  // available through the resource handle passed along output i of the node.
  //
  // Entries may be null.
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      output_handle_shapes_and_types_;

  const int graph_def_version_;
  // Non-owning; dereferenced by GetAttr(), so it must be non-null whenever
  // attributes are read.
  const NodeDef* node_def_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;

  // An error set during construction. TODO(cwhipkey): remove when test
  // constructor is removed.
  Status construction_status_;

  // Pairs of shape or dimension handles that are equivalent, i.e. that
  // represent the same underlying shape or dimension. For each pair at least
  // one of the handles must be unknown (an unknown shape or an unknown
  // dimension), since known shapes and dims are not tracked here.
  // ForgetMerges() clears both lists.
  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

  // InferenceContext is neither copyable nor assignable.
  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
789 };
790 
791 // -----------------------------------------------------------------------------
792 // Template and inline method implementations, please ignore
793 
Dimension()794 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
Dimension(int64 value)795 inline Dimension::Dimension(int64 value) : value_(value) {
796   DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
797       << "Dimension must be non-negative or equal to "
798          "InferenceContext::kUnknownDim but got "
799       << value;
800 }
801 
Shape()802 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)803 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
804     : rank_(dims.size()), dims_(dims) {}
805 
DimensionOrConstant(DimensionHandle dim)806 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
807     : dim(dim) {
808   DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
809 }
810 
DimensionOrConstant(int64 val)811 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
812   DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
813       << "Dimension must be non-negative or equal to "
814          "InferenceContext::kUnknownDim but got "
815       << val;
816 }
817 
818 template <class T>
GetAttr(StringPiece attr_name,T * value)819 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
820   return GetNodeAttr(*node_def_, attr_name, value);
821 }
822 
823 }  // namespace shape_inference
824 }  // namespace tensorflow
825 
826 #endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
827