1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
16
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/IR/Builders.h" // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
23 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
24 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
25 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/shape_inference.h"
29 #include "tensorflow/compiler/xla/util.h"
30
31 namespace xla {
32
GetMlirOpName(HloOpcode opcode)33 static std::string GetMlirOpName(HloOpcode opcode) {
34 std::string op_name = HloOpcodeString(opcode);
35 absl::c_replace(op_name, '-', '_');
36 return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
37 }
38
ToString(mlir::Type ty)39 static std::string ToString(mlir::Type ty) {
40 std::string str;
41 llvm::raw_string_ostream sstream(str);
42 ty.print(sstream);
43 sstream.flush();
44 return str;
45 }
46
47 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(absl::Span<const int64> values,mlir::Builder * builder)48 static mlir::DenseIntElementsAttr GetI64ElementsAttr(
49 absl::Span<const int64> values, mlir::Builder* builder) {
50 auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
51 builder->getIntegerType(64));
52 return mlir::DenseIntElementsAttr::get(
53 ty, llvm::makeArrayRef(values.data(), values.size()));
54 }
55
ConvertPadding(absl::Span<const std::pair<int64_t,int64_t>> padding,mlir::Builder * builder)56 static mlir::DenseIntElementsAttr ConvertPadding(
57 absl::Span<const std::pair<int64_t, int64_t>> padding,
58 mlir::Builder* builder) {
59 llvm::SmallVector<int64_t, 8> elements;
60 elements.reserve(padding.size() * 2);
61 for (const auto& vals : padding) {
62 elements.push_back(vals.first);
63 elements.push_back(vals.second);
64 }
65 auto ty = mlir::RankedTensorType::get(
66 {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
67 return mlir::DenseIntElementsAttr::get(ty, elements);
68 }
69
70 MlirHloBuilder::~MlirHloBuilder() = default;
71
MakeXlaOp(mlir::Value val)72 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
73 mlir::Type ty = val.getType();
74 auto shape = std::make_unique<Shape>(TypeToShape(ty));
75 if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
76 return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
77 }
78
79 int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
80 handle_to_shape_[handle] = std::move(shape);
81 return XlaOp(handle, this);
82 }
83
ConstantLiteral(const LiteralSlice & literal)84 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
85 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
86 TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
87 CreateDenseElementsAttrFromLiteral(literal, builder_));
88 auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
89 return MakeXlaOp(op);
90 });
91 }
92
ConvGeneralDilatedInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const Window & window,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config)93 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
94 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
95 absl::Span<const int64> window_strides,
96 absl::Span<const std::pair<int64, int64>> padding,
97 absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
98 const ConvolutionDimensionNumbers& dimension_numbers,
99 int64_t feature_group_count, int64_t batch_group_count,
100 const PrecisionConfig* precision_config) {
101 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
102 shape, builder_));
103 mlir::ArrayAttr config_attr;
104 if (precision_config)
105 config_attr = ConvertPrecisionConfig(precision_config, &builder_);
106 auto op = builder_.create<mlir::mhlo::ConvOp>(
107 loc_, ty, GetValue(lhs), GetValue(rhs),
108 GetI64ElementsAttr(window_strides, &builder_),
109 ConvertPadding(padding, &builder_),
110 GetI64ElementsAttr(lhs_dilation, &builder_),
111 GetI64ElementsAttr(rhs_dilation, &builder_),
112 /*window_reversal=*/nullptr,
113 ConvertConvDimensionNumbers(dimension_numbers, &builder_),
114 builder_.getI64IntegerAttr(feature_group_count),
115 builder_.getI64IntegerAttr(batch_group_count), config_attr);
116 return MakeXlaOp(op);
117 }
118
FftInternal(const Shape & shape,XlaOp operand,FftType fft_type,absl::Span<const int64> fft_length)119 StatusOr<XlaOp> MlirHloBuilder::FftInternal(
120 const Shape& shape, XlaOp operand, FftType fft_type,
121 absl::Span<const int64> fft_length) {
122 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
123 shape, builder_));
124 auto op = builder_.create<mlir::mhlo::FftOp>(
125 loc_, ty, GetValue(operand),
126 builder_.getStringAttr(FftType_Name(fft_type)),
127 GetI64ElementsAttr(fft_length, &builder_));
128 return MakeXlaOp(op);
129 }
130
CustomCallInternal(const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const string & opaque,absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> output_operand_aliasing,const Literal * literal,absl::optional<Window> window,absl::optional<ConvolutionDimensionNumbers> dnums,CustomCallSchedule schedule,CustomCallApiVersion api_version)131 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
132 const string& call_target_name, absl::Span<const XlaOp> operands,
133 const Shape& shape, const string& opaque,
134 absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
135 bool has_side_effect,
136 absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
137 output_operand_aliasing,
138 const Literal* literal, absl::optional<Window> window,
139 absl::optional<ConvolutionDimensionNumbers> dnums,
140 CustomCallSchedule schedule, CustomCallApiVersion api_version) {
141 if (operand_shapes_with_layout.has_value())
142 return Unimplemented(
143 "CustomCall doesn't support operands shapes with layout");
144 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
145 shape, builder_));
146 TF_ASSIGN_OR_RETURN(auto mlir_api_version,
147 ConvertCustomCallApiVersion(api_version));
148 TF_RET_CHECK(output_operand_aliasing.empty())
149 << "MLIR CustomCallOp does not support output_operand_aliasing yet";
150 TF_RET_CHECK(literal == nullptr)
151 << "MLIR CustomCallOp does not support literal yet";
152 TF_RET_CHECK(!window.has_value())
153 << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
154 TF_RET_CHECK(!dnums.has_value())
155 << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
156 TF_RET_CHECK(schedule == CustomCallSchedule::SCHEDULE_NONE)
157 << "MLIR CustomCallOp does not support custom-call-schedule yet";
158 auto op = builder_.create<mlir::mhlo::CustomCallOp>(
159 loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
160 /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
161 builder_.getStringAttr(opaque),
162 /*api_version=*/
163 mlir::mhlo::CustomCallApiVersionAttr::get(builder_.getContext(),
164 mlir_api_version));
165 return MakeXlaOp(op.getResult(0));
166 }
167
ReduceInternal(const Shape & shape,absl::Span<const XlaOp> all_operands,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)168 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
169 const Shape& shape, absl::Span<const XlaOp> all_operands,
170 const XlaComputation& computation,
171 absl::Span<const int64> dimensions_to_reduce) {
172 // Reduce takes two set of variadic operands inputs and init_values.
173 // all_operands contains both of these so split operands into two parts.
174 int64_t num_args = all_operands.size() / 2;
175 auto op = builder_.create<mlir::mhlo::ReduceOp>(
176 loc_, GetValues(all_operands.first(num_args)),
177 GetValues(all_operands.subspan(num_args)),
178 GetI64ElementsAttr(dimensions_to_reduce, &builder_));
179 TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
180 if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
181 auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
182 return MakeXlaOp(tuple);
183 }
184
ReduceWindowInternal(const Shape & shape,XlaOp operand,XlaOp init_value,const XlaComputation & computation,Window window)185 StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
186 const Shape& shape, XlaOp operand, XlaOp init_value,
187 const XlaComputation& computation, Window window) {
188 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
189 shape, builder_));
190 llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
191 llvm::SmallVector<int64, 8> padding;
192 for (const auto& dim : window.dimensions()) {
193 sizes.push_back(dim.size());
194 strides.push_back(dim.stride());
195 base_dilations.push_back(dim.base_dilation());
196 win_dilations.push_back(dim.window_dilation());
197 padding.push_back(dim.padding_low());
198 padding.push_back(dim.padding_high());
199 }
200 auto padding_ty =
201 mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
202 builder_.getIntegerType(64));
203 auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
204 loc_, ty, GetValue(operand), GetValue(init_value),
205 GetI64ElementsAttr(sizes, &builder_),
206 GetI64ElementsAttr(strides, &builder_),
207 GetI64ElementsAttr(base_dilations, &builder_),
208 GetI64ElementsAttr(win_dilations, &builder_),
209 mlir::DenseIntElementsAttr::get(padding_ty, padding));
210 TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
211 return MakeXlaOp(op.getResult(0));
212 }
213
Iota(const Shape & shape,int64_t iota_dimension)214 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
215 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
216 TF_ASSIGN_OR_RETURN(
217 mlir::Type ty,
218 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
219 auto op = builder_.create<mlir::mhlo::IotaOp>(
220 loc_, ty,
221 builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
222 return MakeXlaOp(op);
223 });
224 }
225
BitcastConvertTypeInternal(const Shape & shape,XlaOp operand)226 StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
227 XlaOp operand) {
228 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
229 shape, builder_));
230 auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
231 GetValue(operand));
232 return MakeXlaOp(op);
233 }
234
TransposeInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> permutation)235 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
236 const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
237 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
238 shape, builder_));
239 auto op = builder_.create<mlir::mhlo::TransposeOp>(
240 loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
241 return MakeXlaOp(op);
242 }
243
RevInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> dimensions)244 StatusOr<XlaOp> MlirHloBuilder::RevInternal(
245 const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
246 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
247 shape, builder_));
248 auto op = builder_.create<mlir::mhlo::ReverseOp>(
249 loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
250 return MakeXlaOp(op);
251 }
252
SortInternal(const Shape & shape,absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64_t dimension,bool is_stable)253 StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
254 absl::Span<const XlaOp> operands,
255 const XlaComputation& comparator,
256 int64_t dimension,
257 bool is_stable) {
258 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
259 shape, builder_));
260 llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
261 if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
262 sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
263 }
264
265 auto op = builder_.create<mlir::mhlo::SortOp>(
266 loc_, sort_types, GetValues(operands),
267 builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
268 TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
269
270 if (ty.isa<mlir::TupleType>()) {
271 auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
272 return MakeXlaOp(tuple);
273 }
274
275 return MakeXlaOp(op.getResult(0));
276 }
277
WhileInternal(const Shape & shape,const XlaComputation & condition,const XlaComputation & body,XlaOp init)278 StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
279 const XlaComputation& condition,
280 const XlaComputation& body,
281 XlaOp init) {
282 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
283 shape, builder_));
284 auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
285 TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
286 TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
287 // TODO(jpienaar): Support multi-operand while op.
288 if (op.getNumResults() != 1)
289 return Unimplemented("Only single result MHLO WhileOp's can be import.");
290 return MakeXlaOp(op.getResult(0));
291 }
292
ReducePrecisionInternal(const Shape & shape,XlaOp operand,const int exponent_bits,const int mantissa_bits)293 StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
294 const Shape& shape, XlaOp operand, const int exponent_bits,
295 const int mantissa_bits) {
296 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
297 shape, builder_));
298 auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
299 loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
300 builder_.getI32IntegerAttr(mantissa_bits));
301 return MakeXlaOp(op);
302 }
303
GatherInternal(const Shape & shape,XlaOp input,XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)304 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
305 const Shape& shape, XlaOp input, XlaOp start_indices,
306 const GatherDimensionNumbers& dimension_numbers,
307 absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
308 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
309 shape, builder_));
310 auto op = builder_.create<mlir::mhlo::GatherOp>(
311 loc_, ty, GetValue(input), GetValue(start_indices),
312 ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
313 GetI64ElementsAttr(slice_sizes, &builder_));
314 return MakeXlaOp(op);
315 }
316
ScatterInternal(const Shape & shape,XlaOp input,XlaOp scatter_indices,XlaOp updates,const XlaComputation & update_computation,const ScatterDimensionNumbers & dimension_numbers,bool indices_are_sorted,bool unique_indices)317 StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
318 const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
319 const XlaComputation& update_computation,
320 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
321 bool unique_indices) {
322 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
323 shape, builder_));
324 auto op = builder_.create<mlir::mhlo::ScatterOp>(
325 loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
326 ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
327 builder_.getBoolAttr(indices_are_sorted),
328 builder_.getBoolAttr(unique_indices));
329
330 TF_RETURN_IF_ERROR(
331 ImportComputation(update_computation.proto(), &op.update_computation()));
332 return MakeXlaOp(op);
333 }
334
SetDimensionSizeInternal(const Shape & shape,XlaOp operand,XlaOp val,int64_t dimension)335 StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
336 XlaOp operand,
337 XlaOp val,
338 int64_t dimension) {
339 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
340 shape, builder_));
341 auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
342 loc_, ty, GetValue(operand), GetValue(val),
343 builder_.getI64IntegerAttr(dimension));
344 return MakeXlaOp(op);
345 }
346
RngOpInternal(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)347 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
348 RandomDistribution distribution, absl::Span<const XlaOp> parameters,
349 const Shape& shape) {
350 // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform
351 // and RngNormal can be mapped to the new op.
352 std::string op_name;
353 if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
354 op_name = "mhlo.rng_uniform";
355 } else {
356 TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
357 << "Unexpected distribution: " << distribution;
358 op_name = "mhlo.rng_normal";
359 }
360
361 if (shape.is_dynamic())
362 return Unimplemented("RngOp with dynamic dims not supported");
363 llvm::SmallVector<XlaOp, 3> operands;
364 operands.append(parameters.begin(), parameters.end());
365 operands.push_back(
366 ConstantLiteral(LiteralUtil::CreateR1<int64>(shape.dimensions())));
367 return CreateOp(op_name, shape, operands);
368 }
369
RngBitGeneratorInternal(const Shape & full_result_shape,RandomAlgorithm algorithm,XlaOp initial_state)370 StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
371 const Shape& full_result_shape, RandomAlgorithm algorithm,
372 XlaOp initial_state) {
373 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
374 full_result_shape, builder_));
375 auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
376 loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state));
377 return MakeXlaOp(op);
378 }
379
ReshapeInternal(const Shape & shape,XlaOp operand,int64_t inferred_dimension)380 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
381 XlaOp operand,
382 int64_t inferred_dimension) {
383 TF_RETURN_IF_ERROR(first_error());
384
385 if (inferred_dimension != -1)
386 return Unimplemented("inferred_dimension not yet supported for Reshape op");
387 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
388 shape, builder_));
389 mlir::Value value = GetValue(operand);
390 auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
391 return MakeXlaOp(op.getResult());
392 }
393
DotGeneralInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const DotDimensionNumbers & dimension_number,const PrecisionConfig * precision_config)394 StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
395 const Shape& shape, XlaOp lhs, XlaOp rhs,
396 const DotDimensionNumbers& dimension_number,
397 const PrecisionConfig* precision_config) {
398 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
399 shape, builder_));
400 auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
401 loc_, ty, GetValue(lhs), GetValue(rhs),
402 ConvertDotDimensionNumbers(dimension_number, &builder_),
403 ConvertPrecisionConfig(precision_config, &builder_));
404 return MakeXlaOp(op.getResult());
405 }
406
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64> broadcast_dimensions)407 StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
408 const Shape& shape, XlaOp operand,
409 absl::Span<const int64> broadcast_dimensions) {
410 TF_RETURN_IF_ERROR(first_error());
411 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
412 shape, builder_));
413 mlir::Value value = GetValue(operand);
414 auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
415 loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
416 return MakeXlaOp(op.getResult());
417 }
418
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)419 StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
420 HloInstructionProto&& instr, HloOpcode opcode,
421 absl::Span<const XlaOp> operands) {
422 return Unimplemented("MlirHloBuilder does not support op %s",
423 HloOpcodeString(opcode));
424 }
425
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction,Comparison::Type type)426 StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
427 XlaOp rhs,
428 ComparisonDirection direction,
429 Comparison::Type type) {
430 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
431 shape, builder_));
432 auto op = builder_.create<mlir::mhlo::CompareOp>(
433 loc_, ty, GetValue(lhs), GetValue(rhs),
434 builder_.getStringAttr(ComparisonDirectionToString(direction)),
435 builder_.getStringAttr(ComparisonTypeToString(type)));
436 return MakeXlaOp(op.getResult());
437 }
438
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)439 XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
440 XlaOp lhs, XlaOp rhs) {
441 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
442 return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
443 });
444 }
445
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)446 StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
447 HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
448 return CreateOp(GetMlirOpName(opcode), shape,
449 llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
450 }
451
CreateToken()452 XlaOp MlirHloBuilder::CreateToken() {
453 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
454 return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
455 loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
456 });
457 }
458
TriangularSolveInternal(const Shape & shape,XlaOp a,XlaOp b,TriangularSolveOptions options)459 StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
460 const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
461 TF_ASSIGN_OR_RETURN(
462 mlir::Type result_ty,
463 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
464 auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
465 loc_, result_ty, GetValue(a), GetValue(b),
466 builder_.getBoolAttr(options.left_side()),
467 builder_.getBoolAttr(options.lower()),
468 builder_.getBoolAttr(options.unit_diagonal()),
469 builder_.getStringAttr(
470 TriangularSolveOptions::Transpose_Name(options.transpose_a())));
471 return MakeXlaOp(op);
472 }
473
CholeskyInternal(const Shape & shape,XlaOp a,bool lower)474 StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
475 bool lower) {
476 TF_ASSIGN_OR_RETURN(
477 mlir::Type result_ty,
478 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
479 auto op = builder_.create<mlir::mhlo::CholeskyOp>(
480 loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
481 return MakeXlaOp(op);
482 }
483
InfeedWithTokenInternal(const Shape & infeed_instruction_shape,XlaOp token,const string & config)484 StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
485 const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
486 TF_ASSIGN_OR_RETURN(mlir::Type result_type,
487 ConvertShapeToType<mlir::RankedTensorType>(
488 infeed_instruction_shape, builder_));
489 mlir::ArrayAttr layout;
490 return MakeXlaOp(
491 builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
492 /*infeed_config=*/config,
493 /*layout=*/layout));
494 }
495
OutfeedWithTokenInternal(XlaOp operand,XlaOp token,const Shape & shape_with_layout,const string & outfeed_config)496 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
497 XlaOp operand, XlaOp token, const Shape& shape_with_layout,
498 const string& outfeed_config) {
499 auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
500 return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
501 loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
502 }
503
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64_t dimension)504 StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
505 const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
506 TF_ASSIGN_OR_RETURN(
507 mlir::Type result_type,
508 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
509 auto mlir_operands = GetValues(operands);
510 return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
511 loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
512 }
513
GetTupleElementInternal(const Shape & shape,XlaOp tuple_data,int64_t index)514 StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
515 XlaOp tuple_data,
516 int64_t index) {
517 TF_ASSIGN_OR_RETURN(
518 mlir::Type result_type,
519 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
520 return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
521 loc_, result_type, GetValue(tuple_data),
522 builder_.getI32IntegerAttr(index)));
523 }
524
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)525 StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
526 const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
527 absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
528 return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
529 loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
530 GetI64ElementsAttr(limit_indices, &builder_),
531 GetI64ElementsAttr(strides, &builder_)));
532 }
533
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64> slice_sizes)534 StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
535 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
536 absl::Span<const int64> slice_sizes) {
537 TF_ASSIGN_OR_RETURN(
538 mlir::Type result_ty,
539 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
540 return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
541 loc_, result_ty, GetValue(operand), GetValues(start_indices),
542 GetI64ElementsAttr(slice_sizes, &builder_)));
543 }
544
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)545 StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
546 const Shape& shape, XlaOp operand, XlaOp update,
547 absl::Span<const XlaOp> start_indices) {
548 TF_ASSIGN_OR_RETURN(
549 mlir::Type result_ty,
550 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
551 return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
552 loc_, result_ty, GetValue(operand), GetValue(update),
553 GetValues(start_indices)));
554 }
555
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)556 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
557 const Shape& shape, XlaOp operand, XlaOp padding_value,
558 const PaddingConfig& padding_config) {
559 TF_ASSIGN_OR_RETURN(
560 mlir::Type result_type,
561 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
562 std::vector<int64> low;
563 std::vector<int64> high;
564 std::vector<int64> internal;
565 for (auto& dimension : padding_config.dimensions()) {
566 low.push_back(dimension.edge_padding_low());
567 high.push_back(dimension.edge_padding_high());
568 internal.push_back(dimension.interior_padding());
569 }
570 return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
571 loc_, result_type, GetValue(operand), GetValue(padding_value),
572 GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
573 GetI64ElementsAttr(internal, &builder_)));
574 }
575
TupleInternal(const Shape & shape,absl::Span<const XlaOp> elements)576 StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
577 const Shape& shape, absl::Span<const XlaOp> elements) {
578 mlir::SmallVector<mlir::Value, 4> operands;
579 for (auto& element : elements) {
580 operands.push_back(GetValue(element));
581 }
582 return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
583 }
584
CreateOp(const std::string & op_name,const Shape & shape,llvm::ArrayRef<XlaOp> operands,llvm::ArrayRef<mlir::NamedAttribute> attributes)585 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
586 const std::string& op_name, const Shape& shape,
587 llvm::ArrayRef<XlaOp> operands,
588 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
589 llvm::SmallVector<mlir::Value, 4> operand_values;
590 operand_values.reserve(operands.size());
591 for (XlaOp xla_op : operands) {
592 operand_values.push_back(GetValue(xla_op));
593 }
594 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
595 shape, builder_));
596 mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
597 mlir::Operation* op = builder_.createOperation(state);
598 return MakeXlaOp(op->getResult(0));
599 }
600
ImportComputation(const HloModuleProto & computation,mlir::Region * region)601 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
602 mlir::Region* region) {
603 TF_ASSIGN_OR_RETURN(auto module_config,
604 xla::HloModule::CreateModuleConfigFromProto(
605 computation, xla::DebugOptions()));
606 TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
607 computation, module_config));
608
609 return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
610 region, &builder_);
611 }
612
GetShapePtr(XlaOp op) const613 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
614 TF_RETURN_IF_ERROR(first_error());
615 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
616 auto it = handle_to_shape_.find(op.handle());
617 if (it == handle_to_shape_.end()) {
618 return InvalidArgument("No XlaOp with handle %d", op.handle());
619 }
620 return it->second.get();
621 }
622
623 } // namespace xla
624