• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shape inference is used by the XLA service as the user builds up
17 // computation requests.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
21 
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
35 // For a given operation and input shapes, infers what the resulting shape is
36 // for the operation. With this functionality, the user does not need to specify
37 // the expected result type for computations that are built up via the API --
38 // the shape that results from an operation is inferred. Some methods have
39 // overloads for inferring shape at the HLO level.
40 //
41 // TODO(b/73352135): Shape inference does not issue very good error messages, in
42 // part because HloInstruction::ToString() is not available since shape
43 // inference runs before the HloInstruction object is created. We need a
44 // solution for this.
45 class ShapeInference {
46  public:
47   // Infers the shape produced by applying the given unary operation to the
48   // given input shape.
49   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
50                                            const Shape& shape);
51   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
52                                            const HloInstruction* operand);
53 
54   // Infers the shape produced by applying the given binary operation to the
55   // given input shapes.
56   static StatusOr<Shape> InferBinaryOpShape(
57       HloOpcode opcode, const Shape& lhs, const Shape& rhs,
58       absl::Span<const int64> broadcast_dimensions);
59   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
60                                             const HloInstruction* lhs,
61                                             const HloInstruction* rhs);
62 
63   // Infers the shape produced by applying the given ternary operation to the
64   // given input shapes.
65   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
66                                              const Shape& rhs,
67                                              const Shape& ehs);
68   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
69                                              const HloInstruction* lhs,
70                                              const HloInstruction* rhs,
71                                              const HloInstruction* ehs);
72 
73   // Infers the shape produced by applying the given variadic operation to the
74   // given input operand shapes.
75   static StatusOr<Shape> InferVariadicOpShape(
76       HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
77   static StatusOr<Shape> InferVariadicOpShape(
78       HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
79 
80   // Infers the shape produced by applying the given mapping computation shape
81   // to the given operand shapes.
82   static StatusOr<Shape> InferMapShape(
83       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
84       absl::Span<const int64> dimensions);
85 
86   // Infers the shape produced by a batch-norm-training operation with the
87   // given operand shapes.
88   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
89                                                      const Shape& scale_shape,
90                                                      const Shape& offset_shape,
91                                                      int64_t feature_index);
92 
93   // Infers the shape produced by a batch-norm-inference operation with the
94   // given operand shapes.
95   static StatusOr<Shape> InferBatchNormInferenceShape(
96       const Shape& operand_shape, const Shape& scale_shape,
97       const Shape& offset_shape, const Shape& mean_shape,
98       const Shape& variance_shape, int64_t feature_index);
99 
100   // Infers the shape produced by a batch-norm-grad operation on these operands.
101   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
102                                                  const Shape& scale_shape,
103                                                  const Shape& mean_shape,
104                                                  const Shape& var_shape,
105                                                  const Shape& output_grad_shape,
106                                                  int64_t feature_index);
107 
108   // Infers the shape produced by applying the given convolutional filter (rhs)
109   // to lhs in the way specified by the fields on window. An optional
110   // preferred_element_type can be specified to upcast the element type.
111   static StatusOr<Shape> InferConvolveShape(
112       const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
113       int64_t batch_group_count, const Window& window,
114       const ConvolutionDimensionNumbers& dimension_numbers,
115       absl::optional<PrimitiveType> preferred_element_type);
116 
117   // Infers the shape produced by the given FFT type on the given operand.
118   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
119                                        absl::Span<const int64> fft_length);
120 
121   // Infers the shape produced by the given triangular solve operation.
122   static StatusOr<Shape> InferTriangularSolveShape(
123       const Shape& a, const Shape& b, const TriangularSolveOptions& options);
124 
125   // Infers the shape produced by a Cholesky operation on the given operand.
126   static StatusOr<Shape> InferCholeskyShape(const Shape& a);
127 
128   // Infers the shape produced by an all-gather with the given operand shape,
129   // concat dimension, and shard count.
130   static StatusOr<Shape> InferAllGatherShape(
131       absl::Span<const Shape* const> operand_shapes,
132       int64_t all_gather_dimension, int64_t shard_count);
133 
134   // Infers the shape produced by an all-gather-start with the given operand
135   // shape, concat dimension, and shard count.
136   static StatusOr<Shape> InferAllGatherStartShape(
137       absl::Span<const Shape* const> operand_shapes,
138       int64_t all_gather_dimension, int64_t shard_count);
139 
140   // Infers the shape produced by an all-gather-done given a certain
141   // all-gather-start shape.
142   static StatusOr<Shape> InferAllGatherDoneShape(
143       const Shape& all_gather_start_shape);
144 
145   // Infers the shape produced by an all-reduce (cross replica sum) with the
146   // given operand shapes.
147   static StatusOr<Shape> InferAllReduceShape(
148       absl::Span<const Shape* const> operand_shapes);
149 
150   // Infers the shape produced by a reduce-scatter with the given operand
151   // shape, scatter dimension, and shard count.
152   static StatusOr<Shape> InferReduceScatterShape(
153       absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
154       int64_t shard_count);
155 
156   // Infers the shape produced by an all-reduce-start (cross replica sum start).
157   static StatusOr<Shape> InferAllReduceStartShape(
158       absl::Span<const Shape* const> operand_shapes);
159 
160   // Infers the shape produced by an all-reduce-done (cross replica sum done).
161   static StatusOr<Shape> InferAllReduceDoneShape(const Shape& operand_shape);
162 
163   // Infers the final shape of an all-to-all operation that is created by the
164   // XLA builder.
165   static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
166                                             int64_t split_dimension,
167                                             int64_t concat_dimension,
168                                             int64_t split_count);
169 
170   // Infers the shape of an HLO all-to-all instruction.
171   static StatusOr<Shape> InferAllToAllTupleShape(
172       absl::Span<const Shape* const> operand_shapes);
173 
174   // Infers the shape of a collective permute operation.
175   static StatusOr<Shape> InferCollectivePermuteShape(
176       absl::Span<const Shape* const> operand_shapes);
177 
178   // Infers the shape of a collective permute start operation.
179   static StatusOr<Shape> InferCollectivePermuteStartShape(
180       absl::Span<const Shape* const> operand_shapes);
181 
182   // Infers the shape of a collective permute done operation.
183   static StatusOr<Shape> InferCollectivePermuteDoneShape(
184       const Shape& operand_shape);
185 
186   // Infers the shape produced by applying the given reduction computation
187   // shape to the given input operand shape.
188   //
189   // `dimensions_to_reduce` names the dimensions to reduce and `to_apply` is
190   // the program shape of the reduction computation. NOTE(review): there is no
191   // `pass_index` parameter here, although this comment used to describe one.
192   static StatusOr<Shape> InferReduceShape(
193       absl::Span<const Shape* const> arg_shapes,
194       absl::Span<const int64> dimensions_to_reduce,
195       const ProgramShape& to_apply);
196 
197   // Infers the shape produced by applying the given computation to the operand
198   // shape with the given window and stride dimensions.
199   static StatusOr<Shape> InferReduceWindowShape(
200       const Shape& operand_shape, const Shape& init_value, const Window& window,
201       const ProgramShape& to_apply_shape);
202   static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
203                                                 const Shape& init_value,
204                                                 const Window& window);
205   static StatusOr<Shape> InferReduceWindowShape(
206       absl::Span<const Shape* const> operands,
207       absl::Span<const Shape* const> init_values, const Window& window,
208       const ProgramShape& to_apply_shape);
209   // Overload taking spans of operands/init values but no `to_apply_shape`.
210   static StatusOr<Shape> InferReduceWindowShape(
211       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
212       const Window& window);
213 
214   // Infers the shape produced by scattering the given source shape to the
215   // selected indices of each window on the operand shape.
216   static StatusOr<Shape> InferSelectAndScatterShape(
217       const Shape& operand_shape, const ProgramShape& select_shape,
218       const Window& window, const Shape& source_shape,
219       const Shape& init_value_shape, const ProgramShape& scatter_shape);
220 
221   // Infers the shape produced by a reverse operation that reverses the order
222   // of the elements in the given dimensions.
223   static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
224                                            absl::Span<const int64> dimensions);
225 
226   // Infers the shape produced by a slice operation spanning from the starts to
227   // the limits in the original shape's dimensions.
228   //
229   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
230   static StatusOr<Shape> InferSliceShape(const Shape& arg,
231                                          absl::Span<const int64> starts,
232                                          absl::Span<const int64> limits,
233                                          absl::Span<const int64> strides);
234 
235   // Infers the shape produced by a dynamic slice operation of size specified
236   // in 'slice_sizes', with dynamic start index shapes 'start_index_shapes'.
237   static StatusOr<Shape> InferDynamicSliceShape(
238       const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
239       absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true);
240 
241   // Infers the shape produced by a dynamic update slice operation based
242   // on the shape of operand and update.
243   static StatusOr<Shape> InferDynamicUpdateSliceShape(
244       const Shape& operand_shape, const Shape& update_shape,
245       absl::Span<const Shape> start_index_shapes,
246       bool allow_scalar_indices = true);
247 
248   // Infers the shape produced by doing a compile-time-constant indexing into
249   // the given input shape. This is essential for operations on tuples, because
250   // it is impossible to infer the type that comes out of the tuple indexing if
251   // it is not a compile time constant.
252   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
253                                                    int64_t index);
254 
255   // Infers the shape produced from a while node. condition and body are the
256   // shapes of computations for the condition and the body of a while node, and
257   // init is the shape of data initially passed in to the body as an argument.
258   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
259   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
260                                          const ProgramShape& body,
261                                          const Shape& init);
262 
263   // Infers the shape produced by a predicated or indexed conditional operation.
264   static StatusOr<Shape> InferConditionalShape(
265       const Shape& branch_index,
266       absl::Span<const ProgramShape> branch_computations,
267       absl::Span<const Shape> branch_operands);
268 
269   // Infers the shape produced by a broadcast operation.
270   static StatusOr<Shape> InferBroadcastShape(
271       const Shape& operand, absl::Span<const int64> broadcast_sizes);
272 
273   // Checks whether the given parameters can form a broadcast. Returns the same
274   // output_shape if it's legal.
275   static StatusOr<Shape> InferBroadcastShape(
276       const Shape& operand_shape, const Shape& output_shape,
277       absl::Span<const int64> broadcast_dimensions);
278 
279   // Infers the shape produced by a reshape operation from the element type of
280   // its operand and the new dimension sizes specified.
281   static StatusOr<Shape> InferReshapeShape(const Shape& operand,
282                                            absl::Span<const int64> dimensions,
283                                            absl::Span<const int64> new_sizes,
284                                            int64_t inferred_dimension);
285 
286   // Infers the shape produced by a dynamic reshape operation from the element
287   // type of its operand and the new dimension sizes specified. The result shape
288   // will have dynamic dimensions as specified in `dims_are_dynamic` and bounds
289   // `new_size_bounds`.
290   static StatusOr<Shape> InferDynamicReshapeShape(
291       const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
292       absl::Span<const int64> new_size_bounds,
293       const std::vector<bool>& dims_are_dynamic);
294 
295   // Infers the shape produced by a transpose operation from the element type of
296   // its operand and its dimensions field.
297   static StatusOr<Shape> InferTransposeShape(
298       const Shape& operand, absl::Span<const int64> dimensions);
299 
300   // Helper that infers the shape produced by performing a concatenate operation
301   // with the given operand shapes.
302   static StatusOr<Shape> InferConcatOpShape(
303       absl::Span<const Shape* const> arg_shapes, int64_t dimension);
304 
305   // Helper that validates the given operand shape can be converted to the
306   // target output_shape via a convert instruction -- the requirement is that
307   // the shape is identical except for the element type.
308   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
309                                            PrimitiveType new_element_type);
310 
311   // Helper that validates the given operand shape can be bitcast converted to
312   // the target output_shape via a bitcast convert instruction -- the
313   // requirement is that the shape is identical except for the element type and
314   // the element types have identical bit-widths.
315   static StatusOr<Shape> InferBitcastConvertShape(
316       const Shape& operand_shape, PrimitiveType new_element_type);
317 
318   // Helper that validates the input data type for a reduce-precision operation,
319   // and returns the result shape.
320   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
321                                                    const int exponent_bits,
322                                                    const int mantissa_bits);
323 
324   // Helper that infers the shape produced by a pad operation based on the
325   // padding configuration.
326   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
327                                        const Shape& padding_value_shape,
328                                        const PaddingConfig& padding_config);
329 
330   // Helper that validates the given arg_shapes are compatible with the shape of
331   // the to_apply parameters, and returns the to_apply result shape.
332   static StatusOr<Shape> InferCallShape(
333       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
334 
335   // Helper that infers the shape produced by performing a dot operation with
336   // the given LHS and RHS shapes. An optional preferred_element_type can be
337   // specified to upcast the element type.
338   static StatusOr<Shape> InferDotOpShape(
339       const Shape& lhs, const Shape& rhs,
340       const DotDimensionNumbers& dimension_numbers,
341       absl::optional<PrimitiveType> preferred_element_type);
342 
343   // Helper that infers the shape of the tensor produced by a gather operation
344   // with the given input shape, gather indices shape and gather dimension
345   // numbers.
346   static StatusOr<Shape> InferGatherShape(
347       const Shape& input_shape, const Shape& start_indices_shape,
348       const GatherDimensionNumbers& gather_dim_numbers,
349       absl::Span<const int64> slice_sizes);
350 
351   // Helper that validates the given input shape, scatter indices shape, updates
352   // shape, and scatter dimension numbers that constitute a scatter operation,
353   // and returns the result shape of the scatter operation.
354   static StatusOr<Shape> InferScatterShape(
355       const Shape& operand_shape, const Shape& scatter_indices_shape,
356       const Shape& updates_shape, const ProgramShape& to_apply_shape,
357       const ScatterDimensionNumbers& scatter_dim_numbers);
358 
359   // Helper that validates the given input shape to GetDimensionSize.
360   static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
361                                                     int64_t dimension);
362 
363   // Helper that validates the given input shape to SetDimensionSize.
364   static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
365                                                     const Shape& val_shape,
366                                                     int64_t dimension);
367 
368   // Helper function for creating a Window proto from user-supplied data.
369   // Returns error if the user-supplied data was invalid.
370   static StatusOr<Window> InferWindowFromDimensions(
371       absl::Span<const int64> window_dimensions,
372       absl::Span<const int64> window_strides,
373       absl::Span<const std::pair<int64, int64>> padding,
374       absl::Span<const int64> lhs_dilation,
375       absl::Span<const int64> rhs_dilation);
376 
377  private:
378   // Helper that infers the shape produced by performing an element-wise binary
379   // operation with the given LHS and RHS shapes.
380   // Note: By "element-wise" we mean operations that look at a single element in
381   // the LHS and a single element in the RHS to produce a single output element,
382   // even in the presence of broadcasting of one of the operands over the other.
383   static StatusOr<Shape> InferElementwiseBinaryOpShape(
384       HloOpcode operation, const Shape& lhs, const Shape& rhs,
385       absl::Span<const int64> broadcast_dimensions);
386 
387   // Helper for inferring the shape of Clamp ops.
388   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
389                                          const Shape& max);
390 
391   // Helper for inferring the shape of Select ops.
392   static StatusOr<Shape> InferSelectShape(const Shape& pred,
393                                           const Shape& on_true,
394                                           const Shape& on_false);
395   // Helper for inferring the shape of TupleSelect ops.
396   static StatusOr<Shape> InferTupleSelectShape(const Shape& pred,
397                                                const Shape& on_true,
398                                                const Shape& on_false);
399 
400   // Helper for inferring shapes of binary operations which use degenerate
401   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
402   // up to match the size of the dimension in the other operand).
403   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
404       HloOpcode operation, const Shape& lhs, const Shape& rhs);
405 
406   // Helper for inferring shapes of binary operations using "InDim"
407   // broadcasting. This is the broadcasting used in the *InDim binary operations
408   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
409   // lower-rank shape than larger_shape. Returns the shape that the
410   // smaller_shape is broadcast to.
411   static StatusOr<Shape> InferInDimBroadcastShape(
412       const Shape& smaller_shape, const Shape& larger_shape,
413       absl::Span<const int64> broadcast_dimensions);
414 
415   TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
416 };
417 
418 }  // namespace xla
419 
420 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
421