1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
18 #include <vector>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/macros.h"
27
28 namespace tensorflow {
29
30 namespace grappler {
31 class GraphProperties;
32 class SymbolicShapeManager;
33 } // namespace grappler
34
35 namespace shape_inference {
36
37 struct DimensionOrConstant;
38 class InferenceContext;
39
// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  // Creates a dimension with an unknown value (InferenceContext::kUnknownDim).
  Dimension();
  // Creates a dimension with the given value. The value must be non-negative
  // or InferenceContext::kUnknownDim (enforced by DCHECK in the definition).
  Dimension(int64 value);
  ~Dimension() {}

  // Immutable after construction; kUnknownDim when the value is not known.
  const int64 value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};
53
54 class DimensionHandle {
55 public:
DimensionHandle()56 DimensionHandle() {}
SameHandle(DimensionHandle d)57 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()58 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
59
60 private:
DimensionHandle(const Dimension * dim)61 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
62
63 const Dimension* operator->() const { return ptr_; }
IsSet()64 bool IsSet() const { return ptr_ != nullptr; }
65
66 const Dimension* ptr_ = nullptr;
67
68 friend struct DimensionOrConstant;
69 friend class InferenceContext;
70 friend class ShapeInferenceTest;
71 friend class ShapeInferenceTestutil;
72 friend class ::tensorflow::grappler::GraphProperties;
73 friend class ::tensorflow::grappler::SymbolicShapeManager;
74
75 // Intentionally copyable.
76 };
77
// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  // Creates a shape of unknown rank (InferenceContext::kUnknownRank).
  Shape();
  // Creates a shape with the given dimensions; rank is dims.size().
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  // kUnknownRank when the rank is not known; otherwise dims_.size().
  const int32 rank_;
  // Empty when rank is unknown.
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};
93
94 class ShapeHandle {
95 public:
ShapeHandle()96 ShapeHandle() {}
SameHandle(ShapeHandle s)97 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()98 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
99
100 private:
ShapeHandle(const Shape * shape)101 ShapeHandle(const Shape* shape) { ptr_ = shape; }
102 const Shape* operator->() const { return ptr_; }
IsSet()103 bool IsSet() const { return ptr_ != nullptr; }
104
105 const Shape* ptr_ = nullptr;
106
107 friend class InferenceContext;
108 friend class ShapeInferenceTest;
109 friend class ShapeInferenceTestutil;
110 friend class ::tensorflow::grappler::SymbolicShapeManager;
111
112 // Intentionally copyable.
113 };
114
// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit: lets callers pass a DimensionHandle where a
  // DimensionOrConstant is expected.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64 val);

  // dim takes precedence. If dim != nullptr, val is ignored.
  DimensionHandle dim;
  int64 val;

 private:
  // Default construction is disallowed; exactly one of the public
  // constructors must be used so that either dim or val is meaningful.
  DimensionOrConstant();
};
132
// A (shape, dtype) pair describing a tensor reachable through a resource or
// variant handle; see the handle shape/type methods on InferenceContext.
struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
  ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
      : shape(s), dtype(t), specialized_type(specialized_t) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
  // The type of a variant-dtype tensor sometimes affects graph building
  // (e.g. for vectorization), and needs to be known statically in such cases.
  SpecializedType specialized_type = ST_INVALID;
};
145
146 // Shape inference functions registered on ops in REGISTER_OP implement
147 // their shape functions in terms of this InferenceContext. An InferenceContext
148 // is created by the framework and passed to a shape inference function. The
149 // shape inference function calls functions on the context, and should call
150 // set_output() to set the shape on all outputs.
151 //
152 // To infer shapes for user-defined functions see ShapeRefiner.
153 //
154 // All Shape* and Dimension* returned by functions of InferenceContext are owned
155 // by the InferenceContext.
156 class InferenceContext {
157 public:
158 static constexpr int64 kUnknownDim = -1;
159 static constexpr int32 kUnknownRank = -1;
160
161 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
162 //
163 // Elements of <input_tensors_as_shapes> are used for when a shape function
164 // makes a call to MakeShapeFromShapeTensor; in particular, when the
165 // input_tensors[i] is nullptr but the shape represented by it is partially
166 // known from analysis of the graph.
167 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
168 // Values of <input_tensors_as_shapes> do not need to outlive the context.
169 InferenceContext(int graph_def_version, const AttrSlice& attrs,
170 const OpDef& op_def,
171 const std::vector<ShapeHandle>& input_shapes,
172 const std::vector<const Tensor*>& input_tensors,
173 const std::vector<ShapeHandle>& input_tensors_as_shapes,
174 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
175 input_handle_shapes_and_types);
176
177 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
178 //
179 // Elements of <input_tensors_as_shapes> are used for when a shape
180 // function makes a call to MakeShapeFromShapeTensor; in particular, when
181 // the input_tensors[i] is nullptr but the shape represented by it is
182 // partially known from analysis of the graph. <input_tensors_as_shapes>
183 // can have fewer elements than <input_shapes>. Values of
184 // <input_tensors_as_shapes> do not need to outlive the context.
185 InferenceContext(
186 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
187 const std::vector<PartialTensorShape>& input_shapes,
188 const std::vector<const Tensor*>& input_tensors,
189 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
190 const std::vector<std::unique_ptr<
191 std::vector<std::pair<PartialTensorShape, DataType>>>>&
192 input_handle_shapes_and_types);
193
194 ~InferenceContext();
195
196 // Runs the shape inference function 'fn' with 'this' as the
197 // argument, returns the status of the inference.
198 //
199 // On error, additional context is provided in the error message.
200 Status Run(
201 const std::function<Status(shape_inference::InferenceContext* c)>& fn);
202
203 // Merge the stored shape of the input in position idx with <shape> according
204 // to the following rules:
205 //
206 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
207 // change. Otherwise if the stored shape is unknown, the new shape will be
208 // <shape>.
209 // - If both shapes are known, then they must have the same rank.
210 // - For any one dimension, if the values for that dimension in both shapes
211 // are known, then the values must match.
212 // - If one shape has equal or more information than the other shape in every
213 // dimension, the new shape will become the shape with more information.
214 // - Example: merging [2,?] and [?,2] results in [2,2]
215 // - Example: [2,2] cannot be merged with [1,2]
216 //
217 // This requires idx to be in the [0, num_inputs) range. If the merge is
218 // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)219 bool MergeInput(int idx, ShapeHandle shape) {
220 ShapeHandle new_shape;
221 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
222 inputs_[idx] = new_shape;
223 return true;
224 }
225
226 // Relax the stored shape of the input in position idx with <shape> according
227 // to the following rules:
228 //
229 // - If the ShapeHandles are the same then the stored shape will be returned.
230 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
231 // be returned. A new shape must be returned because we cannot claim that
232 // the resulting shape is necessarily the same as either of the input
233 // shapes.
234 // - If the shapes both have known ranks but their ranks are different, a new
235 // UnknownShape will be returned.
236 // - For any one dimension, if the value for that dimension in either of the
237 // shapes is unknown, a new shape will be returned with a new UnknownDim in
238 // that dimension.
239 // - For any one dimension, if the values for that dimension in both shapes
240 // are known but do not match, a new shape will be returned with a new
241 // UnknownDim in that dimension.
242 // - If both shapes have the same known rank and match in every dimension,
243 // the stored shape will be returned.
244 // - Example: relaxing [2,?] and [?,2] results in [?,?]
245 // - Example: relaxing [2,2] and [3,2] results in [?,2]
246 // - Example: relaxing [2,2] with [1,2,3] results in ?
247 //
248 // This requires idx to be in the [0, num_inputs) range. If the relax is
249 // successful and the new shape differs from the old one, store the new
250 // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)251 bool RelaxInput(int idx, ShapeHandle shape) {
252 ShapeHandle new_shape;
253 Relax(inputs_[idx], shape, &new_shape);
254 if (inputs_[idx].SameHandle(new_shape)) {
255 return false;
256 }
257 inputs_[idx] = new_shape;
258 return true;
259 }
260
SetInput(int idx,ShapeHandle shape)261 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
262
input(int64 idx)263 ShapeHandle input(int64 idx) const { return inputs_[idx]; }
264 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()265 int num_inputs() const { return inputs_.size(); }
266
267 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
268 // not available at the time of shape inference.
input_tensor(int idx)269 const Tensor* input_tensor(int idx) {
270 // Mark that this idx was requested.
271 request_input_tensor(idx);
272 return input_tensors_[idx];
273 }
274
275 // Notifies the shape refiner that the value of the tensor at index <idx>
276 // is needed. The shape refiner tries to statically compute this tensor,
277 // and if successful re-runs the shape function with this tensor available
278 // in the call to 'input_tensor(idx)'.
request_input_tensor(int idx)279 void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
280
281 // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)282 bool requested_input_tensor(int idx) const {
283 return requested_input_tensor_[idx];
284 }
285
286 // Notifies the shape refiner that the value of the tensor at index <idx>
287 // as a partial shape is needed. The shape refiner tries to statically compute
288 // this, and if successful re-runs the shape function with the
289 // computed PartialTensorShape available in the call to
290 // 'MakeShapeFromShapeTensor(idx, handle)' or
291 // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
request_input_tensor_as_partial_shape(int idx)292 void request_input_tensor_as_partial_shape(int idx) {
293 requested_input_tensor_as_partial_shape_[idx] = true;
294 }
295
296 // Returns true if MakeShapeFromInputTensor was called but the constant
297 // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)298 bool requested_input_tensor_as_partial_shape(int idx) const {
299 return requested_input_tensor_as_partial_shape_[idx];
300 }
301
set_input_tensors(const std::vector<const Tensor * > & input_tensors)302 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
303 input_tensors_ = input_tensors;
304 }
305
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)306 void set_input_tensors_as_shapes(
307 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
308 input_tensors_as_shapes_ = input_tensors_as_shapes;
309 }
310
input_tensors_as_shapes()311 const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
312 return input_tensors_as_shapes_;
313 }
314
output(int64 idx)315 ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)316 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
317 Status set_output(StringPiece output_name,
318 const std::vector<ShapeHandle>& shapes);
319
num_outputs()320 int num_outputs() const { return outputs_.size(); }
output(int idx)321 ShapeHandle output(int idx) const { return outputs_.at(idx); }
322 Status output(StringPiece output_name,
323 std::vector<ShapeHandle>* output) const;
324
attrs()325 const AttrSlice& attrs() const { return attrs_; }
326
327 // idx can be negative for an offset from end of dimensions.
328 // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64 idx)329 DimensionHandle Dim(ShapeHandle s, int64 idx) {
330 if (!s.Handle() || s->rank_ == kUnknownRank) {
331 return UnknownDim();
332 }
333 return DimKnownRank(s, idx);
334 }
335 // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64 idx)336 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
337 CHECK_NE(s->rank_, kUnknownRank);
338 if (idx < 0) {
339 return s->dims_[s->dims_.size() + idx];
340 }
341 return s->dims_[idx];
342 }
343
Rank(ShapeHandle s)344 static int32 Rank(ShapeHandle s) {
345 return s.IsSet() ? s->rank_ : kUnknownRank;
346 }
RankKnown(ShapeHandle s)347 static bool RankKnown(ShapeHandle s) {
348 return (s.IsSet() && (Rank(s) != kUnknownRank));
349 }
Value(DimensionOrConstant d)350 static inline int64 Value(DimensionOrConstant d) {
351 return d.dim.IsSet() ? d.dim->value_ : d.val;
352 }
ValueKnown(DimensionOrConstant d)353 static inline bool ValueKnown(DimensionOrConstant d) {
354 return Value(d) != kUnknownDim;
355 }
356
357 // Fills the output proto with the shape defined by the handle.
358 // "proto" is expected to be empty prior to the call.
359 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
360
361 // Returns true if the rank and all dimensions of the Shape are known.
362 bool FullyDefined(ShapeHandle s);
363
364 // Returns the total number of elements, or an unknown dimension for an
365 // incomplete shape.
366 DimensionHandle NumElements(ShapeHandle s);
367
368 std::string DebugString(ShapeHandle s);
369 std::string DebugString(DimensionHandle d);
370 std::string DebugString(const ShapeAndType& shape_and_type);
371 std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
372
373 // Describes the whole context, for debugging purposes.
374 std::string DebugString() const;
375
376 // If <shape> has rank <rank>, or its rank is unknown, return OK and return
377 // the shape with asserted rank in <*out>. Otherwise return an error.
378 //
379 // Note that <*out> may be set to <shape>.
380 Status WithRank(ShapeHandle shape, int64 rank,
381 ShapeHandle* out) TF_MUST_USE_RESULT;
382 Status WithRankAtLeast(ShapeHandle shape, int64 rank,
383 ShapeHandle* out) TF_MUST_USE_RESULT;
384 Status WithRankAtMost(ShapeHandle shape, int64 rank,
385 ShapeHandle* out) TF_MUST_USE_RESULT;
386
387 // If <dim> has value <value>, or its value is unknown, returns OK and returns
388 // the dimension with asserted value in <*out>. Otherwise returns an error.
389 //
390 // Note that <*out> may be set to <dim>.
391 Status WithValue(DimensionHandle dim, int64 value,
392 DimensionHandle* out) TF_MUST_USE_RESULT;
393
394 // Merges <s0> and <s1> and returns the merged shape in <*out>. See
395 // 'MergeInput' function for full details and examples.
396 Status Merge(ShapeHandle s0, ShapeHandle s1,
397 ShapeHandle* out) TF_MUST_USE_RESULT;
398
399 // Asserts that <s>'s rank >= <prefix>'s rank, and the first
400 // <prefix.rank> dimensions of <s> are compatible with the dimensions of
401 // <prefix>.
402 // Returns the merged results in <*s_out> and <*prefix_out>.
403 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
404 ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
405
406 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
407 // and <d1> have incompatible values, returns an error.
408 //
409 // Note that <*out> may be set to <d0> or <d1>.
410 Status Merge(DimensionHandle d0, DimensionHandle d1,
411 DimensionHandle* out) TF_MUST_USE_RESULT;
412
413 // Returns in <*out> a sub-shape of <s> with dimensions [start:].
414 // <start> can be negative to index from the end of the shape. If <start> >
415 // rank of <s>, then an empty subshape is returned.
416 Status Subshape(ShapeHandle s, int64 start,
417 ShapeHandle* out) TF_MUST_USE_RESULT;
418
419 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
420 // <start> and <end> can be negative, to index from the end of the shape.
421 // <start> and <end> are set to the rank of <s> if > rank of <s>.
422 Status Subshape(ShapeHandle s, int64 start, int64 end,
423 ShapeHandle* out) TF_MUST_USE_RESULT;
424
425 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
426 // <start> and <end> can be negative, to index from the end of the shape.
427 // <start> and <end> are set to the rank of <s> if > rank of <s>.
428 // <stride> can be negative, to reverse the <s>.
429 Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
430 ShapeHandle* out) TF_MUST_USE_RESULT;
431
432 // Returns in <*out> the result of appending the dimensions of <s2> to those
433 // of <s1>.
434 Status Concatenate(ShapeHandle s1, ShapeHandle s2,
435 ShapeHandle* out) TF_MUST_USE_RESULT;
436
437 // Returns in <out> the shape from replacing <s.dim[dim_index]> with
438 // <new_dim>.
439 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
440 ShapeHandle* out) TF_MUST_USE_RESULT;
441
442 // Returns a new shape with the given dims. The returned value is owned by
443 // this context.
444 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
445 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
446
447 // Returns a new unknown shape.
448 ShapeHandle UnknownShape();
449
450 // Returns a shape with specified rank but unknown dims.
451 ShapeHandle UnknownShapeOfRank(int64 rank);
452
453 // Returns a new shape of zero dimensions.
454 ShapeHandle Scalar();
455
456 // Returns a new shape of one dimension.
457 ShapeHandle Vector(DimensionOrConstant dim);
458
459 // Returns a new shape of two dimensions.
460 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
461
462 // Returns in <out> a new shape whose dimension sizes come from input tensor
463 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
464 // the input tensor is NULL, then an unknown shape is returned.
465 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
466
467 // Like the function above, but treats scalar values as unknown
468 // shapes. **NOTE** If the scalar is statically known, its value
469 // must be -1 or an error is returned.
470 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
471 ShapeHandle* out);
472
473 // Returns in <out> a new shape corresponding to <proto>.
474 Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
475 ShapeHandle* out);
476
477 // Returns in <out> a new shape corresponding to <partial_shape>.
478 Status MakeShapeFromPartialTensorShape(
479 const PartialTensorShape& partial_shape, ShapeHandle* out);
480
481 // Returns in <out> a new shape corresponding to <shape>.
482 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
483
484 // Returns a new dimension of the given size. The returned value is owned by
485 // this context.
MakeDim(DimensionOrConstant d)486 inline DimensionHandle MakeDim(DimensionOrConstant d) {
487 return shape_manager_.MakeDim(d);
488 }
489
UnknownDim()490 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
491
492 // Returns in <val> a scalar value from an input tensor <t>. The input tensor
493 // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the
494 // input tensor is not NULL.
495 Status GetScalarFromTensor(const Tensor* t, int64* val);
496
497 // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
498 // int64 elements. Caller must ensure that the input tensor is not NULL.
499 Status GetScalarFromTensor(const Tensor* t, int64 idx, int64* val);
500
501 // Returns a new dimension whose value is given by a scalar input tensor.
502 // The input tensor must be in host memory, since it is dereferenced to get
503 // the value.
504 Status MakeDimForScalarInput(int idx, DimensionHandle* out);
505
506 // Returns a new dimension whose value is given by a scalar input tensor.
507 // This allows for a negative input dimension given the rank of a separate
508 // tensor. This rank can be negative if unknown.
509 // The input tensor must be in host memory, since it is dereferenced to get
510 // the value.
511 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
512 DimensionHandle* out);
513
514 // Look up the attr being evaluated with name attr_name and set *value to its
515 // value. If no attr with attr_name is found in def(), or the attr does not
516 // have a matching type, a non-ok status will be returned.
517 template <class T>
518 Status GetAttr(StringPiece attr_name, T* value) const;
519
520 // Returns in <out> the result of dividing <dividend> by <divisor>.
521 // Returns an error if <divisor> is not positive or if <evenly_divisible>
522 // and <divisor> does not evenly divide <dividend>.
523 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
524 bool evenly_divisible, DimensionHandle* out);
525
526 // Returns in <out> the sum of <first> and <second>.
527 Status Add(DimensionHandle first, DimensionOrConstant second,
528 DimensionHandle* out);
529
530 // Returns in <out> the dimension that is <first> minus <second>.
531 Status Subtract(DimensionHandle first, DimensionOrConstant second,
532 DimensionHandle* out);
533
534 // Returns in <out> the product of <first> and <second>.
535 Status Multiply(DimensionHandle first, DimensionOrConstant second,
536 DimensionHandle* out);
537
538 // Returns in <out> the minimum of <first> and <second>. If either <first> or
539 // <second> is zero the results is zero. Otherwise, if either <first> or
540 // <second> is unknown the results is unknown.
541 Status Min(DimensionHandle first, DimensionOrConstant second,
542 DimensionHandle* out);
543
544 // Returns in <out> the maximum of <first> and <second>. If either <first> or
545 // <second> is unknown the results is unknown.
546 Status Max(DimensionHandle first, DimensionOrConstant second,
547 DimensionHandle* out);
548
construction_status()549 Status construction_status() const { return construction_status_; }
550
551 // Methods to propagate shape and dtype on edges of handles. Handles are the
552 // dtype DT_RESOURCE which can be used to access state stored in a
553 // ResourceManager. When ops (such as variables) consume these handles to
554 // produce tensors they might need to know side-information about the shapes
555 // and dtypes of tensors which can be accessed via the handle. These methods
556 // propagate that information. Output handle dtypes and shapes are ignored if
557 // the output tensor is not of type DT_RESOURCE.
558
559 // Merge the stored shapes and types corresponding to the input handle in
560 // position idx with the specified shapes and types. This requires idx to be
561 // in the [0, num_inputs) range.
562 //
563 // If the merge is successful and any of the new shapes differs from the old
564 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
565 // return true. Return false otherwise.
566 //
567 // See 'MergeInput' function for full details and examples.
568 bool MergeInputHandleShapesAndTypes(
569 int idx,
570 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
571
572 // As MergeInputHandleShapesAndTypes, but for an output.
573 bool MergeOutputHandleShapesAndTypes(
574 int idx,
575 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
576
577 // Relaxes the stored shapes and types corresponding to the input handle in
578 // position idx with the specified shapes and types. This requires idx to be
579 // in the [0, num_inputs) range.
580 //
581 // If the relax is successful (sizes are the same, old dtypes match new ones
582 // or are DT_INVALID), then store the relaxed shapes and return true.
583 // Return false otherwise.
584 //
585 // See 'RelaxInput' function for full details and examples.
586 bool RelaxInputHandleShapesAndMergeTypes(
587 int idx,
588 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
589
590 // As RelaxInputHandleShapesAndTypes, but for an output.
591 bool RelaxOutputHandleShapesAndMergeTypes(
592 int idx,
593 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
594
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)595 void set_input_handle_shapes_and_types(
596 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
597 input_handle_shapes_and_types_[idx] =
598 absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
599 }
600
601 // Returns the output handle shapes and types, for the resource tensor output
602 // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)603 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
604 return output_handle_shapes_and_types_[idx].get();
605 }
606
607 // Returns the inputs handle shapes and types, for the resource tensor output
608 // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)609 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
610 return input_handle_shapes_and_types_[idx].get();
611 }
612
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)613 void set_output_handle_shapes_and_types(
614 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
615 output_handle_shapes_and_types_[idx].reset(
616 new std::vector<ShapeAndType>(shapes_and_types));
617 }
618
619 // Note that shape functions should usually call MakeShapeFromShapeTensor,
620 // as it does more analysis to provide partial shapes.
621 //
622 // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
623 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
624 // then an unknown shape is returned.
625 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
626 ShapeHandle* out);
627
graph_def_version()628 int graph_def_version() const { return graph_def_version_; }
629
MergedShapes()630 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
631 return merged_shapes_;
632 }
MergedDims()633 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
634 const {
635 return merged_dims_;
636 }
637
638 // Adds new outputs; useful when mutating the graph.
639 Status ExpandOutputs(int new_output_size);
640
641 private:
642 // Creates and stores shapes for use in InferenceContext.
643 class ShapeManager {
644 public:
645 ShapeManager();
646 ~ShapeManager();
647
648 // Returns a new shape with the given dims. The returned value is owned by
649 // this class.
650 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
651
652 // Returns a new unknown shape.
653 ShapeHandle UnknownShape();
654
655 // Returns a new dimension of the given size. The returned value
656 // is owned by this class.
MakeDim(DimensionOrConstant d)657 inline DimensionHandle MakeDim(DimensionOrConstant d) {
658 if (d.dim.IsSet()) {
659 return d.dim;
660 } else {
661 all_dims_.push_back(new Dimension(d.val));
662 return all_dims_.back();
663 }
664 }
665
666 private:
667 std::vector<Shape*> all_shapes_; // values are owned.
668 std::vector<Dimension*> all_dims_; // values are owned.
669 };
670
671 friend class ::tensorflow::grappler::GraphProperties;
672
673 friend class ShapeInferenceTest; // For testing Relax functions.
674 friend class ShapeInferenceTestutil; // For testing shapes.
675
676 // Shared initialization across the two constructors. Remove
677 // once we get rid of one of them.
678 void PreInputInit(const OpDef& op_def,
679 const std::vector<const Tensor*>& input_tensors,
680 const std::vector<ShapeHandle>& input_tensors_as_shapes);
681 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
682 input_handle_data);
683
ReturnUnknownShape(ShapeHandle * out)684 Status ReturnUnknownShape(ShapeHandle* out) {
685 *out = UnknownShape();
686 return Status::OK();
687 }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)688 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
689 ShapeHandle* out) {
690 *out = MakeShape(dims);
691 return Status::OK();
692 }
693
694 // Adds additional context to the given status.
695 Status AttachContext(const Status& status);
696
697 // Relaxes an existing value <d_old> with a new value <d_new> and returns the
698 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
699 // values, returns an error.
700 //
701 // Note that <*out> may be set to <d_old> or <d_new>.
702 void Relax(DimensionHandle d_old, DimensionHandle d_new,
703 DimensionHandle* out);
704 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
705 // relaxed shape in <*out>. See 'RelaxInput' function for full details and
706 // examples.
707 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
708
709 // Used to implement MergeInputHandleShapesAndTypes and
710 // MergeOutputHandleShapesAndTypes.
711 bool MergeHandleShapesAndTypes(
712 const std::vector<ShapeAndType>& shapes_and_types,
713 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
714 // Used to implement RelaxInputHandleShapesAndMergeTypes and
715 // RelaxOutputHandleShapesAndMergeTypes.
716 bool RelaxHandleShapesAndMergeTypes(
717 const std::vector<ShapeAndType>& shapes_and_types,
718 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
719
720 // Forget all the previous merged shapes and dims.
ForgetMerges()721 void ForgetMerges() {
722 merged_shapes_.clear();
723 merged_dims_.clear();
724 }
725
726 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
727 Status InternalMakeShapeFromTensor(
728 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
729 ShapeHandle tensor_shape, ShapeHandle* out);
730
731 ShapeManager shape_manager_;
732
733 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
734 // `shape_manager_`.
735 std::vector<ShapeHandle> inputs_;
736 std::vector<const Tensor*> input_tensors_;
737 std::vector<bool> requested_input_tensor_;
738 std::vector<ShapeHandle> outputs_;
739 // Can have fewer elements than inputs_.
740 std::vector<ShapeHandle> input_tensors_as_shapes_;
741 std::vector<bool> requested_input_tensor_as_partial_shape_;
742
743 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
744 // through the resource handle passed along input i of the node.
745 //
746 // Values may be NULL.
747 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
748 input_handle_shapes_and_types_;
749
750 // output_handle_shapes_and_types_[i] is the list of shape/type pairs
751 // available through the resource handle passed along output i of the node.
752 //
753 // Values may be NULL.
754 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
755 output_handle_shapes_and_types_;
756
757 const int graph_def_version_;
758 AttrSlice attrs_;
759 NameRangeMap input_name_map_;
760 NameRangeMap output_name_map_;
761
762 // An error set during construction. TODO(cwhipkey): remove when test
763 // constructor is removed.
764 Status construction_status_;
765
766 // Pair of shape or dim handles that are equivalent, ie that represent the
767 // same underlying shape of dimension. Note that for each pair at least one of
768 // the handles must contain an unknown shape, since we don't keep track of
769 // known shapes or dims here.
770 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
771 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
772
773 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
774 };
775
776 // -----------------------------------------------------------------------------
777 // Template and inline method implementations, please ignore
778
// Default: an unknown dimension.
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
// Known dimension; DCHECK enforces the documented value contract.
inline Dimension::Dimension(int64 value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}
786
// Default: a shape of unknown rank (with no dimensions stored).
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
// Known-rank shape; rank is derived from the number of dimensions given.
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}
790
// Wraps an existing dimension handle; the handle must be set (non-null).
inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
    : dim(dim) {
  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
}

// Wraps a constant value; DCHECK enforces the documented value contract.
inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}
802
// Delegates attr lookup to GetNodeAttr over this context's stored AttrSlice;
// returns non-OK if the attr is missing or has a mismatched type.
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(attrs_, attr_name, value);
}
807
808 } // namespace shape_inference
809 } // namespace tensorflow
810
811 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
812