/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_

#include <memory>
#include <optional>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Some lightweight utilities intended to make HLO instruction creation more
// ergonomic. This is not yet a complete set of helpers; we expect to expand
// the interface as needed on an ad-hoc basis.

// Creates a unary HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand,
                                       const OpMetadata* metadata = nullptr);

// Creates a binary HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs,
                                        const OpMetadata* metadata = nullptr);
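// A minimal usage sketch for MakeBinaryHlo (assumes `a` and `b` are
// HloInstruction* that live in the same computation):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* sum,
//                       MakeBinaryHlo(HloOpcode::kAdd, a, b));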

// Creates a kCopy HLO.
HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to);

// Creates a compare HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeCompareHlo(Comparison::Direction direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs,
                                         const OpMetadata* metadata = nullptr);
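// A usage sketch for MakeCompareHlo, computing `lhs < rhs` element-wise
// (assumes `lhs` and `rhs` are existing instructions in the same computation):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* lt,
//       MakeCompareHlo(Comparison::Direction::kLt, lhs, rhs));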

// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config,
                                     const OpMetadata* metadata = nullptr);
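// A sketch of using MakePadHlo to prepend one element of padding to every
// dimension (assumes `operand` and a scalar `zero` exist in the same
// computation):
//
//   PaddingConfig config;
//   for (int64_t i = 0; i < operand->shape().rank(); ++i) {
//     auto* dim = config.add_dimensions();
//     dim->set_edge_padding_low(1);
//     dim->set_edge_padding_high(0);
//     dim->set_interior_padding(0);
//   }
//   TF_ASSIGN_OR_RETURN(HloInstruction* padded,
//                       MakePadHlo(operand, zero, config));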

// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64_t> start_indices,
                                       absl::Span<const int64_t> limit_indices,
                                       absl::Span<const int64_t> strides,
                                       const OpMetadata* metadata = nullptr);
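// A sketch of using MakeSliceHlo to take the first two rows of an f32[4,8]
// `operand`:
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* rows,
//       MakeSliceHlo(operand, /*start_indices=*/{0, 0},
//                    /*limit_indices=*/{2, 8}, /*strides=*/{1, 1}));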

// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
// If the result shape has integral element type, an optional
// preferred_element_type can be specified to override the element type.
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata = nullptr);

// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeTransposeHlo(
    HloInstruction* operand, absl::Span<const int64_t> dimensions);

// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand);

// As above, but the result shape is built from `result_shape_dim_bounds` and
// the element type of `operand`.
StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64_t> result_shape_dim_bounds, HloInstruction* operand);
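// A sketch of the span-based MakeReshapeHlo overload, flattening an f32[4,8]
// `operand` into f32[32]:
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* flat,
//       MakeReshapeHlo(/*result_shape_dim_bounds=*/{32}, operand));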

// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64_t> slice_sizes,
    const OpMetadata* metadata = nullptr);
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64_t> slice_sizes,
    const OpMetadata* metadata = nullptr);

// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
// `start_indices` must be in the same computation).
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices, const OpMetadata* metadata = nullptr);

// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 absl::Span<const int64_t> result_shape_bounds,
                                 const OpMetadata* metadata = nullptr);
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64_t> broadcast_dimensions,
                                 const Shape& shape,
                                 const OpMetadata* metadata = nullptr);
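// A sketch of using MakeBroadcastHlo to broadcast a scalar `s` into an
// f32[2,3] result (an empty `broadcast_dimensions` means every output
// dimension is a new, broadcasted one):
//
//   HloInstruction* b =
//       MakeBroadcastHlo(s, /*broadcast_dimensions=*/{},
//                        /*result_shape_bounds=*/{2, 3});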

// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeGetTupleElementHlo(
    HloInstruction* operand, int64_t index,
    const OpMetadata* metadata = nullptr);

// Creates a Concatenate HLO instruction and adds it to the computation
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64_t dimension,
    const OpMetadata* metadata = nullptr);
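// E.g. concatenating two f32[2,3] instructions `x` and `y` along dimension 0
// yields an f32[4,3] result (a sketch):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* cat,
//                       MakeConcatHlo({x, y}, /*dimension=*/0));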

// Creates a Convert HLO instruction that converts the given instruction to have
// the given primitive type.
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                 const OpMetadata* metadata = nullptr);

// Creates a Bitcast HLO instruction to the given shape+layout.
HloInstruction* MakeBitcastHlo(HloInstruction* hlo, const Shape& shape,
                               const OpMetadata* metadata = nullptr);

// Creates a BitcastConvert HLO instruction.
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, PrimitiveType type,
                                        const OpMetadata* metadata = nullptr);

// Creates an Iota HLO instruction.
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64_t iota_dimension);

// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation). If the result shape has
// integral element type, an optional preferred_element_type can be specified to
// override the element type.
StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    std::optional<PrimitiveType> preferred_element_type,
    const OpMetadata* metadata = nullptr);

// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation,
                                     const OpMetadata* metadata = nullptr);

// Creates a reduce-precision op, where `operand` is the data to reduce in
// precision, and `exponent_bits` and `mantissa_bits` describe the precision to
// reduce it to.
HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand,
                                       int exponent_bits, int mantissa_bits,
                                       const OpMetadata* metadata = nullptr);

// Creates a Reduce HLO instruction and adds it to the computation containing
// the operand. This will create the sub-computation needed for the reduction in
// the given module. `binary_opcode` should represent a binary operation.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloOpcode binary_opcode,
                                        const OpMetadata* metadata = nullptr);

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64_t> dimensions,
                                        HloComputation* reduce_computation,
                                        const OpMetadata* metadata = nullptr);

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module,
                                        const OpMetadata* metadata = nullptr);

// Generic helper function to create a reduction.
//
// Precondition: the number of operands equals the number of init values, and
// equals the number of elements in the computation's output shape.
//
// Creates a non-variadic reduction if there is exactly one operand, and a
// variadic one otherwise.
StatusOr<HloInstruction*> MakeReduceHlo(
    absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64_t> dimensions, HloComputation* reduce_computation,
    const OpMetadata* metadata = nullptr);
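// A sketch of a sum reduction over dimension 0 via the binary-opcode overload
// above (assumes `operand` and a compatible scalar `zero` already exist):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* sum,
//       MakeReduceHlo(operand, zero, /*dimensions=*/{0}, HloOpcode::kAdd));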

// Creates a Reverse HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64_t> dimensions,
                                         const OpMetadata* metadata = nullptr);

// Creates a Select HLO instruction and adds it to the computation containing
// the predicate. The on_true and on_false instructions must also be contained
// in the same computation. If on_true and on_false are tuples, creates a tuple
// select instead. `pred` is broadcast up from a scalar if necessary.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from = nullptr);

// Forwards the first operand if operands.size() == 1, or creates a tuple
// instruction with all the operands. Crashes if `operands` is empty.
HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands);

// Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending
// order. `is_stable` specifies whether the sorting should be stable.
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module, const OpMetadata* metadata = nullptr);

// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
// given values and adds it to the given computation.
template <typename NativeT>
StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
                                            PrimitiveType type,
                                            absl::Span<const NativeT> values) {
  Literal literal = LiteralUtil::CreateR1<NativeT>(values);
  if (literal.shape().element_type() != type) {
    TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
  }
  return computation->AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
}
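// For example, creating an s32[3] constant {1, 2, 3} (a sketch; `computation`
// is assumed to be an existing HloComputation*):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* c,
//       MakeR1ConstantHlo<int32_t>(computation, S32, {1, 2, 3}));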

// Creates an R0 Constant HLO instruction of the PrimitiveType corresponding to
// `NativeT` with the given value and adds it to the given computation.
template <class NativeT>
HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
  return computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
}

// Creates a scalar constant with the given value, converted to the element
// type of `base`, and broadcasts it to the shape of `base` so that the result
// is elementwise compatible with `base`.
template <class NativeT>
HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
  auto scalar = base->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
                                         .Convert(base->shape().element_type())
                                         .ValueOrDie()));
  if (base->shape().rank() == 0) {
    *scalar->mutable_shape() = base->shape();
    return scalar;
  }
  return base->AddInstruction(
      HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
}
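// For example, `MakeScalarLike(x, 0)` yields a zero broadcast to the shape of
// `x`, which is handy for clamping-style patterns (a sketch):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* relu,
//       MakeBinaryHlo(HloOpcode::kMaximum, x, MakeScalarLike(x, 0)));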

// Creates a fusion instruction and fuses `fused` into the created fusion
// instruction.
StatusOr<HloInstruction*> MakeFusionInstruction(
    HloInstruction* fused, HloInstruction::FusionKind kind);

// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
// their operand(s).

// Collapses (via reshape) the first N (logical) dimensions of `operand` into a
// single leading dimension. `operand` must have rank > `n` and `n` must not be
// 0.
//
// For instance, if `operand` has shape f32[7,8,9] and `n` is 2, then the output
// is `operand` reshaped to [56,9].
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
                                             int64_t n);
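// A sketch matching the example above (assumes `operand` has shape f32[7,8,9]):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* collapsed,
//                       CollapseFirstNDims(operand, /*n=*/2));
//   // collapsed has shape f32[56,9].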

// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
// using a reshape.
//
// For instance, if `operand` has shape f32[3,4,5], then this returns `operand`
// reshaped to f32[1,3,4,5]. If `operand` is an f32 scalar (i.e. has shape
// f32[]), then this returns `operand` reshaped to f32[1].
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64_t n);

// Expands (via reshape) the first (logical) dimension of `operand` into a
// sequence of `expanded_dims` dimensions. `operand` must be at least of rank 1,
// and the number of elements in its first dimension must be equal to the
// product of `expanded_dims`.
//
// For instance, if `operand` has shape f32[200,9,7] and `expanded_dims` is
// {2,5,20}, the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64_t> expanded_dims);

// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide`, from `operand`. Every dimension in
// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be
// sorted and not contain duplicates.
//
// For example, if `operand` is of shape f32[19,1,20,1,7,1,9] and
// `dims_to_elide` is {1,5}, then the result is `operand` reshaped to
// [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_elide);

// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert`, into `operand`. The dimensions in
// `dims_to_insert` refer to the dimensions in the result, and hence should be
// less than the rank of the result. Also, `dims_to_insert` must be sorted.
//
// For example, if `operand` is of shape f32[12,21,8,34] and `dims_to_insert` is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64_t> dims_to_insert);

// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros at the
// front and `zeros_to_append` zeros at the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64_t zeros_to_prepend,
                                             int64_t zeros_to_append);
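// E.g. padding an f32[5] vector out to f32[8] (a sketch):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* padded,
//       PadVectorWithZeros(operand, /*zeros_to_prepend=*/1,
//                          /*zeros_to_append=*/2));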

// Broadcasts a zero value of type `element_type` into a tensor with element
// type `element_type` and dimension bounds `broadcast_dimensions`. The
// broadcast instruction is emitted into `computation`.
HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64_t> broadcast_dimensions);
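// For example, an all-zero f32[2,3] tensor (a sketch):
//
//   HloInstruction* zeros = BroadcastZeros(computation, F32, {2, 3});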

// Same as above, but fills the tensor with ones.
HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64_t> broadcast_dimensions);

// Creates an HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_