• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shape inference is used by the XLA service as the user builds up
17 // computation requests.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
21 
22 #include <vector>
23 
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/lib/gtl/array_slice.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
35 // For a given operation and input shapes, infers what the resulting shape is
36 // for the operation. With this functionality, the user does not need to specify
37 // the expected result type for computations that are built up via the API --
38 // the shape that results from an operation is inferred. Some methods have
39 // overloads for inferring shape at the HLO level.
40 //
41 // TODO(b/73352135): Shape inference does not issue very good error messages, in
42 // part because HloInstruction::ToString() is not available since shape
43 // inference runs before the HloInstruction object is created. We need a
44 // solution for this.
45 class ShapeInference {
46  public:
47   // Infers the shape produced by applying the given unary operation to the
48   // given input shape.
49   static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
50                                            const Shape& arg);
51   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
52                                            const HloInstruction* operand);
53 
54   // Infers the shape produced by applying the given binary operation to the
55   // given input shapes.
56   static StatusOr<Shape> InferBinaryOpShape(
57       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
58       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
59   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
60                                             const HloInstruction* lhs,
61                                             const HloInstruction* rhs);
62 
63   // Infers the shape produced by applying the given ternary operation to the
64   // given input shapes.
65   static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
66                                              const Shape& lhs, const Shape& rhs,
67                                              const Shape& ehs);
68   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
69                                              const HloInstruction* lhs,
70                                              const HloInstruction* rhs,
71                                              const HloInstruction* ehs);
72 
73   // Infers the shape produced by applying the given variadic operation to the
74   // given input operand shapes.
75   static StatusOr<Shape> InferVariadicOpShape(
76       VariadicOperation operation,
77       tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
78   static StatusOr<Shape> InferVariadicOpShape(
79       HloOpcode opcode,
80       tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
81 
82   // Infers the shape produced by applying the given mapping computation shape
83   // to the given operand shapes.
84   static StatusOr<Shape> InferMapShape(
85       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
86       const ProgramShape& to_apply,
87       tensorflow::gtl::ArraySlice<int64> dimensions);
88 
89   // Infers the shape produced by InferBatchNormTraining with the given
90   // operands.
91   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
92                                                      const Shape& scale_shape,
93                                                      const Shape& offset_shape,
94                                                      int64 feature_index);
95 
96   // Infers the shape produced by InferBatchNormInference with the given
97   // operands.
98   static StatusOr<Shape> InferBatchNormInferenceShape(
99       const Shape& operand_shape, const Shape& scale_shape,
100       const Shape& offset_shape, const Shape& mean_shape,
101       const Shape& variance_shape, int64 feature_index);
102 
103   // Infers the shape produced by InferBatchNormGrad with the given operands.
104   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
105                                                  const Shape& scale_shape,
106                                                  const Shape& mean_shape,
107                                                  const Shape& var_shape,
108                                                  const Shape& output_grad_shape,
109                                                  int64 feature_index);
110 
111   // Infers the shape produced by applying the given convolutional
112   // filter (rhs) to lhs in the way specified by the fields on window.
113   static StatusOr<Shape> InferConvolveShape(
114       const Shape& lhs, const Shape& rhs, const Window& window,
115       const ConvolutionDimensionNumbers& dimension_numbers);
116 
117   // Infers the shape produced by the given FFT type on the given operand.
118   static StatusOr<Shape> InferFftShape(
119       const Shape& in, FftType fft_type,
120       tensorflow::gtl::ArraySlice<int64> fft_length);
121 
122   // Infers the shape produced a cross replica sum with the given operand
123   // shapes.
124   static StatusOr<Shape> InferCrossReplicaSumShape(
125       tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
126 
127   // Infers the shape produced by applying the given reduction computation
128   // shape to the given input operand shape.
129   //
130   // If pass_index is true, the reduce function is invoked with the element
131   // index as the leading parameter, and the program shape should match
132   // accordingly (or an error will result).
133   static StatusOr<Shape> InferReduceShape(
134       const Shape& arg, const Shape& init_value,
135       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
136       const ProgramShape& to_apply);
137 
138   // Infers the shape produced by applying the given computation to the operand
139   // shape with the given window and stride dimensions.
140   static StatusOr<Shape> InferReduceWindowShape(
141       const Shape& operand_shape, const Shape& init_value, const Window& window,
142       const ProgramShape& to_apply_shape);
143 
144   // Infers the shape produced by scattering the given source shape to the
145   // selected indices of each window on the operand shape.
146   static StatusOr<Shape> InferSelectAndScatterShape(
147       const Shape& operand_shape, const ProgramShape& select_shape,
148       const Window& window, const Shape& source_shape,
149       const Shape& init_value_shape, const ProgramShape& scatter_shape);
150 
151   // Infers the shape produced by a reverse operation that reverses the order
152   // of the elements in the given dimensions.
153   static StatusOr<Shape> InferReverseShape(
154       const Shape& operand_shape,
155       tensorflow::gtl::ArraySlice<int64> dimensions);
156 
157   // Infers the shape produced by a slice operation spanning from the starts to
158   // the limits in the original shape's dimensions.
159   //
160   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
161   static StatusOr<Shape> InferSliceShape(
162       const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
163       tensorflow::gtl::ArraySlice<int64> limits,
164       tensorflow::gtl::ArraySlice<int64> strides);
165 
166   // Infers the shape produced by a dynamic slice operation of size specified
167   // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
168   static StatusOr<Shape> InferDynamicSliceShape(
169       const Shape& operand_shape, const Shape& start_indices_shape,
170       tensorflow::gtl::ArraySlice<int64> slice_sizes);
171 
172   // Infers the shape produced by a dynamic update slice operation based
173   // on the shape of operand and update.
174   static StatusOr<Shape> InferDynamicUpdateSliceShape(
175       const Shape& operand_shape, const Shape& update_shape,
176       const Shape& start_indices_shape);
177 
178   // Infers the shape produced by doing a compile-time-constant indexing into
179   // the given input shape. This is essential for operations on tuples, because
180   // it is impossible to infer the type that comes out of the tuple indexing if
181   // it is not a compile time constant.
182   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
183                                                    int64 index);
184 
185   // Infers the shape produced from a while node. condition and body are the
186   // shapes of computations for the condition and the body of a while node, and
187   // init is the shape of data initially passed in to the body as an argument.
188   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
189   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
190                                          const ProgramShape& body,
191                                          const Shape& init);
192 
193   // Infers the shape produced by a conditional operation.
194   static StatusOr<Shape> InferConditionalShape(
195       const Shape& predicate, const Shape& true_operand,
196       const Shape& false_operand, const ProgramShape& true_computation,
197       const ProgramShape& false_computation);
198 
199   // Infers the shape produced by a broadcast operation.
200   static StatusOr<Shape> InferBroadcastShape(
201       const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
202 
203   // Infers the shape produced by a reshape operation from the element type of
204   // its operand and the new dimension sizes specified.
205   static StatusOr<Shape> InferReshapeShape(
206       const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
207       tensorflow::gtl::ArraySlice<int64> new_sizes);
208 
209   // Infers the shape produced by a transpose operation from the element type of
210   // its operand and its dimensions field.
211   static StatusOr<Shape> InferTransposeShape(
212       const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
213 
214   // Helper that infers the shape produced by performing a concatenate operation
215   // with the given operand shapes.
216   static StatusOr<Shape> InferConcatOpShape(
217       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
218 
219   // Helper that validates the given operand shape can be converted to the
220   // target output_shape via a convert instruction -- the requirement is that
221   // the shape is identical except for the element type.
222   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
223                                            PrimitiveType new_element_type);
224 
225   // Helper that validates the given operand shape can be bitcast converted to
226   // the target output_shape via a bitcast convert instruction -- the
227   // requirement is that the shape is identical except for the element type and
228   // the element types have identical bit-widths.
229   static StatusOr<Shape> InferBitcastConvertShape(
230       const Shape& operand_shape, PrimitiveType new_element_type);
231 
232   // Helper that validates the input data type for a reduce-precision operation,
233   // and returns the result shape.
234   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
235                                                    const int exponent_bits,
236                                                    const int mantissa_bits);
237 
238   // Helper that infers the shape produced by a pad operation based on the
239   // padding configuration.
240   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
241                                        const Shape& padding_value_shape,
242                                        const PaddingConfig& padding_config);
243 
244   // Helper that validates the given arg_shapes are compatible with the shape of
245   // the to_apply parameters, and returns the to_apply result shape.
246   static StatusOr<Shape> InferCallShape(
247       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
248       const ProgramShape& to_apply);
249 
250   // Helper that infers the shape produced by performing a dot operation with
251   // the given LHS and RHS shapes.
252   static StatusOr<Shape> InferDotOpShape(
253       const Shape& lhs, const Shape& rhs,
254       const DotDimensionNumbers& dimension_numbers);
255 
256   // Helper that infers the shape of the tensor produced by a gather operation
257   // with the given input shape, gather indices shape and gather dimension
258   // numbers.
259   static StatusOr<Shape> InferGatherShape(
260       const Shape& input_shape, const Shape& gather_indices_shape,
261       const GatherDimensionNumbers& gather_dim_numbers,
262       tensorflow::gtl::ArraySlice<int64> window_bounds);
263 
264  private:
265   // Helper that infers the shape produced by performing an element-wise binary
266   // operation with the given LHS and RHS shapes.
267   // Note: By "element-wise" we mean operations that look at a single element in
268   // the LHS and a single element in the RHS to produce a single output element,
269   // even in the presence of broadcasting of one of the operands over the other.
270   static StatusOr<Shape> InferElementwiseBinaryOpShape(
271       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
272       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
273 
274   // Helper for inferring the shape of Clamp ops.
275   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
276                                          const Shape& max);
277 
278   // Helper for inferring the shape of Select ops.
279   static StatusOr<Shape> InferSelectShape(const Shape& pred,
280                                           const Shape& on_true,
281                                           const Shape& on_false);
282 
283   // Helper for inferring shapes of binary operations which use degenerate
284   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
285   // up to match the size of the dimension in the other operand).
286   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
287       BinaryOperation operation, const Shape& lhs, const Shape& rhs);
288 
289   // Helper for inferring shapes of binary operations using "InDim"
290   // broadcasting. This is the broadcasting used in the *InDim binary operations
291   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
292   // lower-rank shape than larger_shape. Returns the shape that the
293   // smaller_shape is broadcast to.
294   static StatusOr<Shape> InferInDimBroadcastShape(
295       BinaryOperation operation, const Shape& smaller_shape,
296       const Shape& larger_shape,
297       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
298 
299   TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
300 };
301 
302 }  // namespace xla
303 
304 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
305