1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
18 #include <vector>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/macros.h"
27
28 namespace tensorflow {
29
30 namespace grappler {
31 class GraphProperties;
32 class SymbolicShapeManager;
33 } // namespace grappler
34
35 namespace shape_inference {
36
37 struct DimensionOrConstant;
38 class InferenceContext;
39
40 // Dimension values are accessed through InferenceContext.
41 class Dimension {
42 private:
43 Dimension();
44 Dimension(int64 value);
~Dimension()45 ~Dimension() {}
46
47 const int64 value_;
48
49 friend class InferenceContext;
50 friend class ShapeManager;
51 TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
52 };
53
54 class DimensionHandle {
55 public:
DimensionHandle()56 DimensionHandle() {}
SameHandle(DimensionHandle d)57 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
Handle()58 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
59
60 private:
DimensionHandle(const Dimension * dim)61 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
62
63 const Dimension* operator->() const { return ptr_; }
IsSet()64 bool IsSet() const { return ptr_ != nullptr; }
65
66 const Dimension* ptr_ = nullptr;
67
68 friend struct DimensionOrConstant;
69 friend class InferenceContext;
70 friend class ShapeInferenceTest;
71 friend class ShapeInferenceTestutil;
72 friend class ::tensorflow::grappler::GraphProperties;
73 friend class ::tensorflow::grappler::SymbolicShapeManager;
74
75 // Intentionally copyable.
76 };
77
78 // Shape rank and dimensions are accessed through InferenceContext.
79 class Shape {
80 private:
81 Shape();
82 Shape(const std::vector<DimensionHandle>& dims);
~Shape()83 ~Shape() {}
84
85 const int32 rank_;
86 const std::vector<DimensionHandle> dims_;
87
88 friend class InferenceContext;
89 friend class ::tensorflow::grappler::SymbolicShapeManager;
90
91 TF_DISALLOW_COPY_AND_ASSIGN(Shape);
92 };
93
94 class ShapeHandle {
95 public:
ShapeHandle()96 ShapeHandle() {}
SameHandle(ShapeHandle s)97 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
Handle()98 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
99
100 private:
ShapeHandle(const Shape * shape)101 ShapeHandle(const Shape* shape) { ptr_ = shape; }
102 const Shape* operator->() const { return ptr_; }
IsSet()103 bool IsSet() const { return ptr_ != nullptr; }
104
105 const Shape* ptr_ = nullptr;
106
107 friend class InferenceContext;
108 friend class ShapeInferenceTest;
109 friend class ShapeInferenceTestutil;
110 friend class ::tensorflow::grappler::SymbolicShapeManager;
111
112 // Intentionally copyable.
113 };
114
115 // Struct used to allow functions to take DimensionHandle or a dimension value.
116 // Not meant to be constructed directly.
117 struct DimensionOrConstant {
118 public:
119 // Intentionally not explicit.
120 DimensionOrConstant(DimensionHandle dim);
121
122 // val must be non-negative or InferenceContext::kUnknownDim.
123 DimensionOrConstant(int64 val);
124
125 // dim takes precedence. If dim != nullptr, val is ignored.
126 DimensionHandle dim;
127 int64 val;
128
129 private:
130 DimensionOrConstant();
131 };
132
133 struct ShapeAndType {
ShapeAndTypeShapeAndType134 ShapeAndType() {}
ShapeAndTypeShapeAndType135 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
136
137 ShapeHandle shape;
138 DataType dtype = DT_INVALID;
139 };
140
141 // Shape inference functions registered on ops in REGISTER_OP implement
142 // their shape functions in terms of this InferenceContext. An InferenceContext
143 // is created by the framework and passed to a shape inference function. The
144 // shape inference function calls functions on the context, and should call
145 // set_output() to set the shape on all outputs.
146 //
147 // To infer shapes for user-defined functions see ShapeRefiner.
148 //
149 // All Shape* and Dimension* returned by functions of InferenceContext are owned
150 // by the InferenceContext.
151 class InferenceContext {
152 public:
153 static constexpr int64 kUnknownDim = -1;
154 static constexpr int32 kUnknownRank = -1;
155
156 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
157 //
158 // Elements of <input_tensors_as_shapes> are used for when a shape function
159 // makes a call to MakeShapeFromShapeTensor; in particular, when the
160 // input_tensors[i] is nullptr but the shape represented by it is partially
161 // known from analysis of the graph.
162 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
163 // Values of <input_tensors_as_shapes> do not need to outlive the context.
164 InferenceContext(int graph_def_version, const AttrSlice& attrs,
165 const OpDef& op_def,
166 const std::vector<ShapeHandle>& input_shapes,
167 const std::vector<const Tensor*>& input_tensors,
168 const std::vector<ShapeHandle>& input_tensors_as_shapes,
169 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
170 input_handle_shapes_and_types);
171
172 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
173 //
174 // Elements of <input_tensors_as_shapes> are used for when a shape
175 // function makes a call to MakeShapeFromShapeTensor; in particular, when
176 // the input_tensors[i] is nullptr but the shape represented by it is
177 // partially known from analysis of the graph. <input_tensors_as_shapes>
178 // can have fewer elements than <input_shapes>. Values of
179 // <input_tensors_as_shapes> do not need to outlive the context.
180 InferenceContext(
181 int graph_def_version, const AttrSlice& attrs, const OpDef& op_def,
182 const std::vector<PartialTensorShape>& input_shapes,
183 const std::vector<const Tensor*>& input_tensors,
184 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
185 const std::vector<std::unique_ptr<
186 std::vector<std::pair<PartialTensorShape, DataType>>>>&
187 input_handle_shapes_and_types);
188
189 ~InferenceContext();
190
191 // Runs the shape inference function 'fn' with 'this' as the
192 // argument, returns the status of the inference.
193 //
194 // On error, additional context is provided in the error message.
195 Status Run(
196 const std::function<Status(shape_inference::InferenceContext* c)>& fn);
197
198 // Merge the stored shape of the input in position idx with <shape> according
199 // to the following rules:
200 //
201 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
202 // change. Otherwise if the stored shape is unknown, the new shape will be
203 // <shape>.
204 // - If both shapes are known, then they must have the same rank.
205 // - For any one dimension, if the values for that dimension in both shapes
206 // are known, then the values must match.
207 // - If one shape has equal or more information than the other shape in every
208 // dimension, the new shape will become the shape with more information.
209 // - Example: merging [2,?] and [?,2] results in [2,2]
210 // - Example: [2,2] cannot be merged with [1,2]
211 //
212 // This requires idx to be in the [0, num_inputs) range. If the merge is
213 // successful, return true. Return false otherwise.
MergeInput(int idx,ShapeHandle shape)214 bool MergeInput(int idx, ShapeHandle shape) {
215 ShapeHandle new_shape;
216 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
217 inputs_[idx] = new_shape;
218 return true;
219 }
220
221 // Relax the stored shape of the input in position idx with <shape> according
222 // to the following rules:
223 //
224 // - If the ShapeHandles are the same then the stored shape will be returned.
225 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
226 // be returned. A new shape must be returned because we cannot claim that
227 // the resulting shape is necessarily the same as either of the input
228 // shapes.
229 // - If the shapes both have known ranks but their ranks are different, a new
230 // UnknownShape will be returned.
231 // - For any one dimension, if the value for that dimension in either of the
232 // shapes is unknown, a new shape will be returned with a new UnknownDim in
233 // that dimension.
234 // - For any one dimension, if the values for that dimension in both shapes
235 // are known but do not match, a new shape will be returned with a new
236 // UnknownDim in that dimension.
237 // - If both shapes have the same known rank and match in every dimension,
238 // the stored shape will be returned.
239 // - Example: relaxing [2,?] and [?,2] results in [?,?]
240 // - Example: relaxing [2,2] and [3,2] results in [?,2]
241 // - Example: relaxing [2,2] with [1,2,3] results in ?
242 //
243 // This requires idx to be in the [0, num_inputs) range. If the relax is
244 // successful and the new shape differs from the old one, store the new
245 // shape and return true. Return false otherwise.
RelaxInput(int idx,ShapeHandle shape)246 bool RelaxInput(int idx, ShapeHandle shape) {
247 ShapeHandle new_shape;
248 Relax(inputs_[idx], shape, &new_shape);
249 if (inputs_[idx].SameHandle(new_shape)) {
250 return false;
251 }
252 inputs_[idx] = new_shape;
253 return true;
254 }
255
SetInput(int idx,ShapeHandle shape)256 void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }
257
input(int64 idx)258 ShapeHandle input(int64 idx) const { return inputs_[idx]; }
259 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
num_inputs()260 int num_inputs() const { return inputs_.size(); }
261
262 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
263 // not available at the time of shape inference.
input_tensor(int idx)264 const Tensor* input_tensor(int idx) {
265 // Mark that this idx was requested.
266 requested_input_tensor_[idx] = true;
267 return input_tensors_[idx];
268 }
269
270 // Returns true iff input_tensor(idx) was called by the shape function.
requested_input_tensor(int idx)271 bool requested_input_tensor(int idx) const {
272 return requested_input_tensor_[idx];
273 }
274
  // Returns true if MakeShapeFromShapeTensor was called for input <idx> but
  // the constant input_tensor was not present.
requested_input_tensor_as_partial_shape(int idx)277 bool requested_input_tensor_as_partial_shape(int idx) const {
278 return requested_input_tensor_as_partial_shape_[idx];
279 }
280
set_input_tensors(const std::vector<const Tensor * > & input_tensors)281 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
282 input_tensors_ = input_tensors;
283 }
284
set_input_tensors_as_shapes(const std::vector<ShapeHandle> & input_tensors_as_shapes)285 void set_input_tensors_as_shapes(
286 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
287 input_tensors_as_shapes_ = input_tensors_as_shapes;
288 }
289
input_tensors_as_shapes()290 const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
291 return input_tensors_as_shapes_;
292 }
293
output(int64 idx)294 ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
set_output(int idx,ShapeHandle shape)295 void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
296 Status set_output(StringPiece output_name,
297 const std::vector<ShapeHandle>& shapes);
298
num_outputs()299 int num_outputs() const { return outputs_.size(); }
output(int idx)300 ShapeHandle output(int idx) const { return outputs_.at(idx); }
301 Status output(StringPiece output_name,
302 std::vector<ShapeHandle>* output) const;
303
attrs()304 const AttrSlice& attrs() const { return attrs_; }
305
306 // idx can be negative for an offset from end of dimensions.
307 // idx must be in the range [-1 * s.rank, s.rank).
Dim(ShapeHandle s,int64 idx)308 DimensionHandle Dim(ShapeHandle s, int64 idx) {
309 if (!s.Handle() || s->rank_ == kUnknownRank) {
310 return UnknownDim();
311 }
312 return DimKnownRank(s, idx);
313 }
314 // As above, but asserts that the rank of the shape is known.
DimKnownRank(ShapeHandle s,int64 idx)315 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
316 CHECK_NE(s->rank_, kUnknownRank);
317 if (idx < 0) {
318 return s->dims_[s->dims_.size() + idx];
319 }
320 return s->dims_[idx];
321 }
322
Rank(ShapeHandle s)323 static int32 Rank(ShapeHandle s) {
324 return s.IsSet() ? s->rank_ : kUnknownRank;
325 }
RankKnown(ShapeHandle s)326 static bool RankKnown(ShapeHandle s) {
327 return (s.IsSet() && (Rank(s) != kUnknownRank));
328 }
Value(DimensionOrConstant d)329 static inline int64 Value(DimensionOrConstant d) {
330 return d.dim.IsSet() ? d.dim->value_ : d.val;
331 }
ValueKnown(DimensionOrConstant d)332 static inline bool ValueKnown(DimensionOrConstant d) {
333 return Value(d) != kUnknownDim;
334 }
335
336 // Fills the output proto with the shape defined by the handle.
337 // "proto" is expected to be empty prior to the call.
338 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
339
340 // Returns true if the rank and all dimensions of the Shape are known.
341 bool FullyDefined(ShapeHandle s);
342
343 // Returns the total number of elements, or an unknown dimension for an
344 // incomplete shape.
345 DimensionHandle NumElements(ShapeHandle s);
346
347 string DebugString(ShapeHandle s);
348 string DebugString(DimensionHandle d);
349 string DebugString(const ShapeAndType& shape_and_type);
350 string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
351
352 // Describes the whole context, for debugging purposes.
353 string DebugString() const;
354
355 // If <shape> has rank <rank>, or its rank is unknown, return OK and return
356 // the shape with asserted rank in <*out>. Otherwise return an error.
357 //
358 // Note that <*out> may be set to <shape>.
359 Status WithRank(ShapeHandle shape, int64 rank,
360 ShapeHandle* out) TF_MUST_USE_RESULT;
361 Status WithRankAtLeast(ShapeHandle shape, int64 rank,
362 ShapeHandle* out) TF_MUST_USE_RESULT;
363 Status WithRankAtMost(ShapeHandle shape, int64 rank,
364 ShapeHandle* out) TF_MUST_USE_RESULT;
365
366 // If <dim> has value <value>, or its value is unknown, returns OK and returns
367 // the dimension with asserted value in <*out>. Otherwise returns an error.
368 //
369 // Note that <*out> may be set to <dim>.
370 Status WithValue(DimensionHandle dim, int64 value,
371 DimensionHandle* out) TF_MUST_USE_RESULT;
372
373 // Merges <s0> and <s1> and returns the merged shape in <*out>. See
374 // 'MergeInput' function for full details and examples.
375 Status Merge(ShapeHandle s0, ShapeHandle s1,
376 ShapeHandle* out) TF_MUST_USE_RESULT;
377
378 // Asserts that <s>'s rank >= <prefix>'s rank, and the first
379 // <prefix.rank> dimensions of <s> are compatible with the dimensions of
380 // <prefix>.
381 // Returns the merged results in <*s_out> and <*prefix_out>.
382 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
383 ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
384
385 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
386 // and <d1> have incompatible values, returns an error.
387 //
388 // Note that <*out> may be set to <d0> or <d1>.
389 Status Merge(DimensionHandle d0, DimensionHandle d1,
390 DimensionHandle* out) TF_MUST_USE_RESULT;
391
392 // Returns in <*out> a sub-shape of <s> with dimensions [start:].
393 // <start> can be negative to index from the end of the shape. If <start> >
394 // rank of <s>, then an empty subshape is returned.
395 Status Subshape(ShapeHandle s, int64 start,
396 ShapeHandle* out) TF_MUST_USE_RESULT;
397
398 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
399 // <start> and <end> can be negative, to index from the end of the shape.
400 // <start> and <end> are set to the rank of <s> if > rank of <s>.
401 Status Subshape(ShapeHandle s, int64 start, int64 end,
402 ShapeHandle* out) TF_MUST_USE_RESULT;
403
404 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
405 // <start> and <end> can be negative, to index from the end of the shape.
406 // <start> and <end> are set to the rank of <s> if > rank of <s>.
407 // <stride> can be negative, to reverse the <s>.
408 Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
409 ShapeHandle* out) TF_MUST_USE_RESULT;
410
411 // Returns in <*out> the result of appending the dimensions of <s2> to those
412 // of <s1>.
413 Status Concatenate(ShapeHandle s1, ShapeHandle s2,
414 ShapeHandle* out) TF_MUST_USE_RESULT;
415
416 // Returns in <out> the shape from replacing <s.dim[dim_index]> with
417 // <new_dim>.
418 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
419 ShapeHandle* out) TF_MUST_USE_RESULT;
420
421 // Returns a new shape with the given dims. The returned value is owned by
422 // this context.
423 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
424 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
425
426 // Returns a new unknown shape.
427 ShapeHandle UnknownShape();
428
429 // Returns a shape with specified rank but unknown dims.
430 ShapeHandle UnknownShapeOfRank(int64 rank);
431
432 // Returns a new shape of zero dimensions.
433 ShapeHandle Scalar();
434
435 // Returns a new shape of one dimension.
436 ShapeHandle Vector(DimensionOrConstant dim);
437
438 // Returns a new shape of two dimensions.
439 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
440
441 // Returns in <out> a new shape whose dimension sizes come from input tensor
442 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
443 // the input tensor is NULL, then an unknown shape is returned.
444 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
445
446 // Like the function above, but treats scalar values as unknown
447 // shapes. **NOTE** If the scalar is statically known, its value
448 // must be -1 or an error is returned.
449 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
450 ShapeHandle* out);
451
452 // Returns in <out> a new shape corresponding to <proto>.
453 Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
454 ShapeHandle* out);
455
456 // Returns in <out> a new shape corresponding to <partial_shape>.
457 Status MakeShapeFromPartialTensorShape(
458 const PartialTensorShape& partial_shape, ShapeHandle* out);
459
460 // Returns in <out> a new shape corresponding to <shape>.
461 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
462
463 // Returns a new dimension of the given size. The returned value is owned by
464 // this context.
MakeDim(DimensionOrConstant d)465 inline DimensionHandle MakeDim(DimensionOrConstant d) {
466 return shape_manager_.MakeDim(d);
467 }
468
UnknownDim()469 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
470
  // Returns in <val> a scalar value from an input tensor <t>. The input tensor
  // must be a 0-dimensional (scalar) int32 or int64 tensor. Caller must ensure
  // that the input tensor is not NULL.
474 Status GetScalarFromTensor(const Tensor* t, int64* val);
475
476 // Returns a new dimension whose value is given by a scalar input tensor.
477 // The input tensor must be in host memory, since it is dereferenced to get
478 // the value.
479 Status MakeDimForScalarInput(int idx, DimensionHandle* out);
480
481 // Returns a new dimension whose value is given by a scalar input tensor.
482 // This allows for a negative input dimension given the rank of a separate
483 // tensor. This rank can be negative if unknown.
484 // The input tensor must be in host memory, since it is dereferenced to get
485 // the value.
486 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
487 DimensionHandle* out);
488
489 // Look up the attr being evaluated with name attr_name and set *value to its
490 // value. If no attr with attr_name is found in def(), or the attr does not
491 // have a matching type, a non-ok status will be returned.
492 template <class T>
493 Status GetAttr(StringPiece attr_name, T* value) const;
494
495 // Returns in <out> the result of dividing <dividend> by <divisor>.
496 // Returns an error if <divisor> is not positive or if <evenly_divisible>
497 // and <divisor> does not evenly divide <dividend>.
498 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
499 bool evenly_divisible, DimensionHandle* out);
500
501 // Returns in <out> the sum of <first> and <second>.
502 Status Add(DimensionHandle first, DimensionOrConstant second,
503 DimensionHandle* out);
504
505 // Returns in <out> the dimension that is <first> minus <second>.
506 Status Subtract(DimensionHandle first, DimensionOrConstant second,
507 DimensionHandle* out);
508
509 // Returns in <out> the product of <first> and <second>.
510 Status Multiply(DimensionHandle first, DimensionOrConstant second,
511 DimensionHandle* out);
512
  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero, the result is zero. Otherwise, if either <first> or
  // <second> is unknown, the result is unknown.
516 Status Min(DimensionHandle first, DimensionOrConstant second,
517 DimensionHandle* out);
518
  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown, the result is unknown.
521 Status Max(DimensionHandle first, DimensionOrConstant second,
522 DimensionHandle* out);
523
construction_status()524 Status construction_status() const { return construction_status_; }
525
526 // Methods to propagate shape and dtype on edges of handles. Handles are the
527 // dtype DT_RESOURCE which can be used to access state stored in a
528 // ResourceManager. When ops (such as variables) consume these handles to
529 // produce tensors they might need to know side-information about the shapes
530 // and dtypes of tensors which can be accessed via the handle. These methods
531 // propagate that information. Output handle dtypes and shapes are ignored if
532 // the output tensor is not of type DT_RESOURCE.
533
534 // Merge the stored shapes and types corresponding to the input handle in
535 // position idx with the specified shapes and types. This requires idx to be
536 // in the [0, num_inputs) range.
537 //
538 // If the merge is successful and any of the new shapes differs from the old
539 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
540 // return true. Return false otherwise.
541 //
542 // See 'MergeInput' function for full details and examples.
543 bool MergeInputHandleShapesAndTypes(
544 int idx,
545 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
546
547 // As MergeInputHandleShapesAndTypes, but for an output.
548 bool MergeOutputHandleShapesAndTypes(
549 int idx,
550 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
551
552 // Relaxes the stored shapes and types corresponding to the input handle in
553 // position idx with the specified shapes and types. This requires idx to be
554 // in the [0, num_inputs) range.
555 //
556 // If the relax is successful (sizes are the same, old dtypes match new ones
557 // or are DT_INVALID), then store the relaxed shapes and return true.
558 // Return false otherwise.
559 //
560 // See 'RelaxInput' function for full details and examples.
561 bool RelaxInputHandleShapesAndMergeTypes(
562 int idx,
563 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
564
  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
566 bool RelaxOutputHandleShapesAndMergeTypes(
567 int idx,
568 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
569
set_input_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)570 void set_input_handle_shapes_and_types(
571 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
572 input_handle_shapes_and_types_[idx] =
573 absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
574 }
575
576 // Returns the output handle shapes and types, for the resource tensor output
577 // at index <idx>. Returns NULL if the shape and types were never set.
output_handle_shapes_and_types(int idx)578 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
579 return output_handle_shapes_and_types_[idx].get();
580 }
581
  // Returns the input handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
input_handle_shapes_and_types(int idx)584 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
585 return input_handle_shapes_and_types_[idx].get();
586 }
587
set_output_handle_shapes_and_types(int idx,const std::vector<ShapeAndType> & shapes_and_types)588 void set_output_handle_shapes_and_types(
589 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
590 output_handle_shapes_and_types_[idx].reset(
591 new std::vector<ShapeAndType>(shapes_and_types));
592 }
593
594 // Note that shape functions should usually call MakeShapeFromShapeTensor,
595 // as it does more analysis to provide partial shapes.
596 //
597 // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
598 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
599 // then an unknown shape is returned.
600 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
601 ShapeHandle* out);
602
graph_def_version()603 int graph_def_version() const { return graph_def_version_; }
604
MergedShapes()605 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
606 return merged_shapes_;
607 }
MergedDims()608 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
609 const {
610 return merged_dims_;
611 }
612
613 // Adds new outputs; useful when mutating the graph.
614 Status ExpandOutputs(int new_output_size);
615
616 private:
617 // Creates and stores shapes for use in InferenceContext.
618 class ShapeManager {
619 public:
620 ShapeManager();
621 ~ShapeManager();
622
623 // Returns a new shape with the given dims. The returned value is owned by
624 // this class.
625 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
626
627 // Returns a new unknown shape.
628 ShapeHandle UnknownShape();
629
630 // Returns a new dimension of the given size. The returned value
631 // is owned by this class.
MakeDim(DimensionOrConstant d)632 inline DimensionHandle MakeDim(DimensionOrConstant d) {
633 if (d.dim.IsSet()) {
634 return d.dim;
635 } else {
636 all_dims_.push_back(new Dimension(d.val));
637 return all_dims_.back();
638 }
639 }
640
641 private:
642 std::vector<Shape*> all_shapes_; // values are owned.
643 std::vector<Dimension*> all_dims_; // values are owned.
644 };
645
646 friend class ::tensorflow::grappler::GraphProperties;
647
648 friend class ShapeInferenceTest; // For testing Relax functions.
649 friend class ShapeInferenceTestutil; // For testing shapes.
650
651 // Shared initialization across the two constructors. Remove
652 // once we get rid of one of them.
653 void PreInputInit(const OpDef& op_def,
654 const std::vector<const Tensor*>& input_tensors,
655 const std::vector<ShapeHandle>& input_tensors_as_shapes);
656 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
657 input_handle_data);
658
ReturnUnknownShape(ShapeHandle * out)659 Status ReturnUnknownShape(ShapeHandle* out) {
660 *out = UnknownShape();
661 return Status::OK();
662 }
ReturnCreatedShape(const std::vector<DimensionHandle> & dims,ShapeHandle * out)663 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
664 ShapeHandle* out) {
665 *out = MakeShape(dims);
666 return Status::OK();
667 }
668
669 // Adds additional context to the given status.
670 Status AttachContext(const Status& status);
671
  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. This never fails: see 'RelaxInput' for how
  // unknown or mismatching values are relaxed to an unknown dimension.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
677 void Relax(DimensionHandle d_old, DimensionHandle d_new,
678 DimensionHandle* out);
679 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
680 // relaxed shape in <*out>. See 'RelaxInput' function for full details and
681 // examples.
682 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
683
684 // Used to implement MergeInputHandleShapesAndTypes and
685 // MergeOutputHandleShapesAndTypes.
686 bool MergeHandleShapesAndTypes(
687 const std::vector<ShapeAndType>& shapes_and_types,
688 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
689 // Used to implement RelaxInputHandleShapesAndMergeTypes and
690 // RelaxOutputHandleShapesAndMergeTypes.
691 bool RelaxHandleShapesAndMergeTypes(
692 const std::vector<ShapeAndType>& shapes_and_types,
693 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
694
695 // Forget all the previous merged shapes and dims.
ForgetMerges()696 void ForgetMerges() {
697 merged_shapes_.clear();
698 merged_dims_.clear();
699 }
700
701 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
702 Status InternalMakeShapeFromTensor(
703 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
704 ShapeHandle tensor_shape, ShapeHandle* out);
705
706 ShapeManager shape_manager_;
707
708 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
709 // `shape_manager_`.
710 std::vector<ShapeHandle> inputs_;
711 std::vector<const Tensor*> input_tensors_;
712 std::vector<bool> requested_input_tensor_;
713 std::vector<ShapeHandle> outputs_;
714 // Can have fewer elements than inputs_.
715 std::vector<ShapeHandle> input_tensors_as_shapes_;
716 std::vector<bool> requested_input_tensor_as_partial_shape_;
717
718 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
719 // through the resource handle passed along input i of the node.
720 //
721 // Values may be NULL.
722 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
723 input_handle_shapes_and_types_;
724
725 // output_handle_shapes_and_types_[i] is the list of shape/type pairs
726 // available through the resource handle passed along output i of the node.
727 //
728 // Values may be NULL.
729 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
730 output_handle_shapes_and_types_;
731
732 const int graph_def_version_;
733 AttrSlice attrs_;
734 NameRangeMap input_name_map_;
735 NameRangeMap output_name_map_;
736
737 // An error set during construction. TODO(cwhipkey): remove when test
738 // constructor is removed.
739 Status construction_status_;
740
  // Pairs of shape or dim handles that are equivalent, i.e. that represent the
  // same underlying shape or dimension. Note that for each pair at least one
  // of the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
745 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
746 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
747
748 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
749 };
750
751 // -----------------------------------------------------------------------------
752 // Template and inline method implementations, please ignore
753
Dimension()754 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
// Constructs a Dimension holding <value>. The DCHECK verifies that <value> is
// non-negative or equal to InferenceContext::kUnknownDim; the value is stored
// either way.
inline Dimension::Dimension(int64 value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}
761
Shape()762 inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
Shape(const std::vector<DimensionHandle> & dims)763 inline Shape::Shape(const std::vector<DimensionHandle>& dims)
764 : rank_(dims.size()), dims_(dims) {}
765
DimensionOrConstant(DimensionHandle dim)766 inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
767 : dim(dim) {
768 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
769 }
770
// Wraps a plain constant. `dim` is left default-constructed (unset), so `val`
// is the value that counts. The DCHECK verifies that <val> is non-negative or
// equal to InferenceContext::kUnknownDim.
inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}
777
778 template <class T>
GetAttr(StringPiece attr_name,T * value)779 Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
780 return GetNodeAttr(attrs_, attr_name, value);
781 }
782
783 } // namespace shape_inference
784 } // namespace tensorflow
785
786 #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
787