1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
26
27 namespace tensorflow {
28
// Grappler classes that the handle types below name as friends; declared here
// so the qualified friend declarations resolve.
namespace grappler {
class SymbolicShapeManager;
class GraphProperties;
}  // namespace grappler
33
34 namespace shape_inference {
35
// Forward declarations; both types are defined further down in this namespace.
class InferenceContext;
struct DimensionOrConstant;
38
39 // This header contains the InferenceContext that is used to infer the shape of
40 // the results of an operation or flag an operation with invalid inputs (e.g.,
41 // mismatched shapes for elementwise operation) by ShapeRefiner. The shape of an
42 // operation is computed using the OpShapeInferenceFn set via SetShapeFn in op
43 // registration. The OpShapeInferenceFn uses a per op InferenceContext populated
44 // with input shapes to compute resultant shape (including resource shapes).
45 //
46 // The shapes created in the InferenceContext are bound to the lifetime of the
47 // InferenceContext in which it was created. E.g., in
48 //
49 // ```c++
50 // InferenceContext c;
51 // // Below a ShapeHandle is returned by MakeShape, while UnknownDim returns a
52 // // DimensionHandle.
53 // ShapeHandle in0 = c.MakeShape({10, c.UnknownDim()});
54 // ```
55 //
56 // the ShapeHandle `in0` (and the nested unknown dim inside) is only valid while
57 // `c` is in scope, as ShapeHandle and DimensionHandle are effectively
58 // wrappers around pointers stored inside the context with the lifetime of the
59 // value pointed to managed by the context. The result from one operation's
60 // inference context will be passed as input to the inference of consumer
61 // operations. Hence it is possible for ShapeHandles produced by inference on a
62 // node to consist of ShapeHandles owned by different InferenceContexts. While
// inferring the shapes of a Graph, the InferenceContexts of all
// nodes/operations in the Graph remain resident for the lifetime of the Graph
// (e.g., there is a map from each node to its InferenceContext -- technically
// its ExtendedInferenceContext, which additionally stores the element types of
// inputs & outputs -- which remains resident).
68 //
69 // For functions, the body of the function is instantiated as a Graph while
70 // inferring the result shapes of a function call node. The rules above apply
71 // while the function's shape is being inferred, but the contexts associated
72 // with nodes in the function body are released once the function call's
73 // resultant shapes are inferred. The shapes of results returned by a function
74 // are propagated to the InferenceContext of the function call's op (which is
75 // associated with a Graph of nodes whose shape is being inferred) as the return
76 // values of a function call node are the inputs of its consumer, but the return
77 // values are produced by nodes inside the function whose InferenceContexts
78 // (which owns the values pointed to by ShapeHandle and DimensionHandle) are
// reclaimed after inferring function result shapes. Recursive user-defined
// functions are not supported, hence inference of functions is fully nested,
// with the InferenceContexts of function calls forming a stack.
82 //
83 // For example, consider the following call and function:
84 //
// ```python
// @tf.function
// def g(st):
//   d = tf.add(st, st)
//   return d
//
// @tf.function
// def f():
//   st = tf.A()
//   result = g(st)
//   return h(result)
// ```
97 //
98 // During inference of f, the shape of `A` will be inferred and the results from
99 // its InferenceContext used as inputs to function call `g(st)`. The call node
100 // will have an InferenceContext created (call it outer context) and the graph
101 // corresponding to function `g` will be instantiated. The result shape of the
102 // Arg nodes of the function will be associated with input from outer context.
103 // During inference of `g` (for the callsite `g(st)` in `f`), the
// InferenceContexts of all nodes inside `g` will remain alive. Thus, when the
// shape of `tf.add` is computed it may rely on all inputs. Once the RetVal
// nodes of a function are reached, we know the shapes of their inputs may
// correspond to shapes queried in the outer context, so they are explicitly
// copied to the outer context. In this case that means that the shape of `d`
// is copied to the InferenceContext of `g(st)`, and so when `h(result)` is
// executed this shape may be queried. Furthermore, no shapes computed due to
// the call `g(st)` can be queried past this point and, as the RetVal shapes
// have been copied into the outer context, all InferenceContexts associated
// with nodes in function `g` instantiated for `g(st)` may be, and are,
// released.
114
115 // Dimension values are accessed through InferenceContext.
116 class Dimension {
117 private:
118 Dimension();
119 Dimension(int64_t value);
~Dimension()120 ~Dimension() {}
121
122 const int64_t value_;
123
124 friend class InferenceContext;
125 friend class ShapeManager;
126 TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
127 };
128
129 class DimensionHandle {
130 public:
DimensionHandle()131 DimensionHandle() {}
SameHandle(DimensionHandle d)132 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()133 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
134
135 private:
DimensionHandle(const Dimension * dim)136 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
137
138 const Dimension* operator->() const { return ptr_; }
IsSet()139 bool IsSet() const { return ptr_ != nullptr; }
140
141 const Dimension* ptr_ = nullptr;
142
143 friend struct DimensionOrConstant;
144 friend class InferenceContext;
145 friend class ShapeInferenceTest;
146 friend class ShapeInferenceTestutil;
147 friend class ::tensorflow::grappler::GraphProperties;
148 friend class ::tensorflow::grappler::SymbolicShapeManager;
149
150 // Intentionally copyable.
151 };
152
153 // Shape rank and dimensions are accessed through InferenceContext.
154 class Shape {
155 private:
156 Shape();
157 Shape(const std::vector<DimensionHandle>& dims);
~Shape()158 ~Shape() {}
159
160 const int32 rank_;
161 const std::vector<DimensionHandle> dims_;
162
163 friend class InferenceContext;
164 friend class ::tensorflow::grappler::SymbolicShapeManager;
165
166 TF_DISALLOW_COPY_AND_ASSIGN(Shape);
167 };
168
169 class ShapeHandle {
170 public:
ShapeHandle()171 ShapeHandle() {}
SameHandle(ShapeHandle s)172 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()173 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
174
175 private:
ShapeHandle(const Shape * shape)176 ShapeHandle(const Shape* shape) { ptr_ = shape; }
177 const Shape* operator->() const { return ptr_; }
IsSet()178 bool IsSet() const { return ptr_ != nullptr; }
179
180 const Shape* ptr_ = nullptr;
181
182 friend class InferenceContext;
183 friend class ShapeInferenceTest;
184 friend class ShapeInferenceTestutil;
185 friend class ::tensorflow::grappler::SymbolicShapeManager;
186
187 // Intentionally copyable.
188 };
189
190 // Struct used to allow functions to take DimensionHandle or a dimension value.
191 // Not meant to be constructed directly.
192 struct DimensionOrConstant {
193 public:
194 // Intentionally not explicit.
195 DimensionOrConstant(DimensionHandle dim);
196
197 // val must be non-negative or InferenceContext::kUnknownDim.
198 DimensionOrConstant(int64_t val);
199
200 // dim takes precedence. If dim != nullptr, val is ignored.
201 DimensionHandle dim;
202 int64_t val;
203
204 private:
205 DimensionOrConstant();
206 };
207
208 struct ShapeAndType {
ShapeAndTypeShapeAndType209 ShapeAndType() {}
ShapeAndTypeShapeAndType210 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
211 // TODO(mdan): Remove dtype from constructor, and use type_ instead.
212 // dtype is kept here for backward compatibiity. Its information should
213 // be redundant to that in type;
ShapeAndTypeShapeAndType214 ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_)
215 : shape(s), dtype(t), type(type_) {}
216
217 ShapeHandle shape;
218 DataType dtype = DT_INVALID;
219 FullTypeDef type;
220 };
221
222 // Shape inference functions registered on ops in REGISTER_OP implement
223 // their shape functions in terms of this InferenceContext. An InferenceContext
224 // is created by the framework and passed to a shape inference function. The
225 // shape inference function calls functions on the context, and should call
226 // set_output() to set the shape on all outputs.
227 //
228 // To infer shapes for user-defined functions see ShapeRefiner.
229 //
230 // All Shape* and Dimension* returned by functions of InferenceContext are owned
231 // by the InferenceContext.
232 class InferenceContext {
233 public:
234 static constexpr int64_t kUnknownDim = -1;
235 static constexpr int32_t kUnknownRank = -1;
236
// Constructs a context for running a shape function, with the input shapes
// given as ShapeHandles.
//
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used when a shape function calls
// MakeShapeFromShapeTensor; in particular, when input_tensors[i] is nullptr
// but the shape it represents is partially known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// <input_handle_shapes_and_types> is taken by value. It holds unique_ptrs, so
// it can only be moved in, and ownership passes to the context.
InferenceContext(int graph_def_version, const AttrSlice& attrs,
                 const OpDef& op_def,
                 const std::vector<ShapeHandle>& input_shapes,
                 const std::vector<const Tensor*>& input_tensors,
                 const std::vector<ShapeHandle>& input_tensors_as_shapes,
                 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                     input_handle_shapes_and_types);

// As above, but the input shapes, the shapes represented by input tensors, and
// the per-input resource handle shapes/dtypes are given as PartialTensorShape
// values.
//
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used when a shape function calls
// MakeShapeFromShapeTensor; in particular, when input_tensors[i] is nullptr
// but the shape it represents is partially known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
InferenceContext(
    int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
    const std::vector<PartialTensorShape>& input_shapes,
    const std::vector<const Tensor*>& input_tensors,
    const std::vector<PartialTensorShape>& input_tensors_as_shapes,
    const std::vector<std::unique_ptr<
        std::vector<std::pair<PartialTensorShape, DataType>>>>&
        input_handle_shapes_and_types);

// Destroys the context together with every Shape and Dimension it owns. Any
// ShapeHandle/DimensionHandle obtained from this context must not be used
// afterwards.
~InferenceContext();

// Runs the shape inference function 'fn' with 'this' as the argument and
// returns the status of the inference.
//
// On error, additional context is provided in the error message.
Status Run(
    const std::function<Status(shape_inference::InferenceContext* c)>& fn);
278
279 // Merge the stored shape of the input in position idx with <shape> according
280 // to the following rules:
281 //
282 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
283 // change. Otherwise if the stored shape is unknown, the new shape will be
284 // <shape>.
285 // - If both shapes are known, then they must have the same rank.
286 // - For any one dimension, if the values for that dimension in both shapes
287 // are known, then the values must match.
288 // - If one shape has equal or more information than the other shape in every
289 // dimension, the new shape will become the shape with more information.
290 // - Example: merging [2,?] and [?,2] results in [2,2]
291 // - Example: [2,2] cannot be merged with [1,2]
292 //
293 // This requires idx to be in the [0, num_inputs) range. If the merge is
294 // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)295 bool MergeInput(int idx, ShapeHandle shape) {
296 ShapeHandle new_shape;
297 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
298 inputs_[idx] = new_shape;
299 return true;
300 }
301
302 // Relax the stored shape of the input in position idx with <shape> according
303 // to the following rules:
304 //
305 // - If the ShapeHandles are the same then the stored shape will be returned.
306 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
307 // be returned. A new shape must be returned because we cannot claim that
308 // the resulting shape is necessarily the same as either of the input
309 // shapes.
310 // - If the shapes both have known ranks but their ranks are different, a new
311 // UnknownShape will be returned.
312 // - For any one dimension, if the value for that dimension in either of the
313 // shapes is unknown, a new shape will be returned with a new UnknownDim in
314 // that dimension.
315 // - For any one dimension, if the values for that dimension in both shapes
316 // are known but do not match, a new shape will be returned with a new
317 // UnknownDim in that dimension.
318 // - If both shapes have the same known rank and match in every dimension,
319 // the stored shape will be returned.
320 // - Example: relaxing [2,?] and [?,2] results in [?,?]
321 // - Example: relaxing [2,2] and [3,2] results in [?,2]
322 // - Example: relaxing [2,2] with [1,2,3] results in ?
323 //
324 // This requires idx to be in the [0, num_inputs) range. If the relax is
325 // successful and the new shape differs from the old one, store the new
326 // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)327 bool RelaxInput(int idx, ShapeHandle shape) {
328 ShapeHandle new_shape;
329 Relax(inputs_[idx], shape, &new_shape);
330 if (inputs_[idx].SameHandle(new_shape)) {
331 return false;
332 }
333 inputs_[idx] = new_shape;
334 return true;
335 }
336
SetInput(int idx,ShapeHandle shape)337 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
338
input(int64_t idx)339 ShapeHandle input(int64_t idx) const { return inputs_[idx]; }
340 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()341 int num_inputs() const { return inputs_.size(); }
342
343 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
344 // not available at the time of shape inference.
input_tensor(int idx)345 const Tensor* input_tensor(int idx) {
346 // Mark that this idx was requested.
347 request_input_tensor(idx);
348 return input_tensors_[idx];
349 }
350
351 // Notifies the shape refiner that the value of the tensor at index <idx>
352 // is needed. The shape refiner tries to statically compute this tensor,
353 // and if successful re-runs the shape function with this tensor available
354 // in the call to 'input_tensor(idx)'.
request_input_tensor(int idx)355 void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
356
357 // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)358 bool requested_input_tensor(int idx) const {
359 return requested_input_tensor_[idx];
360 }
361
362 // Notifies the shape refiner that the value of the tensor at index <idx>
363 // as a partial shape is needed. The shape refiner tries to statically compute
364 // this, and if successful re-runs the shape function with the
365 // computed PartialTensorShape available in the call to
366 // 'MakeShapeFromShapeTensor(idx, handle)' or
367 // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
request_input_tensor_as_partial_shape(int idx)368 void request_input_tensor_as_partial_shape(int idx) {
369 requested_input_tensor_as_partial_shape_[idx] = true;
370 }
371
372 // Returns true if MakeShapeFromInputTensor was called but the constant
373 // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)374 bool requested_input_tensor_as_partial_shape(int idx) const {
375 return requested_input_tensor_as_partial_shape_[idx];
376 }
377
set_input_tensors(const std::vector<const Tensor * > & input_tensors)378 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
379 input_tensors_ = input_tensors;
380 }
381
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)382 void set_input_tensors_as_shapes(
383 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
384 input_tensors_as_shapes_ = input_tensors_as_shapes;
385 }
386
input_tensors_as_shapes()387 const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
388 return input_tensors_as_shapes_;
389 }
390
output(int64_t idx)391 ShapeHandle output(int64_t idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)392 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
393 Status set_output(StringPiece output_name,
394 const std::vector<ShapeHandle>& shapes);
395
num_outputs()396 int num_outputs() const { return outputs_.size(); }
output(int idx)397 ShapeHandle output(int idx) const { return outputs_.at(idx); }
398 Status output(StringPiece output_name,
399 std::vector<ShapeHandle>* output) const;
400
401 // Returns the value for attribute named `attr_name`.
GetAttr(StringPiece attr_name,const AttrValue ** attr_value)402 Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const {
403 return attrs_.Find(attr_name, attr_value);
404 }
GetAttr(StringPiece attr_name)405 const AttrValue* GetAttr(StringPiece attr_name) const {
406 return attrs_.Find(attr_name);
407 }
408
ret_types()409 const FullTypeDef& ret_types() const { return ret_types_; }
410
411 // idx can be negative for an offset from end of dimensions.
412 // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64_t idx)413 DimensionHandle Dim(ShapeHandle s, int64_t idx) {
414 if (!s.Handle() || s->rank_ == kUnknownRank) {
415 return UnknownDim();
416 }
417 return DimKnownRank(s, idx);
418 }
419 // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64_t idx)420 static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) {
421 CHECK_NE(s->rank_, kUnknownRank);
422 if (idx < 0) {
423 return s->dims_[s->dims_.size() + idx];
424 }
425 return s->dims_[idx];
426 }
427
Rank(ShapeHandle s)428 static int32 Rank(ShapeHandle s) {
429 return s.IsSet() ? s->rank_ : kUnknownRank;
430 }
RankKnown(ShapeHandle s)431 static bool RankKnown(ShapeHandle s) {
432 return (s.IsSet() && (Rank(s) != kUnknownRank));
433 }
Value(DimensionOrConstant d)434 static inline int64_t Value(DimensionOrConstant d) {
435 return d.dim.IsSet() ? d.dim->value_ : d.val;
436 }
ValueKnown(DimensionOrConstant d)437 static inline bool ValueKnown(DimensionOrConstant d) {
438 return Value(d) != kUnknownDim;
439 }
440
// Fills the output proto with the shape defined by the handle.
// "proto" is expected to be empty prior to the call.
void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
// As above, but returns the proto by value.
TensorShapeProto ShapeHandleToProto(ShapeHandle handle);

// Returns true if the rank and all dimensions of the Shape are known.
bool FullyDefined(ShapeHandle s);

// Returns the total number of elements, or an unknown dimension for an
// incomplete shape.
DimensionHandle NumElements(ShapeHandle s);

// Human-readable renderings of a shape, a dimension, a shape/type pair and a
// list of shape/type pairs, for debugging.
std::string DebugString(ShapeHandle s);
std::string DebugString(DimensionHandle d);
std::string DebugString(const ShapeAndType& shape_and_type);
std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);

// Describes the whole context, for debugging purposes.
std::string DebugString() const;
460
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
// the shape with asserted rank in <*out>. Otherwise return an error.
// WithRankAtLeast / WithRankAtMost perform the same check with the
// requirement relaxed to a lower / upper bound on the rank (per their names;
// the definitions are out of line).
//
// Note that <*out> may be set to <shape>.
Status WithRank(ShapeHandle shape, int64_t rank,
                ShapeHandle* out) TF_MUST_USE_RESULT;
Status WithRankAtLeast(ShapeHandle shape, int64_t rank,
                       ShapeHandle* out) TF_MUST_USE_RESULT;
Status WithRankAtMost(ShapeHandle shape, int64_t rank,
                      ShapeHandle* out) TF_MUST_USE_RESULT;

// If <dim> has value <value>, or its value is unknown, returns OK and returns
// the dimension with asserted value in <*out>. Otherwise returns an error.
//
// Note that <*out> may be set to <dim>.
Status WithValue(DimensionHandle dim, int64_t value,
                 DimensionHandle* out) TF_MUST_USE_RESULT;

// Merges <s0> and <s1> and returns the merged shape in <*out>. Incompatible
// shapes yield a non-OK status. See the 'MergeInput' function for the full
// rules and examples.
Status Merge(ShapeHandle s0, ShapeHandle s1,
             ShapeHandle* out) TF_MUST_USE_RESULT;

// Asserts that <s>'s rank >= <prefix>'s rank, and that the first
// <prefix.rank> dimensions of <s> are compatible with the dimensions of
// <prefix>.
// Returns the merged results in <*s_out> and <*prefix_out>.
Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
                   ShapeHandle* prefix_out) TF_MUST_USE_RESULT;

// Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
// and <d1> have incompatible values, returns an error.
//
// Note that <*out> may be set to <d0> or <d1>.
Status Merge(DimensionHandle d0, DimensionHandle d1,
             DimensionHandle* out) TF_MUST_USE_RESULT;

// Returns in <*out> a sub-shape of <s> with dimensions [start:].
// <start> can be negative to index from the end of the shape. If <start> >
// rank of <s>, then an empty subshape is returned.
Status Subshape(ShapeHandle s, int64_t start,
                ShapeHandle* out) TF_MUST_USE_RESULT;

// Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
// <start> and <end> can be negative, to index from the end of the shape.
// <start> and <end> are set to the rank of <s> if > rank of <s>.
Status Subshape(ShapeHandle s, int64_t start, int64_t end,
                ShapeHandle* out) TF_MUST_USE_RESULT;

// Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
// <start> and <end> can be negative, to index from the end of the shape.
// <start> and <end> are set to the rank of <s> if > rank of <s>.
// <stride> can be negative, to reverse <s>.
Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride,
                ShapeHandle* out) TF_MUST_USE_RESULT;

// Returns in <*out> the result of appending the dimensions of <s2> to those
// of <s1>.
Status Concatenate(ShapeHandle s1, ShapeHandle s2,
                   ShapeHandle* out) TF_MUST_USE_RESULT;

// Returns in <out> the shape obtained by replacing <s.dim[dim_index]> with
// <new_dim>.
Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
526
// Returns a new shape with the given dims. The returned value is owned by
// this context.
ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);

// Returns a new unknown shape.
ShapeHandle UnknownShape();

// Returns a shape with the specified rank but unknown dims.
ShapeHandle UnknownShapeOfRank(int64_t rank);

// Returns a new shape of zero dimensions.
ShapeHandle Scalar();

// Returns a new shape of one dimension.
ShapeHandle Vector(DimensionOrConstant dim);

// Returns a new shape of two dimensions.
ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);

// Returns in <out> a new shape whose dimension sizes come from input tensor
// <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
// the input tensor is NULL, then an unknown shape is returned.
Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);

// Like the function above, but treats scalar values as unknown
// shapes. **NOTE** If the scalar is statically known, its value
// must be -1 or an error is returned.
Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
                                                         ShapeHandle* out);

// Returns in <out> a new shape corresponding to <proto>.
Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                               ShapeHandle* out);

// Returns in <out> a new shape corresponding to <partial_shape>.
Status MakeShapeFromPartialTensorShape(
    const PartialTensorShape& partial_shape, ShapeHandle* out);

// Returns in <out> a new shape corresponding to <shape>.
Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
// NOTE(review): despite its name, this overload takes a TensorShape (not an
// input index). It looks like a StatusOr-returning counterpart of
// MakeShapeFromTensorShape above -- confirm against the definition.
StatusOr<ShapeHandle> MakeShapeFromShapeTensor(const TensorShape& shape);
569
570 // Returns a new dimension of the given size. The returned value is owned by
571 // this context.
MakeDim(DimensionOrConstant d)572 inline DimensionHandle MakeDim(DimensionOrConstant d) {
573 return shape_manager_.MakeDim(d);
574 }
575
UnknownDim()576 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
577
// Returns in <val> a scalar value from an input tensor <t>. The input tensor
// must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the
// input tensor is not NULL.
Status GetScalarFromTensor(const Tensor* t, int64_t* val);

// Returns in <val> element <idx> of a 1D input tensor <t> with int32 or
// int64 elements. Caller must ensure that the input tensor is not NULL.
Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val);

// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
Status MakeDimForScalarInput(int idx, DimensionHandle* out);

// Returns a new dimension whose value is given by a scalar input tensor.
// This allows for a negative input dimension given the rank of a separate
// tensor. This rank can be negative if unknown.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
                                                 DimensionHandle* out);

// Looks up the attr named attr_name on the op being evaluated and sets *value
// to its value. If no attr named attr_name is found in def(), or the attr does
// not have a matching type, a non-OK status is returned.
template <class T>
Status GetAttr(StringPiece attr_name, T* value) const;
605
// Returns in <out> the result of dividing <dividend> by <divisor>.
// Returns an error if <divisor> is not positive, or if <evenly_divisible> is
// true and <divisor> does not evenly divide <dividend>.
Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
              bool evenly_divisible, DimensionHandle* out);

// Returns in <out> the sum of <first> and <second>.
Status Add(DimensionHandle first, DimensionOrConstant second,
           DimensionHandle* out);

// Returns in <out> the dimension that is <first> minus <second>.
Status Subtract(DimensionHandle first, DimensionOrConstant second,
                DimensionHandle* out);

// Returns in <out> the product of <first> and <second>.
Status Multiply(DimensionHandle first, DimensionOrConstant second,
                DimensionHandle* out);

// Returns in <out> the minimum of <first> and <second>. If either <first> or
// <second> is zero, the result is zero. Otherwise, if either <first> or
// <second> is unknown, the result is unknown.
Status Min(DimensionHandle first, DimensionOrConstant second,
           DimensionHandle* out);

// Returns in <out> the maximum of <first> and <second>. If either <first> or
// <second> is unknown, the result is unknown.
Status Max(DimensionHandle first, DimensionOrConstant second,
           DimensionHandle* out);
634
construction_status()635 Status construction_status() const { return construction_status_; }
636
// Methods to propagate shape and dtype on edges of handles. Handles are
// tensors of dtype DT_RESOURCE, which can be used to access state stored in a
// ResourceManager. When ops (such as variables) consume these handles to
// produce tensors, they might need side-information about the shapes and
// dtypes of the tensors that can be accessed via the handle. These methods
// propagate that information. Output handle dtypes and shapes are ignored if
// the output tensor is not of type DT_RESOURCE.

// Merges the stored shapes and types of the input handle at position idx with
// the specified shapes and types. This requires idx to be in the
// [0, num_inputs) range.
//
// If the merge is successful and any of the new shapes differs from the old
// one, or any of the old dtypes was DT_INVALID, stores the new shapes and
// returns true. Returns false otherwise.
//
// See the 'MergeInput' function for full details and examples.
bool MergeInputHandleShapesAndTypes(
    int idx,
    const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

// As MergeInputHandleShapesAndTypes, but for an output.
bool MergeOutputHandleShapesAndTypes(
    int idx,
    const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

// Relaxes the stored shapes and types of the input handle at position idx
// with the specified shapes and types. This requires idx to be in the
// [0, num_inputs) range.
//
// If the relax is successful (sizes are the same, and old dtypes match the new
// ones or are DT_INVALID), stores the relaxed shapes and returns true.
// Returns false otherwise.
//
// See the 'RelaxInput' function for full details and examples.
bool RelaxInputHandleShapesAndMergeTypes(
    int idx,
    const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;

// As RelaxInputHandleShapesAndMergeTypes, but for an output.
bool RelaxOutputHandleShapesAndMergeTypes(
    int idx,
    const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
680
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)681 void set_input_handle_shapes_and_types(
682 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
683 input_handle_shapes_and_types_[idx] =
684 absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
685 }
686
687 // Returns the output handle shapes and types, for the resource tensor output
688 // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)689 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
690 return output_handle_shapes_and_types_[idx].get();
691 }
692
693 // Returns the inputs handle shapes and types, for the resource tensor input
694 // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)695 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
696 return input_handle_shapes_and_types_[idx].get();
697 }
698
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)699 void set_output_handle_shapes_and_types(
700 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
701 output_handle_shapes_and_types_[idx] =
702 absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
703 }
704
// Note that shape functions should usually call MakeShapeFromShapeTensor,
// as it does more analysis to provide partial shapes.
//
// Returns in <out> a new shape whose dimension sizes come from tensor <t>.
// The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is
// nullptr, then an unknown shape is returned.
//
// NOTE(review): <tensor_shape> is not described here; presumably it is the
// shape of <t> itself, used for analysis when <t>'s value is unavailable —
// confirm against the implementation.
Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                           ShapeHandle* out);
713
graph_def_version()714 int graph_def_version() const { return graph_def_version_; }
715
MergedShapes()716 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
717 return merged_shapes_;
718 }
MergedDims()719 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
720 const {
721 return merged_dims_;
722 }
723
// Adds new outputs; useful when mutating the graph.
//
// NOTE(review): <new_output_size> is presumably the new total number of
// outputs rather than the number to append — confirm against the
// implementation.
Status ExpandOutputs(int new_output_size);
726
727 private:
728 // Creates and stores shapes for use in InferenceContext.
729 class ShapeManager {
730 public:
731 ShapeManager();
732 ~ShapeManager();
733
734 // Returns a new shape with the given dims. The returned value is owned by
735 // this class.
736 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
737
738 // Returns a new unknown shape.
739 ShapeHandle UnknownShape();
740
741 // Returns a new dimension of the given size. The returned value
742 // is owned by this class.
MakeDim(DimensionOrConstant d)743 inline DimensionHandle MakeDim(DimensionOrConstant d) {
744 if (d.dim.IsSet()) {
745 return d.dim;
746 } else {
747 all_dims_.push_back(new Dimension(d.val));
748 return all_dims_.back();
749 }
750 }
751
752 private:
753 std::vector<Shape*> all_shapes_; // values are owned.
754 std::vector<Dimension*> all_dims_; // values are owned.
755 };
756
friend class ::tensorflow::grappler::GraphProperties;

friend class ShapeInferenceTest;      // For testing Relax functions.
friend class ShapeInferenceTestutil;  // For testing shapes.

// Shared initialization across the two constructors. Remove
// once we get rid of one of them.
//
// PreInputInit takes the op definition plus any known input tensors and
// tensors-as-shapes. PostInputInit takes the per-input resource handle data
// by value; since std::unique_ptr is move-only, callers must move the vector
// in. Presumably it is stored into input_handle_shapes_and_types_ — confirm
// against the implementation.
void PreInputInit(const OpDef& op_def,
                  const std::vector<const Tensor*>& input_tensors,
                  const std::vector<ShapeHandle>& input_tensors_as_shapes);
void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                       input_handle_data);
769
ReturnUnknownShape(ShapeHandle * out)770 Status ReturnUnknownShape(ShapeHandle* out) {
771 *out = UnknownShape();
772 return OkStatus();
773 }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)774 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
775 ShapeHandle* out) {
776 *out = MakeShape(dims);
777 return OkStatus();
778 }
779
// Adds additional context to the given status and returns the result.
// The exact context appended is defined in the implementation, not here.
Status AttachContext(const Status& status);
782
// Relaxes an existing value <d_old> with a new value <d_new> and returns the
// relaxed dimension in <*out>.
//
// This overload returns void, so it has no way to report an error.
// NOTE(review): an earlier comment claimed incompatible values return an
// error; presumably they relax to an unknown dimension instead — confirm
// against the implementation.
//
// Note that <*out> may be set to <d_old> or <d_new>.
void Relax(DimensionHandle d_old, DimensionHandle d_new,
           DimensionHandle* out);
// Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
// relaxed shape in <*out>. See 'RelaxInput' function for full details and
// examples.
void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
794
// Used to implement MergeInputHandleShapesAndTypes and
// MergeOutputHandleShapesAndTypes.
//
// Merges <shapes_and_types> into the stored list <*to_update>, which is
// modified in place. The bool result is TF_MUST_USE_RESULT.
bool MergeHandleShapesAndTypes(
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
// Used to implement RelaxInputHandleShapesAndMergeTypes and
// RelaxOutputHandleShapesAndMergeTypes.
//
// Relaxes <shapes_and_types> into the stored list <*to_update>, which is
// modified in place. The bool result is TF_MUST_USE_RESULT.
bool RelaxHandleShapesAndMergeTypes(
    const std::vector<ShapeAndType>& shapes_and_types,
    std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
805
806 // Forget all the previous merged shapes and dims.
ForgetMerges()807 void ForgetMerges() {
808 merged_shapes_.clear();
809 merged_dims_.clear();
810 }
811
// Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
//
// <treat_unknown_scalar_tensor_as_unknown_shape> selects between the two
// callers' behaviors; presumably it makes a scalar tensor holding the
// unknown value yield an unknown shape — confirm against the implementation.
// The remaining parameters mirror MakeShapeFromTensor.
Status InternalMakeShapeFromTensor(
    bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
    ShapeHandle tensor_shape, ShapeHandle* out);
816
// Creates and owns the Shape/Dimension objects that the handles below refer
// to.
ShapeManager shape_manager_;

// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
// `shape_manager_`.
std::vector<ShapeHandle> inputs_;
// Raw pointers; presumably not owned by this context — confirm.
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
// Can have fewer elements than inputs_.
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;

// input_handle_shapes_and_types_[i] is the list of shape/type pairs available
// through the resource handle passed along input i of the node.
//
// Values may be nullptr.
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    input_handle_shapes_and_types_;

// output_handle_shapes_and_types_[i] is the list of shape/type pairs
// available through the resource handle passed along output i of the node.
//
// Values may be nullptr.
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
    output_handle_shapes_and_types_;

// Return types for the node this context is associated with. This information
// is to eventually consolidate all the dtype and shape info, allowing for
// output_handle_shapes_and_types_ to be removed.
FullTypeDef ret_types_;

const int graph_def_version_;
AttrSlice attrs_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;

// An error set during construction. TODO(cwhipkey): remove when test
// constructor is removed.
Status construction_status_;

// Pairs of shape or dim handles that are equivalent, i.e. that represent the
// same underlying shape or dimension. Note that for each pair at least one of
// the handles must contain an unknown shape, since we don't keep track of
// known shapes or dims here.
std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;

TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
865 };
866
867 // -----------------------------------------------------------------------------
868 // Template and inline method implementations, please ignore
869
Dimension()870 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
Dimension(int64_t value)871 inline Dimension::Dimension(int64_t value) : value_(value) {
872 DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
873 << "Dimension must be non-negative or equal to "
874 "InferenceContext::kUnknownDim but got "
875 << value;
876 }
877
Shape()878 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)879 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
880 : rank_(dims.size()), dims_(dims) {}
881
DimensionOrConstant(DimensionHandle dim)882 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
883 : dim(dim) {
884 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
885 }
886
DimensionOrConstant(int64_t val)887 inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) {
888 DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
889 << "Dimension must be non-negative or equal to "
890 "InferenceContext::kUnknownDim but got "
891 << val;
892 }
893
894 template <class T>
GetAttr(StringPiece attr_name,T * value)895 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
896 return GetNodeAttr(attrs_, attr_name, value);
897 }
898
899 } // namespace shape_inference
900 } // namespace tensorflow
901
902 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
903