• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shape inference is used by the XLA service as the user builds up
17 // computation requests.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
21 
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
35 // For a given operation and input shapes, infers what the resulting shape is
36 // for the operation. With this functionality, the user does not need to specify
37 // the expected result type for computations that are built up via the API --
38 // the shape that results from an operation is inferred. Some methods have
39 // overloads for inferring shape at the HLO level.
40 //
41 // TODO(b/73352135): Shape inference does not issue very good error messages, in
42 // part because HloInstruction::ToString() is not available since shape
43 // inference runs before the HloInstruction object is created. We need a
44 // solution for this.
45 class ShapeInference {
46  public:
47   // Infers the shape produced by applying the given unary operation to the
48   // given input shape.
49   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
50                                            const Shape& shape);
51   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
52                                            const HloInstruction* operand);
53 
54   // Infers the shape produced by applying the given binary operation to the
55   // given input shapes.
56   static StatusOr<Shape> InferBinaryOpShape(
57       HloOpcode opcode, const Shape& lhs, const Shape& rhs,
58       absl::Span<const int64> broadcast_dimensions);
59   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
60                                             const HloInstruction* lhs,
61                                             const HloInstruction* rhs);
62 
63   // Infers the shape produced by applying the given ternary operation to the
64   // given input shapes.
65   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
66                                              const Shape& rhs,
67                                              const Shape& ehs);
68   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
69                                              const HloInstruction* lhs,
70                                              const HloInstruction* rhs,
71                                              const HloInstruction* ehs);
72 
73   // Infers the shape produced by applying the given variadic operation to the
74   // given input operand shapes.
75   static StatusOr<Shape> InferVariadicOpShape(
76       HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
77   static StatusOr<Shape> InferVariadicOpShape(
78       HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
79 
80   // Infers the shape produced by applying the given mapping computation shape
81   // to the given operand shapes.
82   static StatusOr<Shape> InferMapShape(
83       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
84       absl::Span<const int64> dimensions);
85 
86   // Infers the shape produced by a BatchNormTraining operation with the
87   // given operand shapes.
88   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
89                                                      const Shape& scale_shape,
90                                                      const Shape& offset_shape,
91                                                      int64 feature_index);
92 
93   // Infers the shape produced by a BatchNormInference operation with the
94   // given operand shapes.
95   static StatusOr<Shape> InferBatchNormInferenceShape(
96       const Shape& operand_shape, const Shape& scale_shape,
97       const Shape& offset_shape, const Shape& mean_shape,
98       const Shape& variance_shape, int64 feature_index);
99 
100   // Infers the shape produced by a BatchNormGrad operation on these shapes.
101   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
102                                                  const Shape& scale_shape,
103                                                  const Shape& mean_shape,
104                                                  const Shape& var_shape,
105                                                  const Shape& output_grad_shape,
106                                                  int64 feature_index);
107 
108   // Infers the shape produced by applying the given convolutional filter (rhs)
109   // to lhs in the way specified by the fields on window. An optional
110   // preferred_element_type can be specified to upcast the element type.
111   static StatusOr<Shape> InferConvolveShape(
112       const Shape& lhs, const Shape& rhs, int64 feature_group_count,
113       int64 batch_group_count, const Window& window,
114       const ConvolutionDimensionNumbers& dimension_numbers,
115       absl::optional<PrimitiveType> preferred_element_type);
116 
117   // Infers the shape produced by the given FFT type on the given operand.
118   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
119                                        absl::Span<const int64> fft_length);
120 
121   // Infers the shape produced by the given triangular solve operation.
122   static StatusOr<Shape> InferTriangularSolveShape(
123       const Shape& a, const Shape& b, const TriangularSolveOptions& options);
124 
125   // Infers the shape produced by a Cholesky operation on the given operand.
126   static StatusOr<Shape> InferCholeskyShape(const Shape& a);
127 
128   // Infers the shape produced by an all-gather with the given operand shape,
129   // concat dimension, and shard count.
130   static StatusOr<Shape> InferAllGatherShape(const Shape& operand_shape,
131                                              int64 all_gather_dimension,
132                                              int64 shard_count);
133 
134   // Infers the shape produced by an all-reduce (cross replica sum) with the
135   // given operand shapes.
136   static StatusOr<Shape> InferAllReduceShape(
137       absl::Span<const Shape* const> operand_shapes);
138 
139   // Infers the final shape of an all-to-all operation that is created by the
140   // XLA builder.
141   static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
142                                             int64 split_dimension,
143                                             int64 concat_dimension,
144                                             int64 split_count);
145 
146   // Infers the shape of an HLO all-to-all instruction.
147   static StatusOr<Shape> InferAllToAllTupleShape(
148       absl::Span<const Shape* const> operand_shapes);
149 
150   // Infers the shape of a collective permute operation.
151   static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
152 
153   // Infers the shape produced by applying the given reduction computation
154   // shape to the given input operand shapes.
155   // NOTE(review): no pass_index parameter is declared below; confirm or drop:
156   // If pass_index is true, the reduce function is invoked with the element
157   // index as the leading parameter, and the program shape should match
158   // accordingly (or an error will result).
159   static StatusOr<Shape> InferReduceShape(
160       absl::Span<const Shape* const> arg_shapes,
161       absl::Span<const int64> dimensions_to_reduce,
162       const ProgramShape& to_apply);
163 
164   // Infers the shape produced by applying the given computation to the operand
165   // shape with the given window and stride dimensions.
166   static StatusOr<Shape> InferReduceWindowShape(
167       const Shape& operand_shape, const Shape& init_value, const Window& window,
168       const ProgramShape& to_apply_shape);
169   static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
170                                                 const Shape& init_value,
171                                                 const Window& window);
172   static StatusOr<Shape> InferReduceWindowShape(
173       absl::Span<const Shape* const> operands,
174       absl::Span<const Shape* const> init_values, const Window& window,
175       const ProgramShape& to_apply_shape);
176 
177   static StatusOr<Shape> InferReduceWindowShape(
178       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
179       const Window& window);
180 
181   // Infers the shape produced by scattering the given source shape to the
182   // selected indices of each window on the operand shape.
183   static StatusOr<Shape> InferSelectAndScatterShape(
184       const Shape& operand_shape, const ProgramShape& select_shape,
185       const Window& window, const Shape& source_shape,
186       const Shape& init_value_shape, const ProgramShape& scatter_shape);
187 
188   // Infers the shape produced by a reverse operation that reverses the order
189   // of the elements in the given dimensions.
190   static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
191                                            absl::Span<const int64> dimensions);
192 
193   // Infers the shape produced by a slice operation spanning from the starts to
194   // the limits in the original shape's dimensions.
195   //
196   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
197   static StatusOr<Shape> InferSliceShape(const Shape& arg,
198                                          absl::Span<const int64> starts,
199                                          absl::Span<const int64> limits,
200                                          absl::Span<const int64> strides);
201 
202   // Infers the shape produced by a dynamic slice operation of size specified
203   // in 'slice_sizes', with dynamic start index shapes 'start_index_shapes'.
204   static StatusOr<Shape> InferDynamicSliceShape(
205       const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
206       absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true);
207 
208   // Infers the shape produced by a dynamic update slice operation based
209   // on the shape of operand and update.
210   static StatusOr<Shape> InferDynamicUpdateSliceShape(
211       const Shape& operand_shape, const Shape& update_shape,
212       absl::Span<const Shape> start_index_shapes,
213       bool allow_scalar_indices = true);
214 
215   // Infers the shape produced by doing a compile-time-constant indexing into
216   // the given input shape. This is essential for operations on tuples, because
217   // it is impossible to infer the type that comes out of the tuple indexing if
218   // it is not a compile time constant.
219   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
220                                                    int64 index);
221 
222   // Infers the shape produced from a while node. condition and body are the
223   // shapes of computations for the condition and the body of a while node, and
224   // init is the shape of data initially passed in to the body as an argument.
225   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
226   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
227                                          const ProgramShape& body,
228                                          const Shape& init);
229 
230   // Infers the shape produced by a predicated or indexed conditional operation.
231   static StatusOr<Shape> InferConditionalShape(
232       const Shape& branch_index,
233       absl::Span<const ProgramShape> branch_computations,
234       absl::Span<const Shape> branch_operands);
235 
236   // Infers the shape produced by a broadcast operation.
237   static StatusOr<Shape> InferBroadcastShape(
238       const Shape& operand, absl::Span<const int64> broadcast_sizes);
239 
240   // Checks whether the given parameters can form a broadcast. Returns the same
241   // output_shape if it's legal.
242   static StatusOr<Shape> InferBroadcastShape(
243       const Shape& operand_shape, const Shape& output_shape,
244       absl::Span<const int64> broadcast_dimensions);
245 
246   // Infers the shape produced by a reshape operation from the element type of
247   // its operand and the new dimension sizes specified.
248   static StatusOr<Shape> InferReshapeShape(const Shape& operand,
249                                            absl::Span<const int64> dimensions,
250                                            absl::Span<const int64> new_sizes,
251                                            int64 inferred_dimension);
252 
253   // Infers the shape produced by a dynamic reshape operation from the element
254   // type of its operand and the new dimension sizes specified. The result shape
255   // will have dynamic dimensions as specified in `dims_are_dynamic` and bounds
256   // `new_size_bounds`.
257   static StatusOr<Shape> InferDynamicReshapeShape(
258       const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
259       absl::Span<const int64> new_size_bounds,
260       const std::vector<bool>& dims_are_dynamic);
261 
262   // Infers the shape produced by a transpose operation from the element type of
263   // its operand and its dimensions field.
264   static StatusOr<Shape> InferTransposeShape(
265       const Shape& operand, absl::Span<const int64> dimensions);
266 
267   // Helper that infers the shape produced by performing a concatenate operation
268   // with the given operand shapes.
269   static StatusOr<Shape> InferConcatOpShape(
270       absl::Span<const Shape* const> arg_shapes, int64 dimension);
271 
272   // Helper that validates the given operand shape can be converted to
273   // new_element_type via a convert instruction -- the requirement is that
274   // the shape is identical except for the element type.
275   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
276                                            PrimitiveType new_element_type);
277 
278   // Helper that validates the given operand shape can be bitcast converted to
279   // new_element_type via a bitcast convert instruction -- the requirement is
280   // that the shape is identical except for the element type and the element
281   // types have identical bit-widths.
282   static StatusOr<Shape> InferBitcastConvertShape(
283       const Shape& operand_shape, PrimitiveType new_element_type);
284 
285   // Helper that validates the input data type for a reduce-precision operation,
286   // and returns the result shape.
287   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
288                                                    const int exponent_bits,
289                                                    const int mantissa_bits);
290 
291   // Helper that infers the shape produced by a pad operation based on the
292   // padding configuration.
293   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
294                                        const Shape& padding_value_shape,
295                                        const PaddingConfig& padding_config);
296 
297   // Helper that validates the given arg_shapes are compatible with the shape of
298   // the to_apply parameters, and returns the to_apply result shape.
299   static StatusOr<Shape> InferCallShape(
300       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
301 
302   // Helper that infers the shape produced by performing a dot operation with
303   // the given LHS and RHS shapes. An optional preferred_element_type can be
304   // specified to upcast the element type.
305   static StatusOr<Shape> InferDotOpShape(
306       const Shape& lhs, const Shape& rhs,
307       const DotDimensionNumbers& dimension_numbers,
308       absl::optional<PrimitiveType> preferred_element_type);
309 
310   // Helper that infers the shape of the tensor produced by a gather operation
311   // with the given input shape, gather indices shape and gather dimension
312   // numbers.
313   static StatusOr<Shape> InferGatherShape(
314       const Shape& input_shape, const Shape& start_indices_shape,
315       const GatherDimensionNumbers& gather_dim_numbers,
316       absl::Span<const int64> slice_sizes);
317 
318   // Helper that validates the given input shape, scatter indices shape, updates
319   // shape, and scatter dimension numbers that constitute a scatter operation,
320   // and returns the result shape of the scatter operation.
321   static StatusOr<Shape> InferScatterShape(
322       const Shape& operand_shape, const Shape& scatter_indices_shape,
323       const Shape& updates_shape, const ProgramShape& to_apply_shape,
324       const ScatterDimensionNumbers& scatter_dim_numbers);
325 
326   // Helper that validates the given input shape to GetDimensionSize.
327   static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
328                                                     int64 dimension);
329 
330   // Helper that validates the given input shape to SetDimensionSize.
331   static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
332                                                     const Shape& val_shape,
333                                                     int64 dimension);
334 
335   // Helper function for creating a Window proto from user-supplied data.
336   // Returns error if the user-supplied data was invalid.
337   static StatusOr<Window> InferWindowFromDimensions(
338       absl::Span<const int64> window_dimensions,
339       absl::Span<const int64> window_strides,
340       absl::Span<const std::pair<int64, int64>> padding,
341       absl::Span<const int64> lhs_dilation,
342       absl::Span<const int64> rhs_dilation);
343 
344  private:
345   // Helper that infers the shape produced by performing an element-wise binary
346   // operation with the given LHS and RHS shapes.
347   // Note: By "element-wise" we mean operations that look at a single element in
348   // the LHS and a single element in the RHS to produce a single output element,
349   // even in the presence of broadcasting of one of the operands over the other.
350   static StatusOr<Shape> InferElementwiseBinaryOpShape(
351       HloOpcode operation, const Shape& lhs, const Shape& rhs,
352       absl::Span<const int64> broadcast_dimensions);
353 
354   // Helper for inferring the shape of Clamp ops.
355   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
356                                          const Shape& max);
357 
358   // Helper for inferring the shape of Select ops.
359   static StatusOr<Shape> InferSelectShape(const Shape& pred,
360                                           const Shape& on_true,
361                                           const Shape& on_false);
362   // Helper for inferring the shape of TupleSelect ops.
363   static StatusOr<Shape> InferTupleSelectShape(const Shape& pred,
364                                                const Shape& on_true,
365                                                const Shape& on_false);
366 
367   // Helper for inferring shapes of binary operations which use degenerate
368   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
369   // up to match the size of the dimension in the other operand).
370   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
371       HloOpcode operation, const Shape& lhs, const Shape& rhs);
372 
373   // Helper for inferring shapes of binary operations using "InDim"
374   // broadcasting. This is the broadcasting used in the *InDim binary operations
375   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
376   // lower-rank shape than larger_shape. Returns the shape that the
377   // smaller_shape is broadcast to.
378   static StatusOr<Shape> InferInDimBroadcastShape(
379       const Shape& smaller_shape, const Shape& larger_shape,
380       absl::Span<const int64> broadcast_dimensions);
381 
382   TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
383 };
384 
385 }  // namespace xla
386 
387 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
388