1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/platform/macros.h"
26
27 namespace tensorflow {
28
29 class ShapeRefiner;
30 class ShapeRefinerTest;
31
32 namespace grappler {
33 class GraphProperties;
34 class SymbolicShapeManager;
35 } // namespace grappler
36
37 namespace shape_inference {
38
39 struct DimensionOrConstant;
40 class InferenceContext;
41
42 // Dimension values are accessed through InferenceContext.
43 class Dimension {
44 private:
45 Dimension();
46 Dimension(int64 value);
~Dimension()47 ~Dimension() {}
48
49 const int64 value_;
50
51 friend class InferenceContext;
52 friend class ShapeManager;
53 TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
54 };
55
56 class DimensionHandle {
57 public:
DimensionHandle()58 DimensionHandle() {}
SameHandle(DimensionHandle d)59 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()60 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
61
62 private:
DimensionHandle(const Dimension * dim)63 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
64
65 const Dimension* operator->() const { return ptr_; }
IsSet()66 bool IsSet() const { return ptr_ != nullptr; }
67
68 const Dimension* ptr_ = nullptr;
69
70 friend struct DimensionOrConstant;
71 friend class InferenceContext;
72 friend class ShapeInferenceTest;
73 friend class ShapeInferenceTestutil;
74 friend class ::tensorflow::ShapeRefinerTest;
75 friend class ShapeManager;
76 friend class ::tensorflow::grappler::GraphProperties;
77 friend class ::tensorflow::grappler::SymbolicShapeManager;
78
79 // Intentionally copyable.
80 };
81
82 // Shape rank and dimensions are accessed through InferenceContext.
83 class Shape {
84 private:
85 Shape();
86 Shape(const std::vector<DimensionHandle>& dims);
~Shape()87 ~Shape() {}
88
89 const int32 rank_;
90 const std::vector<DimensionHandle> dims_;
91
92 friend class InferenceContext;
93 friend class ShapeManager;
94 friend class ::tensorflow::grappler::SymbolicShapeManager;
95
96 TF_DISALLOW_COPY_AND_ASSIGN(Shape);
97 };
98
99 class ShapeHandle {
100 public:
ShapeHandle()101 ShapeHandle() {}
SameHandle(ShapeHandle s)102 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()103 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
104
105 private:
ShapeHandle(const Shape * shape)106 ShapeHandle(const Shape* shape) { ptr_ = shape; }
107 const Shape* operator->() const { return ptr_; }
IsSet()108 bool IsSet() const { return ptr_ != nullptr; }
109
110 const Shape* ptr_ = nullptr;
111
112 friend class InferenceContext;
113 friend class ShapeInferenceTest;
114 friend class ShapeInferenceTestutil;
115 friend class ::tensorflow::ShapeRefinerTest;
116 friend class ShapeManager;
117 friend class ::tensorflow::grappler::SymbolicShapeManager;
118
119 // Intentionally copyable.
120 };
121
122 // Struct used to allow functions to take DimensionHandle or a dimension value.
123 // Not meant to be constructed directly.
124 struct DimensionOrConstant {
125 public:
126 // Intentionally not explicit.
127 DimensionOrConstant(DimensionHandle dim);
128
129 // val must be non-negative or InferenceContext::kUnknownDim.
130 DimensionOrConstant(int64 val);
131
132 // dim takes precedence. If dim != nullptr, val is ignored.
133 DimensionHandle dim;
134 int64 val;
135
136 private:
137 DimensionOrConstant();
138 };
139
140 struct ShapeAndType {
ShapeAndTypeShapeAndType141 ShapeAndType() {}
ShapeAndTypeShapeAndType142 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
143
144 ShapeHandle shape;
145 DataType dtype = DT_INVALID;
146 };
147
148 // Shape inference functions registered on ops in REGISTER_OP implement
149 // their shape functions in terms of this InferenceContext. An InferenceContext
150 // is created by the framework and passed to a shape inference function. The
151 // shape inference function calls functions on the context, and should call
152 // set_output() to set the shape on all outputs.
153 //
154 // To infer shapes for user-defined functions see ShapeRefiner.
155 //
156 // All Shape* and Dimension* returned by functions of InferenceContext are owned
157 // by the InferenceContext.
158 class InferenceContext {
159 public:
160 static constexpr int64 kUnknownDim = -1;
161 static constexpr int32 kUnknownRank = -1;
162
163 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
164 //
165 // Elements of <input_tensors_as_shapes> are used for when a shape function
166 // makes a call to MakeShapeFromShapeTensor; in particular, when the
167 // input_tensors[i] is nullptr but the shape represented by it is partially
168 // known from analysis of the graph.
169 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
170 // Values of <input_tensors_as_shapes> do not need to outlive the context.
171 //
172 // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
173 InferenceContext(int graph_def_version, const NodeDef* node_def,
174 const OpDef& op_def,
175 const std::vector<ShapeHandle>& input_shapes,
176 const std::vector<const Tensor*>& input_tensors,
177 const std::vector<ShapeHandle>& input_tensors_as_shapes,
178 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
179 input_handle_shapes_and_types);
180
181 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
182 //
183 // Elements of <input_tensors_as_shapes> are used for when a shape
184 // function makes a call to MakeShapeFromShapeTensor; in particular, when
185 // the input_tensors[i] is nullptr but the shape represented by it is
186 // partially known from analysis of the graph. <input_tensors_as_shapes>
187 // can have fewer elements than <input_shapes>. Values of
188 // <input_tensors_as_shapes> do not need to outlive the context.
189 //
190 // REQUIRES: <node_def> is not NULL, and must outlive the
191 // InferenceContext.
192 InferenceContext(
193 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
194 const std::vector<TensorShapeProto>& input_shapes,
195 const std::vector<const Tensor*>& input_tensors,
196 const std::vector<TensorShapeProto>& input_tensors_as_shapes,
197 const std::vector<
198 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
199 input_handle_shapes_and_types);
200
201 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
202 //
203 // Elements of <input_tensors_as_shapes> are used for when a shape
204 // function makes a call to MakeShapeFromShapeTensor; in particular, when
205 // the input_tensors[i] is nullptr but the shape represented by it is
206 // partially known from analysis of the graph. <input_tensors_as_shapes>
207 // can have fewer elements than <input_shapes>. Values of
208 // <input_tensors_as_shapes> do not need to outlive the context.
209 //
210 // REQUIRES: <node_def> is not NULL, and must outlive the
211 // InferenceContext.
212 InferenceContext(
213 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
214 const std::vector<PartialTensorShape>& input_shapes,
215 const std::vector<const Tensor*>& input_tensors,
216 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
217 const std::vector<std::unique_ptr<
218 std::vector<std::pair<PartialTensorShape, DataType>>>>&
219 input_handle_shapes_and_types);
220
221 ~InferenceContext();
222
223 // Runs the shape inference function 'fn' with 'this' as the
224 // argument, returns the status of the inference.
225 //
226 // On error, additional context is provided in the error message.
227 Status Run(
228 const std::function<Status(shape_inference::InferenceContext* c)>& fn);
229
230 // Merge the stored shape of the input in position idx with <shape> according
231 // to the following rules:
232 //
233 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
234 // change. Otherwise if the stored shape is unknown, the new shape will be
235 // <shape>.
236 // - If both shapes are known, then they must have the same rank.
237 // - For any one dimension, if the values for that dimension in both shapes
238 // are known, then the values must match.
239 // - If one shape has equal or more information than the other shape in every
240 // dimension, the new shape will become the shape with more information.
241 // - Example: merging [2,?] and [?,2] results in [2,2]
242 // - Example: [2,2] cannot be merged with [1,2]
243 //
244 // This requires idx to be in the [0, num_inputs) range. If the merge is
245 // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)246 bool MergeInput(int idx, ShapeHandle shape) {
247 ShapeHandle new_shape;
248 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
249 inputs_[idx] = new_shape;
250 return true;
251 }
252
253 // Relax the stored shape of the input in position idx with <shape> according
254 // to the following rules:
255 //
256 // - If the ShapeHandles are the same then the stored shape will be returned.
257 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
258 // be returned. A new shape must be returned because we cannot claim that
259 // the resulting shape is necessarily the same as either of the input
260 // shapes.
261 // - If the shapes both have known ranks but their ranks are different, a new
262 // UnknownShape will be returned.
263 // - For any one dimension, if the value for that dimension in either of the
264 // shapes is unknown, a new shape will be returned with a new UnknownDim in
265 // that dimension.
266 // - For any one dimension, if the values for that dimension in both shapes
267 // are known but do not match, a new shape will be returned with a new
268 // UnknownDim in that dimension.
269 // - If both shapes have the same known rank and match in every dimension,
270 // the stored shape will be returned.
271 // - Example: relaxing [2,?] and [?,2] results in [?,?]
272 // - Example: relaxing [2,2] and [3,2] results in [?,2]
273 // - Example: relaxing [2,2] with [1,2,3] results in ?
274 //
275 // This requires idx to be in the [0, num_inputs) range. If the relax is
276 // successful and the new shape differs from the old one, store the new
277 // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)278 bool RelaxInput(int idx, ShapeHandle shape) {
279 ShapeHandle new_shape;
280 Relax(inputs_[idx], shape, &new_shape);
281 if (inputs_[idx].SameHandle(new_shape)) {
282 return false;
283 }
284 inputs_[idx] = new_shape;
285 return true;
286 }
287
SetInput(int idx,ShapeHandle shape)288 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
289
input(int64 idx)290 ShapeHandle input(int64 idx) const { return inputs_[idx]; }
291 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()292 int num_inputs() const { return inputs_.size(); }
293
294 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
295 // not available at the time of shape inference.
input_tensor(int idx)296 const Tensor* input_tensor(int idx) {
297 // Mark that this idx was requested.
298 requested_input_tensor_[idx] = true;
299 return input_tensors_[idx];
300 }
301
302 // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)303 bool requested_input_tensor(int idx) const {
304 return requested_input_tensor_[idx];
305 }
306
307 // Returns true if MakeShapeFromInputTensor was called but the constant
308 // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)309 bool requested_input_tensor_as_partial_shape(int idx) const {
310 return requested_input_tensor_as_partial_shape_[idx];
311 }
312
set_input_tensors(const std::vector<const Tensor * > & input_tensors)313 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
314 input_tensors_ = input_tensors;
315 }
316
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)317 void set_input_tensors_as_shapes(
318 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
319 input_tensors_as_shapes_ = input_tensors_as_shapes;
320 }
321
input_tensors_as_shapes()322 const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
323 return input_tensors_as_shapes_;
324 }
325
output(int64 idx)326 ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)327 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
328 Status set_output(StringPiece output_name,
329 const std::vector<ShapeHandle>& shapes);
330
num_outputs()331 int num_outputs() const { return outputs_.size(); }
output(int idx)332 ShapeHandle output(int idx) const { return outputs_.at(idx); }
333 Status output(StringPiece output_name,
334 std::vector<ShapeHandle>* output) const;
335
attrs()336 AttrSlice attrs() const { return AttrSlice(*node_def_); }
337
338 string op() const;
339
340 // idx can be negative for an offset from end of dimensions.
341 // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64 idx)342 DimensionHandle Dim(ShapeHandle s, int64 idx) {
343 if (s->rank_ == kUnknownRank) {
344 return UnknownDim();
345 }
346 return DimKnownRank(s, idx);
347 }
348 // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64 idx)349 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
350 CHECK_NE(s->rank_, kUnknownRank);
351 if (idx < 0) {
352 return s->dims_[s->dims_.size() + idx];
353 }
354 return s->dims_[idx];
355 }
356
Rank(ShapeHandle s)357 static int32 Rank(ShapeHandle s) {
358 DCHECK(s.IsSet());
359 return s.IsSet() ? s->rank_ : kUnknownRank;
360 }
RankKnown(ShapeHandle s)361 static bool RankKnown(ShapeHandle s) {
362 return (s.IsSet() && (Rank(s) != kUnknownRank));
363 }
Value(DimensionOrConstant d)364 static inline int64 Value(DimensionOrConstant d) {
365 return d.dim.IsSet() ? d.dim->value_ : d.val;
366 }
ValueKnown(DimensionOrConstant d)367 static inline bool ValueKnown(DimensionOrConstant d) {
368 return Value(d) != kUnknownDim;
369 }
370
371 // Fills the output proto with the shape defined by the handle.
372 // "proto" is expected to be empty prior to the call.
373 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
374
375 // Returns true if the rank and all dimensions of the Shape are known.
376 bool FullyDefined(ShapeHandle s);
377
378 // Returns the total number of elements, or an unknown dimension for an
379 // incomplete shape.
380 DimensionHandle NumElements(ShapeHandle s);
381
382 string DebugString(ShapeHandle s);
383 string DebugString(DimensionHandle d);
384 string DebugString(const ShapeAndType& shape_and_type);
385 string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
386
387 // Describes the whole context, for debugging purposes.
388 string DebugString() const;
389
390 // If <shape> has rank <rank>, or its rank is unknown, return OK and return
391 // the shape with asserted rank in <*out>. Otherwise return an error.
392 //
393 // Note that <*out> may be set to <shape>.
394 Status WithRank(ShapeHandle shape, int64 rank,
395 ShapeHandle* out) TF_MUST_USE_RESULT;
396 Status WithRankAtLeast(ShapeHandle shape, int64 rank,
397 ShapeHandle* out) TF_MUST_USE_RESULT;
398 Status WithRankAtMost(ShapeHandle shape, int64 rank,
399 ShapeHandle* out) TF_MUST_USE_RESULT;
400
401 // If <dim> has value <value>, or its value is unknown, returns OK and returns
402 // the dimension with asserted value in <*out>. Otherwise returns an error.
403 //
404 // Note that <*out> may be set to <dim>.
405 Status WithValue(DimensionHandle dim, int64 value,
406 DimensionHandle* out) TF_MUST_USE_RESULT;
407
408 // Merges <s0> and <s1> and returns the merged shape in <*out>. See
409 // 'MergeInput' function for full details and examples.
410 Status Merge(ShapeHandle s0, ShapeHandle s1,
411 ShapeHandle* out) TF_MUST_USE_RESULT;
412
413 // Asserts that <s>'s rank >= <prefix>'s rank, and the first
414 // <prefix.rank> dimensions of <s> are compatible with the dimensions of
415 // <prefix>.
416 // Returns the merged results in <*s_out> and <*prefix_out>.
417 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
418 ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
419
420 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
421 // and <d1> have incompatible values, returns an error.
422 //
423 // Note that <*out> may be set to <d0> or <d1>.
424 Status Merge(DimensionHandle d0, DimensionHandle d1,
425 DimensionHandle* out) TF_MUST_USE_RESULT;
426
427 // Returns in <*out> a sub-shape of <s> with dimensions [start:].
428 // <start> can be negative to index from the end of the shape. If <start> >
429 // rank of <s>, then an empty subshape is returned.
430 Status Subshape(ShapeHandle s, int64 start,
431 ShapeHandle* out) TF_MUST_USE_RESULT;
432
433 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
434 // <start> and <end> can be negative, to index from the end of the shape.
435 // <start> and <end> are set to the rank of <s> if > rank of <s>.
436 Status Subshape(ShapeHandle s, int64 start, int64 end,
437 ShapeHandle* out) TF_MUST_USE_RESULT;
438
439 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
440 // <start> and <end> can be negative, to index from the end of the shape.
441 // <start> and <end> are set to the rank of <s> if > rank of <s>.
442 // <stride> can be negative, to reverse the <s>.
443 Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
444 ShapeHandle* out) TF_MUST_USE_RESULT;
445
446 // Returns in <*out> the result of appending the dimensions of <s2> to those
447 // of <s1>.
448 Status Concatenate(ShapeHandle s1, ShapeHandle s2,
449 ShapeHandle* out) TF_MUST_USE_RESULT;
450
451 // Returns in <out> the shape from replacing <s.dim[dim_index]> with
452 // <new_dim>.
453 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
454 ShapeHandle* out) TF_MUST_USE_RESULT;
455
456 // Returns a new shape with the given dims. The returned value is owned by
457 // this context.
458 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
459 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
460
461 // Returns a new unknown shape.
462 ShapeHandle UnknownShape();
463
464 // Returns a shape with specified rank but unknown dims.
465 ShapeHandle UnknownShapeOfRank(int64 rank);
466
467 // Returns a new shape of zero dimensions.
468 ShapeHandle Scalar();
469
470 // Returns a new shape of one dimension.
471 ShapeHandle Vector(DimensionOrConstant dim);
472
473 // Returns a new shape of two dimensions.
474 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
475
476 // Returns in <out> a new shape whose dimension sizes come from input tensor
477 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
478 // the input tensor is NULL, then an unknown shape is returned.
479 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
480
481 // Like the function above, but treats scalar values as unknown
482 // shapes. **NOTE** If the scalar is statically known, its value
483 // must be -1 or an error is returned.
484 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
485 ShapeHandle* out);
486
487 // Returns in <out> a new shape corresponding to <proto>.
488 Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
489 ShapeHandle* out);
490
491 // Returns in <out> a new shape corresponding to <partial_shape>.
492 Status MakeShapeFromPartialTensorShape(
493 const PartialTensorShape& partial_shape, ShapeHandle* out);
494
495 // Returns in <out> a new shape corresponding to <shape>.
496 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
497
498 // Returns a new dimension of the given size. The returned value is owned by
499 // this context.
MakeDim(DimensionOrConstant d)500 inline DimensionHandle MakeDim(DimensionOrConstant d) {
501 return shape_manager_.MakeDim(d);
502 }
503
UnknownDim()504 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
505
506 // Returns in <val> a scalar value from an input tensor <t>. The input tensor
507 // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
508 // input tensor is not NULL.
509 Status GetScalarFromTensor(const Tensor* t, int64* val);
510
511 // Returns a new dimension whose value is given by a scalar input tensor.
512 // The input tensor must be in host memory, since it is dereferenced to get
513 // the value.
514 Status MakeDimForScalarInput(int idx, DimensionHandle* out);
515
516 // Returns a new dimension whose value is given by a scalar input tensor.
517 // This allows for a negative input dimension given the rank of a separate
518 // tensor. This rank can be negative if unknown.
519 // The input tensor must be in host memory, since it is dereferenced to get
520 // the value.
521 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
522 DimensionHandle* out);
523
524 // Look up the attr for the NodeDef being evaluated with name attr_name and
525 // set *value to its value. If no attr with attr_name is found in def(), or
526 // the attr does not have a matching type, a non-ok status will be returned.
527 template <class T>
528 Status GetAttr(StringPiece attr_name, T* value) const;
529
530 // Returns in <out> the result of dividing <dividend> by <divisor>.
531 // Returns an error if <divisor> is not positive or if <evenly_divisible>
532 // and <divisor> does not evenly divide <dividend>.
533 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
534 bool evenly_divisible, DimensionHandle* out);
535
536 // Returns in <out> the sum of <first> and <second>.
537 Status Add(DimensionHandle first, DimensionOrConstant second,
538 DimensionHandle* out);
539
540 // Returns in <out> the dimension that is <first> minus <second>.
541 Status Subtract(DimensionHandle first, DimensionOrConstant second,
542 DimensionHandle* out);
543
544 // Returns in <out> the product of <first> and <second>.
545 Status Multiply(DimensionHandle first, DimensionOrConstant second,
546 DimensionHandle* out);
547
548 // Returns in <out> the minimum of <first> and <second>. If either <first> or
549 // <second> is zero the results is zero. Otherwise, if either <first> or
550 // <second> is unknown the results is unknown.
551 Status Min(DimensionHandle first, DimensionOrConstant second,
552 DimensionHandle* out);
553
554 // Returns in <out> the maximum of <first> and <second>. If either <first> or
555 // <second> is unknown the results is unknown.
556 Status Max(DimensionHandle first, DimensionOrConstant second,
557 DimensionHandle* out);
558
construction_status()559 Status construction_status() const { return construction_status_; }
560
561 // Methods to propagate shape and dtype on edges of handles. Handles are the
562 // dtype DT_RESOURCE which can be used to access state stored in a
563 // ResourceManager. When ops (such as variables) consume these handles to
564 // produce tensors they might need to know side-information about the shapes
565 // and dtypes of tensors which can be accessed via the handle. These methods
566 // propagate that information. Output handle dtypes and shapes are ignored if
567 // the output tensor is not of type DT_RESOURCE.
568
569 // Merge the stored shapes and types corresponding to the input handle in
570 // position idx with the specified shapes and types. This requires idx to be
571 // in the [0, num_inputs) range.
572 //
573 // If the merge is successful and any of the new shapes differs from the old
574 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
575 // return true. Return false otherwise.
576 //
577 // See 'MergeInput' function for full details and examples.
578 bool MergeInputHandleShapesAndTypes(
579 int idx,
580 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
581
582 // As MergeInputHandleShapesAndTypes, but for an output.
583 bool MergeOutputHandleShapesAndTypes(
584 int idx,
585 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
586
587 // Relaxes the stored shapes and types corresponding to the input handle in
588 // position idx with the specified shapes and types. This requires idx to be
589 // in the [0, num_inputs) range.
590 //
591 // If the relax is successful (sizes are the same, old dtypes match new ones
592 // or are DT_INVALID), then store the relaxed shapes and return true.
593 // Return false otherwise.
594 //
595 // See 'RelaxInput' function for full details and examples.
596 bool RelaxInputHandleShapesAndMergeTypes(
597 int idx,
598 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
599
600 // As RelaxInputHandleShapesAndTypes, but for an output.
601 bool RelaxOutputHandleShapesAndMergeTypes(
602 int idx,
603 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
604
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)605 void set_input_handle_shapes_and_types(
606 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
607 input_handle_shapes_and_types_[idx].reset(
608 new std::vector<ShapeAndType>(shapes_and_types));
609 }
610
611 // Returns the output handle shapes and types, for the resource tensor output
612 // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)613 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
614 return output_handle_shapes_and_types_[idx].get();
615 }
616
617 // Returns the inputs handle shapes and types, for the resource tensor output
618 // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)619 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
620 return input_handle_shapes_and_types_[idx].get();
621 }
622
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)623 void set_output_handle_shapes_and_types(
624 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
625 output_handle_shapes_and_types_[idx].reset(
626 new std::vector<ShapeAndType>(shapes_and_types));
627 }
628
629 // Note that shape functions should usually call MakeShapeFromShapeTensor,
630 // as it does more analysis to provide partial shapes.
631 //
632 // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
633 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
634 // then an unknown shape is returned.
635 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
636 ShapeHandle* out);
637
graph_def_version()638 int graph_def_version() const { return graph_def_version_; }
639
MergedShapes()640 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
641 return merged_shapes_;
642 }
MergedDims()643 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
644 const {
645 return merged_dims_;
646 }
647
648 // Adds new outputs; useful when mutating the graph.
649 Status ExpandOutputs(int new_output_size);
650
651 private:
652 // Creates and stores shapes for use in InferenceContext.
653 class ShapeManager {
654 public:
655 ShapeManager();
656 ~ShapeManager();
657
658 // Returns a new shape with the given dims. The returned value is owned by
659 // this class.
660 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
661
662 // Returns a new unknown shape.
663 ShapeHandle UnknownShape();
664
665 // Returns a new dimension of the given size. The returned value
666 // is owned by this class.
MakeDim(DimensionOrConstant d)667 inline DimensionHandle MakeDim(DimensionOrConstant d) {
668 if (d.dim.IsSet()) {
669 return d.dim;
670 } else {
671 all_dims_.push_back(new Dimension(d.val));
672 return all_dims_.back();
673 }
674 }
675
676 private:
677 std::vector<Shape*> all_shapes_; // values are owned.
678 std::vector<Dimension*> all_dims_; // values are owned.
679 };
680
681 friend class ::tensorflow::grappler::GraphProperties;
682
683 // Friend for user-defined function shape inference purposes.
684 friend class ::tensorflow::ShapeRefiner;
685
686 friend class ShapeInferenceTest; // For testing Relax functions.
687 friend class ShapeInferenceTestutil; // For testing shapes.
688
689 // Shared initialization across the two constructors. Remove
690 // once we get rid of one of them.
691 void PreInputInit(const OpDef& op_def,
692 const std::vector<const Tensor*>& input_tensors,
693 const std::vector<ShapeHandle>& input_tensors_as_shapes);
694 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
695 input_handle_data);
696
697 DimensionHandle GetDimension(const DimensionOrConstant& d);
698
ReturnUnknownShape(ShapeHandle * out)699 Status ReturnUnknownShape(ShapeHandle* out) {
700 *out = UnknownShape();
701 return Status::OK();
702 }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)703 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
704 ShapeHandle* out) {
705 *out = MakeShape(dims);
706 return Status::OK();
707 }
708
709 // Adds additional context to the given status.
710 Status AttachContext(const Status& status);
711
712 // Relaxes an existing value <d_old> with a new value <d_new> and returns the
713 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
714 // values, returns an error.
715 //
716 // Note that <*out> may be set to <d_old> or <d_new>.
717 void Relax(DimensionHandle d_old, DimensionHandle d_new,
718 DimensionHandle* out);
719 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
720 // relaxed shape in <*out>. See 'RelaxInput' function for full details and
721 // examples.
722 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
723
724 // Used to implement MergeInputHandleShapesAndTypes and
725 // MergeOutputHandleShapesAndTypes.
726 bool MergeHandleShapesAndTypes(
727 const std::vector<ShapeAndType>& shapes_and_types,
728 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
729 // Used to implement RelaxInputHandleShapesAndMergeTypes and
730 // RelaxOutputHandleShapesAndMergeTypes.
731 bool RelaxHandleShapesAndMergeTypes(
732 const std::vector<ShapeAndType>& shapes_and_types,
733 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
734
735 // Forget all the previous merged shapes and dims.
ForgetMerges()736 void ForgetMerges() {
737 merged_shapes_.clear();
738 merged_dims_.clear();
739 }
740
741 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
742 Status InternalMakeShapeFromTensor(
743 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
744 ShapeHandle tensor_shape, ShapeHandle* out);
745
746 ShapeManager shape_manager_;
747
748 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
749 // `shape_manager_`.
750 std::vector<ShapeHandle> inputs_;
751 std::vector<const Tensor*> input_tensors_;
752 std::vector<bool> requested_input_tensor_;
753 std::vector<ShapeHandle> outputs_;
754 // Can have fewer elements than inputs_.
755 std::vector<ShapeHandle> input_tensors_as_shapes_;
756 std::vector<bool> requested_input_tensor_as_partial_shape_;
757
758 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
759 // through the resource handle passed along input i of the node.
760 //
761 // Values may be NULL.
762 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
763 input_handle_shapes_and_types_;
764
765 // output_handle_shapes_and_types_[i] is the list of shape/type pairs
766 // available through the resource handle passed along output i of the node.
767 //
768 // Values may be NULL.
769 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
770 output_handle_shapes_and_types_;
771
772 const int graph_def_version_;
773 const NodeDef* node_def_;
774 NameRangeMap input_name_map_;
775 NameRangeMap output_name_map_;
776
777 // An error set during construction. TODO(cwhipkey): remove when test
778 // constructor is removed.
779 Status construction_status_;
780
781 // Pair of shape or dim handles that are equivalent, ie that represent the
782 // same underlying shape of dimension. Note that for each pair at least one of
783 // the handles must contain an unknown shape, since we don't keep track of
784 // known shapes or dims here.
785 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
786 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
787
788 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
789 };
790
791 // -----------------------------------------------------------------------------
792 // Template and inline method implementations, please ignore
793
Dimension()794 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
Dimension(int64 value)795 inline Dimension::Dimension(int64 value) : value_(value) {
796 DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
797 << "Dimension must be non-negative or equal to "
798 "InferenceContext::kUnknownDim but got "
799 << value;
800 }
801
Shape()802 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)803 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
804 : rank_(dims.size()), dims_(dims) {}
805
DimensionOrConstant(DimensionHandle dim)806 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
807 : dim(dim) {
808 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
809 }
810
DimensionOrConstant(int64 val)811 inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
812 DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
813 << "Dimension must be non-negative or equal to "
814 "InferenceContext::kUnknownDim but got "
815 << val;
816 }
817
818 template <class T>
GetAttr(StringPiece attr_name,T * value)819 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
820 return GetNodeAttr(*node_def_, attr_name, value);
821 }
822
823 } // namespace shape_inference
824 } // namespace tensorflow
825
826 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
827