1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
28
29 namespace tensorflow {
30
// Grappler classes that the handle and shape types below grant friend access
// to; only their names are needed here.
namespace grappler {
class SymbolicShapeManager;
class GraphProperties;
}  // namespace grappler
35
36 namespace shape_inference {
37
// Forward declarations: both names are used in friend declarations before
// their definitions further down in this header.
class InferenceContext;
struct DimensionOrConstant;
40
41 // Dimension values are accessed through InferenceContext.
42 class Dimension {
43 private:
44 Dimension();
45 Dimension(int64_t value);
~Dimension()46 ~Dimension() {}
47
48 const int64 value_;
49
50 friend class InferenceContext;
51 friend class ShapeManager;
52 TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
53 };
54
55 class DimensionHandle {
56 public:
DimensionHandle()57 DimensionHandle() {}
SameHandle(DimensionHandle d)58 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()59 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
60
61 private:
DimensionHandle(const Dimension * dim)62 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
63
64 const Dimension* operator->() const { return ptr_; }
IsSet()65 bool IsSet() const { return ptr_ != nullptr; }
66
67 const Dimension* ptr_ = nullptr;
68
69 friend struct DimensionOrConstant;
70 friend class InferenceContext;
71 friend class ShapeInferenceTest;
72 friend class ShapeInferenceTestutil;
73 friend class ::tensorflow::grappler::GraphProperties;
74 friend class ::tensorflow::grappler::SymbolicShapeManager;
75
76 // Intentionally copyable.
77 };
78
79 // Shape rank and dimensions are accessed through InferenceContext.
80 class Shape {
81 private:
82 Shape();
83 Shape(const std::vector<DimensionHandle>& dims);
~Shape()84 ~Shape() {}
85
86 const int32 rank_;
87 const std::vector<DimensionHandle> dims_;
88
89 friend class InferenceContext;
90 friend class ::tensorflow::grappler::SymbolicShapeManager;
91
92 TF_DISALLOW_COPY_AND_ASSIGN(Shape);
93 };
94
95 class ShapeHandle {
96 public:
ShapeHandle()97 ShapeHandle() {}
SameHandle(ShapeHandle s)98 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()99 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
100
101 private:
ShapeHandle(const Shape * shape)102 ShapeHandle(const Shape* shape) { ptr_ = shape; }
103 const Shape* operator->() const { return ptr_; }
IsSet()104 bool IsSet() const { return ptr_ != nullptr; }
105
106 const Shape* ptr_ = nullptr;
107
108 friend class InferenceContext;
109 friend class ShapeInferenceTest;
110 friend class ShapeInferenceTestutil;
111 friend class ::tensorflow::grappler::SymbolicShapeManager;
112
113 // Intentionally copyable.
114 };
115
116 // Struct used to allow functions to take DimensionHandle or a dimension value.
117 // Not meant to be constructed directly.
118 struct DimensionOrConstant {
119 public:
120 // Intentionally not explicit.
121 DimensionOrConstant(DimensionHandle dim);
122
123 // val must be non-negative or InferenceContext::kUnknownDim.
124 DimensionOrConstant(int64_t val);
125
126 // dim takes precedence. If dim != nullptr, val is ignored.
127 DimensionHandle dim;
128 int64 val;
129
130 private:
131 DimensionOrConstant();
132 };
133
134 struct ShapeAndType {
ShapeAndTypeShapeAndType135 ShapeAndType() {}
ShapeAndTypeShapeAndType136 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
137 // TODO(mdan): Remove dtype from constructor, and use type_ instead.
138 // dtype is kept here for backward compatibiity. Its information should
139 // be redundant to that in type;
ShapeAndTypeShapeAndType140 ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_)
141 : shape(s), dtype(t), type(type_) {}
142
143 ShapeHandle shape;
144 DataType dtype = DT_INVALID;
145 FullTypeDef type;
146 };
147
148 // Shape inference functions registered on ops in REGISTER_OP implement
149 // their shape functions in terms of this InferenceContext. An InferenceContext
150 // is created by the framework and passed to a shape inference function. The
151 // shape inference function calls functions on the context, and should call
152 // set_output() to set the shape on all outputs.
153 //
154 // To infer shapes for user-defined functions see ShapeRefiner.
155 //
156 // All Shape* and Dimension* returned by functions of InferenceContext are owned
157 // by the InferenceContext.
158 class InferenceContext {
159 public:
160 static constexpr int64_t kUnknownDim = -1;
161 static constexpr int32_t kUnknownRank = -1;
162
163 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
164 //
165 // Elements of <input_tensors_as_shapes> are used for when a shape function
166 // makes a call to MakeShapeFromShapeTensor; in particular, when the
167 // input_tensors[i] is nullptr but the shape represented by it is partially
168 // known from analysis of the graph.
169 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
170 // Values of <input_tensors_as_shapes> do not need to outlive the context.
171 InferenceContext(int graph_def_version, const AttrSlice& attrs,
172 const OpDef& op_def,
173 const std::vector<ShapeHandle>& input_shapes,
174 const std::vector<const Tensor*>& input_tensors,
175 const std::vector<ShapeHandle>& input_tensors_as_shapes,
176 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
177 input_handle_shapes_and_types);
178
179 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
180 //
181 // Elements of <input_tensors_as_shapes> are used for when a shape
182 // function makes a call to MakeShapeFromShapeTensor; in particular, when
183 // the input_tensors[i] is nullptr but the shape represented by it is
184 // partially known from analysis of the graph. <input_tensors_as_shapes>
185 // can have fewer elements than <input_shapes>. Values of
186 // <input_tensors_as_shapes> do not need to outlive the context.
187 InferenceContext(
188 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
189 const std::vector<PartialTensorShape>& input_shapes,
190 const std::vector<const Tensor*>& input_tensors,
191 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
192 const std::vector<std::unique_ptr<
193 std::vector<std::pair<PartialTensorShape, DataType>>>>&
194 input_handle_shapes_and_types);
195
196 ~InferenceContext();
197
198 // Runs the shape inference function 'fn' with 'this' as the
199 // argument, returns the status of the inference.
200 //
201 // On error, additional context is provided in the error message.
202 Status Run(
203 const std::function<Status(shape_inference::InferenceContext* c)>& fn);
204
205 // Merge the stored shape of the input in position idx with <shape> according
206 // to the following rules:
207 //
208 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
209 // change. Otherwise if the stored shape is unknown, the new shape will be
210 // <shape>.
211 // - If both shapes are known, then they must have the same rank.
212 // - For any one dimension, if the values for that dimension in both shapes
213 // are known, then the values must match.
214 // - If one shape has equal or more information than the other shape in every
215 // dimension, the new shape will become the shape with more information.
216 // - Example: merging [2,?] and [?,2] results in [2,2]
217 // - Example: [2,2] cannot be merged with [1,2]
218 //
219 // This requires idx to be in the [0, num_inputs) range. If the merge is
220 // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)221 bool MergeInput(int idx, ShapeHandle shape) {
222 ShapeHandle new_shape;
223 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
224 inputs_[idx] = new_shape;
225 return true;
226 }
227
228 // Relax the stored shape of the input in position idx with <shape> according
229 // to the following rules:
230 //
231 // - If the ShapeHandles are the same then the stored shape will be returned.
232 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
233 // be returned. A new shape must be returned because we cannot claim that
234 // the resulting shape is necessarily the same as either of the input
235 // shapes.
236 // - If the shapes both have known ranks but their ranks are different, a new
237 // UnknownShape will be returned.
238 // - For any one dimension, if the value for that dimension in either of the
239 // shapes is unknown, a new shape will be returned with a new UnknownDim in
240 // that dimension.
241 // - For any one dimension, if the values for that dimension in both shapes
242 // are known but do not match, a new shape will be returned with a new
243 // UnknownDim in that dimension.
244 // - If both shapes have the same known rank and match in every dimension,
245 // the stored shape will be returned.
246 // - Example: relaxing [2,?] and [?,2] results in [?,?]
247 // - Example: relaxing [2,2] and [3,2] results in [?,2]
248 // - Example: relaxing [2,2] with [1,2,3] results in ?
249 //
250 // This requires idx to be in the [0, num_inputs) range. If the relax is
251 // successful and the new shape differs from the old one, store the new
252 // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)253 bool RelaxInput(int idx, ShapeHandle shape) {
254 ShapeHandle new_shape;
255 Relax(inputs_[idx], shape, &new_shape);
256 if (inputs_[idx].SameHandle(new_shape)) {
257 return false;
258 }
259 inputs_[idx] = new_shape;
260 return true;
261 }
262
SetInput(int idx,ShapeHandle shape)263 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
264
input(int64_t idx)265 ShapeHandle input(int64_t idx) const { return inputs_[idx]; }
266 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()267 int num_inputs() const { return inputs_.size(); }
268
269 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
270 // not available at the time of shape inference.
input_tensor(int idx)271 const Tensor* input_tensor(int idx) {
272 // Mark that this idx was requested.
273 request_input_tensor(idx);
274 return input_tensors_[idx];
275 }
276
277 // Notifies the shape refiner that the value of the tensor at index <idx>
278 // is needed. The shape refiner tries to statically compute this tensor,
279 // and if successful re-runs the shape function with this tensor available
280 // in the call to 'input_tensor(idx)'.
request_input_tensor(int idx)281 void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; }
282
283 // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)284 bool requested_input_tensor(int idx) const {
285 return requested_input_tensor_[idx];
286 }
287
288 // Notifies the shape refiner that the value of the tensor at index <idx>
289 // as a partial shape is needed. The shape refiner tries to statically compute
290 // this, and if successful re-runs the shape function with the
291 // computed PartialTensorShape available in the call to
292 // 'MakeShapeFromShapeTensor(idx, handle)' or
293 // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'.
request_input_tensor_as_partial_shape(int idx)294 void request_input_tensor_as_partial_shape(int idx) {
295 requested_input_tensor_as_partial_shape_[idx] = true;
296 }
297
298 // Returns true if MakeShapeFromInputTensor was called but the constant
299 // input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)300 bool requested_input_tensor_as_partial_shape(int idx) const {
301 return requested_input_tensor_as_partial_shape_[idx];
302 }
303
set_input_tensors(const std::vector<const Tensor * > & input_tensors)304 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
305 input_tensors_ = input_tensors;
306 }
307
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)308 void set_input_tensors_as_shapes(
309 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
310 input_tensors_as_shapes_ = input_tensors_as_shapes;
311 }
312
input_tensors_as_shapes()313 const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
314 return input_tensors_as_shapes_;
315 }
316
output(int64_t idx)317 ShapeHandle output(int64_t idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)318 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
319 Status set_output(StringPiece output_name,
320 const std::vector<ShapeHandle>& shapes);
321
num_outputs()322 int num_outputs() const { return outputs_.size(); }
output(int idx)323 ShapeHandle output(int idx) const { return outputs_.at(idx); }
324 Status output(StringPiece output_name,
325 std::vector<ShapeHandle>* output) const;
326
327 // Returns the value for attribute named `attr_name`.
GetAttr(StringPiece attr_name,const AttrValue ** attr_value)328 Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const {
329 return attrs_.Find(attr_name, attr_value);
330 }
GetAttr(StringPiece attr_name)331 const AttrValue* GetAttr(StringPiece attr_name) const {
332 return attrs_.Find(attr_name);
333 }
334
ret_types()335 const FullTypeDef& ret_types() const { return ret_types_; }
336
337 // idx can be negative for an offset from end of dimensions.
338 // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64_t idx)339 DimensionHandle Dim(ShapeHandle s, int64_t idx) {
340 if (!s.Handle() || s->rank_ == kUnknownRank) {
341 return UnknownDim();
342 }
343 return DimKnownRank(s, idx);
344 }
345 // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64_t idx)346 static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) {
347 CHECK_NE(s->rank_, kUnknownRank);
348 if (idx < 0) {
349 return s->dims_[s->dims_.size() + idx];
350 }
351 return s->dims_[idx];
352 }
353
Rank(ShapeHandle s)354 static int32 Rank(ShapeHandle s) {
355 return s.IsSet() ? s->rank_ : kUnknownRank;
356 }
RankKnown(ShapeHandle s)357 static bool RankKnown(ShapeHandle s) {
358 return (s.IsSet() && (Rank(s) != kUnknownRank));
359 }
Value(DimensionOrConstant d)360 static inline int64 Value(DimensionOrConstant d) {
361 return d.dim.IsSet() ? d.dim->value_ : d.val;
362 }
ValueKnown(DimensionOrConstant d)363 static inline bool ValueKnown(DimensionOrConstant d) {
364 return Value(d) != kUnknownDim;
365 }
366
367 // Fills the output proto with the shape defined by the handle.
368 // "proto" is expected to be empty prior to the call.
369 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
370
371 // Returns true if the rank and all dimensions of the Shape are known.
372 bool FullyDefined(ShapeHandle s);
373
374 // Returns the total number of elements, or an unknown dimension for an
375 // incomplete shape.
376 DimensionHandle NumElements(ShapeHandle s);
377
378 std::string DebugString(ShapeHandle s);
379 std::string DebugString(DimensionHandle d);
380 std::string DebugString(const ShapeAndType& shape_and_type);
381 std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
382
383 // Describes the whole context, for debugging purposes.
384 std::string DebugString() const;
385
386 // If <shape> has rank <rank>, or its rank is unknown, return OK and return
387 // the shape with asserted rank in <*out>. Otherwise return an error.
388 //
389 // Note that <*out> may be set to <shape>.
390 Status WithRank(ShapeHandle shape, int64_t rank,
391 ShapeHandle* out) TF_MUST_USE_RESULT;
392 Status WithRankAtLeast(ShapeHandle shape, int64_t rank,
393 ShapeHandle* out) TF_MUST_USE_RESULT;
394 Status WithRankAtMost(ShapeHandle shape, int64_t rank,
395 ShapeHandle* out) TF_MUST_USE_RESULT;
396
397 // If <dim> has value <value>, or its value is unknown, returns OK and returns
398 // the dimension with asserted value in <*out>. Otherwise returns an error.
399 //
400 // Note that <*out> may be set to <dim>.
401 Status WithValue(DimensionHandle dim, int64_t value,
402 DimensionHandle* out) TF_MUST_USE_RESULT;
403
404 // Merges <s0> and <s1> and returns the merged shape in <*out>. See
405 // 'MergeInput' function for full details and examples.
406 Status Merge(ShapeHandle s0, ShapeHandle s1,
407 ShapeHandle* out) TF_MUST_USE_RESULT;
408
409 // Asserts that <s>'s rank >= <prefix>'s rank, and the first
410 // <prefix.rank> dimensions of <s> are compatible with the dimensions of
411 // <prefix>.
412 // Returns the merged results in <*s_out> and <*prefix_out>.
413 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
414 ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
415
416 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
417 // and <d1> have incompatible values, returns an error.
418 //
419 // Note that <*out> may be set to <d0> or <d1>.
420 Status Merge(DimensionHandle d0, DimensionHandle d1,
421 DimensionHandle* out) TF_MUST_USE_RESULT;
422
423 // Returns in <*out> a sub-shape of <s> with dimensions [start:].
424 // <start> can be negative to index from the end of the shape. If <start> >
425 // rank of <s>, then an empty subshape is returned.
426 Status Subshape(ShapeHandle s, int64_t start,
427 ShapeHandle* out) TF_MUST_USE_RESULT;
428
429 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
430 // <start> and <end> can be negative, to index from the end of the shape.
431 // <start> and <end> are set to the rank of <s> if > rank of <s>.
432 Status Subshape(ShapeHandle s, int64_t start, int64_t end,
433 ShapeHandle* out) TF_MUST_USE_RESULT;
434
435 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
436 // <start> and <end> can be negative, to index from the end of the shape.
437 // <start> and <end> are set to the rank of <s> if > rank of <s>.
438 // <stride> can be negative, to reverse the <s>.
439 Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride,
440 ShapeHandle* out) TF_MUST_USE_RESULT;
441
442 // Returns in <*out> the result of appending the dimensions of <s2> to those
443 // of <s1>.
444 Status Concatenate(ShapeHandle s1, ShapeHandle s2,
445 ShapeHandle* out) TF_MUST_USE_RESULT;
446
447 // Returns in <out> the shape from replacing <s.dim[dim_index]> with
448 // <new_dim>.
449 Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim,
450 ShapeHandle* out) TF_MUST_USE_RESULT;
451
452 // Returns a new shape with the given dims. The returned value is owned by
453 // this context.
454 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
455 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
456
457 // Returns a new unknown shape.
458 ShapeHandle UnknownShape();
459
460 // Returns a shape with specified rank but unknown dims.
461 ShapeHandle UnknownShapeOfRank(int64_t rank);
462
463 // Returns a new shape of zero dimensions.
464 ShapeHandle Scalar();
465
466 // Returns a new shape of one dimension.
467 ShapeHandle Vector(DimensionOrConstant dim);
468
469 // Returns a new shape of two dimensions.
470 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
471
472 // Returns in <out> a new shape whose dimension sizes come from input tensor
473 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
474 // the input tensor is NULL, then an unknown shape is returned.
475 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
476
477 // Like the function above, but treats scalar values as unknown
478 // shapes. **NOTE** If the scalar is statically known, its value
479 // must be -1 or an error is returned.
480 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
481 ShapeHandle* out);
482
483 // Returns in <out> a new shape corresponding to <proto>.
484 Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
485 ShapeHandle* out);
486
487 // Returns in <out> a new shape corresponding to <partial_shape>.
488 Status MakeShapeFromPartialTensorShape(
489 const PartialTensorShape& partial_shape, ShapeHandle* out);
490
491 // Returns in <out> a new shape corresponding to <shape>.
492 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
493
494 // Returns a new dimension of the given size. The returned value is owned by
495 // this context.
MakeDim(DimensionOrConstant d)496 inline DimensionHandle MakeDim(DimensionOrConstant d) {
497 return shape_manager_.MakeDim(d);
498 }
499
UnknownDim()500 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
501
502 // Returns in <val> a scalar value from an input tensor <t>. The input tensor
503 // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the
504 // input tensor is not NULL.
505 Status GetScalarFromTensor(const Tensor* t, int64* val);
506
507 // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or
508 // int64 elements. Caller must ensure that the input tensor is not NULL.
509 Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64* val);
510
511 // Returns a new dimension whose value is given by a scalar input tensor.
512 // The input tensor must be in host memory, since it is dereferenced to get
513 // the value.
514 Status MakeDimForScalarInput(int idx, DimensionHandle* out);
515
516 // Returns a new dimension whose value is given by a scalar input tensor.
517 // This allows for a negative input dimension given the rank of a separate
518 // tensor. This rank can be negative if unknown.
519 // The input tensor must be in host memory, since it is dereferenced to get
520 // the value.
521 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
522 DimensionHandle* out);
523
524 // Look up the attr being evaluated with name attr_name and set *value to its
525 // value. If no attr with attr_name is found in def(), or the attr does not
526 // have a matching type, a non-ok status will be returned.
527 template <class T>
528 Status GetAttr(StringPiece attr_name, T* value) const;
529
530 // Returns in <out> the result of dividing <dividend> by <divisor>.
531 // Returns an error if <divisor> is not positive or if <evenly_divisible>
532 // and <divisor> does not evenly divide <dividend>.
533 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
534 bool evenly_divisible, DimensionHandle* out);
535
536 // Returns in <out> the sum of <first> and <second>.
537 Status Add(DimensionHandle first, DimensionOrConstant second,
538 DimensionHandle* out);
539
540 // Returns in <out> the dimension that is <first> minus <second>.
541 Status Subtract(DimensionHandle first, DimensionOrConstant second,
542 DimensionHandle* out);
543
544 // Returns in <out> the product of <first> and <second>.
545 Status Multiply(DimensionHandle first, DimensionOrConstant second,
546 DimensionHandle* out);
547
548 // Returns in <out> the minimum of <first> and <second>. If either <first> or
549 // <second> is zero the results is zero. Otherwise, if either <first> or
550 // <second> is unknown the results is unknown.
551 Status Min(DimensionHandle first, DimensionOrConstant second,
552 DimensionHandle* out);
553
554 // Returns in <out> the maximum of <first> and <second>. If either <first> or
555 // <second> is unknown the results is unknown.
556 Status Max(DimensionHandle first, DimensionOrConstant second,
557 DimensionHandle* out);
558
construction_status()559 Status construction_status() const { return construction_status_; }
560
561 // Methods to propagate shape and dtype on edges of handles. Handles are the
562 // dtype DT_RESOURCE which can be used to access state stored in a
563 // ResourceManager. When ops (such as variables) consume these handles to
564 // produce tensors they might need to know side-information about the shapes
565 // and dtypes of tensors which can be accessed via the handle. These methods
566 // propagate that information. Output handle dtypes and shapes are ignored if
567 // the output tensor is not of type DT_RESOURCE.
568
569 // Merge the stored shapes and types corresponding to the input handle in
570 // position idx with the specified shapes and types. This requires idx to be
571 // in the [0, num_inputs) range.
572 //
573 // If the merge is successful and any of the new shapes differs from the old
574 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
575 // return true. Return false otherwise.
576 //
577 // See 'MergeInput' function for full details and examples.
578 bool MergeInputHandleShapesAndTypes(
579 int idx,
580 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
581
582 // As MergeInputHandleShapesAndTypes, but for an output.
583 bool MergeOutputHandleShapesAndTypes(
584 int idx,
585 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
586
587 // Relaxes the stored shapes and types corresponding to the input handle in
588 // position idx with the specified shapes and types. This requires idx to be
589 // in the [0, num_inputs) range.
590 //
591 // If the relax is successful (sizes are the same, old dtypes match new ones
592 // or are DT_INVALID), then store the relaxed shapes and return true.
593 // Return false otherwise.
594 //
595 // See 'RelaxInput' function for full details and examples.
596 bool RelaxInputHandleShapesAndMergeTypes(
597 int idx,
598 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
599
600 // As RelaxInputHandleShapesAndTypes, but for an output.
601 bool RelaxOutputHandleShapesAndMergeTypes(
602 int idx,
603 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
604
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)605 void set_input_handle_shapes_and_types(
606 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
607 input_handle_shapes_and_types_[idx] =
608 absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
609 }
610
611 // Returns the output handle shapes and types, for the resource tensor output
612 // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)613 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
614 return output_handle_shapes_and_types_[idx].get();
615 }
616
617 // Returns the inputs handle shapes and types, for the resource tensor output
618 // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)619 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
620 return input_handle_shapes_and_types_[idx].get();
621 }
622
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)623 void set_output_handle_shapes_and_types(
624 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
625 output_handle_shapes_and_types_[idx].reset(
626 new std::vector<ShapeAndType>(shapes_and_types));
627 }
628
629 // Note that shape functions should usually call MakeShapeFromShapeTensor,
630 // as it does more analysis to provide partial shapes.
631 //
632 // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
633 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
634 // then an unknown shape is returned.
635 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
636 ShapeHandle* out);
637
graph_def_version()638 int graph_def_version() const { return graph_def_version_; }
639
MergedShapes()640 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
641 return merged_shapes_;
642 }
MergedDims()643 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
644 const {
645 return merged_dims_;
646 }
647
648 // Adds new outputs; useful when mutating the graph.
649 Status ExpandOutputs(int new_output_size);
650
651 private:
652 // Creates and stores shapes for use in InferenceContext.
653 class ShapeManager {
654 public:
655 ShapeManager();
656 ~ShapeManager();
657
658 // Returns a new shape with the given dims. The returned value is owned by
659 // this class.
660 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
661
662 // Returns a new unknown shape.
663 ShapeHandle UnknownShape();
664
665 // Returns a new dimension of the given size. The returned value
666 // is owned by this class.
MakeDim(DimensionOrConstant d)667 inline DimensionHandle MakeDim(DimensionOrConstant d) {
668 if (d.dim.IsSet()) {
669 return d.dim;
670 } else {
671 all_dims_.push_back(new Dimension(d.val));
672 return all_dims_.back();
673 }
674 }
675
676 private:
677 std::vector<Shape*> all_shapes_; // values are owned.
678 std::vector<Dimension*> all_dims_; // values are owned.
679 };
680
681 friend class ::tensorflow::grappler::GraphProperties;
682
683 friend class ShapeInferenceTest; // For testing Relax functions.
684 friend class ShapeInferenceTestutil; // For testing shapes.
685
686 // Shared initialization across the two constructors. Remove
687 // once we get rid of one of them.
688 void PreInputInit(const OpDef& op_def,
689 const std::vector<const Tensor*>& input_tensors,
690 const std::vector<ShapeHandle>& input_tensors_as_shapes);
691 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
692 input_handle_data);
693
ReturnUnknownShape(ShapeHandle * out)694 Status ReturnUnknownShape(ShapeHandle* out) {
695 *out = UnknownShape();
696 return Status::OK();
697 }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)698 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
699 ShapeHandle* out) {
700 *out = MakeShape(dims);
701 return Status::OK();
702 }
703
704 // Adds additional context to the given status.
705 Status AttachContext(const Status& status);
706
707 // Relaxes an existing value <d_old> with a new value <d_new> and returns the
708 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
709 // values, returns an error.
710 //
711 // Note that <*out> may be set to <d_old> or <d_new>.
712 void Relax(DimensionHandle d_old, DimensionHandle d_new,
713 DimensionHandle* out);
714 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
715 // relaxed shape in <*out>. See 'RelaxInput' function for full details and
716 // examples.
717 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
718
719 // Used to implement MergeInputHandleShapesAndTypes and
720 // MergeOutputHandleShapesAndTypes.
721 bool MergeHandleShapesAndTypes(
722 const std::vector<ShapeAndType>& shapes_and_types,
723 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
724 // Used to implement RelaxInputHandleShapesAndMergeTypes and
725 // RelaxOutputHandleShapesAndMergeTypes.
726 bool RelaxHandleShapesAndMergeTypes(
727 const std::vector<ShapeAndType>& shapes_and_types,
728 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
729
730 // Forget all the previous merged shapes and dims.
ForgetMerges()731 void ForgetMerges() {
732 merged_shapes_.clear();
733 merged_dims_.clear();
734 }
735
736 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
737 Status InternalMakeShapeFromTensor(
738 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
739 ShapeHandle tensor_shape, ShapeHandle* out);
740
741 ShapeManager shape_manager_;
742
743 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
744 // `shape_manager_`.
745 std::vector<ShapeHandle> inputs_;
746 std::vector<const Tensor*> input_tensors_;
747 std::vector<bool> requested_input_tensor_;
748 std::vector<ShapeHandle> outputs_;
749 // Can have fewer elements than inputs_.
750 std::vector<ShapeHandle> input_tensors_as_shapes_;
751 std::vector<bool> requested_input_tensor_as_partial_shape_;
752
753 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
754 // through the resource handle passed along input i of the node.
755 //
756 // Values may be NULL.
757 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
758 input_handle_shapes_and_types_;
759
760 // output_handle_shapes_and_types_[i] is the list of shape/type pairs
761 // available through the resource handle passed along output i of the node.
762 //
763 // Values may be NULL.
764 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
765 output_handle_shapes_and_types_;
766
767 // Return types for the node this context is associated with. This information
768 // is to eventually consolidate all the dtype and shape info, allowing for
769 // output_handle_shapes_and_types_ to be removed.
770 FullTypeDef ret_types_;
771
772 const int graph_def_version_;
773 AttrSlice attrs_;
774 NameRangeMap input_name_map_;
775 NameRangeMap output_name_map_;
776
777 // An error set during construction. TODO(cwhipkey): remove when test
778 // constructor is removed.
779 Status construction_status_;
780
781 // Pair of shape or dim handles that are equivalent, ie that represent the
782 // same underlying shape of dimension. Note that for each pair at least one of
783 // the handles must contain an unknown shape, since we don't keep track of
784 // known shapes or dims here.
785 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
786 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
787
788 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
789 };
790
791 // -----------------------------------------------------------------------------
792 // Template and inline method implementations, please ignore
793
Dimension()794 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
// Creates a dimension holding `value`. In debug builds, the DCHECK rejects any
// value that is negative and not InferenceContext::kUnknownDim.
inline Dimension::Dimension(int64_t value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}
801
Shape()802 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)803 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
804 : rank_(dims.size()), dims_(dims) {}
805
DimensionOrConstant(DimensionHandle dim)806 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
807 : dim(dim) {
808 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
809 }
810
// Wraps a constant dimension value. `dim` stays unset (its default is
// nullptr), so InferenceContext::Value() reads `val`. In debug builds, the
// DCHECK rejects any value that is negative and not
// InferenceContext::kUnknownDim.
inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}
817
818 template <class T>
GetAttr(StringPiece attr_name,T * value)819 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
820 return GetNodeAttr(attrs_, attr_name, value);
821 }
822
823 } // namespace shape_inference
824 } // namespace tensorflow
825
826 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
827