• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7     http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/statusor.h"
24 namespace xla {
26 // Some lightweight utilities intended to make HLO instruction creation more
27 // ergonomic.  We don't have a complete set of helpers yet -- I expect we'll
28 // expand this interface as needed on an ad-hoc basis.
30 // Creates a binary HLO instruction and adds it to the computation containing
31 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
32 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
33                                         HloInstruction* rhs);
35 // Creates a compare HLO instruction and adds it to the computation containing
36 // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
37 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
38                                          HloInstruction* lhs,
39                                          HloInstruction* rhs);
41 // Creates a pad HLO instruction and adds it to the computation containing
42 // `operand` and `padding_value` (`operand` and `padding_value` must be in the
43 // same computation).
44 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
45                                      HloInstruction* padding_value,
46                                      const PaddingConfig& padding_config);
48 // Creates a slice HLO instruction and adds it to the computation containing
49 // `operand`.
50 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
51                                        absl::Span<const int64> start_indices,
52                                        absl::Span<const int64> limit_indices,
53                                        absl::Span<const int64> strides);
55 // Creates a convolution HLO instruction and adds it to the computation
56 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
57 StatusOr<HloInstruction*> MakeConvolveHlo(
58     HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
59     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
60     const PrecisionConfig& precision_config);
62 // Creates a transpose HLO instruction and adds it to the computation containing
63 // `operand`.
64 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
65                                            absl::Span<const int64> dimensions);
67 // Creates a reshape HLO instruction and adds it to the computation containing
68 // `operand`.
69 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
70                                          HloInstruction* operand);
72 StatusOr<HloInstruction*> MakeReshapeHlo(
73     absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
75 // Creates a dynamic-slice HLO instruction and adds it to the computation
76 // containing `operand` and `start_indices` (`operand` and `start_indices` must
77 // be in the same computation).
78 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
79     HloInstruction* operand, HloInstruction* start_indices,
80     absl::Span<const int64> slice_sizes);
82 // Creates a dynamic-update-slice HLO instruction and adds it to the computation
83 // containing `operand`, `update` and `start_indices` (`operand`, `update` and
84 // `start_indices` must be in the same computation).
85 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
86     HloInstruction* operand, HloInstruction* update,
87     HloInstruction* start_indices);
89 // Creates a broadcast HLO instruction and adds it to the computation containing
90 // `operand`.
91 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
92                                  absl::Span<const int64> broadcast_dimensions,
93                                  absl::Span<const int64> result_shape_bounds);
95 // Creates a GetTupleElement HLO instruction and adds it to the computation
96 // containing `operand`.
97 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
98                                                  int64 index);
100 // Creates a Concatenate HLO instruction and adds it to the computation
101 // containing `operands` (`operands` must be non-empty and every element must be
102 // contained in the same computation).
103 StatusOr<HloInstruction*> MakeConcatHlo(
104     absl::Span<HloInstruction* const> operands, int64 dimension);
106 // Creates a Dot HLO instruction and adds it to the computation containing `lhs`
107 // and `rhs` (both must be in the same computation).
108 StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
109                                      const DotDimensionNumbers& dim_numbers,
110                                      const PrecisionConfig& precision_config);
112 // Creates a Map HLO instruction and adds it to the computation containing the
113 // operands. All operands must be in the same computation.
114 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
115                                      HloComputation* map_computation);
117 // Creates a Reduce HLO instruction and adds it to the computation containing
118 // the operand. This will create the sub-computation needed for the reduction in
119 // the given module. binary_opcode should represent a binary operation.
120 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
121                                         HloInstruction* init_value,
122                                         HloOpcode binary_opcode,
123                                         HloModule* module);
125 // Creates a Select HLO instruction and adds it to the computation containing
126 // the predicate. The on_true and on_false instructions must also be contained
127 // in the same computation.
128 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
129                                         HloInstruction* on_true,
130                                         HloInstruction* on_false);
132 // Creates a Sort HLO instruction and adds it to the computation containing the
133 // operands. All operands must be in the same computation. Also creates a
134 // default compare sub-computation which sorts the first operand into ascending
135 // order. 'is_stable' specifies whether the sorting should be stable.
136 StatusOr<HloInstruction*> MakeSortHlo(
137     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
138     int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
139     HloModule* module);
141 // Creates an R1 Constant HLO instruction of the given PrimitiveType with the
142 // given values and adds it to the given computation.
143 template <typename NativeT>
MakeR1ConstantHlo(HloComputation * computation,PrimitiveType type,absl::Span<const NativeT> values)144 StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
145                                             PrimitiveType type,
146                                             absl::Span<const NativeT> values) {
147   Literal literal = LiteralUtil::CreateR1<NativeT>(values);
148   if (literal.shape().element_type() != type) {
149     TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
150   }
151   return computation->AddInstruction(
152       HloInstruction::CreateConstant(std::move(literal)));
153 }
155 // -----------------------------------------------------------------------------
156 // Some other miscellaneous helpers to generate common HLO patterns.  All of
157 // these add all the instructions they generate into the computation containing
158 // their operand(s).
160 // Collapses (via reshape) the first N (logical) dimensions of `operand` into a
161 // single leading dimension.  `operand` must have rank > `n` and `n` must not be
162 // 0.
163 //
164 // For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is
165 // the `operand` reshaped to [56,9].
166 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n);
168 // Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
169 // using a reshape.
170 //
171 // For instance if operand has shape f32[3,4,5] then this returns the operand
172 // reshaped to f32[1,3,4,5].  If the operand is a f32 scalar (i.e. has shape
173 // f32[]) then this returns the operand reshaped to f32[1].
174 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
175                                                 int64 n);
177 // Expands (via reshape) the first (logical) dimension of `operand` into a
178 // sequence of `expanded_dims` dimensions.  `operand` must at least be of rank 1
179 // and the number of elements in its first dimension must be equal to the
180 // product of `expanded_dims`.
181 //
182 // For instance if `operand` has shape f32[200,9,7] and expanded_dims is
183 // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
184 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
185     HloInstruction* operand, absl::Span<const int64> expanded_dims);
187 // Elides (via reshape) a set of degenerate dimensions (dimensions containing
188 // exactly one element), `dims_to_elide` from `operand`.  Every dimension in
189 // `dims_to_elide` must be a degenerate dimension.  `dims_to_elide` must be
190 // sorted and not contain duplicates.
191 //
192 // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
193 // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
194 StatusOr<HloInstruction*> ElideDegenerateDims(
195     HloInstruction* operand, absl::Span<const int64> dims_to_elide);
197 // Inserts (via reshape) a set of degenerate dimensions (dimensions containing
198 // exactly one element), `dims_to_insert` into `operand`. The dimensions in
199 // `dims_to_insert` refer to the dimensions in the result, and hence should be
200 // less than the rank of the result. Also, `dims_to_insert` must be sorted.
201 //
202 // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
203 // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
204 StatusOr<HloInstruction*> InsertDegenerateDims(
205     HloInstruction* operand, absl::Span<const int64> dims_to_insert);
207 // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
208 // front and `zeros_to_append` zeros in the back.
209 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
210                                              int64 zeros_to_prepend,
211                                              int64 zeros_to_append);
213 // Broadcasts a zero value of type `element_type` into a tensor with element
214 // type `element_type` and dimension bounds `broadcast_dimensions`.  The
215 // broadcast instruction is emitted into `computation`.
216 HloInstruction* BroadcastZeros(HloComputation* computation,
217                                PrimitiveType element_type,
218                                absl::Span<const int64> broadcast_dimensions);
220 // Creates a HLO computation that takes arguments of type `domain` and produces
221 // a value of type `range`.
222 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
223     absl::Span<const Shape* const> domain, const Shape& range,
224     absl::string_view name);
226 }  // namespace xla