• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
18 
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/statusor.h"
23 
24 namespace xla {
25 
26 // Some lightweight utilities intended to make HLO instruction creation more
27 // ergonomic.  We don't have a complete set of helpers yet -- I expect we'll
28 // expand this interface as needed on an ad-hoc basis.
29 
30 // Creates a unary HLO instruction and adds it to the computation containing
31 // `operand`.
32 StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
33                                        HloInstruction* operand);
34 
35 // Creates a binary HLO instruction and adds it to the computation containing
36 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
37 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
38                                         HloInstruction* rhs);
39 
40 // Creates a compare HLO instruction and adds it to the computation containing
41 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
42 StatusOr<HloInstruction*> MakeCompareHlo(Comparison::Direction direction,
43                                          HloInstruction* lhs,
44                                          HloInstruction* rhs);
45 
46 // Creates a pad HLO instruction and adds it to the computation containing
47 // `operand` and `padding_value` (`operand` and `padding_value` must be in the
48 // same computation).
49 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
50                                      HloInstruction* padding_value,
51                                      const PaddingConfig& padding_config);
52 
53 // Creates a slice HLO instruction and adds it to the computation containing
54 // `operand`.
55 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
56                                        absl::Span<const int64> start_indices,
57                                        absl::Span<const int64> limit_indices,
58                                        absl::Span<const int64> strides);
59 
60 // Creates a convolution HLO instruction and adds it to the computation
61 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
62 // If the result shape has integral element type, an optional
63 // preferred_element_type can be specified to override the element type.
64 StatusOr<HloInstruction*> MakeConvolveHlo(
65     HloInstruction* lhs, HloInstruction* rhs, int64_t feature_group_count,
66     int64_t batch_group_count, const Window& window,
67     const ConvolutionDimensionNumbers& dimension_numbers,
68     const PrecisionConfig& precision_config,
69     absl::optional<PrimitiveType> preferred_element_type);
70 
71 // Creates a transpose HLO instruction and adds it to the computation containing
72 // `operand`.
73 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
74                                            absl::Span<const int64> dimensions);
75 
76 // Creates a reshape HLO instruction and adds it to the computation containing
77 // `operand`.
78 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
79                                          HloInstruction* operand);
80 
81 StatusOr<HloInstruction*> MakeReshapeHlo(
82     absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
83 
84 // Creates a dynamic-slice HLO instruction and adds it to the computation
85 // containing `operand` and `start_indices` (`operand` and `start_indices` must
86 // be in the same computation).
87 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
88     HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
89     absl::Span<const int64> slice_sizes);
90 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
91     HloInstruction* operand, HloInstruction* start_indices,
92     absl::Span<const int64> slice_sizes);
93 
94 // Creates a dynamic-update-slice HLO instruction and adds it to the computation
95 // containing `operand`, `update` and `start_indices` (`operand`, `update` and
96 // `start_indices` must be in the same computation).
97 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
98     HloInstruction* operand, HloInstruction* update,
99     HloInstruction* start_indices);
100 
101 // Creates a broadcast HLO instruction and adds it to the computation containing
102 // `operand`.
103 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
104                                  absl::Span<const int64> broadcast_dimensions,
105                                  absl::Span<const int64> result_shape_bounds);
106 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
107                                  absl::Span<const int64> broadcast_dimensions,
108                                  const Shape& shape);
109 
110 // Creates a GetTupleElement HLO instruction and adds it to the computation
111 // containing `operand`.
112 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
113                                                  int64_t index);
114 
115 // Creates a Concatenate HLO instruction and adds it to the computation
116 // containing `operands` (`operands` must be non-empty and every element must be
117 // contained in the same computation).
118 StatusOr<HloInstruction*> MakeConcatHlo(
119     absl::Span<HloInstruction* const> operands, int64_t dimension);
120 
121 // Creates a Convert HLO instruction that converts the given instruction to have
122 // the given primitive type.
123 HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);
124 
125 // Creates a BitcastConvert HLO instruction.
126 HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
127                                         PrimitiveType type);
128 
129 // Creates an Iota HLO instruction.
130 HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
131                             int64_t iota_dimension);
132 
133 // Creates a Dot HLO instruction and adds it to the computation containing `lhs`
134 // and `rhs` (both must be in the same computation). If the result shape has
135 // integral element type, an optional preferred_element_type can be specified to
136 // override the element type.
137 StatusOr<HloInstruction*> MakeDotHlo(
138     HloInstruction* lhs, HloInstruction* rhs,
139     const DotDimensionNumbers& dim_numbers,
140     const PrecisionConfig& precision_config,
141     absl::optional<PrimitiveType> preferred_element_type);
142 
143 // Creates a Map HLO instruction and adds it to the computation containing the
144 // operands. All operands must be in the same computation.
145 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
146                                      HloComputation* map_computation);
147 
148 // Creates a Reduce HLO instruction and adds it to the computation containing
149 // the operand. This will create the sub-computation needed for the reduction in
150 // the given module. binary_opcode should represent a binary operation.
151 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
152                                         HloInstruction* init_value,
153                                         absl::Span<const int64> dimensions,
154                                         HloOpcode binary_opcode);
155 
156 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
157                                         HloInstruction* init_value,
158                                         HloOpcode binary_opcode,
159                                         HloModule* module);
160 
161 // Creates a Reverse HLO instruction and adds it to the computation containing
162 // `operand`.
163 StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
164                                          absl::Span<const int64> dimensions);
165 
166 // Creates a Select HLO instruction and adds it to the computation containing
167 // the predicate. The on_true and on_false instructions must also be contained
168 // in the same computation. If on_true and on_false are tuples, create a tuple
169 // select instead. `pred` is broadcasted up from a scalar if necessary.
170 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
171                                         HloInstruction* on_true,
172                                         HloInstruction* on_false,
173                                         HloInstruction* derived_from = nullptr);
174 
175 // Creates a Sort HLO instruction and adds it to the computation containing the
176 // operands. All operands must be in the same computation. Also creates a
177 // default compare sub-computation which sorts the first operand into ascending
178 // order. 'is_stable' specifies whether the sorting should be stable.
179 StatusOr<HloInstruction*> MakeSortHlo(
180     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
181     int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
182     HloModule* module);
183 
184 // Creates an R1 Constant HLO instruction of the given PrimitiveType with the
185 // given values and adds it to the given computation.
186 template <typename NativeT>
MakeR1ConstantHlo(HloComputation * computation,PrimitiveType type,absl::Span<const NativeT> values)187 StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
188                                             PrimitiveType type,
189                                             absl::Span<const NativeT> values) {
190   Literal literal = LiteralUtil::CreateR1<NativeT>(values);
191   if (literal.shape().element_type() != type) {
192     TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
193   }
194   return computation->AddInstruction(
195       HloInstruction::CreateConstant(std::move(literal)));
196 }
197 
198 // Creates an R0 Constant HLO instruction of the PrimitiveType corresponding to
199 // `NativeT` with the given value and adds it to the given computation.
200 template <class NativeT>
MakeR0ConstantHlo(HloComputation * computation,NativeT value)201 HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
202   return computation->AddInstruction(
203       HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
204 }
205 
206 // Makes a scalar that is elementwise compatible with the shape of the base
207 // instruction.
208 template <class NativeT>
MakeScalarLike(HloInstruction * base,NativeT value)209 HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
210   auto scalar = base->parent()->AddInstruction(
211       HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
212                                          .Convert(base->shape().element_type())
213                                          .ValueOrDie()));
214   if (base->shape().rank() == 0) {
215     *scalar->mutable_shape() = base->shape();
216     return scalar;
217   }
218   return base->parent()->AddInstruction(
219       HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
220 }
221 
222 // Creates a fusion instruction and fuses `fused` into the created fusion
223 // instruction.
224 StatusOr<HloInstruction*> MakeFusionInstruction(
225     HloInstruction* fused, HloInstruction::FusionKind kind);
226 
227 // -----------------------------------------------------------------------------
228 // Some other miscellaneous helpers to generate common HLO patterns.  All of
229 // these add all the instructions they generate into the computation containing
230 // their operand(s).
231 
232 // Collapses (via reshape) the first N (logical) dimensions of `operand` into a
233 // single leading dimension.  `operand` must have rank > `n` and `n` must not be
234 // 0.
235 //
236 // For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is
237 // the `operand` reshaped to [56,9].
238 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand,
239                                              int64_t n);
240 
241 // Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
242 // using a reshape.
243 //
244 // For instance if operand has shape f32[3,4,5] then this returns the operand
245 // reshaped to f32[1,3,4,5].  If the operand is a f32 scalar (i.e. has shape
246 // f32[]) then this returns the operand reshaped to f32[1].
247 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
248                                                 int64_t n);
249 
250 // Expands (via reshape) the first (logical) dimension of `operand` into a
251 // sequence of `expanded_dims` dimensions.  `operand` must at least be of rank 1
252 // and the number of elements in its first dimension must be equal to the
253 // product of `expanded_dims`.
254 //
255 // For instance if `operand` has shape f32[200,9,7] and expanded_dims is
256 // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
257 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
258     HloInstruction* operand, absl::Span<const int64> expanded_dims);
259 
260 // Elides (via reshape) a set of degenerate dimensions (dimensions containing
261 // exactly one element), `dims_to_elide` from `operand`.  Every dimension in
262 // `dims_to_elide` must be a degenerate dimension.  `dims_to_elide` must be
263 // sorted and not contain duplicates.
264 //
265 // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
266 // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
267 StatusOr<HloInstruction*> ElideDegenerateDims(
268     HloInstruction* operand, absl::Span<const int64> dims_to_elide);
269 
270 // Inserts (via reshape) a set of degenerate dimensions (dimensions containing
271 // exactly one element), `dims_to_insert` into `operand`. The dimensions in
272 // `dims_to_insert` refer to the dimensions in the result, and hence should be
273 // less than the rank of the result. Also, `dims_to_insert` must be sorted.
274 //
275 // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
276 // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
277 StatusOr<HloInstruction*> InsertDegenerateDims(
278     HloInstruction* operand, absl::Span<const int64> dims_to_insert);
279 
280 // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
281 // front and `zeros_to_append` zeros in the back.
282 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
283                                              int64_t zeros_to_prepend,
284                                              int64_t zeros_to_append);
285 
286 // Broadcasts a zero value of type `element_type` into a tensor with element
287 // type `element_type` and dimension bounds `broadcast_dimensions`.  The
288 // broadcast instruction is emitted into `computation`.
289 HloInstruction* BroadcastZeros(HloComputation* computation,
290                                PrimitiveType element_type,
291                                absl::Span<const int64> broadcast_dimensions);
292 
293 // Same as above, but fill the tensor with ones.
294 HloInstruction* BroadcastOnes(HloComputation* computation,
295                               PrimitiveType element_type,
296                               absl::Span<const int64> broadcast_dimensions);
297 
298 // Creates a HLO computation that takes arguments of type `domain` and produces
299 // a value of type `range`.
300 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
301     absl::Span<const Shape* const> domain, const Shape& range,
302     absl::string_view name);
303 
304 }  // namespace xla
305 
306 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
307