/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

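// Returns the MHLO dialect op name for the given HLO opcode, e.g. kAdd maps
// to "mhlo.add". Dashes in HLO opcode names are replaced with underscores.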
static std::string GetMlirOpName(HloOpcode opcode) {
  std::string op_name = HloOpcodeString(opcode);
  absl::c_replace(op_name, '-', '_');
  return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
}

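// Prints an MLIR type to a string, for use in error messages.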
static std::string ToString(mlir::Type ty) {
  std::string str;
  llvm::raw_string_ostream sstream(str);
  ty.print(sstream);
  sstream.flush();
  return str;
}

// Returns a 1D 64-bit dense elements attribute with the given values.
static mlir::DenseIntElementsAttr GetI64ElementsAttr(
    absl::Span<const int64> values, mlir::Builder* builder) {
  auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
                                        builder->getIntegerType(64));
  return mlir::DenseIntElementsAttr::get(
      ty, llvm::makeArrayRef(values.data(), values.size()));
}

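// Converts XLA (low, high) padding pairs into an [N, 2] 64-bit dense elements
// attribute, one row per padded dimension.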
static mlir::DenseIntElementsAttr ConvertPadding(
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    mlir::Builder* builder) {
  llvm::SmallVector<int64_t, 8> elements;
  elements.reserve(padding.size() * 2);
  for (const auto& vals : padding) {
    elements.push_back(vals.first);
    elements.push_back(vals.second);
  }
  auto ty = mlir::RankedTensorType::get(
      {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
  return mlir::DenseIntElementsAttr::get(ty, elements);
}

MlirHloBuilder::~MlirHloBuilder() = default;

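// Wraps an MLIR value as an XlaOp. The value's opaque pointer doubles as the
// op handle, and the shape derived from its type is cached for later
// GetShapePtr() lookups.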
StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
  mlir::Type ty = val.getType();
  auto shape = std::make_unique<Shape>(TypeToShape(ty));
  if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
  }

  int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
  handle_to_shape_[handle] = std::move(shape);
  return XlaOp(handle, this);
}

XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
                        CreateDenseElementsAttrFromLiteral(literal, builder_));
    auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
    return MakeXlaOp(op);
  });
}

StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  mlir::ArrayAttr config_attr;
  if (precision_config)
    config_attr = ConvertPrecisionConfig(precision_config, &builder_);
  auto op = builder_.create<mlir::mhlo::ConvOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      GetI64ElementsAttr(window_strides, &builder_),
      ConvertPadding(padding, &builder_),
      GetI64ElementsAttr(lhs_dilation, &builder_),
      GetI64ElementsAttr(rhs_dilation, &builder_),
      /*window_reversal=*/nullptr,
      ConvertConvDimensionNumbers(dimension_numbers, &builder_),
      builder_.getI64IntegerAttr(feature_group_count),
      builder_.getI64IntegerAttr(batch_group_count), config_attr);
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::FftInternal(
    const Shape& shape, XlaOp operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::FftOp>(
      loc_, ty, GetValue(operand),
      builder_.getStringAttr(FftType_Name(fft_type)),
      GetI64ElementsAttr(fft_length, &builder_));
  return MakeXlaOp(op);
}

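// Lowers CustomCall to mhlo.custom_call. Features the MLIR op does not model
// yet (operand layouts, output operand aliasing, literals) are rejected.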
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
    const string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal) {
  if (operand_shapes_with_layout.has_value())
    return Unimplemented(
        "CustomCall doesn't support operand shapes with layout");
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  TF_RET_CHECK(output_operand_aliasing.empty())
      << "MLIR CustomCallOp does not support output_operand_aliasing yet";
  TF_RET_CHECK(literal == nullptr)
      << "MLIR CustomCallOp does not support literal yet";
  auto op = builder_.create<mlir::mhlo::CustomCallOp>(
      loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
      /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
      builder_.getStringAttr(opaque));
  return MakeXlaOp(op.getResult(0));
}

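// Lowers Reduce to mhlo.reduce, importing the reduction computation as the
// op's body region. Multi-result reductions are packed back into a tuple.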
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64> dimensions_to_reduce) {
  // Reduce takes two sets of variadic operands: inputs and init_values.
  // all_operands contains both, so split it into two halves.
  int64_t num_args = all_operands.size() / 2;
  auto op = builder_.create<mlir::mhlo::ReduceOp>(
      loc_, GetValues(all_operands.first(num_args)),
      GetValues(all_operands.subspan(num_args)),
      GetI64ElementsAttr(dimensions_to_reduce, &builder_));
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
  if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
  auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
  return MakeXlaOp(tuple);
}

StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
  llvm::SmallVector<int64, 8> padding;
  for (const auto& dim : window.dimensions()) {
    sizes.push_back(dim.size());
    strides.push_back(dim.stride());
    base_dilations.push_back(dim.base_dilation());
    win_dilations.push_back(dim.window_dilation());
    padding.push_back(dim.padding_low());
    padding.push_back(dim.padding_high());
  }
  auto padding_ty =
      mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
                                  builder_.getIntegerType(64));
  auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
      loc_, ty, GetValue(operand), GetValue(init_value),
      GetI64ElementsAttr(sizes, &builder_),
      GetI64ElementsAttr(strides, &builder_),
      GetI64ElementsAttr(base_dilations, &builder_),
      GetI64ElementsAttr(win_dilations, &builder_),
      mlir::DenseIntElementsAttr::get(padding_ty, padding));
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
  return MakeXlaOp(op);
}

XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        mlir::Type ty,
        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
    auto op = builder_.create<mlir::mhlo::IotaOp>(
        loc_, ty,
        builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
    return MakeXlaOp(op);
  });
}

StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                           XlaOp operand) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
                                                          GetValue(operand));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::TransposeOp>(
      loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::RevInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::ReverseOp>(
      loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
  return MakeXlaOp(op);
}

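// Lowers Sort to mhlo.sort. The comparator is imported as a region; for
// multi-operand sorts the variadic results are re-packed into a tuple to
// match XLA's tuple-shaped result.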
StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
                                             absl::Span<const XlaOp> operands,
                                             const XlaComputation& comparator,
                                             int64 dimension, bool is_stable) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
  if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
    sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
  }

  auto op = builder_.create<mlir::mhlo::SortOp>(
      loc_, sort_types, GetValues(operands),
      builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
  TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));

  if (ty.isa<mlir::TupleType>()) {
    auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
    return MakeXlaOp(tuple);
  }

  return MakeXlaOp(op.getResult(0));
}

StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
                                              const XlaComputation& condition,
                                              const XlaComputation& body,
                                              XlaOp init) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
  TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
  TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
    const Shape& shape, XlaOp operand, const int exponent_bits,
    const int mantissa_bits) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
      loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
      builder_.getI32IntegerAttr(mantissa_bits));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::GatherOp>(
      loc_, ty, GetValue(input), GetValue(start_indices),
      ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
      GetI64ElementsAttr(slice_sizes, &builder_));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::ScatterOp>(
      loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
      ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
      builder_.getBoolAttr(indices_are_sorted),
      builder_.getBoolAttr(unique_indices));

  TF_RETURN_IF_ERROR(
      ImportComputation(update_computation.proto(), &op.update_computation()));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                         XlaOp operand,
                                                         XlaOp val,
                                                         int64 dimension) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
      loc_, ty, GetValue(operand), GetValue(val),
      builder_.getI64IntegerAttr(dimension));
  return MakeXlaOp(op);
}

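// Lowers RngUniform/RngNormal through the generic CreateOp path, passing the
// result dimensions as an extra constant operand.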
StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
    RandomDistribution distribution, absl::Span<const XlaOp> parameters,
    const Shape& shape) {
  // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform
  // and RngNormal can be mapped to the new op.
  std::string op_name;
  if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
    op_name = "mhlo.rng_uniform";
  } else {
    TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
        << "Unexpected distribution: " << distribution;
    op_name = "mhlo.rng_normal";
  }

  if (shape.is_dynamic())
    return Unimplemented("RngOp with dynamic dims not supported");
  llvm::SmallVector<XlaOp, 3> operands;
  operands.append(parameters.begin(), parameters.end());
  operands.push_back(
      ConstantLiteral(LiteralUtil::CreateR1<int64>(shape.dimensions())));
  return CreateOp(op_name, shape, operands);
}

StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         full_result_shape, builder_));
  auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
      loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
                                                XlaOp operand,
                                                int64 inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error());

  if (inferred_dimension != -1)
    return Unimplemented("inferred_dimension not yet supported for Reshape op");
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  mlir::Value value = GetValue(operand);
  auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
  return MakeXlaOp(op.getResult());
}

StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_number,
    const PrecisionConfig* precision_config) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      ConvertDotDimensionNumbers(dimension_number, &builder_),
      ConvertPrecisionConfig(precision_config, &builder_));
  return MakeXlaOp(op.getResult());
}

StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error());
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  mlir::Value value = GetValue(operand);
  auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
      loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
  return MakeXlaOp(op.getResult());
}

StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
    HloInstructionProto&& instr, HloOpcode opcode,
    absl::Span<const XlaOp> operands) {
  return Unimplemented("MlirHloBuilder does not support op %s",
                       HloOpcodeString(opcode));
}

StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
                                        XlaOp rhs,
                                        ComparisonDirection direction,
                                        Comparison::Type type) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::CompareOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      builder_.getStringAttr(ComparisonDirectionToString(direction)),
      builder_.getStringAttr(ComparisonTypeToString(type)));
  return MakeXlaOp(op.getResult());
}

XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                          XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
  });
}

StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
    HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
  return CreateOp(GetMlirOpName(opcode), shape,
                  llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
}

XlaOp MlirHloBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
        loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
  });
}

StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
      loc_, result_ty, GetValue(a), GetValue(b),
      builder_.getBoolAttr(options.left_side()),
      builder_.getBoolAttr(options.lower()),
      builder_.getBoolAttr(options.unit_diagonal()),
      builder_.getStringAttr(
          TriangularSolveOptions::Transpose_Name(options.transpose_a())));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                                 bool lower) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto op = builder_.create<mlir::mhlo::CholeskyOp>(
      loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
  TF_ASSIGN_OR_RETURN(mlir::Type result_type,
                      ConvertShapeToType<mlir::RankedTensorType>(
                          infeed_instruction_shape, builder_));
  mlir::ArrayAttr layout;
  return MakeXlaOp(
      builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
                                            /*infeed_config=*/config,
                                            /*layout=*/layout));
}

StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const string& outfeed_config) {
  auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
  return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
      loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
}

StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  auto mlir_operands = GetValues(operands);
  return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
      loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
}

StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
                                                        XlaOp tuple_data,
                                                        int64 index) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
      loc_, result_type, GetValue(tuple_data),
      builder_.getI32IntegerAttr(index)));
}

StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
    absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
  return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
      loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
      GetI64ElementsAttr(limit_indices, &builder_),
      GetI64ElementsAttr(strides, &builder_)));
}

StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64> slice_sizes) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
      loc_, result_ty, GetValue(operand), GetValues(start_indices),
      GetI64ElementsAttr(slice_sizes, &builder_)));
}

StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
      loc_, result_ty, GetValue(operand), GetValue(update),
      GetValues(start_indices)));
}

StatusOr<XlaOp> MlirHloBuilder::PadInternal(
    const Shape& shape, XlaOp operand, XlaOp padding_value,
    const PaddingConfig& padding_config) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  std::vector<int64> low;
  std::vector<int64> high;
  std::vector<int64> internal;
  for (auto& dimension : padding_config.dimensions()) {
    low.push_back(dimension.edge_padding_low());
    high.push_back(dimension.edge_padding_high());
    internal.push_back(dimension.interior_padding());
  }
  return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
      loc_, result_type, GetValue(operand), GetValue(padding_value),
      GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
      GetI64ElementsAttr(internal, &builder_)));
}

StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
    const Shape& shape, absl::Span<const XlaOp> elements) {
  mlir::SmallVector<mlir::Value, 4> operands;
  for (auto& element : elements) {
    operands.push_back(GetValue(element));
  }
  return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
}

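// Generic fallback: builds a single-result op of the given name from an
// OperationState. Used for ops that need no custom attribute handling,
// e.g. CreateOp("mhlo.add", shape, {lhs, rhs}) from BinaryOpNoBroadcast.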
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
    const std::string& op_name, const Shape& shape,
    llvm::ArrayRef<XlaOp> operands,
    llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  llvm::SmallVector<mlir::Value, 4> operand_values;
  operand_values.reserve(operands.size());
  for (XlaOp xla_op : operands) {
    operand_values.push_back(GetValue(xla_op));
  }
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
  mlir::Operation* op = builder_.createOperation(state);
  return MakeXlaOp(op->getResult(0));
}

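// Deserializes the computation proto into an HloModule and imports its entry
// computation into the given region.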
Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
                                         mlir::Region* region) {
  TF_ASSIGN_OR_RETURN(auto module_config,
                      xla::HloModule::CreateModuleConfigFromProto(
                          computation, xla::DebugOptions()));
  TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
                                           computation, module_config));

  return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
                                             region, &builder_);
}

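// Looks up the cached shape for the given op handle, recorded by MakeXlaOp().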
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error());
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  auto it = handle_to_shape_.find(op.handle());
  if (it == handle_to_shape_.end()) {
    return InvalidArgument("No XlaOp with handle %d", op.handle());
  }
  return it->second.get();
}

}  // namespace xla