• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
18 
19 #include <memory>
20 #include <optional>
21 
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 
27 namespace xla {
28 
29 // Some lightweight utilities intended to make HLO instruction creation more
30 // ergonomic.  We don't have a complete set of helpers yet -- I expect we'll
31 // expand this interface as needed on an ad-hoc basis.
32 
33 // Creates a unary HLO instruction and adds it to the computation containing
34 // `operand`.
35 StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
36                                        HloInstruction* operand,
37                                        const OpMetadata* metadata = nullptr);
38 
39 // Creates a binary HLO instruction and adds it to the computation containing
40 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
41 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
42                                         HloInstruction* rhs,
43                                         const OpMetadata* metadata = nullptr);
44 
45 // Creates a kCopy HLO.
46 HloInstruction* MakeCopyHlo(HloInstruction* from, const Shape& to);
47 
48 // Creates a compare HLO instruction and adds it to the computation containing
49 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
50 StatusOr<HloInstruction*> MakeCompareHlo(Comparison::Direction direction,
51                                          HloInstruction* lhs,
52                                          HloInstruction* rhs,
53                                          const OpMetadata* metadata = nullptr);
54 
55 // Creates a pad HLO instruction and adds it to the computation containing
56 // `operand` and `padding_value` (`operand` and `padding_value` must be in the
57 // same computation).
58 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
59                                      HloInstruction* padding_value,
60                                      const PaddingConfig& padding_config,
61                                      const OpMetadata* metadata = nullptr);
62 
63 // Creates a slice HLO instruction and adds it to the computation containing
64 // `operand`.
65 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
66                                        absl::Span<const int64_t> start_indices,
67                                        absl::Span<const int64_t> limit_indices,
68                                        absl::Span<const int64_t> strides,
69                                        const OpMetadata* metadata = nullptr);
70 
71 // Creates a convolution HLO instruction and adds it to the computation
72 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
73 // If the result shape has integral element type, an optional
74 // preferred_element_type can be specified to override the element type.
75 StatusOr<HloInstruction*> MakeConvolveHlo(
76     HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
77     int64_t batch_group_count, const Window& window,
78     const ConvolutionDimensionNumbers& dimension_numbers,
79     const PrecisionConfig& precision_config,
80     std::optional<PrimitiveType> preferred_element_type,
81     const OpMetadata* metadata = nullptr);
82 
83 // Creates a transpose HLO instruction and adds it to the computation containing
84 // `operand`.
85 StatusOr<HloInstruction*> MakeTransposeHlo(
86     HloInstruction* operand, absl::Span<const int64_t> dimensions);
87 
88 // Creates a reshape HLO instruction and adds it to the computation containing
89 // `operand`.
90 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
91                                          HloInstruction* operand);
92 
93 StatusOr<HloInstruction*> MakeReshapeHlo(
94     absl::Span<const int64_t> result_shape_dim_bounds, HloInstruction* operand);
95 
96 // Creates a dynamic-slice HLO instruction and adds it to the computation
97 // containing `operand` and `start_indices` (`operand` and `start_indices` must
98 // be in the same computation).
99 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
100     HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
101     absl::Span<const int64_t> slice_sizes,
102     const OpMetadata* metadata = nullptr);
103 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
104     HloInstruction* operand, HloInstruction* start_indices,
105     absl::Span<const int64_t> slice_sizes,
106     const OpMetadata* metadata = nullptr);
107 
108 // Creates a dynamic-update-slice HLO instruction and adds it to the computation
109 // containing `operand`, `update` and `start_indices` (`operand`, `update` and
110 // `start_indices` must be in the same computation).
111 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
112     HloInstruction* operand, HloInstruction* update,
113     HloInstruction* start_indices, const OpMetadata* metadata = nullptr);
114 
115 // Creates a broadcast HLO instruction and adds it to the computation containing
116 // `operand`.
117 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
118                                  absl::Span<const int64_t> broadcast_dimensions,
119                                  absl::Span<const int64_t> result_shape_bounds,
120                                  const OpMetadata* metadata = nullptr);
121 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
122                                  absl::Span<const int64_t> broadcast_dimensions,
123                                  const Shape& shape,
124                                  const OpMetadata* metadata = nullptr);
125 
126 // Creates a GetTupleElement HLO instruction and adds it to the computation
127 // containing `operand`.
128 StatusOr<HloInstruction*> MakeGetTupleElementHlo(
129     HloInstruction* operand, int64_t index,
130     const OpMetadata* metadata = nullptr);
131 
132 // Creates a Concatenate HLO instruction and adds it to the computation
133 // containing `operands` (`operands` must be non-empty and every element must be
134 // contained in the same computation).
135 StatusOr<HloInstruction*> MakeConcatHlo(
136     absl::Span<HloInstruction* const> operands, int64_t dimension,
137     const OpMetadata* metadata = nullptr);
138 
139 // Creates a Convert HLO instruction that converts the given instruction to have
140 // the given primitive type.
141 HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type,
142                                  const OpMetadata* metadata = nullptr);
143 
144 // Creates a Bitcast HLO instruction to the given shape+layout.
145 HloInstruction* MakeBitcastHlo(HloInstruction* hlo, const Shape& shape,
146                                const OpMetadata* metadata = nullptr);
147 
148 // Creates a BitcastConvert HLO instruction.
149 HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, PrimitiveType type,
150                                         const OpMetadata* metadata = nullptr);
151 
152 // Creates an Iota HLO instruction.
153 HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
154                             int64_t iota_dimension);
155 
156 // Creates a Dot HLO instruction and adds it to the computation containing `lhs`
157 // and `rhs` (both must be in the same computation). If the result shape has
158 // integral element type, an optional preferred_element_type can be specified to
159 // override the element type.
160 StatusOr<HloInstruction*> MakeDotHlo(
161     HloInstruction* lhs, HloInstruction* rhs,
162     const DotDimensionNumbers& dim_numbers,
163     const PrecisionConfig& precision_config,
164     std::optional<PrimitiveType> preferred_element_type,
165     const OpMetadata* metadata = nullptr);
166 
167 // Creates a Map HLO instruction and adds it to the computation containing the
168 // operands. All operands must be in the same computation.
169 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
170                                      HloComputation* map_computation,
171                                      const OpMetadata* metadata = nullptr);
172 
173 // Creates a reduce-precision op, where operand is the data to reduce in
174 // precision, and exponent_bits and mantissa_bits describe the precision to
175 // reduce it to.
176 HloInstruction* MakeReducePrecisionHlo(HloInstruction* operand,
177                                        int exponent_bits, int mantissa_bits,
178                                        const OpMetadata* metadata = nullptr);
179 
180 // Creates a Reduce HLO instruction and adds it to the computation containing
181 // the operand. This will create the sub-computation needed for the reduction in
182 // the given module. binary_opcode should represent a binary operation.
183 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
184                                         HloInstruction* init_value,
185                                         absl::Span<const int64_t> dimensions,
186                                         HloOpcode binary_opcode,
187                                         const OpMetadata* metadata = nullptr);
188 
189 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
190                                         HloInstruction* init_value,
191                                         absl::Span<const int64_t> dimensions,
192                                         HloComputation* reduce_computation,
193                                         const OpMetadata* metadata = nullptr);
194 
195 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
196                                         HloInstruction* init_value,
197                                         HloOpcode binary_opcode,
198                                         HloModule* module,
199                                         const OpMetadata* metadata = nullptr);
200 
201 // Generic helper function to create a reduction.
202 //
203 // Precondition: size of operands is equal to the size of init values and equal
204 // to the size of the computation output shape.
205 //
206 // Creates a non-variadic reduction if the size is singular, and a variadic one
207 // otherwise.
208 StatusOr<HloInstruction*> MakeReduceHlo(
209     absl::Span<HloInstruction* const> operands,
210     absl::Span<HloInstruction* const> init_values,
211     absl::Span<const int64_t> dimensions, HloComputation* reduce_computation,
212     const OpMetadata* metadata = nullptr);
213 
214 // Creates a Reverse HLO instruction and adds it to the computation containing
215 // `operand`.
216 StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
217                                          absl::Span<const int64_t> dimensions,
218                                          const OpMetadata* metadata = nullptr);
219 
220 // Creates a Select HLO instruction and adds it to the computation containing
221 // the predicate. The on_true and on_false instructions must also be contained
222 // in the same computation. If on_true and on_false are tuples, create a tuple
223 // select instead. `pred` is broadcast up from a scalar if necessary.
224 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
225                                         HloInstruction* on_true,
226                                         HloInstruction* on_false,
227                                         HloInstruction* derived_from = nullptr);
228 
229 // Forwards the first operand if operands.size() == 1, or creates a tuple
230 // instruction with all the operands. Crashes if `operands` is empty.
231 HloInstruction* MaybeMakeTuple(absl::Span<HloInstruction* const> operands);
232 
233 // Creates a Sort HLO instruction and adds it to the computation containing the
234 // operands. All operands must be in the same computation. Also creates a
235 // default compare sub-computation which sorts the first operand into ascending
236 // order. 'is_stable' specifies whether the sorting should be stable.
237 StatusOr<HloInstruction*> MakeSortHlo(
238     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
239     int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
240     HloModule* module, const OpMetadata* metadata = nullptr);
241 
242 // Creates an R1 Constant HLO instruction of the given PrimitiveType with the
243 // given values and adds it to the given computation.
244 template <typename NativeT>
MakeR1ConstantHlo(HloComputation * computation,PrimitiveType type,absl::Span<const NativeT> values)245 StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
246                                             PrimitiveType type,
247                                             absl::Span<const NativeT> values) {
248   Literal literal = LiteralUtil::CreateR1<NativeT>(values);
249   if (literal.shape().element_type() != type) {
250     TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
251   }
252   return computation->AddInstruction(
253       HloInstruction::CreateConstant(std::move(literal)));
254 }
255 
256 // Creates an R0 Constant HLO instruction of the PrimitiveType corresponding to
257 // `NativeT` with the given value and adds it to the given computation.
258 template <class NativeT>
MakeR0ConstantHlo(HloComputation * computation,NativeT value)259 HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
260   return computation->AddInstruction(
261       HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
262 }
263 
264 // Makes a scalar that is elementwise compatible with the shape of the base
265 // instruction.
266 template <class NativeT>
MakeScalarLike(HloInstruction * base,NativeT value)267 HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
268   auto scalar = base->AddInstruction(
269       HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
270                                          .Convert(base->shape().element_type())
271                                          .ValueOrDie()));
272   if (base->shape().rank() == 0) {
273     *scalar->mutable_shape() = base->shape();
274     return scalar;
275   }
276   return base->AddInstruction(
277       HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
278 }
279 
280 // Creates a fusion instruction and fuses `fused` into the created fusion
281 // instruction.
282 StatusOr<HloInstruction*> MakeFusionInstruction(
283     HloInstruction* fused, HloInstruction::FusionKind kind);
284 
285 // -----------------------------------------------------------------------------
286 // Some other miscellaneous helpers to generate common HLO patterns.  All of
287 // these add all the instructions they generate into the computation containing
288 // their operand(s).
289 
290 // Collapses (via reshape) the first N (logical) dimensions of `operand` into a
291 // single leading dimension.  `operand` must have rank > `n` and `n` must not be
292 // 0.
293 //
294 // For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is
295 // the `operand` reshaped to [56,9].
296 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
297                                              int64_t n);
298 
299 // Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
300 // using a reshape.
301 //
302 // For instance if operand has shape f32[3,4,5] then this returns the operand
303 // reshaped to f32[1,3,4,5].  If the operand is a f32 scalar (i.e. has shape
304 // f32[]) then this returns the operand reshaped to f32[1].
305 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
306                                                 int64_t n);
307 
308 // Expands (via reshape) the first (logical) dimension of `operand` into a
309 // sequence of `expanded_dims` dimensions.  `operand` must at least be of rank 1
310 // and the number of elements in its first dimension must be equal to the
311 // product of `expanded_dims`.
312 //
313 // For instance if `operand` has shape f32[200,9,7] and expanded_dims is
314 // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
315 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
316     HloInstruction* operand, absl::Span<const int64_t> expanded_dims);
317 
318 // Elides (via reshape) a set of degenerate dimensions (dimensions containing
319 // exactly one element), `dims_to_elide` from `operand`.  Every dimension in
320 // `dims_to_elide` must be a degenerate dimension.  `dims_to_elide` must be
321 // sorted and not contain duplicates.
322 //
323 // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
324 // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
325 StatusOr<HloInstruction*> ElideDegenerateDims(
326     HloInstruction* operand, absl::Span<const int64_t> dims_to_elide);
327 
328 // Inserts (via reshape) a set of degenerate dimensions (dimensions containing
329 // exactly one element), `dims_to_insert` into `operand`. The dimensions in
330 // `dims_to_insert` refer to the dimensions in the result, and hence should be
331 // less than the rank of the result. Also, `dims_to_insert` must be sorted.
332 //
333 // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
334 // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
335 StatusOr<HloInstruction*> InsertDegenerateDims(
336     HloInstruction* operand, absl::Span<const int64_t> dims_to_insert);
337 
338 // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
339 // front and `zeros_to_append` zeros in the back.
340 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
341                                              int64_t zeros_to_prepend,
342                                              int64_t zeros_to_append);
343 
344 // Broadcasts a zero value of type `element_type` into a tensor with element
345 // type `element_type` and dimension bounds `broadcast_dimensions`.  The
346 // broadcast instruction is emitted into `computation`.
347 HloInstruction* BroadcastZeros(HloComputation* computation,
348                                PrimitiveType element_type,
349                                absl::Span<const int64_t> broadcast_dimensions);
350 
351 // Same as above, but fill the tensor with ones.
352 HloInstruction* BroadcastOnes(HloComputation* computation,
353                               PrimitiveType element_type,
354                               absl::Span<const int64_t> broadcast_dimensions);
355 
356 // Creates an HLO computation that takes arguments of type `domain` and produces
357 // a value of type `range`.
358 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
359     absl::Span<const Shape* const> domain, const Shape& range,
360     absl::string_view name);
361 
362 }  // namespace xla
363 
364 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
365