/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Some lightweight utilities intended to make HLO instruction creation more
// ergonomic. We don't have a complete set of helpers yet -- we expect to
// expand this interface as needed, on an ad-hoc basis.

// Creates a unary HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand);

// Creates a binary HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs);
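
// Example (an illustrative sketch, not part of the API surface; `x` and `y`
// stand for pre-existing HloInstruction* values in the same computation):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction * sum,
//                       MakeBinaryHlo(HloOpcode::kAdd, x, y));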

// Creates a compare HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeCompareHlo(Comparison::Direction direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs);
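
// Example (an illustrative sketch; `x` and `y` stand for pre-existing
// HloInstruction* values in the same computation):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction * less,
//                       MakeCompareHlo(Comparison::Direction::kLt, x, y));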

// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config);
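
// Example (an illustrative sketch; `op` and `zero` stand for pre-existing
// instructions, with `zero` a scalar of `op`'s element type):
//
//   PaddingConfig config;
//   for (int64 i = 0; i < op->shape().rank(); ++i) {
//     auto* dim = config.add_dimensions();
//     dim->set_edge_padding_low(1);   // one element of padding in front
//     dim->set_edge_padding_high(1);  // one element of padding in back
//     dim->set_interior_padding(0);   // no padding between elements
//   }
//   TF_ASSIGN_OR_RETURN(HloInstruction * padded, MakePadHlo(op, zero, config));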

// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64> start_indices,
                                       absl::Span<const int64> limit_indices,
                                       absl::Span<const int64> strides);
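
// Example (an illustrative sketch): if a hypothetical `op` has shape
// f32[10,20], the following extracts the f32[5,20] prefix along dimension 0:
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction * slice,
//       MakeSliceHlo(op, /*start_indices=*/{0, 0}, /*limit_indices=*/{5, 20},
//                    /*strides=*/{1, 1}));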

// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
// If the result shape has integral element type, an optional
// preferred_element_type can be specified to override the element type.
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
    int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type);

// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
                                           absl::Span<const int64> dimensions);

// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand);

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
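
// Example (an illustrative sketch): reshaping a hypothetical f32[7,8] `op`
// into f32[56] with the dimension-bounds overload:
//
//   TF_ASSIGN_OR_RETURN(HloInstruction * flat, MakeReshapeHlo({56}, op));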

// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes);
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes);
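
// Example (an illustrative sketch; `op` is a hypothetical f32[10,20]
// instruction and `i`, `j` scalar index instructions in the same computation):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction * row,
//       MakeDynamicSliceHlo(op, {i, j}, /*slice_sizes=*/{1, 20}));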

// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
// `start_indices` must be in the same computation).
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices);

// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 absl::Span<const int64> result_shape_bounds);
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 const Shape& shape);
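
// Example (an illustrative sketch): broadcasting a hypothetical scalar `s` to
// f32[2,3]; an empty `broadcast_dimensions` means no operand dimension maps
// onto a result dimension:
//
//   HloInstruction* bcast =
//       MakeBroadcastHlo(s, /*broadcast_dimensions=*/{}, {2, 3});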

// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64 index);

// Creates a Concatenate HLO instruction and adds it to the computation
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64 dimension);
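
// Example (an illustrative sketch; `a` and `b` are hypothetical f32[4,7]
// instructions in the same computation):
//
//   // `cat` has shape f32[8,7].
//   TF_ASSIGN_OR_RETURN(HloInstruction * cat,
//                       MakeConcatHlo({a, b}, /*dimension=*/0));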

// Creates a Convert HLO instruction that converts the given instruction to have
// the given primitive type.
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);

// Creates a BitcastConvert HLO instruction.
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
                                        PrimitiveType type);

// Creates an Iota HLO instruction.
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64 iota_dimension);
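
// Example (an illustrative sketch; `computation` is a hypothetical
// HloComputation*): an s32[10] iota vector 0, 1, ..., 9:
//
//   HloInstruction* iota = MakeIotaHlo(
//       computation, ShapeUtil::MakeShape(S32, {10}), /*iota_dimension=*/0);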

// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation). If the result shape has
// integral element type, an optional preferred_element_type can be specified to
// override the element type.
StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type);

// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation);

// Creates a Reduce HLO instruction and adds it to the computation containing
// the operand. This will create the sub-computation needed for the reduction in
// the given module. binary_opcode should represent a binary operation.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64> dimensions,
                                        HloOpcode binary_opcode);

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module);
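
// Example (an illustrative sketch): summing away dimension 0 of a
// hypothetical f32[10,20] `op`, with `zero` a scalar f32 constant in the same
// computation:
//
//   // `sum` has shape f32[20].
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction * sum,
//       MakeReduceHlo(op, zero, /*dimensions=*/{0}, HloOpcode::kAdd));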

// Creates a Reverse HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64> dimensions);

// Creates a Select HLO instruction and adds it to the computation containing
// the predicate. The on_true and on_false instructions must also be contained
// in the same computation. If on_true and on_false are tuples, creates a tuple
// select instead. `pred` is broadcast up from a scalar if necessary.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from = nullptr);
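
// Example (an illustrative sketch; `pred`, `t`, and `f` are hypothetical
// instructions in the same computation, with `pred` of element type PRED):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction * sel, MakeSelectHlo(pred, t, f));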

// Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending
// order. 'is_stable' specifies whether the sorting should be stable.
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module);

// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
// given values and adds it to the given computation.
template <typename NativeT>
StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
                                            PrimitiveType type,
                                            absl::Span<const NativeT> values) {
  Literal literal = LiteralUtil::CreateR1<NativeT>(values);
  if (literal.shape().element_type() != type) {
    TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
  }
  return computation->AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
}
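
// Example (an illustrative sketch; `computation` is a hypothetical
// HloComputation*): an s32[3] constant {1, 2, 3}:
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction * c,
//       MakeR1ConstantHlo<int32>(computation, S32, {1, 2, 3}));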

// Creates an R0 Constant HLO instruction of the PrimitiveType corresponding to
// `NativeT` with the given value and adds it to the given computation.
template <class NativeT>
HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
  return computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
}

// Creates a scalar constant with the given value, converted to the element
// type of `base`, and broadcasts it (if necessary) so the result is
// elementwise compatible with the shape of `base`.
template <class NativeT>
HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
  auto scalar = base->parent()->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
                                         .Convert(base->shape().element_type())
                                         .ValueOrDie()));
  if (base->shape().rank() == 0) {
    // If `base` is itself a scalar, preserve its exact shape (including
    // layout) rather than emitting a broadcast.
    *scalar->mutable_shape() = base->shape();
    return scalar;
  }
  return base->parent()->AddInstruction(
      HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
}
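
// Example (an illustrative sketch): if a hypothetical `base` has shape
// f32[3,4], this yields an f32[3,4] broadcast of the constant 1.0:
//
//   HloInstruction* ones = MakeScalarLike(base, 1.0f);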

// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
// their operand(s).

// Collapses (via reshape) the first N (logical) dimensions of `operand` into a
// single leading dimension. `operand` must have rank > `n` and `n` must not be
// 0.
//
// For instance, if `operand` has shape f32[7,8,9] and `n` is 2, then the
// output is `operand` reshaped to f32[56,9].
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n);

// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
// using a reshape.
//
// For instance, if `operand` has shape f32[3,4,5], then this returns the
// operand reshaped to f32[1,3,4,5]. If the operand is an f32 scalar (i.e. has
// shape f32[]), then this returns the operand reshaped to f32[1].
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64 n);

// Expands (via reshape) the first (logical) dimension of `operand` into a
// sequence of `expanded_dims` dimensions. `operand` must have rank at least 1,
// and the number of elements in its first dimension must be equal to the
// product of `expanded_dims`.
//
// For instance, if `operand` has shape f32[200,9,7] and expanded_dims is
// {2,5,20}, the result is `operand` reshaped to f32[2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64> expanded_dims);

// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide`, from `operand`. Every dimension in
// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be
// sorted and not contain duplicates.
//
// For example, if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
// is {1,5}, then the result is `operand` reshaped to f32[19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_elide);

// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert`, into `operand`. The dimensions in
// `dims_to_insert` refer to the dimensions in the result, and hence should be
// less than the rank of the result. Also, `dims_to_insert` must be sorted.
//
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to f32[1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_insert);

// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64 zeros_to_prepend,
                                             int64 zeros_to_append);

// Broadcasts a zero value of type `element_type` into a tensor with element
// type `element_type` and dimension bounds `broadcast_dimensions`. The
// broadcast instruction is emitted into `computation`.
HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64> broadcast_dimensions);

// Same as above, but fills the tensor with ones.
HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64> broadcast_dimensions);
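
// Example (an illustrative sketch; `computation` is a hypothetical
// HloComputation*): an f32[2,3] tensor of zeros:
//
//   HloInstruction* zeros = BroadcastZeros(computation, F32, {2, 3});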

// Creates an HLO computation that takes arguments of type `domain` and
// produces a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_