• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
16 
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
23 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
24 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
25 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/shape_inference.h"
29 #include "tensorflow/compiler/xla/util.h"
30 
31 namespace xla {
32 
GetMlirOpName(HloOpcode opcode)33 static std::string GetMlirOpName(HloOpcode opcode) {
34   std::string op_name = HloOpcodeString(opcode);
35   absl::c_replace(op_name, '-', '_');
36   return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
37 }
38 
ToString(mlir::Type ty)39 static std::string ToString(mlir::Type ty) {
40   std::string str;
41   llvm::raw_string_ostream sstream(str);
42   ty.print(sstream);
43   sstream.flush();
44   return str;
45 }
46 
47 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(absl::Span<const int64> values,mlir::Builder * builder)48 static mlir::DenseIntElementsAttr GetI64ElementsAttr(
49     absl::Span<const int64> values, mlir::Builder* builder) {
50   auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
51                                         builder->getIntegerType(64));
52   return mlir::DenseIntElementsAttr::get(
53       ty, llvm::makeArrayRef(values.data(), values.size()));
54 }
55 
ConvertPadding(absl::Span<const std::pair<int64_t,int64_t>> padding,mlir::Builder * builder)56 static mlir::DenseIntElementsAttr ConvertPadding(
57     absl::Span<const std::pair<int64_t, int64_t>> padding,
58     mlir::Builder* builder) {
59   llvm::SmallVector<int64_t, 8> elements;
60   elements.reserve(padding.size() * 2);
61   for (const auto& vals : padding) {
62     elements.push_back(vals.first);
63     elements.push_back(vals.second);
64   }
65   auto ty = mlir::RankedTensorType::get(
66       {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
67   return mlir::DenseIntElementsAttr::get(ty, elements);
68 }
69 
// Defaulted destructor, defined out of line in this file; there is no custom
// teardown, and members are destroyed by their own destructors.
70 MlirHloBuilder::~MlirHloBuilder() = default;
71 
MakeXlaOp(mlir::Value val)72 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
73   mlir::Type ty = val.getType();
74   auto shape = std::make_unique<Shape>(TypeToShape(ty));
75   if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
76     return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
77   }
78 
79   int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
80   handle_to_shape_[handle] = std::move(shape);
81   return XlaOp(handle, this);
82 }
83 
ConstantLiteral(const LiteralSlice & literal)84 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
85   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
86     TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
87                         CreateDenseElementsAttrFromLiteral(literal, builder_));
88     auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
89     return MakeXlaOp(op);
90   });
91 }
92 
ConvGeneralDilatedInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const Window & window,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64_t feature_group_count,int64_t batch_group_count,const PrecisionConfig * precision_config)93 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
94     const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
95     absl::Span<const int64> window_strides,
96     absl::Span<const std::pair<int64, int64>> padding,
97     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
98     const ConvolutionDimensionNumbers& dimension_numbers,
99     int64_t feature_group_count, int64_t batch_group_count,
100     const PrecisionConfig* precision_config) {
101   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
102                                          shape, builder_));
103   mlir::ArrayAttr config_attr;
104   if (precision_config)
105     config_attr = ConvertPrecisionConfig(precision_config, &builder_);
106   auto op = builder_.create<mlir::mhlo::ConvOp>(
107       loc_, ty, GetValue(lhs), GetValue(rhs),
108       GetI64ElementsAttr(window_strides, &builder_),
109       ConvertPadding(padding, &builder_),
110       GetI64ElementsAttr(lhs_dilation, &builder_),
111       GetI64ElementsAttr(rhs_dilation, &builder_),
112       /*window_reversal=*/nullptr,
113       ConvertConvDimensionNumbers(dimension_numbers, &builder_),
114       builder_.getI64IntegerAttr(feature_group_count),
115       builder_.getI64IntegerAttr(batch_group_count), config_attr);
116   return MakeXlaOp(op);
117 }
118 
FftInternal(const Shape & shape,XlaOp operand,FftType fft_type,absl::Span<const int64> fft_length)119 StatusOr<XlaOp> MlirHloBuilder::FftInternal(
120     const Shape& shape, XlaOp operand, FftType fft_type,
121     absl::Span<const int64> fft_length) {
122   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
123                                          shape, builder_));
124   auto op = builder_.create<mlir::mhlo::FftOp>(
125       loc_, ty, GetValue(operand),
126       builder_.getStringAttr(FftType_Name(fft_type)),
127       GetI64ElementsAttr(fft_length, &builder_));
128   return MakeXlaOp(op);
129 }
130 
CustomCallInternal(const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const string & opaque,absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> output_operand_aliasing,const Literal * literal,absl::optional<Window> window,absl::optional<ConvolutionDimensionNumbers> dnums,CustomCallSchedule schedule,CustomCallApiVersion api_version)131 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
132     const string& call_target_name, absl::Span<const XlaOp> operands,
133     const Shape& shape, const string& opaque,
134     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
135     bool has_side_effect,
136     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
137         output_operand_aliasing,
138     const Literal* literal, absl::optional<Window> window,
139     absl::optional<ConvolutionDimensionNumbers> dnums,
140     CustomCallSchedule schedule, CustomCallApiVersion api_version) {
141   if (operand_shapes_with_layout.has_value())
142     return Unimplemented(
143         "CustomCall doesn't support operands shapes with layout");
144   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
145                                          shape, builder_));
146   TF_ASSIGN_OR_RETURN(auto mlir_api_version,
147                       ConvertCustomCallApiVersion(api_version));
148   TF_RET_CHECK(output_operand_aliasing.empty())
149       << "MLIR CustomCallOp does not support output_operand_aliasing yet";
150   TF_RET_CHECK(literal == nullptr)
151       << "MLIR CustomCallOp does not support literal yet";
152   TF_RET_CHECK(!window.has_value())
153       << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
154   TF_RET_CHECK(!dnums.has_value())
155       << "MLIR CustomCallOp does not support ConvolutionDimensionNumbers yet";
156   TF_RET_CHECK(schedule == CustomCallSchedule::SCHEDULE_NONE)
157       << "MLIR CustomCallOp does not support custom-call-schedule yet";
158   auto op = builder_.create<mlir::mhlo::CustomCallOp>(
159       loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
160       /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
161       builder_.getStringAttr(opaque),
162       /*api_version=*/
163       mlir::mhlo::CustomCallApiVersionAttr::get(builder_.getContext(),
164                                                 mlir_api_version));
165   return MakeXlaOp(op.getResult(0));
166 }
167 
ReduceInternal(const Shape & shape,absl::Span<const XlaOp> all_operands,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)168 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
169     const Shape& shape, absl::Span<const XlaOp> all_operands,
170     const XlaComputation& computation,
171     absl::Span<const int64> dimensions_to_reduce) {
172   // Reduce takes two set of variadic operands inputs and init_values.
173   // all_operands contains both of these so split operands into two parts.
174   int64_t num_args = all_operands.size() / 2;
175   auto op = builder_.create<mlir::mhlo::ReduceOp>(
176       loc_, GetValues(all_operands.first(num_args)),
177       GetValues(all_operands.subspan(num_args)),
178       GetI64ElementsAttr(dimensions_to_reduce, &builder_));
179   TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
180   if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
181   auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
182   return MakeXlaOp(tuple);
183 }
184 
ReduceWindowInternal(const Shape & shape,XlaOp operand,XlaOp init_value,const XlaComputation & computation,Window window)185 StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
186     const Shape& shape, XlaOp operand, XlaOp init_value,
187     const XlaComputation& computation, Window window) {
188   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
189                                          shape, builder_));
190   llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
191   llvm::SmallVector<int64, 8> padding;
192   for (const auto& dim : window.dimensions()) {
193     sizes.push_back(dim.size());
194     strides.push_back(dim.stride());
195     base_dilations.push_back(dim.base_dilation());
196     win_dilations.push_back(dim.window_dilation());
197     padding.push_back(dim.padding_low());
198     padding.push_back(dim.padding_high());
199   }
200   auto padding_ty =
201       mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
202                                   builder_.getIntegerType(64));
203   auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
204       loc_, ty, GetValue(operand), GetValue(init_value),
205       GetI64ElementsAttr(sizes, &builder_),
206       GetI64ElementsAttr(strides, &builder_),
207       GetI64ElementsAttr(base_dilations, &builder_),
208       GetI64ElementsAttr(win_dilations, &builder_),
209       mlir::DenseIntElementsAttr::get(padding_ty, padding));
210   TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
211   return MakeXlaOp(op.getResult(0));
212 }
213 
Iota(const Shape & shape,int64_t iota_dimension)214 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
215   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
216     TF_ASSIGN_OR_RETURN(
217         mlir::Type ty,
218         ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
219     auto op = builder_.create<mlir::mhlo::IotaOp>(
220         loc_, ty,
221         builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
222     return MakeXlaOp(op);
223   });
224 }
225 
BitcastConvertTypeInternal(const Shape & shape,XlaOp operand)226 StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
227                                                            XlaOp operand) {
228   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
229                                          shape, builder_));
230   auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
231                                                           GetValue(operand));
232   return MakeXlaOp(op);
233 }
234 
TransposeInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> permutation)235 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
236     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
237   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
238                                          shape, builder_));
239   auto op = builder_.create<mlir::mhlo::TransposeOp>(
240       loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
241   return MakeXlaOp(op);
242 }
243 
RevInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> dimensions)244 StatusOr<XlaOp> MlirHloBuilder::RevInternal(
245     const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
246   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
247                                          shape, builder_));
248   auto op = builder_.create<mlir::mhlo::ReverseOp>(
249       loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
250   return MakeXlaOp(op);
251 }
252 
SortInternal(const Shape & shape,absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64_t dimension,bool is_stable)253 StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
254                                              absl::Span<const XlaOp> operands,
255                                              const XlaComputation& comparator,
256                                              int64_t dimension,
257                                              bool is_stable) {
258   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
259                                          shape, builder_));
260   llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
261   if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
262     sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
263   }
264 
265   auto op = builder_.create<mlir::mhlo::SortOp>(
266       loc_, sort_types, GetValues(operands),
267       builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
268   TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
269 
270   if (ty.isa<mlir::TupleType>()) {
271     auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
272     return MakeXlaOp(tuple);
273   }
274 
275   return MakeXlaOp(op.getResult(0));
276 }
277 
WhileInternal(const Shape & shape,const XlaComputation & condition,const XlaComputation & body,XlaOp init)278 StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
279                                               const XlaComputation& condition,
280                                               const XlaComputation& body,
281                                               XlaOp init) {
282   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
283                                          shape, builder_));
284   auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
285   TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
286   TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
287   // TODO(jpienaar): Support multi-operand while op.
288   if (op.getNumResults() != 1)
289     return Unimplemented("Only single result MHLO WhileOp's can be import.");
290   return MakeXlaOp(op.getResult(0));
291 }
292 
ReducePrecisionInternal(const Shape & shape,XlaOp operand,const int exponent_bits,const int mantissa_bits)293 StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
294     const Shape& shape, XlaOp operand, const int exponent_bits,
295     const int mantissa_bits) {
296   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
297                                          shape, builder_));
298   auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
299       loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
300       builder_.getI32IntegerAttr(mantissa_bits));
301   return MakeXlaOp(op);
302 }
303 
GatherInternal(const Shape & shape,XlaOp input,XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)304 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
305     const Shape& shape, XlaOp input, XlaOp start_indices,
306     const GatherDimensionNumbers& dimension_numbers,
307     absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
308   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
309                                          shape, builder_));
310   auto op = builder_.create<mlir::mhlo::GatherOp>(
311       loc_, ty, GetValue(input), GetValue(start_indices),
312       ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
313       GetI64ElementsAttr(slice_sizes, &builder_));
314   return MakeXlaOp(op);
315 }
316 
ScatterInternal(const Shape & shape,XlaOp input,XlaOp scatter_indices,XlaOp updates,const XlaComputation & update_computation,const ScatterDimensionNumbers & dimension_numbers,bool indices_are_sorted,bool unique_indices)317 StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
318     const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
319     const XlaComputation& update_computation,
320     const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
321     bool unique_indices) {
322   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
323                                          shape, builder_));
324   auto op = builder_.create<mlir::mhlo::ScatterOp>(
325       loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
326       ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
327       builder_.getBoolAttr(indices_are_sorted),
328       builder_.getBoolAttr(unique_indices));
329 
330   TF_RETURN_IF_ERROR(
331       ImportComputation(update_computation.proto(), &op.update_computation()));
332   return MakeXlaOp(op);
333 }
334 
SetDimensionSizeInternal(const Shape & shape,XlaOp operand,XlaOp val,int64_t dimension)335 StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
336                                                          XlaOp operand,
337                                                          XlaOp val,
338                                                          int64_t dimension) {
339   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
340                                          shape, builder_));
341   auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
342       loc_, ty, GetValue(operand), GetValue(val),
343       builder_.getI64IntegerAttr(dimension));
344   return MakeXlaOp(op);
345 }
346 
RngOpInternal(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)347 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
348     RandomDistribution distribution, absl::Span<const XlaOp> parameters,
349     const Shape& shape) {
350   // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform
351   // and RngNormal can be mapped to the new op.
352   std::string op_name;
353   if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
354     op_name = "mhlo.rng_uniform";
355   } else {
356     TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
357         << "Unexpected distribution: " << distribution;
358     op_name = "mhlo.rng_normal";
359   }
360 
361   if (shape.is_dynamic())
362     return Unimplemented("RngOp with dynamic dims not supported");
363   llvm::SmallVector<XlaOp, 3> operands;
364   operands.append(parameters.begin(), parameters.end());
365   operands.push_back(
366       ConstantLiteral(LiteralUtil::CreateR1<int64>(shape.dimensions())));
367   return CreateOp(op_name, shape, operands);
368 }
369 
RngBitGeneratorInternal(const Shape & full_result_shape,RandomAlgorithm algorithm,XlaOp initial_state)370 StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
371     const Shape& full_result_shape, RandomAlgorithm algorithm,
372     XlaOp initial_state) {
373   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
374                                          full_result_shape, builder_));
375   auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
376       loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state));
377   return MakeXlaOp(op);
378 }
379 
ReshapeInternal(const Shape & shape,XlaOp operand,int64_t inferred_dimension)380 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
381                                                 XlaOp operand,
382                                                 int64_t inferred_dimension) {
383   TF_RETURN_IF_ERROR(first_error());
384 
385   if (inferred_dimension != -1)
386     return Unimplemented("inferred_dimension not yet supported for Reshape op");
387   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
388                                          shape, builder_));
389   mlir::Value value = GetValue(operand);
390   auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
391   return MakeXlaOp(op.getResult());
392 }
393 
DotGeneralInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const DotDimensionNumbers & dimension_number,const PrecisionConfig * precision_config)394 StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
395     const Shape& shape, XlaOp lhs, XlaOp rhs,
396     const DotDimensionNumbers& dimension_number,
397     const PrecisionConfig* precision_config) {
398   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
399                                          shape, builder_));
400   auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
401       loc_, ty, GetValue(lhs), GetValue(rhs),
402       ConvertDotDimensionNumbers(dimension_number, &builder_),
403       ConvertPrecisionConfig(precision_config, &builder_));
404   return MakeXlaOp(op.getResult());
405 }
406 
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64> broadcast_dimensions)407 StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
408     const Shape& shape, XlaOp operand,
409     absl::Span<const int64> broadcast_dimensions) {
410   TF_RETURN_IF_ERROR(first_error());
411   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
412                                          shape, builder_));
413   mlir::Value value = GetValue(operand);
414   auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
415       loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
416   return MakeXlaOp(op.getResult());
417 }
418 
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)419 StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
420     HloInstructionProto&& instr, HloOpcode opcode,
421     absl::Span<const XlaOp> operands) {
422   return Unimplemented("MlirHloBuilder does not support op %s",
423                        HloOpcodeString(opcode));
424 }
425 
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction,Comparison::Type type)426 StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
427                                         XlaOp rhs,
428                                         ComparisonDirection direction,
429                                         Comparison::Type type) {
430   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
431                                          shape, builder_));
432   auto op = builder_.create<mlir::mhlo::CompareOp>(
433       loc_, ty, GetValue(lhs), GetValue(rhs),
434       builder_.getStringAttr(ComparisonDirectionToString(direction)),
435       builder_.getStringAttr(ComparisonTypeToString(type)));
436   return MakeXlaOp(op.getResult());
437 }
438 
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)439 XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
440                                           XlaOp lhs, XlaOp rhs) {
441   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
442     return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
443   });
444 }
445 
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)446 StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
447     HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
448   return CreateOp(GetMlirOpName(opcode), shape,
449                   llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
450 }
451 
CreateToken()452 XlaOp MlirHloBuilder::CreateToken() {
453   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
454     return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
455         loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
456   });
457 }
458 
TriangularSolveInternal(const Shape & shape,XlaOp a,XlaOp b,TriangularSolveOptions options)459 StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
460     const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
461   TF_ASSIGN_OR_RETURN(
462       mlir::Type result_ty,
463       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
464   auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
465       loc_, result_ty, GetValue(a), GetValue(b),
466       builder_.getBoolAttr(options.left_side()),
467       builder_.getBoolAttr(options.lower()),
468       builder_.getBoolAttr(options.unit_diagonal()),
469       builder_.getStringAttr(
470           TriangularSolveOptions::Transpose_Name(options.transpose_a())));
471   return MakeXlaOp(op);
472 }
473 
CholeskyInternal(const Shape & shape,XlaOp a,bool lower)474 StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
475                                                  bool lower) {
476   TF_ASSIGN_OR_RETURN(
477       mlir::Type result_ty,
478       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
479   auto op = builder_.create<mlir::mhlo::CholeskyOp>(
480       loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
481   return MakeXlaOp(op);
482 }
483 
InfeedWithTokenInternal(const Shape & infeed_instruction_shape,XlaOp token,const string & config)484 StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
485     const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
486   TF_ASSIGN_OR_RETURN(mlir::Type result_type,
487                       ConvertShapeToType<mlir::RankedTensorType>(
488                           infeed_instruction_shape, builder_));
489   mlir::ArrayAttr layout;
490   return MakeXlaOp(
491       builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
492                                             /*infeed_config=*/config,
493                                             /*layout=*/layout));
494 }
495 
OutfeedWithTokenInternal(XlaOp operand,XlaOp token,const Shape & shape_with_layout,const string & outfeed_config)496 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
497     XlaOp operand, XlaOp token, const Shape& shape_with_layout,
498     const string& outfeed_config) {
499   auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
500   return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
501       loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
502 }
503 
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64_t dimension)504 StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
505     const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
506   TF_ASSIGN_OR_RETURN(
507       mlir::Type result_type,
508       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
509   auto mlir_operands = GetValues(operands);
510   return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
511       loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
512 }
513 
GetTupleElementInternal(const Shape & shape,XlaOp tuple_data,int64_t index)514 StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
515                                                         XlaOp tuple_data,
516                                                         int64_t index) {
517   TF_ASSIGN_OR_RETURN(
518       mlir::Type result_type,
519       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
520   return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
521       loc_, result_type, GetValue(tuple_data),
522       builder_.getI32IntegerAttr(index)));
523 }
524 
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)525 StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
526     const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
527     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
528   return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
529       loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
530       GetI64ElementsAttr(limit_indices, &builder_),
531       GetI64ElementsAttr(strides, &builder_)));
532 }
533 
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64> slice_sizes)534 StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
535     const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
536     absl::Span<const int64> slice_sizes) {
537   TF_ASSIGN_OR_RETURN(
538       mlir::Type result_ty,
539       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
540   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
541       loc_, result_ty, GetValue(operand), GetValues(start_indices),
542       GetI64ElementsAttr(slice_sizes, &builder_)));
543 }
544 
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)545 StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
546     const Shape& shape, XlaOp operand, XlaOp update,
547     absl::Span<const XlaOp> start_indices) {
548   TF_ASSIGN_OR_RETURN(
549       mlir::Type result_ty,
550       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
551   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
552       loc_, result_ty, GetValue(operand), GetValue(update),
553       GetValues(start_indices)));
554 }
555 
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)556 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
557     const Shape& shape, XlaOp operand, XlaOp padding_value,
558     const PaddingConfig& padding_config) {
559   TF_ASSIGN_OR_RETURN(
560       mlir::Type result_type,
561       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
562   std::vector<int64> low;
563   std::vector<int64> high;
564   std::vector<int64> internal;
565   for (auto& dimension : padding_config.dimensions()) {
566     low.push_back(dimension.edge_padding_low());
567     high.push_back(dimension.edge_padding_high());
568     internal.push_back(dimension.interior_padding());
569   }
570   return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
571       loc_, result_type, GetValue(operand), GetValue(padding_value),
572       GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
573       GetI64ElementsAttr(internal, &builder_)));
574 }
575 
TupleInternal(const Shape & shape,absl::Span<const XlaOp> elements)576 StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
577     const Shape& shape, absl::Span<const XlaOp> elements) {
578   mlir::SmallVector<mlir::Value, 4> operands;
579   for (auto& element : elements) {
580     operands.push_back(GetValue(element));
581   }
582   return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
583 }
584 
CreateOp(const std::string & op_name,const Shape & shape,llvm::ArrayRef<XlaOp> operands,llvm::ArrayRef<mlir::NamedAttribute> attributes)585 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
586     const std::string& op_name, const Shape& shape,
587     llvm::ArrayRef<XlaOp> operands,
588     llvm::ArrayRef<mlir::NamedAttribute> attributes) {
589   llvm::SmallVector<mlir::Value, 4> operand_values;
590   operand_values.reserve(operands.size());
591   for (XlaOp xla_op : operands) {
592     operand_values.push_back(GetValue(xla_op));
593   }
594   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
595                                          shape, builder_));
596   mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
597   mlir::Operation* op = builder_.createOperation(state);
598   return MakeXlaOp(op->getResult(0));
599 }
600 
ImportComputation(const HloModuleProto & computation,mlir::Region * region)601 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
602                                          mlir::Region* region) {
603   TF_ASSIGN_OR_RETURN(auto module_config,
604                       xla::HloModule::CreateModuleConfigFromProto(
605                           computation, xla::DebugOptions()));
606   TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
607                                            computation, module_config));
608 
609   return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
610                                              region, &builder_);
611 }
612 
GetShapePtr(XlaOp op) const613 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
614   TF_RETURN_IF_ERROR(first_error());
615   TF_RETURN_IF_ERROR(CheckOpBuilder(op));
616   auto it = handle_to_shape_.find(op.handle());
617   if (it == handle_to_shape_.end()) {
618     return InvalidArgument("No XlaOp with handle %d", op.handle());
619   }
620   return it->second.get();
621 }
622 
623 }  // namespace xla
624