• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17 
18 #include <vector>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/macros.h"
27 
28 namespace tensorflow {
29 
30 namespace grappler {
31 class GraphProperties;
32 class SymbolicShapeManager;
33 }  // namespace grappler
34 
35 namespace shape_inference {
36 
37 struct DimensionOrConstant;
38 class InferenceContext;
39 
40 // Dimension values are accessed through InferenceContext.
41 class Dimension {
42  private:
43   Dimension();
44   Dimension(int64 value);
~Dimension()45   ~Dimension() {}
46 
47   const int64 value_;
48 
49   friend class InferenceContext;
50   friend class ShapeManager;
51   TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
52 };
53 
54 class DimensionHandle {
55  public:
DimensionHandle()56   DimensionHandle() {}
SameHandle(DimensionHandle d)57   bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()58   std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
59 
60  private:
DimensionHandle(const Dimension * dim)61   DimensionHandle(const Dimension* dim) { ptr_ = dim; }
62 
63   const Dimension* operator->() const { return ptr_; }
IsSet()64   bool IsSet() const { return ptr_ != nullptr; }
65 
66   const Dimension* ptr_ = nullptr;
67 
68   friend struct DimensionOrConstant;
69   friend class InferenceContext;
70   friend class ShapeInferenceTest;
71   friend class ShapeInferenceTestutil;
72   friend class ::tensorflow::grappler::GraphProperties;
73   friend class ::tensorflow::grappler::SymbolicShapeManager;
74 
75   // Intentionally copyable.
76 };
77 
78 // Shape rank and dimensions are accessed through InferenceContext.
79 class Shape {
80  private:
81   Shape();
82   Shape(const std::vector<DimensionHandle>& dims);
~Shape()83   ~Shape() {}
84 
85   const int32 rank_;
86   const std::vector<DimensionHandle> dims_;
87 
88   friend class InferenceContext;
89   friend class ::tensorflow::grappler::SymbolicShapeManager;
90 
91   TF_DISALLOW_COPY_AND_ASSIGN(Shape);
92 };
93 
94 class ShapeHandle {
95  public:
ShapeHandle()96   ShapeHandle() {}
SameHandle(ShapeHandle s)97   bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()98   std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
99 
100  private:
ShapeHandle(const Shape * shape)101   ShapeHandle(const Shape* shape) { ptr_ = shape; }
102   const Shape* operator->() const { return ptr_; }
IsSet()103   bool IsSet() const { return ptr_ != nullptr; }
104 
105   const Shape* ptr_ = nullptr;
106 
107   friend class InferenceContext;
108   friend class ShapeInferenceTest;
109   friend class ShapeInferenceTestutil;
110   friend class ::tensorflow::grappler::SymbolicShapeManager;
111 
112   // Intentionally copyable.
113 };
114 
115 // Struct used to allow functions to take DimensionHandle or a dimension value.
116 // Not meant to be constructed directly.
117 struct DimensionOrConstant {
118  public:
119   // Intentionally not explicit.
120   DimensionOrConstant(DimensionHandle dim);
121 
122   // val must be non-negative or InferenceContext::kUnknownDim.
123   DimensionOrConstant(int64 val);
124 
125   // dim takes precedence. If dim != nullptr, val is ignored.
126   DimensionHandle dim;
127   int64 val;
128 
129  private:
130   DimensionOrConstant();
131 };
132 
133 struct ShapeAndType {
ShapeAndTypeShapeAndType134   ShapeAndType() {}
ShapeAndTypeShapeAndType135   ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
136 
137   ShapeHandle shape;
138   DataType dtype = DT_INVALID;
139 };
140 
141 // Shape inference functions registered on ops in REGISTER_OP implement
142 // their shape functions in terms of this InferenceContext.  An InferenceContext
143 // is created by the framework and passed to a shape inference function.  The
144 // shape inference function calls functions on the context, and should call
145 // set_output() to set the shape on all outputs.
146 //
147 // To infer shapes for user-defined functions see ShapeRefiner.
148 //
149 // All Shape* and Dimension* returned by functions of InferenceContext are owned
150 // by the InferenceContext.
151 class InferenceContext {
152  public:
153   static constexpr int64 kUnknownDim = -1;
154   static constexpr int32 kUnknownRank = -1;
155 
156   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
157   //
158   // Elements of <input_tensors_as_shapes> are used for when a shape function
159   // makes a call to MakeShapeFromShapeTensor; in particular, when the
160   // input_tensors[i] is nullptr but the shape represented by it is partially
161   // known from analysis of the graph.
162   // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
163   // Values of <input_tensors_as_shapes> do not need to outlive the context.
164   InferenceContext(int graph_def_version, const AttrSlice& attrs,
165                    const OpDef& op_def,
166                    const std::vector<ShapeHandle>& input_shapes,
167                    const std::vector<const Tensor*>& input_tensors,
168                    const std::vector<ShapeHandle>& input_tensors_as_shapes,
169                    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
170                        input_handle_shapes_and_types);
171 
172   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
173   //
174   // Elements of <input_tensors_as_shapes> are used for when a shape
175   // function makes a call to MakeShapeFromShapeTensor; in particular, when
176   // the input_tensors[i] is nullptr but the shape represented by it is
177   // partially known from analysis of the graph. <input_tensors_as_shapes>
178   // can have fewer elements than <input_shapes>. Values of
179   // <input_tensors_as_shapes> do not need to outlive the context.
180   InferenceContext(
181       int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
182       const std::vector<PartialTensorShape>& input_shapes,
183       const std::vector<const Tensor*>& input_tensors,
184       const std::vector<PartialTensorShape>& input_tensors_as_shapes,
185       const std::vector<std::unique_ptr<
186           std::vector<std::pair<PartialTensorShape, DataType>>>>&
187           input_handle_shapes_and_types);
188 
189   ~InferenceContext();
190 
191   // Runs the shape inference function 'fn' with 'this' as the
192   // argument, returns the status of the inference.
193   //
194   // On error, additional context is provided in the error message.
195   Status Run(
196       const std::function<Status(shape_inference::InferenceContext* c)>& fn);
197 
198   // Merge the stored shape of the input in position idx with <shape> according
199   // to the following rules:
200   //
201   // - If the ShapeHandles are the same or <shape> is unknown, there will be no
202   //   change. Otherwise if the stored shape is unknown, the new shape will be
203   //   <shape>.
204   // - If both shapes are known, then they must have the same rank.
205   // - For any one dimension, if the values for that dimension in both shapes
206   //   are known, then the values must match.
207   // - If one shape has equal or more information than the other shape in every
208   //   dimension, the new shape will become the shape with more information.
209   // - Example: merging [2,?] and [?,2] results in [2,2]
210   // - Example: [2,2] cannot be merged with [1,2]
211   //
212   // This requires idx to be in the [0, num_inputs) range. If the merge is
213   // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)214   bool MergeInput(int idx, ShapeHandle shape) {
215     ShapeHandle new_shape;
216     if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
217     inputs_[idx] = new_shape;
218     return true;
219   }
220 
221   // Relax the stored shape of the input in position idx with <shape> according
222   // to the following rules:
223   //
224   // - If the ShapeHandles are the same then the stored shape will be returned.
225   // - If either of the ShapeHandles are unknown, then a new UnknownShape will
226   //   be returned. A new shape must be returned because we cannot claim that
227   //   the resulting shape is necessarily the same as either of the input
228   //   shapes.
229   // - If the shapes both have known ranks but their ranks are different, a new
230   //   UnknownShape will be returned.
231   // - For any one dimension, if the value for that dimension in either of the
232   //   shapes is unknown, a new shape will be returned with a new UnknownDim in
233   //   that dimension.
234   // - For any one dimension, if the values for that dimension in both shapes
235   //   are known but do not match, a new shape will be returned with a new
236   //   UnknownDim in that dimension.
237   // - If both shapes have the same known rank and match in every dimension,
238   //   the stored shape will be returned.
239   // - Example: relaxing [2,?] and [?,2] results in [?,?]
240   // - Example: relaxing [2,2] and [3,2] results in [?,2]
241   // - Example: relaxing [2,2] with [1,2,3] results in ?
242   //
243   // This requires idx to be in the [0, num_inputs) range. If the relax is
244   // successful and the new shape differs from the old one, store the new
245   // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)246   bool RelaxInput(int idx, ShapeHandle shape) {
247     ShapeHandle new_shape;
248     Relax(inputs_[idx], shape, &new_shape);
249     if (inputs_[idx].SameHandle(new_shape)) {
250       return false;
251     }
252     inputs_[idx] = new_shape;
253     return true;
254   }
255 
SetInput(int idx,ShapeHandle shape)256   void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
257 
input(int64 idx)258   ShapeHandle input(int64 idx) const { return inputs_[idx]; }
259   Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()260   int num_inputs() const { return inputs_.size(); }
261 
262   // Returns the input tensor at index <idx>, or nullptr if the input tensor is
263   // not available at the time of shape inference.
input_tensor(int idx)264   const Tensor* input_tensor(int idx) {
265     // Mark that this idx was requested.
266     requested_input_tensor_[idx] = true;
267     return input_tensors_[idx];
268   }
269 
270   // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)271   bool requested_input_tensor(int idx) const {
272     return requested_input_tensor_[idx];
273   }
274 
275   // Returns true if MakeShapeFromInputTensor was called but the constant
276   // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)277   bool requested_input_tensor_as_partial_shape(int idx) const {
278     return requested_input_tensor_as_partial_shape_[idx];
279   }
280 
set_input_tensors(const std::vector<const Tensor * > & input_tensors)281   void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
282     input_tensors_ = input_tensors;
283   }
284 
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)285   void set_input_tensors_as_shapes(
286       const std::vector<ShapeHandle>& input_tensors_as_shapes) {
287     input_tensors_as_shapes_ = input_tensors_as_shapes;
288   }
289 
input_tensors_as_shapes()290   const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
291     return input_tensors_as_shapes_;
292   }
293 
output(int64 idx)294   ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)295   void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
296   Status set_output(StringPiece output_name,
297                     const std::vector<ShapeHandle>& shapes);
298 
num_outputs()299   int num_outputs() const { return outputs_.size(); }
output(int idx)300   ShapeHandle output(int idx) const { return outputs_.at(idx); }
301   Status output(StringPiece output_name,
302                 std::vector<ShapeHandle>* output) const;
303 
attrs()304   const AttrSlice& attrs() const { return attrs_; }
305 
306   // idx can be negative for an offset from end of dimensions.
307   // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64 idx)308   DimensionHandle Dim(ShapeHandle s, int64 idx) {
309     if (!s.Handle() || s->rank_ == kUnknownRank) {
310       return UnknownDim();
311     }
312     return DimKnownRank(s, idx);
313   }
314   // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64 idx)315   static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
316     CHECK_NE(s->rank_, kUnknownRank);
317     if (idx < 0) {
318       return s->dims_[s->dims_.size() + idx];
319     }
320     return s->dims_[idx];
321   }
322 
Rank(ShapeHandle s)323   static int32 Rank(ShapeHandle s) {
324     return s.IsSet() ? s->rank_ : kUnknownRank;
325   }
RankKnown(ShapeHandle s)326   static bool RankKnown(ShapeHandle s) {
327     return (s.IsSet() && (Rank(s) != kUnknownRank));
328   }
Value(DimensionOrConstant d)329   static inline int64 Value(DimensionOrConstant d) {
330     return d.dim.IsSet() ? d.dim->value_ : d.val;
331   }
ValueKnown(DimensionOrConstant d)332   static inline bool ValueKnown(DimensionOrConstant d) {
333     return Value(d) != kUnknownDim;
334   }
335 
336   // Fills the output proto with the shape defined by the handle.
337   // "proto" is expected to be empty prior to the call.
338   void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
339 
340   // Returns true if the rank and all dimensions of the Shape are known.
341   bool FullyDefined(ShapeHandle s);
342 
343   // Returns the total number of elements, or an unknown dimension for an
344   // incomplete shape.
345   DimensionHandle NumElements(ShapeHandle s);
346 
347   string DebugString(ShapeHandle s);
348   string DebugString(DimensionHandle d);
349   string DebugString(const ShapeAndType& shape_and_type);
350   string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
351 
352   // Describes the whole context, for debugging purposes.
353   string DebugString() const;
354 
355   // If <shape> has rank <rank>, or its rank is unknown, return OK and return
356   // the shape with asserted rank in <*out>. Otherwise return an error.
357   //
358   // Note that <*out> may be set to <shape>.
359   Status WithRank(ShapeHandle shape, int64 rank,
360                   ShapeHandle* out) TF_MUST_USE_RESULT;
361   Status WithRankAtLeast(ShapeHandle shape, int64 rank,
362                          ShapeHandle* out) TF_MUST_USE_RESULT;
363   Status WithRankAtMost(ShapeHandle shape, int64 rank,
364                         ShapeHandle* out) TF_MUST_USE_RESULT;
365 
366   // If <dim> has value <value>, or its value is unknown, returns OK and returns
367   // the dimension with asserted value in <*out>. Otherwise returns an error.
368   //
369   // Note that <*out> may be set to <dim>.
370   Status WithValue(DimensionHandle dim, int64 value,
371                    DimensionHandle* out) TF_MUST_USE_RESULT;
372 
373   // Merges <s0> and <s1> and returns the merged shape in <*out>. See
374   // 'MergeInput' function for full details and examples.
375   Status Merge(ShapeHandle s0, ShapeHandle s1,
376                ShapeHandle* out) TF_MUST_USE_RESULT;
377 
378   // Asserts that <s>'s rank >= <prefix>'s rank, and the first
379   // <prefix.rank> dimensions of <s> are compatible with the dimensions of
380   // <prefix>.
381   // Returns the merged results in <*s_out> and <*prefix_out>.
382   Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
383                      ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
384 
385   // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
386   // and <d1> have incompatible values, returns an error.
387   //
388   // Note that <*out> may be set to <d0> or <d1>.
389   Status Merge(DimensionHandle d0, DimensionHandle d1,
390                DimensionHandle* out) TF_MUST_USE_RESULT;
391 
392   // Returns in <*out> a sub-shape of <s> with dimensions [start:].
393   // <start> can be negative to index from the end of the shape. If <start> >
394   // rank of <s>, then an empty subshape is returned.
395   Status Subshape(ShapeHandle s, int64 start,
396                   ShapeHandle* out) TF_MUST_USE_RESULT;
397 
398   // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
399   // <start> and <end> can be negative, to index from the end of the shape.
400   // <start> and <end> are set to the rank of <s> if > rank of <s>.
401   Status Subshape(ShapeHandle s, int64 start, int64 end,
402                   ShapeHandle* out) TF_MUST_USE_RESULT;
403 
404   // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
405   // <start> and <end> can be negative, to index from the end of the shape.
406   // <start> and <end> are set to the rank of <s> if > rank of <s>.
407   // <stride> can be negative, to reverse the <s>.
408   Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
409                   ShapeHandle* out) TF_MUST_USE_RESULT;
410 
411   // Returns in <*out> the result of appending the dimensions of <s2> to those
412   // of <s1>.
413   Status Concatenate(ShapeHandle s1, ShapeHandle s2,
414                      ShapeHandle* out) TF_MUST_USE_RESULT;
415 
416   // Returns in <out> the shape from replacing <s.dim[dim_index]> with
417   // <new_dim>.
418   Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
419                     ShapeHandle* out) TF_MUST_USE_RESULT;
420 
421   // Returns a new shape with the given dims. The returned value is owned by
422   // this context.
423   ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
424   ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
425 
426   // Returns a new unknown shape.
427   ShapeHandle UnknownShape();
428 
429   // Returns a shape with specified rank but unknown dims.
430   ShapeHandle UnknownShapeOfRank(int64 rank);
431 
432   // Returns a new shape of zero dimensions.
433   ShapeHandle Scalar();
434 
435   // Returns a new shape of one dimension.
436   ShapeHandle Vector(DimensionOrConstant dim);
437 
438   // Returns a new shape of two dimensions.
439   ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
440 
441   // Returns in <out> a new shape whose dimension sizes come from input tensor
442   // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor.  If
443   // the input tensor is NULL, then an unknown shape is returned.
444   Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
445 
446   // Like the function above, but treats scalar values as unknown
447   // shapes.  **NOTE** If the scalar is statically known, its value
448   // must be -1 or an error is returned.
449   Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
450                                                            ShapeHandle* out);
451 
452   // Returns in <out> a new shape corresponding to <proto>.
453   Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
454                                  ShapeHandle* out);
455 
456   // Returns in <out> a new shape corresponding to <partial_shape>.
457   Status MakeShapeFromPartialTensorShape(
458       const PartialTensorShape& partial_shape, ShapeHandle* out);
459 
460   // Returns in <out> a new shape corresponding to <shape>.
461   Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
462 
463   // Returns a new dimension of the given size.  The returned value is owned by
464   // this context.
MakeDim(DimensionOrConstant d)465   inline DimensionHandle MakeDim(DimensionOrConstant d) {
466     return shape_manager_.MakeDim(d);
467   }
468 
UnknownDim()469   inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
470 
471   // Returns in <val> the value of the scalar input tensor <t>.  The input
472   // tensor must be a scalar (0-dimensional) int32 or int64 tensor.  Caller
473   // must ensure that the input tensor is not NULL.
474   Status GetScalarFromTensor(const Tensor* t, int64* val);
475 
476   // Returns a new dimension whose value is given by a scalar input tensor.
477   // The input tensor must be in host memory, since it is dereferenced to get
478   // the value.
479   Status MakeDimForScalarInput(int idx, DimensionHandle* out);
480 
481   // Returns a new dimension whose value is given by a scalar input tensor.
482   // This allows for a negative input dimension given the rank of a separate
483   // tensor.  This rank can be negative if unknown.
484   // The input tensor must be in host memory, since it is dereferenced to get
485   // the value.
486   Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
487                                                    DimensionHandle* out);
488 
489   // Look up the attr being evaluated with name attr_name and set *value to its
490   // value. If no attr with attr_name is found in def(), or the attr does not
491   // have a matching type, a non-ok status will be returned.
492   template <class T>
493   Status GetAttr(StringPiece attr_name, T* value) const;
494 
495   // Returns in <out> the result of dividing <dividend> by <divisor>.
496   // Returns an error if <divisor>  is not positive or if <evenly_divisible>
497   // and <divisor> does not evenly divide <dividend>.
498   Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
499                 bool evenly_divisible, DimensionHandle* out);
500 
501   // Returns in <out> the sum of <first> and <second>.
502   Status Add(DimensionHandle first, DimensionOrConstant second,
503              DimensionHandle* out);
504 
505   // Returns in <out> the dimension that is <first> minus <second>.
506   Status Subtract(DimensionHandle first, DimensionOrConstant second,
507                   DimensionHandle* out);
508 
509   // Returns in <out> the product of <first> and <second>.
510   Status Multiply(DimensionHandle first, DimensionOrConstant second,
511                   DimensionHandle* out);
512 
513   // Returns in <out> the minimum of <first> and <second>. If either <first> or
514   // <second> is zero, the result is zero. Otherwise, if either <first> or
515   // <second> is unknown, the result is unknown.
516   Status Min(DimensionHandle first, DimensionOrConstant second,
517              DimensionHandle* out);
518 
519   // Returns in <out> the maximum of <first> and <second>. If either <first> or
520   // <second> is unknown, the result is unknown.
521   Status Max(DimensionHandle first, DimensionOrConstant second,
522              DimensionHandle* out);
523 
construction_status()524   Status construction_status() const { return construction_status_; }
525 
526   // Methods to propagate shape and dtype on edges of handles. Handles are the
527   // dtype DT_RESOURCE which can be used to access state stored in a
528   // ResourceManager. When ops (such as variables) consume these handles to
529   // produce tensors they might need to know side-information about the shapes
530   // and dtypes of tensors which can be accessed via the handle. These methods
531   // propagate that information. Output handle dtypes and shapes are ignored if
532   // the output tensor is not of type DT_RESOURCE.
533 
534   // Merge the stored shapes and types corresponding to the input handle in
535   // position idx with the specified shapes and types. This requires idx to be
536   // in the [0, num_inputs) range.
537   //
538   // If the merge is successful and any of the new shapes differs from the old
539   // one, or any of the old dtypes was DT_INVALID, store the new shapes and
540   // return true.  Return false otherwise.
541   //
542   // See 'MergeInput' function for full details and examples.
543   bool MergeInputHandleShapesAndTypes(
544       int idx,
545       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
546 
547   // As MergeInputHandleShapesAndTypes, but for an output.
548   bool MergeOutputHandleShapesAndTypes(
549       int idx,
550       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
551 
552   // Relaxes the stored shapes and types corresponding to the input handle in
553   // position idx with the specified shapes and types. This requires idx to be
554   // in the [0, num_inputs) range.
555   //
556   // If the relax is successful (sizes are the same, old dtypes match new ones
557   // or are DT_INVALID), then store the relaxed shapes and return true.
558   // Return false otherwise.
559   //
560   // See 'RelaxInput' function for full details and examples.
561   bool RelaxInputHandleShapesAndMergeTypes(
562       int idx,
563       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
564 
565   // As RelaxInputHandleShapesAndMergeTypes, but for an output.
566   bool RelaxOutputHandleShapesAndMergeTypes(
567       int idx,
568       const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
569 
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)570   void set_input_handle_shapes_and_types(
571       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
572     input_handle_shapes_and_types_[idx] =
573         absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
574   }
575 
576   // Returns the output handle shapes and types, for the resource tensor output
577   // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)578   const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
579     return output_handle_shapes_and_types_[idx].get();
580   }
581 
582   // Returns the input handle shapes and types, for the resource tensor input
583   // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)584   const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
585     return input_handle_shapes_and_types_[idx].get();
586   }
587 
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)588   void set_output_handle_shapes_and_types(
589       int idx, const std::vector<ShapeAndType>& shapes_and_types) {
590     output_handle_shapes_and_types_[idx].reset(
591         new std::vector<ShapeAndType>(shapes_and_types));
592   }
593 
594   // Note that shape functions should usually call MakeShapeFromShapeTensor,
595   // as it does more analysis to provide partial shapes.
596   //
597   // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
598   // The tensor must be a 1-dimensional int32 or int64 tensor.  If <t> is NULL,
599   // then an unknown shape is returned.
600   Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
601                              ShapeHandle* out);
602 
graph_def_version()603   int graph_def_version() const { return graph_def_version_; }
604 
MergedShapes()605   const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
606     return merged_shapes_;
607   }
MergedDims()608   const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
609       const {
610     return merged_dims_;
611   }
612 
613   // Adds new outputs; useful when mutating the graph.
614   Status ExpandOutputs(int new_output_size);
615 
616  private:
617   // Creates and stores shapes for use in InferenceContext.
618   class ShapeManager {
619    public:
620     ShapeManager();
621     ~ShapeManager();
622 
623     // Returns a new shape with the given dims. The returned value is owned by
624     // this class.
625     ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
626 
627     // Returns a new unknown shape.
628     ShapeHandle UnknownShape();
629 
630     // Returns a new dimension of the given size.  The returned value
631     // is owned by this class.
MakeDim(DimensionOrConstant d)632     inline DimensionHandle MakeDim(DimensionOrConstant d) {
633       if (d.dim.IsSet()) {
634         return d.dim;
635       } else {
636         all_dims_.push_back(new Dimension(d.val));
637         return all_dims_.back();
638       }
639     }
640 
641    private:
642     std::vector<Shape*> all_shapes_;    // values are owned.
643     std::vector<Dimension*> all_dims_;  // values are owned.
644   };
645 
646   friend class ::tensorflow::grappler::GraphProperties;
647 
648   friend class ShapeInferenceTest;      // For testing Relax functions.
649   friend class ShapeInferenceTestutil;  // For testing shapes.
650 
651   // Shared initialization across the two constructors.  Remove
652   // once we get rid of one of them.
653   void PreInputInit(const OpDef& op_def,
654                     const std::vector<const Tensor*>& input_tensors,
655                     const std::vector<ShapeHandle>& input_tensors_as_shapes);
656   void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
657                          input_handle_data);
658 
ReturnUnknownShape(ShapeHandle * out)659   Status ReturnUnknownShape(ShapeHandle* out) {
660     *out = UnknownShape();
661     return Status::OK();
662   }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)663   Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
664                             ShapeHandle* out) {
665     *out = MakeShape(dims);
666     return Status::OK();
667   }
668 
669   // Adds additional context to the given status.
670   Status AttachContext(const Status& status);
671 
672   // Relaxes an existing value <d_old> with a new value <d_new> and returns the
673   // relaxed dimension in <*out>. Relaxation cannot fail (this returns void):
674   // if <d_old> and <d_new> have differing values, <*out> is an unknown dim.
675   //
676   // Note that <*out> may be set to <d_old> or <d_new>.
677   void Relax(DimensionHandle d_old, DimensionHandle d_new,
678              DimensionHandle* out);
679   // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
680   // relaxed shape in <*out>. See 'RelaxInput' function for full details and
681   // examples.
682   void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
683 
684   // Used to implement MergeInputHandleShapesAndTypes and
685   // MergeOutputHandleShapesAndTypes.
686   bool MergeHandleShapesAndTypes(
687       const std::vector<ShapeAndType>& shapes_and_types,
688       std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
689   // Used to implement RelaxInputHandleShapesAndMergeTypes and
690   // RelaxOutputHandleShapesAndMergeTypes.
691   bool RelaxHandleShapesAndMergeTypes(
692       const std::vector<ShapeAndType>& shapes_and_types,
693       std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
694 
695   // Forget all the previous merged shapes and dims.
ForgetMerges()696   void ForgetMerges() {
697     merged_shapes_.clear();
698     merged_dims_.clear();
699   }
700 
701   // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
702   Status InternalMakeShapeFromTensor(
703       bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
704       ShapeHandle tensor_shape, ShapeHandle* out);
705 
706   ShapeManager shape_manager_;
707 
708   // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
709   // `shape_manager_`.
710   std::vector<ShapeHandle> inputs_;
711   std::vector<const Tensor*> input_tensors_;
712   std::vector<bool> requested_input_tensor_;
713   std::vector<ShapeHandle> outputs_;
714   // Can have fewer elements than inputs_.
715   std::vector<ShapeHandle> input_tensors_as_shapes_;
716   std::vector<bool> requested_input_tensor_as_partial_shape_;
717 
718   // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
719   // through the resource handle passed along input i of the node.
720   //
721   // Values may be NULL.
722   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
723       input_handle_shapes_and_types_;
724 
725   // output_handle_shapes_and_types_[i] is the list of shape/type pairs
726   // available through the resource handle passed along output i of the node.
727   //
728   // Values may be NULL.
729   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
730       output_handle_shapes_and_types_;
731 
732   const int graph_def_version_;
733   AttrSlice attrs_;
734   NameRangeMap input_name_map_;
735   NameRangeMap output_name_map_;
736 
737   // An error set during construction. TODO(cwhipkey): remove when test
738   // constructor is removed.
739   Status construction_status_;
740 
741   // Pair of shape or dim handles that are equivalent, ie that represent the
742   // same underlying shape of dimension. Note that for each pair at least one of
743   // the handles must contain an unknown shape, since we don't keep track of
744   // known shapes or dims here.
745   std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
746   std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
747 
748   TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
749 };
750 
// -----------------------------------------------------------------------------
// Template and inline method implementations. These are implementation
// details; readers of the API above can skip this section.
753 
Dimension()754 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
Dimension(int64 value)755 inline Dimension::Dimension(int64 value) : value_(value) {
756   DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
757       << "Dimension must be non-negative or equal to "
758          "InferenceContext::kUnknownDim but got "
759       << value;
760 }
761 
Shape()762 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)763 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
764     : rank_(dims.size()), dims_(dims) {}
765 
DimensionOrConstant(DimensionHandle dim)766 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
767     : dim(dim) {
768   DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
769 }
770 
DimensionOrConstant(int64 val)771 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
772   DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
773       << "Dimension must be non-negative or equal to "
774          "InferenceContext::kUnknownDim but got "
775       << val;
776 }
777 
778 template <class T>
GetAttr(StringPiece attr_name,T * value)779 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
780   return GetNodeAttr(attrs_, attr_name, value);
781 }
782 
783 }  // namespace shape_inference
784 }  // namespace tensorflow
785 
786 #endif  // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
787