• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
29 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
30 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
31 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
32 #include "tensorflow/compiler/xla/comparison_util.h"
33 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
34 #include "tensorflow/compiler/xla/service/hlo_module.h"
35 #include "tensorflow/compiler/xla/service/shape_inference.h"
36 #include "tensorflow/compiler/xla/util.h"
37 
38 namespace xla {
39 
GetMlirOpName(HloOpcode opcode)40 static std::string GetMlirOpName(HloOpcode opcode) {
41   std::string op_name = HloOpcodeString(opcode);
42   absl::c_replace(op_name, '-', '_');
43   return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
44 }
45 
ToString(mlir::Type ty)46 static std::string ToString(mlir::Type ty) {
47   std::string str;
48   llvm::raw_string_ostream sstream(str);
49   ty.print(sstream);
50   sstream.flush();
51   return str;
52 }
53 
54 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(absl::Span<const int64_t> values,mlir::Builder * builder)55 static mlir::DenseIntElementsAttr GetI64ElementsAttr(
56     absl::Span<const int64_t> values, mlir::Builder* builder) {
57   auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
58                                         builder->getIntegerType(64));
59   return mlir::DenseIntElementsAttr::get(
60       ty, llvm::makeArrayRef(values.data(), values.size()));
61 }
62 
ConvertPadding(absl::Span<const std::pair<int64_t,int64_t>> padding,mlir::Builder * builder)63 static mlir::DenseIntElementsAttr ConvertPadding(
64     absl::Span<const std::pair<int64_t, int64_t>> padding,
65     mlir::Builder* builder) {
66   llvm::SmallVector<int64_t, 8> elements;
67   elements.reserve(padding.size() * 2);
68   for (const auto& vals : padding) {
69     elements.push_back(vals.first);
70     elements.push_back(vals.second);
71   }
72   auto ty = mlir::RankedTensorType::get(
73       {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
74   return mlir::DenseIntElementsAttr::get(ty, elements);
75 }
76 
77 MlirHloBuilder::~MlirHloBuilder() = default;
78 
MakeXlaOp(mlir::Value val)79 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
80   mlir::Type ty = val.getType();
81   auto shape = std::make_unique<Shape>(TypeToShape(ty));
82   if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
83     return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
84   }
85 
86   int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
87   handle_to_shape_[handle] = std::move(shape);
88   return XlaOp(handle, this);
89 }
90 
ConstantLiteral(const LiteralSlice & literal)91 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
92   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
93     TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
94                         CreateDenseElementsAttrFromLiteral(literal, builder_));
95     auto op = builder_.create<mlir::mhlo::ConstantOp>(loc_, attr);
96     return MakeXlaOp(op);
97   });
98 }
99 
// Lowers a general dilated convolution to mhlo.convolution.
// NOTE: the `window` proto itself is not read here; its contents arrive via
// the explicit window_strides/padding/dilation spans.
StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  // A null precision config is represented by an absent (null) ArrayAttr.
  mlir::ArrayAttr config_attr;
  if (precision_config)
    config_attr = ConvertPrecisionConfig(precision_config, &builder_);
  auto op = builder_.create<mlir::mhlo::ConvolutionOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      GetI64ElementsAttr(window_strides, &builder_),
      ConvertPadding(padding, &builder_),
      GetI64ElementsAttr(lhs_dilation, &builder_),
      GetI64ElementsAttr(rhs_dilation, &builder_),
      /*window_reversal=*/nullptr,
      ConvertConvDimensionNumbers(dimension_numbers, &builder_),
      builder_.getI64IntegerAttr(feature_group_count),
      builder_.getI64IntegerAttr(batch_group_count), config_attr);
  return MakeXlaOp(op);
}
126 
// Lowers an FFT to mhlo.fft.
StatusOr<XlaOp> MlirHloBuilder::FftInternal(
    const Shape& shape, XlaOp operand, FftType fft_type,
    absl::Span<const int64_t> fft_length) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  // Map the XLA FftType enum to the MHLO enum by round-tripping through the
  // proto enum name.
  auto fft_type_attr = mlir::mhlo::symbolizeFftType(FftType_Name(fft_type));
  auto op = builder_.create<mlir::mhlo::FftOp>(
      loc_, ty, GetValue(operand),
      mlir::mhlo::FftTypeAttr::get(builder_.getContext(),
                                   fft_type_attr.getValue()),
      GetI64ElementsAttr(fft_length, &builder_));
  return MakeXlaOp(op);
}
140 
141 // TODO(b/235207091) Add actual support for the called computation.
CustomCallInternal(const std::string & call_target_name,absl::Span<const XlaOp> operands,const XlaComputation * computation,const Shape & shape,const std::string & opaque,std::optional<absl::Span<const Shape>> operand_shapes_with_layout,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> output_operand_aliasing,const Literal * literal,std::optional<Window> window,std::optional<ConvolutionDimensionNumbers> dnums,CustomCallSchedule schedule,CustomCallApiVersion api_version)142 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
143     const std::string& call_target_name, absl::Span<const XlaOp> operands,
144     const XlaComputation* computation, const Shape& shape,
145     const std::string& opaque,
146     std::optional<absl::Span<const Shape>> operand_shapes_with_layout,
147     bool has_side_effect,
148     absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
149         output_operand_aliasing,
150     const Literal* literal, std::optional<Window> window,
151     std::optional<ConvolutionDimensionNumbers> dnums,
152     CustomCallSchedule schedule, CustomCallApiVersion api_version) {
153   TF_RET_CHECK(output_operand_aliasing.empty())
154       << "MLIR CustomCallOp does not support output_operand_aliasing yet";
155   TF_RET_CHECK(literal == nullptr)
156       << "MLIR CustomCallOp does not support literal yet";
157   TF_RET_CHECK(!window.has_value())
158       << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
159   TF_RET_CHECK(!dnums.has_value())
160       << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
161   TF_RET_CHECK(schedule == CustomCallSchedule::SCHEDULE_NONE)
162       << "MLIR CustomCallOp does not support custom-call-schedule yet";
163   TF_RET_CHECK(computation == nullptr || computation->IsNull() ||
164                build_functions_ == true)
165       << "MLIR CustomCallOp with computation isn't supported when not allowed "
166          "to create functions";
167 
168   llvm::SmallVector<mlir::NamedAttribute> attributes;
169   if (operand_shapes_with_layout.has_value()) {
170     TF_ASSIGN_OR_RETURN(mlir::ArrayAttr operand_layouts,
171                         ExtractLayoutsFromShapes(
172                             operand_shapes_with_layout.value(), &builder_));
173     attributes.push_back(
174         builder_.getNamedAttr("operand_layouts", operand_layouts));
175 
176     mlir::ArrayAttr result_layouts;
177     if (shape.IsTuple()) {
178       TF_ASSIGN_OR_RETURN(result_layouts,
179                           ExtractLayoutsFromTuple(shape, &builder_));
180     } else {
181       TF_ASSIGN_OR_RETURN(result_layouts,
182                           ExtractLayoutsFromShapes({shape}, &builder_));
183     }
184     attributes.push_back(
185         builder_.getNamedAttr("result_layouts", result_layouts));
186   }
187   TF_ASSIGN_OR_RETURN(auto mlir_api_version,
188                       ConvertCustomCallApiVersion(api_version));
189   attributes.push_back(builder_.getNamedAttr(
190       "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
191                          builder_.getContext(), mlir_api_version)));
192   attributes.push_back(builder_.getNamedAttr(
193       "call_target_name", builder_.getStringAttr(call_target_name)));
194   attributes.push_back(builder_.getNamedAttr(
195       "has_side_effect", builder_.getBoolAttr(has_side_effect)));
196   attributes.push_back(
197       builder_.getNamedAttr("backend_config", builder_.getStringAttr(opaque)));
198 
199   if (computation && !computation->IsNull()) {
200     llvm::SmallVector<mlir::Attribute> computation_names;
201     for (const auto& computation_proto : computation->proto().computations()) {
202       computation_names.push_back(mlir::SymbolRefAttr::get(
203           builder_.getContext(), computation_proto.name()));
204     }
205     attributes.push_back(builder_.getNamedAttr(
206         "called_computations", builder_.getArrayAttr(computation_names)));
207 
208     // Create new function(s) to represent the called computations. As a result,
209     // this legalization may only be called during a module pass rather than the
210     // typical parallelized func pass which is not permitted to create
211     // functions.
212     TF_RETURN_IF_ERROR(ImportComputation(
213         computation->proto(),
214         builder_.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>()));
215   }
216 
217   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
218                                          shape, builder_));
219   auto op = builder_.create<mlir::mhlo::CustomCallOp>(
220       loc_, ty, GetValues(operands), attributes);
221   return MakeXlaOp(op.getResult(0));
222 }
223 
// Lowers a (possibly variadic) reduce to mhlo.reduce.
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64_t> dimensions_to_reduce) {
  // Reduce takes two set of variadic operands inputs and init_values.
  // all_operands contains both of these so split operands into two parts.
  int64_t num_args = all_operands.size() / 2;
  auto op = builder_.create<mlir::mhlo::ReduceOp>(
      loc_, GetValues(all_operands.first(num_args)),
      GetValues(all_operands.subspan(num_args)),
      GetI64ElementsAttr(dimensions_to_reduce, &builder_));
  // Import the reducer computation into the op's body region, flattening any
  // tuple-typed region arguments.
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body(),
                                       /*flatten_region_arg_tuple*/ true));
  // A multi-result reduce is packaged back into a single tuple-typed XlaOp.
  if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
  auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
  return MakeXlaOp(tuple);
}
241 
// Lowers a reduce-window to mhlo.reduce_window.
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  // Unpack the Window proto into the per-dimension attribute vectors the
  // MHLO op expects.
  llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations, win_dilations;
  llvm::SmallVector<int64_t, 8> padding;
  for (const auto& dim : window.dimensions()) {
    sizes.push_back(dim.size());
    strides.push_back(dim.stride());
    base_dilations.push_back(dim.base_dilation());
    win_dilations.push_back(dim.window_dilation());
    padding.push_back(dim.padding_low());
    padding.push_back(dim.padding_high());
  }
  // Padding is a rank-2 tensor<Nx2xi64>: one (low, high) row per dimension.
  auto padding_ty =
      mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
                                  builder_.getIntegerType(64));
  auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
      loc_, ty, GetValue(operand), GetValue(init_value),
      GetI64ElementsAttr(sizes, &builder_),
      GetI64ElementsAttr(strides, &builder_),
      GetI64ElementsAttr(base_dilations, &builder_),
      GetI64ElementsAttr(win_dilations, &builder_),
      mlir::DenseIntElementsAttr::get(padding_ty, padding));
  // Import the reducer into the op's body region.
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body(),
                                       /*flatten_region_arg_tuple*/ true));
  return MakeXlaOp(op.getResult(0));
}
271 
Iota(const Shape & shape,int64_t iota_dimension)272 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
273   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
274     TF_ASSIGN_OR_RETURN(
275         mlir::Type ty,
276         ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
277     auto op = builder_.create<mlir::mhlo::IotaOp>(
278         loc_, ty,
279         builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
280     return MakeXlaOp(op);
281   });
282 }
283 
BitcastConvertTypeInternal(const Shape & shape,XlaOp operand)284 StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
285                                                            XlaOp operand) {
286   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
287                                          shape, builder_));
288   auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
289                                                           GetValue(operand));
290   return MakeXlaOp(op);
291 }
292 
TransposeInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> permutation)293 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
294     const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
295   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
296                                          shape, builder_));
297   auto op = builder_.create<mlir::mhlo::TransposeOp>(
298       loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
299   return MakeXlaOp(op);
300 }
301 
RevInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> dimensions)302 StatusOr<XlaOp> MlirHloBuilder::RevInternal(
303     const Shape& shape, XlaOp operand, absl::Span<const int64_t> dimensions) {
304   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
305                                          shape, builder_));
306   auto op = builder_.create<mlir::mhlo::ReverseOp>(
307       loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
308   return MakeXlaOp(op);
309 }
310 
// Lowers a (possibly variadic) sort to mhlo.sort.
StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
                                             absl::Span<const XlaOp> operands,
                                             const XlaComputation& comparator,
                                             int64_t dimension,
                                             bool is_stable) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  // A variadic sort has a tuple result shape, but mhlo.sort produces one
  // result per operand, so flatten the tuple into individual result types.
  llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
  if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
    sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
  }

  auto op = builder_.create<mlir::mhlo::SortOp>(
      loc_, sort_types, GetValues(operands),
      builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
  // Import the comparator computation into the op's comparator region.
  TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));

  // Re-wrap multiple results into a tuple to match the XLA result shape.
  if (ty.isa<mlir::TupleType>()) {
    auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
    return MakeXlaOp(tuple);
  }

  return MakeXlaOp(op.getResult(0));
}
335 
// Lowers a while loop to mhlo.while.
StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
                                              const XlaComputation& condition,
                                              const XlaComputation& body,
                                              XlaOp init) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));

  // mhlo.while carries its loop state as flattened values rather than a
  // tuple, so flatten the init value (and its type) before creating the op.
  llvm::SmallVector<mlir::Value> flattened_operands;
  llvm::SmallVector<mlir::Type> flattened_operand_types;

  HloFunctionImporter::FlattenTupleType(ty, flattened_operand_types);
  HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(init),
                                         flattened_operands);

  auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, flattened_operand_types,
                                                 flattened_operands);

  // Import condition and body into the op's regions, flattening tuple-typed
  // region arguments to match the flattened operands.
  TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond(),
                                       /*flatten_region_arg_tuple*/ true));
  TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body(),
                                       /*flatten_region_arg_tuple*/ true));

  // If the XLA result is a tuple, rebuild it from the flattened results so
  // the returned XlaOp carries the expected tuple shape.
  if (ty.isa<mlir::TupleType>()) {
    llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
    llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
    auto result = HloFunctionImporter::CreateTupleValue(
        &builder_, loc_, flattened_results_ref, ty);
    auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
    return MakeXlaOp(defining_tuple_op);
  }

  return MakeXlaOp(op.getResult(0));
}
369 
ReducePrecisionInternal(const Shape & shape,XlaOp operand,const int exponent_bits,const int mantissa_bits)370 StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
371     const Shape& shape, XlaOp operand, const int exponent_bits,
372     const int mantissa_bits) {
373   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
374                                          shape, builder_));
375   auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
376       loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
377       builder_.getI32IntegerAttr(mantissa_bits));
378   return MakeXlaOp(op);
379 }
380 
GatherInternal(const Shape & shape,XlaOp input,XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64_t> slice_sizes,bool indices_are_sorted)381 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
382     const Shape& shape, XlaOp input, XlaOp start_indices,
383     const GatherDimensionNumbers& dimension_numbers,
384     absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
385   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
386                                          shape, builder_));
387   auto op = builder_.create<mlir::mhlo::GatherOp>(
388       loc_, ty, GetValue(input), GetValue(start_indices),
389       ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
390       GetI64ElementsAttr(slice_sizes, &builder_));
391   return MakeXlaOp(op);
392 }
393 
// Lowers a scatter to mhlo.scatter. Only the single-input form is supported.
StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
    const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
    absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  // TODO(b/230137437): Allow variadic scatter after adding mhlo support.
  if (inputs.size() != 1) {
    return Unimplemented("Variadic scatter not implemented in mhlo yet.");
  }
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::mhlo::ScatterOp>(
      loc_, ty, GetValue(inputs[0]), GetValue(scatter_indices),
      GetValue(updates[0]),
      ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
      builder_.getBoolAttr(indices_are_sorted),
      builder_.getBoolAttr(unique_indices));

  // Import the update computation into the op's region.
  TF_RETURN_IF_ERROR(
      ImportComputation(update_computation.proto(), &op.update_computation()));
  return MakeXlaOp(op.getResult(0));
}
416 
SetDimensionSizeInternal(const Shape & shape,XlaOp operand,XlaOp val,int64_t dimension)417 StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
418                                                          XlaOp operand,
419                                                          XlaOp val,
420                                                          int64_t dimension) {
421   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
422                                          shape, builder_));
423   auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
424       loc_, ty, GetValue(operand), GetValue(val),
425       builder_.getI64IntegerAttr(dimension));
426   return MakeXlaOp(op);
427 }
428 
RngOpInternal(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)429 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
430     RandomDistribution distribution, absl::Span<const XlaOp> parameters,
431     const Shape& shape) {
432   mlir::mhlo::RngDistributionAttr attr;
433   if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
434     attr = mlir::mhlo::RngDistributionAttr::get(
435         builder_.getContext(), mlir::mhlo::RngDistribution::UNIFORM);
436   } else {
437     TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
438         << "Unexpected distribution: " << distribution;
439     attr = mlir::mhlo::RngDistributionAttr::get(
440         builder_.getContext(), mlir::mhlo::RngDistribution::NORMAL);
441   }
442   llvm::SmallVector<mlir::NamedAttribute, 1> attributes = {
443       builder_.getNamedAttr("rng_distribution", attr)};
444 
445   if (shape.is_dynamic())
446     return Unimplemented("RngOp with dynamic dims not supported");
447   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
448                                          shape, builder_));
449 
450   auto op = builder_.create<mlir::mhlo::RngOp>(
451       loc_, ty, GetValue(parameters[0]), GetValue(parameters[1]),
452       GetValue(
453           ConstantLiteral(LiteralUtil::CreateR1<int64_t>(shape.dimensions()))),
454       attr);
455   return MakeXlaOp(op.getResult());
456 }
457 
// Lowers an RNG bit generator to mhlo.rng_bit_generator.
StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         full_result_shape, builder_));

  // The op produces flattened results, so flatten any tuple result type.
  llvm::SmallVector<mlir::Type> flattened_ret_types;
  HloFunctionImporter::FlattenTupleType(ty, flattened_ret_types);

  auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
      builder_.getContext(), *mlir::mhlo::symbolizeRngAlgorithm(algorithm));
  auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
      loc_, flattened_ret_types, algorithm_attr, GetValue(initial_state));

  // Re-wrap flattened results into a tuple to match the XLA result shape.
  if (ty.isa<mlir::TupleType>()) {
    llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
    llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
    auto result = HloFunctionImporter::CreateTupleValue(
        &builder_, loc_, flattened_results_ref, ty);
    auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
    return MakeXlaOp(defining_tuple_op);
  }

  return MakeXlaOp(op.getResult(0));
}
483 
ReshapeInternal(const Shape & shape,XlaOp operand,int64_t inferred_dimension)484 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
485                                                 XlaOp operand,
486                                                 int64_t inferred_dimension) {
487   TF_RETURN_IF_ERROR(first_error());
488 
489   if (inferred_dimension != -1)
490     return Unimplemented("inferred_dimension not yet supported for Reshape op");
491   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
492                                          shape, builder_));
493   mlir::Value value = GetValue(operand);
494   auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
495   return MakeXlaOp(op.getResult());
496 }
497 
DotGeneralInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const DotDimensionNumbers & dimension_number,const PrecisionConfig * precision_config)498 StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
499     const Shape& shape, XlaOp lhs, XlaOp rhs,
500     const DotDimensionNumbers& dimension_number,
501     const PrecisionConfig* precision_config) {
502   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
503                                          shape, builder_));
504   auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
505       loc_, ty, GetValue(lhs), GetValue(rhs),
506       ConvertDotDimensionNumbers(dimension_number, &builder_),
507       ConvertPrecisionConfig(precision_config, &builder_));
508   return MakeXlaOp(op.getResult());
509 }
510 
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64_t> broadcast_dimensions)511 StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
512     const Shape& shape, XlaOp operand,
513     absl::Span<const int64_t> broadcast_dimensions) {
514   TF_RETURN_IF_ERROR(first_error());
515   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
516                                          shape, builder_));
517   mlir::Value value = GetValue(operand);
518   auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
519       loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
520   return MakeXlaOp(op.getResult());
521 }
522 
// The proto-based instruction path is intentionally unsupported: this
// builder creates MLIR ops directly instead of HloInstructionProtos.
StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
    HloInstructionProto&& instr, HloOpcode opcode,
    absl::Span<const XlaOp> operands) {
  return Unimplemented("MlirHloBuilder does not support op %s",
                       HloOpcodeString(opcode));
}
529 
// Lowers a comparison to mhlo.compare.
StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
                                        XlaOp rhs,
                                        ComparisonDirection direction,
                                        Comparison::Type type) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  // The direction and comparison-type enums are round-tripped through their
  // string names to obtain the corresponding MHLO enum attributes.
  auto op = builder_.create<mlir::mhlo::CompareOp>(
      loc_, ty, GetValue(lhs), GetValue(rhs),
      mlir::mhlo::ComparisonDirectionAttr::get(
          builder_.getContext(), mlir::mhlo::symbolizeComparisonDirection(
                                     ComparisonDirectionToString(direction))
                                     .getValue()),
      mlir::mhlo::ComparisonTypeAttr::get(
          builder_.getContext(),
          mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
              .getValue()));
  return MakeXlaOp(op.getResult());
}
548 
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)549 XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
550                                           XlaOp lhs, XlaOp rhs) {
551   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
552     return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
553   });
554 }
555 
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)556 StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
557     HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
558   return CreateOp(GetMlirOpName(opcode), shape,
559                   llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
560 }
561 
CreateToken()562 XlaOp MlirHloBuilder::CreateToken() {
563   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
564     return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
565         loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
566   });
567 }
568 
// Lowers a triangular solve to mhlo.triangular_solve.
StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_ty,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  // The transpose option is mapped to the MHLO enum via its proto name.
  auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
      loc_, result_ty, GetValue(a), GetValue(b),
      builder_.getBoolAttr(options.left_side()),
      builder_.getBoolAttr(options.lower()),
      builder_.getBoolAttr(options.unit_diagonal()),
      mlir::mhlo::TransposeAttr::get(
          builder_.getContext(),
          ::mlir::mhlo::symbolizeTranspose(
              TriangularSolveOptions::Transpose_Name(options.transpose_a()))
              .getValue()));
  return MakeXlaOp(op);
}
586 
CholeskyInternal(const Shape & shape,XlaOp a,bool lower)587 StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
588                                                  bool lower) {
589   TF_ASSIGN_OR_RETURN(
590       mlir::Type result_ty,
591       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
592   auto op = builder_.create<mlir::mhlo::CholeskyOp>(
593       loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
594   return MakeXlaOp(op);
595 }
596 
// Builds an mhlo.infeed op. The infeed result type is flattened into a list
// of non-tuple result types for the op, and the flat results are then
// re-packed into an mhlo.tuple so the returned XlaOp carries the original
// (tuple-shaped) infeed result.
StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token,
    const std::string& config) {
  TF_ASSIGN_OR_RETURN(mlir::Type result_type,
                      ConvertShapeToType<mlir::RankedTensorType>(
                          infeed_instruction_shape, builder_));
  llvm::SmallVector<mlir::Type> flattened_ret_types;
  HloFunctionImporter::FlattenTupleType(result_type, flattened_ret_types);

  // Default-constructed (null) ArrayAttr: no layout attribute is attached.
  mlir::ArrayAttr layout;
  auto op = builder_.create<mlir::mhlo::InfeedOp>(loc_, flattened_ret_types,
                                                  GetValue(token),
                                                  /*infeed_config=*/config,
                                                  /*layout=*/layout);

  // Reassemble the flat op results into a value of the original result type.
  llvm::SmallVector<mlir::Value> flattened_results = op->getResults();
  llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
  auto result = HloFunctionImporter::CreateTupleValue(
      &builder_, loc_, flattened_results_ref, result_type);
  // NOTE(review): assumes CreateTupleValue yields a value defined by an
  // mhlo.tuple op (i.e. the infeed shape is a tuple); getDefiningOp would
  // return null otherwise — confirm callers never pass a non-tuple shape.
  auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
  return MakeXlaOp(defining_tuple_op);
}
619 
OutfeedWithTokenInternal(XlaOp operand,XlaOp token,const Shape & shape_with_layout,const std::string & outfeed_config)620 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
621     XlaOp operand, XlaOp token, const Shape& shape_with_layout,
622     const std::string& outfeed_config) {
623   auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
624   llvm::SmallVector<mlir::Value> flattened_operands;
625   HloFunctionImporter::FlattenTupleValue(&builder_, loc_, GetValue(operand),
626                                          flattened_operands);
627   return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
628       loc_, token_type, flattened_operands, GetValue(token), outfeed_config));
629 }
630 
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64_t dimension)631 StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
632     const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
633   TF_ASSIGN_OR_RETURN(
634       mlir::Type result_type,
635       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
636   auto mlir_operands = GetValues(operands);
637   return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
638       loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
639 }
640 
GetTupleElementInternal(const Shape & shape,XlaOp tuple_data,int64_t index)641 StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
642                                                         XlaOp tuple_data,
643                                                         int64_t index) {
644   TF_ASSIGN_OR_RETURN(
645       mlir::Type result_type,
646       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
647   return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
648       loc_, result_type, GetValue(tuple_data),
649       builder_.getI32IntegerAttr(index)));
650 }
651 
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64_t> start_indices,absl::Span<const int64_t> limit_indices,absl::Span<const int64_t> strides)652 StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
653     const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
654     absl::Span<const int64_t> limit_indices,
655     absl::Span<const int64_t> strides) {
656   return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
657       loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
658       GetI64ElementsAttr(limit_indices, &builder_),
659       GetI64ElementsAttr(strides, &builder_)));
660 }
661 
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64_t> slice_sizes)662 StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
663     const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
664     absl::Span<const int64_t> slice_sizes) {
665   TF_ASSIGN_OR_RETURN(
666       mlir::Type result_ty,
667       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
668   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
669       loc_, result_ty, GetValue(operand), GetValues(start_indices),
670       GetI64ElementsAttr(slice_sizes, &builder_)));
671 }
672 
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)673 StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
674     const Shape& shape, XlaOp operand, XlaOp update,
675     absl::Span<const XlaOp> start_indices) {
676   TF_ASSIGN_OR_RETURN(
677       mlir::Type result_ty,
678       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
679   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
680       loc_, result_ty, GetValue(operand), GetValue(update),
681       GetValues(start_indices)));
682 }
683 
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)684 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
685     const Shape& shape, XlaOp operand, XlaOp padding_value,
686     const PaddingConfig& padding_config) {
687   TF_ASSIGN_OR_RETURN(
688       mlir::Type result_type,
689       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
690   llvm::SmallVector<int64_t> low, high, internal;
691   for (auto& dimension : padding_config.dimensions()) {
692     low.push_back(dimension.edge_padding_low());
693     high.push_back(dimension.edge_padding_high());
694     internal.push_back(dimension.interior_padding());
695   }
696   return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
697       loc_, result_type, GetValue(operand), GetValue(padding_value),
698       GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
699       GetI64ElementsAttr(internal, &builder_)));
700 }
701 
TupleInternal(const Shape & shape,absl::Span<const XlaOp> elements)702 StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
703     const Shape& shape, absl::Span<const XlaOp> elements) {
704   mlir::SmallVector<mlir::Value, 4> operands;
705   for (auto& element : elements) {
706     operands.push_back(GetValue(element));
707   }
708   return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
709 }
710 
CreateOp(const std::string & op_name,const Shape & shape,llvm::ArrayRef<XlaOp> operands,llvm::ArrayRef<mlir::NamedAttribute> attributes)711 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
712     const std::string& op_name, const Shape& shape,
713     llvm::ArrayRef<XlaOp> operands,
714     llvm::ArrayRef<mlir::NamedAttribute> attributes) {
715   llvm::SmallVector<mlir::Value, 4> operand_values;
716   operand_values.reserve(operands.size());
717   for (XlaOp xla_op : operands) {
718     operand_values.push_back(GetValue(xla_op));
719   }
720   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
721                                          shape, builder_));
722   mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
723   mlir::Operation* op = builder_.create(state);
724   return MakeXlaOp(op->getResult(0));
725 }
726 
CreateHloModuleFromProto(const HloModuleProto & proto)727 StatusOr<std::unique_ptr<xla::HloModule>> CreateHloModuleFromProto(
728     const HloModuleProto& proto) {
729   TF_ASSIGN_OR_RETURN(
730       auto module_config,
731       xla::HloModule::CreateModuleConfigFromProto(proto, xla::DebugOptions()));
732   TF_ASSIGN_OR_RETURN(auto hlo_module,
733                       xla::HloModule::CreateFromProto(proto, module_config));
734   return hlo_module;
735 }
736 
ImportComputation(const HloModuleProto & computation,mlir::Region * region,bool flatten_region_arg_tuple)737 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
738                                          mlir::Region* region,
739                                          bool flatten_region_arg_tuple) {
740   TF_ASSIGN_OR_RETURN(auto hlo_module, CreateHloModuleFromProto(computation));
741 
742   return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
743                                              region, &builder_,
744                                              flatten_region_arg_tuple);
745 }
746 
ImportComputation(const HloModuleProto & computation,mlir::ModuleOp module)747 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
748                                          mlir::ModuleOp module) {
749   TF_ASSIGN_OR_RETURN(auto hlo_module, CreateHloModuleFromProto(computation));
750 
751   return HloFunctionImporter::ImportAsFunc(*hlo_module->entry_computation(),
752                                            module, {}, &builder_,
753                                            /*is_main=*/false);
754 }
755 
GetShapePtr(XlaOp op) const756 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
757   TF_RETURN_IF_ERROR(first_error());
758   TF_RETURN_IF_ERROR(CheckOpBuilder(op));
759   auto it = handle_to_shape_.find(op.handle());
760   if (it == handle_to_shape_.end()) {
761     return InvalidArgument("No XlaOp with handle %d", op.handle());
762   }
763   return it->second.get();
764 }
765 
766 }  // namespace xla
767