1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
16
17 #include <string>
18 #include <utility>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
29 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
30 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
31 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
32 #include "tensorflow/compiler/xla/comparison_util.h"
33 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
34 #include "tensorflow/compiler/xla/service/hlo_module.h"
35 #include "tensorflow/compiler/xla/service/shape_inference.h"
36 #include "tensorflow/compiler/xla/util.h"
37
38 namespace xla {
39
GetMlirOpName(HloOpcode opcode)40 static std::string GetMlirOpName(HloOpcode opcode) {
41 std::string op_name = HloOpcodeString(opcode);
42 absl::c_replace(op_name, '-', '_');
43 return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
44 }
45
ToString(mlir::Type ty)46 static std::string ToString(mlir::Type ty) {
47 std::string str;
48 llvm::raw_string_ostream sstream(str);
49 ty.print(sstream);
50 sstream.flush();
51 return str;
52 }
53
54 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(absl::Span<const int64_t> values,mlir::Builder * builder)55 static mlir::DenseIntElementsAttr GetI64ElementsAttr(
56 absl::Span<const int64_t> values, mlir::Builder* builder) {
57 auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
58 builder->getIntegerType(64));
59 return mlir::DenseIntElementsAttr::get(
60 ty, llvm::makeArrayRef(values.data(), values.size()));
61 }
62
ConvertPadding(absl::Span<const std::pair<int64_t,int64_t>> padding,mlir::Builder * builder)63 static mlir::DenseIntElementsAttr ConvertPadding(
64 absl::Span<const std::pair<int64_t, int64_t>> padding,
65 mlir::Builder* builder) {
66 llvm::SmallVector<int64_t, 8> elements;
67 elements.reserve(padding.size() * 2);
68 for (const auto& vals : padding) {
69 elements.push_back(vals.first);
70 elements.push_back(vals.second);
71 }
72 auto ty = mlir::RankedTensorType::get(
73 {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
74 return mlir::DenseIntElementsAttr::get(ty, elements);
75 }
76
// Destructor, defined out of line with the compiler-generated default behavior.
MlirHloBuilder::~MlirHloBuilder() = default;
78
MakeXlaOp(mlir::Value val)79 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
80 mlir::Type ty = val.getType();
81 auto shape = std::make_unique<Shape>(TypeToShape(ty));
82 if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
83 return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
84 }
85
86 int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
87 handle_to_shape_[handle] = std::move(shape);
88 return XlaOp(handle, this);
89 }
90
ConstantLiteral(const LiteralSlice & literal)91 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
92 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
93 TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
94 CreateDenseElementsAttrFromLiteral(literal, builder_));
95 auto op = builder_.create<mlir::mhlo::ConstantOp>(loc_, attr);
96 return MakeXlaOp(op);
97 });
98 }
99
ConvGeneralDilatedInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const Window & window,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config)100 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
101 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
102 absl::Span<const int64_t> window_strides,
103 absl::Span<const std::pair<int64_t, int64_t>> padding,
104 absl::Span<const int64_t> lhs_dilation,
105 absl::Span<const int64_t> rhs_dilation,
106 const ConvolutionDimensionNumbers& dimension_numbers,
107 int64_t feature_group_count, int64_t batch_group_count,
108 const PrecisionConfig* precision_config) {
109 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
110 shape, builder_));
111 mlir::ArrayAttr config_attr;
112 if (precision_config)
113 config_attr = ConvertPrecisionConfig(precision_config, &builder_);
114 auto op = builder_.create<mlir::mhlo::ConvolutionOp>(
115 loc_, ty, GetValue(lhs), GetValue(rhs),
116 GetI64ElementsAttr(window_strides, &builder_),
117 ConvertPadding(padding, &builder_),
118 GetI64ElementsAttr(lhs_dilation, &builder_),
119 GetI64ElementsAttr(rhs_dilation, &builder_),
120 /*window_reversal=*/nullptr,
121 ConvertConvDimensionNumbers(dimension_numbers, &builder_),
122 builder_.getI64IntegerAttr(feature_group_count),
123 builder_.getI64IntegerAttr(batch_group_count), config_attr);
124 return MakeXlaOp(op);
125 }
126
FftInternal(const Shape & shape,XlaOp operand,FftType fft_type,absl::Span<const int64_t> fft_length)127 StatusOr<XlaOp> MlirHloBuilder::FftInternal(
128 const Shape& shape, XlaOp operand, FftType fft_type,
129 absl::Span<const int64_t> fft_length) {
130 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
131 shape, builder_));
132 auto fft_type_attr = mlir::mhlo::symbolizeFftType(FftType_Name(fft_type));
133 auto op = builder_.create<mlir::mhlo::FftOp>(
134 loc_, ty, GetValue(operand),
135 mlir::mhlo::FftTypeAttr::get(builder_.getContext(),
136 fft_type_attr.getValue()),
137 GetI64ElementsAttr(fft_length, &builder_));
138 return MakeXlaOp(op);
139 }
140
141 // TODO(b/235207091) Add actual support for the called computation.
CustomCallInternal(const std::string & call_target_name,absl::Span<const XlaOp> operands,const XlaComputation * computation,const Shape & shape,const std::string & opaque,std::optional<absl::Span<const Shape>> operand_shapes_with_layout,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,std::optional<Window> window,std::optional<ConvolutionDimensionNumbers> dnums,CustomCallSchedule schedule,CustomCallApiVersion api_version)142 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
143 const std::string& call_target_name, absl::Span<const XlaOp> operands,
144 const XlaComputation* computation, const Shape& shape,
145 const std::string& opaque,
146 std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
147 bool has_side_effect,
148 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
149 output_operand_aliasing,
150 const Literal* literal, std::optional<Window> window,
151 std::optional<ConvolutionDimensionNumbers> dnums,
152 CustomCallSchedule schedule, CustomCallApiVersion api_version) {
153 TF_RET_CHECK(output_operand_aliasing.empty())
154 << "MLIR CustomCallOp does not support output_operand_aliasing yet";
155 TF_RET_CHECK(literal == nullptr)
156 << "MLIR CustomCallOp does not support literal yet";
157 TF_RET_CHECK(!window.has_value())
158 << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
159 TF_RET_CHECK(!dnums.has_value())
160 << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
161 TF_RET_CHECK(schedule == CustomCallSchedule::SCHEDULE_NONE)
162 << "MLIR CustomCallOp does not support custom-call-schedule yet";
163 TF_RET_CHECK(computation == nullptr || computation->IsNull() ||
164 build_functions_ == true)
165 << "MLIR CustomCallOp with computation isn't supported when not allowed "
166 "to create functions";
167
168 llvm::SmallVector<mlir::NamedAttribute> attributes;
169 if (operand_shapes_with_layout.has_value()) {
170 TF_ASSIGN_OR_RETURN(mlir::ArrayAttr operand_layouts,
171 ExtractLayoutsFromShapes(
172 operand_shapes_with_layout.value(), &builder_));
173 attributes.push_back(
174 builder_.getNamedAttr("operand_layouts", operand_layouts));
175
176 mlir::ArrayAttr result_layouts;
177 if (shape.IsTuple()) {
178 TF_ASSIGN_OR_RETURN(result_layouts,
179 ExtractLayoutsFromTuple(shape, &builder_));
180 } else {
181 TF_ASSIGN_OR_RETURN(result_layouts,
182 ExtractLayoutsFromShapes({shape}, &builder_));
183 }
184 attributes.push_back(
185 builder_.getNamedAttr("result_layouts", result_layouts));
186 }
187 TF_ASSIGN_OR_RETURN(auto mlir_api_version,
188 ConvertCustomCallApiVersion(api_version));
189 attributes.push_back(builder_.getNamedAttr(
190 "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
191 builder_.getContext(), mlir_api_version)));
192 attributes.push_back(builder_.getNamedAttr(
193 "call_target_name", builder_.getStringAttr(call_target_name)));
194 attributes.push_back(builder_.getNamedAttr(
195 "has_side_effect", builder_.getBoolAttr(has_side_effect)));
196 attributes.push_back(
197 builder_.getNamedAttr("backend_config", builder_.getStringAttr(opaque)));
198
199 if (computation && !computation->IsNull()) {
200 llvm::SmallVector<mlir::Attribute> computation_names;
201 for (const auto& computation_proto : computation->proto().computations()) {
202 computation_names.push_back(mlir::SymbolRefAttr::get(
203 builder_.getContext(), computation_proto.name()));
204 }
205 attributes.push_back(builder_.getNamedAttr(
206 "called_computations", builder_.getArrayAttr(computation_names)));
207
208 // Create new function(s) to represent the called computations. As a result,
209 // this legalization may only be called during a module pass rather than the
210 // typical parallelized func pass which is not permitted to create
211 // functions.
212 TF_RETURN_IF_ERROR(ImportComputation(
213 computation->proto(),
214 builder_.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>()));
215 }
216
217 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
218 shape, builder_));
219 auto op = builder_.create<mlir::mhlo::CustomCallOp>(
220 loc_, ty, GetValues(operands), attributes);
221 return MakeXlaOp(op.getResult(0));
222 }
223
ReduceInternal(const Shape & shape,absl::Span<const XlaOp> all_operands,const XlaComputation & computation,absl::Span<const int64_t> dimensions_to_reduce)224 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
225 const Shape& shape, absl::Span<const XlaOp> all_operands,
226 const XlaComputation& computation,
227 absl::Span<const int64_t> dimensions_to_reduce) {
228 // Reduce takes two set of variadic operands inputs and init_values.
229 // all_operands contains both of these so split operands into two parts.
230 int64_t num_args = all_operands.size() / 2;
231 auto op = builder_.create<mlir::mhlo::ReduceOp>(
232 loc_, GetValues(all_operands.first(num_args)),
233 GetValues(all_operands.subspan(num_args)),
234 GetI64ElementsAttr(dimensions_to_reduce, &builder_));
235 TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body(),
236 /*flatten_region_arg_tuple*/ true));
237 if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
238 auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
239 return MakeXlaOp(tuple);
240 }
241
ReduceWindowInternal(const Shape & shape,XlaOp operand,XlaOp init_value,const XlaComputation & computation,Window window)242 StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
243 const Shape& shape, XlaOp operand, XlaOp init_value,
244 const XlaComputation& computation, Window window) {
245 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
246 shape, builder_));
247 llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations, win_dilations;
248 llvm::SmallVector<int64_t, 8> padding;
249 for (const auto& dim : window.dimensions()) {
250 sizes.push_back(dim.size());
251 strides.push_back(dim.stride());
252 base_dilations.push_back(dim.base_dilation());
253 win_dilations.push_back(dim.window_dilation());
254 padding.push_back(dim.padding_low());
255 padding.push_back(dim.padding_high());
256 }
257 auto padding_ty =
258 mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
259 builder_.getIntegerType(64));
260 auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
261 loc_, ty, GetValue(operand), GetValue(init_value),
262 GetI64ElementsAttr(sizes, &builder_),
263 GetI64ElementsAttr(strides, &builder_),
264 GetI64ElementsAttr(base_dilations, &builder_),
265 GetI64ElementsAttr(win_dilations, &builder_),
266 mlir::DenseIntElementsAttr::get(padding_ty, padding));
267 TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body(),
268 /*flatten_region_arg_tuple*/ true));
269 return MakeXlaOp(op.getResult(0));
270 }
271
Iota(const Shape & shape,int64_t iota_dimension)272 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
273 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
274 TF_ASSIGN_OR_RETURN(
275 mlir::Type ty,
276 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
277 auto op = builder_.create<mlir::mhlo::IotaOp>(
278 loc_, ty,
279 builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
280 return MakeXlaOp(op);
281 });
282 }
283
BitcastConvertTypeInternal(const Shape & shape,XlaOp operand)284 StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
285 XlaOp operand) {
286 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
287 shape, builder_));
288 auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
289 GetValue(operand));
290 return MakeXlaOp(op);
291 }
292
TransposeInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> permutation)293 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
294 const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
295 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
296 shape, builder_));
297 auto op = builder_.create<mlir::mhlo::TransposeOp>(
298 loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
299 return MakeXlaOp(op);
300 }
301
RevInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> dimensions)302 StatusOr<XlaOp> MlirHloBuilder::RevInternal(
303 const Shape& shape, XlaOp operand, absl::Span<const int64_t> dimensions) {
304 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
305 shape, builder_));
306 auto op = builder_.create<mlir::mhlo::ReverseOp>(
307 loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
308 return MakeXlaOp(op);
309 }
310
SortInternal(const Shape & shape,absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64_t dimension,bool is_stable)311 StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
312 absl::Span<const XlaOp> operands,
313 const XlaComputation& comparator,
314 int64_t dimension,
315 bool is_stable) {
316 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
317 shape, builder_));
318 llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
319 if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
320 sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
321 }
322
323 auto op = builder_.create<mlir::mhlo::SortOp>(
324 loc_, sort_types, GetValues(operands),
325 builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
326 TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
327
328 if (ty.isa<mlir::TupleType>()) {
329 auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
330 return MakeXlaOp(tuple);
331 }
332
333 return MakeXlaOp(op.getResult(0));
334 }
335
WhileInternal(const Shape & shape,const XlaComputation & condition,const XlaComputation & body,XlaOp init)336 StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
337 const XlaComputation& condition,
338 const XlaComputation& body,
339 XlaOp init) {
340 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
341 shape, builder_));
342
343 llvm::SmallVector<mlir::Value> flattened_operands;
344 llvm::SmallVector<mlir::Type> flattened_operand_types;
345
346 HloFunctionImporter::FlattenTupleType(ty, flattened_operand_types);
347 HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(init),
348 flattened_operands);
349
350 auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, flattened_operand_types,
351 flattened_operands);
352
353 TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond(),
354 /*flatten_region_arg_tuple*/ true));
355 TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body(),
356 /*flatten_region_arg_tuple*/ true));
357
358 if (ty.isa<mlir::TupleType>()) {
359 llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
360 llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
361 auto result = HloFunctionImporter::CreateTupleValue(
362 &builder_, loc_, flattened_results_ref, ty);
363 auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
364 return MakeXlaOp(defining_tuple_op);
365 }
366
367 return MakeXlaOp(op.getResult(0));
368 }
369
ReducePrecisionInternal(const Shape & shape,XlaOp operand,const int exponent_bits,const int mantissa_bits)370 StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
371 const Shape& shape, XlaOp operand, const int exponent_bits,
372 const int mantissa_bits) {
373 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
374 shape, builder_));
375 auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
376 loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
377 builder_.getI32IntegerAttr(mantissa_bits));
378 return MakeXlaOp(op);
379 }
380
GatherInternal(const Shape & shape,XlaOp input,XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)381 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
382 const Shape& shape, XlaOp input, XlaOp start_indices,
383 const GatherDimensionNumbers& dimension_numbers,
384 absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
385 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
386 shape, builder_));
387 auto op = builder_.create<mlir::mhlo::GatherOp>(
388 loc_, ty, GetValue(input), GetValue(start_indices),
389 ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
390 GetI64ElementsAttr(slice_sizes, &builder_));
391 return MakeXlaOp(op);
392 }
393
ScatterInternal(const Shape & shape,absl::Span<const XlaOp> inputs,XlaOp scatter_indices,absl::Span<const XlaOp> updates,const XlaComputation & update_computation,const ScatterDimensionNumbers & dimension_numbers,bool indices_are_sorted,bool unique_indices)394 StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
395 const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
396 absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
397 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
398 bool unique_indices) {
399 // TODO(b/230137437): Allow variadic scatter after adding mhlo support.
400 if (inputs.size() != 1) {
401 return Unimplemented("Variadic scatter not implemented in mhlo yet.");
402 }
403 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
404 shape, builder_));
405 auto op = builder_.create<mlir::mhlo::ScatterOp>(
406 loc_, ty, GetValue(inputs[0]), GetValue(scatter_indices),
407 GetValue(updates[0]),
408 ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
409 builder_.getBoolAttr(indices_are_sorted),
410 builder_.getBoolAttr(unique_indices));
411
412 TF_RETURN_IF_ERROR(
413 ImportComputation(update_computation.proto(), &op.update_computation()));
414 return MakeXlaOp(op.getResult(0));
415 }
416
SetDimensionSizeInternal(const Shape & shape,XlaOp operand,XlaOp val,int64_t dimension)417 StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
418 XlaOp operand,
419 XlaOp val,
420 int64_t dimension) {
421 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
422 shape, builder_));
423 auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
424 loc_, ty, GetValue(operand), GetValue(val),
425 builder_.getI64IntegerAttr(dimension));
426 return MakeXlaOp(op);
427 }
428
RngOpInternal(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)429 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
430 RandomDistribution distribution, absl::Span<const XlaOp> parameters,
431 const Shape& shape) {
432 mlir::mhlo::RngDistributionAttr attr;
433 if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
434 attr = mlir::mhlo::RngDistributionAttr::get(
435 builder_.getContext(), mlir::mhlo::RngDistribution::UNIFORM);
436 } else {
437 TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
438 << "Unexpected distribution: " << distribution;
439 attr = mlir::mhlo::RngDistributionAttr::get(
440 builder_.getContext(), mlir::mhlo::RngDistribution::NORMAL);
441 }
442 llvm::SmallVector<mlir::NamedAttribute, 1> attributes = {
443 builder_.getNamedAttr("rng_distribution", attr)};
444
445 if (shape.is_dynamic())
446 return Unimplemented("RngOp with dynamic dims not supported");
447 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
448 shape, builder_));
449
450 auto op = builder_.create<mlir::mhlo::RngOp>(
451 loc_, ty, GetValue(parameters[0]), GetValue(parameters[1]),
452 GetValue(
453 ConstantLiteral(LiteralUtil::CreateR1<int64_t>(shape.dimensions()))),
454 attr);
455 return MakeXlaOp(op.getResult());
456 }
457
RngBitGeneratorInternal(const Shape & full_result_shape,RandomAlgorithm algorithm,XlaOp initial_state)458 StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
459 const Shape& full_result_shape, RandomAlgorithm algorithm,
460 XlaOp initial_state) {
461 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
462 full_result_shape, builder_));
463
464 llvm::SmallVector<mlir::Type> flattened_ret_types;
465 HloFunctionImporter::FlattenTupleType(ty, flattened_ret_types);
466
467 auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
468 builder_.getContext(), *mlir::mhlo::symbolizeRngAlgorithm(algorithm));
469 auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
470 loc_, flattened_ret_types, algorithm_attr, GetValue(initial_state));
471
472 if (ty.isa<mlir::TupleType>()) {
473 llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
474 llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
475 auto result = HloFunctionImporter::CreateTupleValue(
476 &builder_, loc_, flattened_results_ref, ty);
477 auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
478 return MakeXlaOp(defining_tuple_op);
479 }
480
481 return MakeXlaOp(op.getResult(0));
482 }
483
ReshapeInternal(const Shape & shape,XlaOp operand,int64_t inferred_dimension)484 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
485 XlaOp operand,
486 int64_t inferred_dimension) {
487 TF_RETURN_IF_ERROR(first_error());
488
489 if (inferred_dimension != -1)
490 return Unimplemented("inferred_dimension not yet supported for Reshape op");
491 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
492 shape, builder_));
493 mlir::Value value = GetValue(operand);
494 auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
495 return MakeXlaOp(op.getResult());
496 }
497
DotGeneralInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const DotDimensionNumbers & dimension_number,const PrecisionConfig * precision_config)498 StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
499 const Shape& shape, XlaOp lhs, XlaOp rhs,
500 const DotDimensionNumbers& dimension_number,
501 const PrecisionConfig* precision_config) {
502 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
503 shape, builder_));
504 auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
505 loc_, ty, GetValue(lhs), GetValue(rhs),
506 ConvertDotDimensionNumbers(dimension_number, &builder_),
507 ConvertPrecisionConfig(precision_config, &builder_));
508 return MakeXlaOp(op.getResult());
509 }
510
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64_t> broadcast_dimensions)511 StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
512 const Shape& shape, XlaOp operand,
513 absl::Span<const int64_t> broadcast_dimensions) {
514 TF_RETURN_IF_ERROR(first_error());
515 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
516 shape, builder_));
517 mlir::Value value = GetValue(operand);
518 auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
519 loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
520 return MakeXlaOp(op.getResult());
521 }
522
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)523 StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
524 HloInstructionProto&& instr, HloOpcode opcode,
525 absl::Span<const XlaOp> operands) {
526 return Unimplemented("MlirHloBuilder does not support op %s",
527 HloOpcodeString(opcode));
528 }
529
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction,Comparison::Type type)530 StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
531 XlaOp rhs,
532 ComparisonDirection direction,
533 Comparison::Type type) {
534 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
535 shape, builder_));
536 auto op = builder_.create<mlir::mhlo::CompareOp>(
537 loc_, ty, GetValue(lhs), GetValue(rhs),
538 mlir::mhlo::ComparisonDirectionAttr::get(
539 builder_.getContext(), mlir::mhlo::symbolizeComparisonDirection(
540 ComparisonDirectionToString(direction))
541 .getValue()),
542 mlir::mhlo::ComparisonTypeAttr::get(
543 builder_.getContext(),
544 mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
545 .getValue()));
546 return MakeXlaOp(op.getResult());
547 }
548
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)549 XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
550 XlaOp lhs, XlaOp rhs) {
551 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
552 return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
553 });
554 }
555
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)556 StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
557 HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
558 return CreateOp(GetMlirOpName(opcode), shape,
559 llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
560 }
561
CreateToken()562 XlaOp MlirHloBuilder::CreateToken() {
563 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
564 return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
565 loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
566 });
567 }
568
TriangularSolveInternal(const Shape & shape,XlaOp a,XlaOp b,TriangularSolveOptions options)569 StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
570 const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
571 TF_ASSIGN_OR_RETURN(
572 mlir::Type result_ty,
573 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
574 auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
575 loc_, result_ty, GetValue(a), GetValue(b),
576 builder_.getBoolAttr(options.left_side()),
577 builder_.getBoolAttr(options.lower()),
578 builder_.getBoolAttr(options.unit_diagonal()),
579 mlir::mhlo::TransposeAttr::get(
580 builder_.getContext(),
581 ::mlir::mhlo::symbolizeTranspose(
582 TriangularSolveOptions::Transpose_Name(options.transpose_a()))
583 .getValue()));
584 return MakeXlaOp(op);
585 }
586
CholeskyInternal(const Shape & shape,XlaOp a,bool lower)587 StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
588 bool lower) {
589 TF_ASSIGN_OR_RETURN(
590 mlir::Type result_ty,
591 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
592 auto op = builder_.create<mlir::mhlo::CholeskyOp>(
593 loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
594 return MakeXlaOp(op);
595 }
596
InfeedWithTokenInternal(const Shape & infeed_instruction_shape,XlaOp token,const std::string & config)597 StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
598 const Shape& infeed_instruction_shape, XlaOp token,
599 const std::string& config) {
600 TF_ASSIGN_OR_RETURN(mlir::Type result_type,
601 ConvertShapeToType<mlir::RankedTensorType>(
602 infeed_instruction_shape, builder_));
603 llvm::SmallVector<mlir::Type> flattened_ret_types;
604 HloFunctionImporter::FlattenTupleType(result_type, flattened_ret_types);
605
606 mlir::ArrayAttr layout;
607 auto op = builder_.create<mlir::mhlo::InfeedOp>(loc_, flattened_ret_types,
608 GetValue(token),
609 /*infeed_config=*/config,
610 /*layout=*/layout);
611
612 llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
613 llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
614 auto result = HloFunctionImporter::CreateTupleValue(
615 &builder_, loc_, flattened_results_ref, result_type);
616 auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
617 return MakeXlaOp(defining_tuple_op);
618 }
619
OutfeedWithTokenInternal(XlaOp operand,XlaOp token,const Shape & shape_with_layout,const std::string & outfeed_config)620 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
621 XlaOp operand, XlaOp token, const Shape& shape_with_layout,
622 const std::string& outfeed_config) {
623 auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
624 llvm::SmallVector<mlir::Value> flattened_operands;
625 HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(operand),
626 flattened_operands);
627 return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
628 loc_, token_type, flattened_operands, GetValue(token), outfeed_config));
629 }
630
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64_t dimension)631 StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
632 const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
633 TF_ASSIGN_OR_RETURN(
634 mlir::Type result_type,
635 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
636 auto mlir_operands = GetValues(operands);
637 return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
638 loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
639 }
640
GetTupleElementInternal(const Shape & shape,XlaOp tuple_data,int64_t index)641 StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
642 XlaOp tuple_data,
643 int64_t index) {
644 TF_ASSIGN_OR_RETURN(
645 mlir::Type result_type,
646 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
647 return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
648 loc_, result_type, GetValue(tuple_data),
649 builder_.getI32IntegerAttr(index)));
650 }
651
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)652 StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
653 const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
654 absl::Span<const int64_t> limit_indices,
655 absl::Span<const int64_t> strides) {
656 return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
657 loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
658 GetI64ElementsAttr(limit_indices, &builder_),
659 GetI64ElementsAttr(strides, &builder_)));
660 }
661
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64_t> slice_sizes)662 StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
663 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
664 absl::Span<const int64_t> slice_sizes) {
665 TF_ASSIGN_OR_RETURN(
666 mlir::Type result_ty,
667 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
668 return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
669 loc_, result_ty, GetValue(operand), GetValues(start_indices),
670 GetI64ElementsAttr(slice_sizes, &builder_)));
671 }
672
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)673 StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
674 const Shape& shape, XlaOp operand, XlaOp update,
675 absl::Span<const XlaOp> start_indices) {
676 TF_ASSIGN_OR_RETURN(
677 mlir::Type result_ty,
678 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
679 return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
680 loc_, result_ty, GetValue(operand), GetValue(update),
681 GetValues(start_indices)));
682 }
683
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)684 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
685 const Shape& shape, XlaOp operand, XlaOp padding_value,
686 const PaddingConfig& padding_config) {
687 TF_ASSIGN_OR_RETURN(
688 mlir::Type result_type,
689 ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
690 llvm::SmallVector<int64_t> low, high, internal;
691 for (auto& dimension : padding_config.dimensions()) {
692 low.push_back(dimension.edge_padding_low());
693 high.push_back(dimension.edge_padding_high());
694 internal.push_back(dimension.interior_padding());
695 }
696 return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
697 loc_, result_type, GetValue(operand), GetValue(padding_value),
698 GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
699 GetI64ElementsAttr(internal, &builder_)));
700 }
701
TupleInternal(const Shape & shape,absl::Span<const XlaOp> elements)702 StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
703 const Shape& shape, absl::Span<const XlaOp> elements) {
704 mlir::SmallVector<mlir::Value, 4> operands;
705 for (auto& element : elements) {
706 operands.push_back(GetValue(element));
707 }
708 return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
709 }
710
CreateOp(const std::string & op_name,const Shape & shape,llvm::ArrayRef<XlaOp> operands,llvm::ArrayRef<mlir::NamedAttribute> attributes)711 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
712 const std::string& op_name, const Shape& shape,
713 llvm::ArrayRef<XlaOp> operands,
714 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
715 llvm::SmallVector<mlir::Value, 4> operand_values;
716 operand_values.reserve(operands.size());
717 for (XlaOp xla_op : operands) {
718 operand_values.push_back(GetValue(xla_op));
719 }
720 TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
721 shape, builder_));
722 mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
723 mlir::Operation* op = builder_.create(state);
724 return MakeXlaOp(op->getResult(0));
725 }
726
CreateHloModuleFromProto(const HloModuleProto & proto)727 StatusOr<std::unique_ptr<xla::HloModule>> CreateHloModuleFromProto(
728 const HloModuleProto& proto) {
729 TF_ASSIGN_OR_RETURN(
730 auto module_config,
731 xla::HloModule::CreateModuleConfigFromProto(proto, xla::DebugOptions()));
732 TF_ASSIGN_OR_RETURN(auto hlo_module,
733 xla::HloModule::CreateFromProto(proto, module_config));
734 return hlo_module;
735 }
736
ImportComputation(const HloModuleProto & computation,mlir::Region * region,bool flatten_region_arg_tuple)737 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
738 mlir::Region* region,
739 bool flatten_region_arg_tuple) {
740 TF_ASSIGN_OR_RETURN(auto hlo_module, CreateHloModuleFromProto(computation));
741
742 return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
743 region, &builder_,
744 flatten_region_arg_tuple);
745 }
746
ImportComputation(const HloModuleProto & computation,mlir::ModuleOp module)747 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
748 mlir::ModuleOp module) {
749 TF_ASSIGN_OR_RETURN(auto hlo_module, CreateHloModuleFromProto(computation));
750
751 return HloFunctionImporter::ImportAsFunc(*hlo_module->entry_computation(),
752 module, {}, &builder_,
753 /*is_main=*/false);
754 }
755
GetShapePtr(XlaOp op) const756 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
757 TF_RETURN_IF_ERROR(first_error());
758 TF_RETURN_IF_ERROR(CheckOpBuilder(op));
759 auto it = handle_to_shape_.find(op.handle());
760 if (it == handle_to_shape_.end()) {
761 return InvalidArgument("No XlaOp with handle %d", op.handle());
762 }
763 return it->second.get();
764 }
765
766 } // namespace xla
767