• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
16 
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
23 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
24 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
25 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/shape_inference.h"
29 #include "tensorflow/compiler/xla/util.h"
30 
31 namespace xla {
32 
GetMlirOpName(HloOpcode opcode)33 static std::string GetMlirOpName(HloOpcode opcode) {
34   std::string op_name = HloOpcodeString(opcode);
35   absl::c_replace(op_name, '-', '_');
36   return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
37 }
38 
ToString(mlir::Type ty)39 static std::string ToString(mlir::Type ty) {
40   std::string str;
41   llvm::raw_string_ostream sstream(str);
42   ty.print(sstream);
43   sstream.flush();
44   return str;
45 }
46 
47 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(absl::Span<const int64> values,mlir::Builder * builder)48 static mlir::DenseIntElementsAttr GetI64ElementsAttr(
49     absl::Span<const int64> values, mlir::Builder* builder) {
50   auto ty = mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
51                                         builder->getIntegerType(64));
52   return mlir::DenseIntElementsAttr::get(
53       ty, llvm::makeArrayRef(values.data(), values.size()));
54 }
55 
ConvertPadding(absl::Span<const std::pair<int64_t,int64_t>> padding,mlir::Builder * builder)56 static mlir::DenseIntElementsAttr ConvertPadding(
57     absl::Span<const std::pair<int64_t, int64_t>> padding,
58     mlir::Builder* builder) {
59   llvm::SmallVector<int64_t, 8> elements;
60   elements.reserve(padding.size() * 2);
61   for (const auto& vals : padding) {
62     elements.push_back(vals.first);
63     elements.push_back(vals.second);
64   }
65   auto ty = mlir::RankedTensorType::get(
66       {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
67   return mlir::DenseIntElementsAttr::get(ty, elements);
68 }
69 
70 MlirHloBuilder::~MlirHloBuilder() = default;
71 
MakeXlaOp(mlir::Value val)72 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
73   mlir::Type ty = val.getType();
74   auto shape = std::make_unique<Shape>(TypeToShape(ty));
75   if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
76     return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
77   }
78 
79   int64_t handle = reinterpret_cast<int64_t>(val.getAsOpaquePointer());
80   handle_to_shape_[handle] = std::move(shape);
81   return XlaOp(handle, this);
82 }
83 
ConstantLiteral(const LiteralSlice & literal)84 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
85   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
86     TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
87                         CreateDenseElementsAttrFromLiteral(literal, builder_));
88     auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
89     return MakeXlaOp(op);
90   });
91 }
92 
ConvGeneralDilatedInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const Window & window,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation,const ConvolutionDimensionNumbers & dimension_numbers,int64 feature_group_count,int64 batch_group_count,const PrecisionConfig * precision_config)93 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
94     const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
95     absl::Span<const int64> window_strides,
96     absl::Span<const std::pair<int64, int64>> padding,
97     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
98     const ConvolutionDimensionNumbers& dimension_numbers,
99     int64 feature_group_count, int64 batch_group_count,
100     const PrecisionConfig* precision_config) {
101   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
102                                          shape, builder_));
103   mlir::ArrayAttr config_attr;
104   if (precision_config)
105     config_attr = ConvertPrecisionConfig(precision_config, &builder_);
106   auto op = builder_.create<mlir::mhlo::ConvOp>(
107       loc_, ty, GetValue(lhs), GetValue(rhs),
108       GetI64ElementsAttr(window_strides, &builder_),
109       ConvertPadding(padding, &builder_),
110       GetI64ElementsAttr(lhs_dilation, &builder_),
111       GetI64ElementsAttr(rhs_dilation, &builder_),
112       /*window_reversal=*/nullptr,
113       ConvertConvDimensionNumbers(dimension_numbers, &builder_),
114       builder_.getI64IntegerAttr(feature_group_count),
115       builder_.getI64IntegerAttr(batch_group_count), config_attr);
116   return MakeXlaOp(op);
117 }
118 
FftInternal(const Shape & shape,XlaOp operand,FftType fft_type,absl::Span<const int64> fft_length)119 StatusOr<XlaOp> MlirHloBuilder::FftInternal(
120     const Shape& shape, XlaOp operand, FftType fft_type,
121     absl::Span<const int64> fft_length) {
122   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
123                                          shape, builder_));
124   auto op = builder_.create<mlir::mhlo::FftOp>(
125       loc_, ty, GetValue(operand),
126       builder_.getStringAttr(FftType_Name(fft_type)),
127       GetI64ElementsAttr(fft_length, &builder_));
128   return MakeXlaOp(op);
129 }
130 
CustomCallInternal(const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const string & opaque,absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,bool has_side_effect,absl::Span<const std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> output_operand_aliasing,const Literal * literal)131 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
132     const string& call_target_name, absl::Span<const XlaOp> operands,
133     const Shape& shape, const string& opaque,
134     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
135     bool has_side_effect,
136     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
137         output_operand_aliasing,
138     const Literal* literal) {
139   if (operand_shapes_with_layout.has_value())
140     return Unimplemented(
141         "CustomCall doesn't support operands shapes with layout");
142   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
143                                          shape, builder_));
144   TF_RET_CHECK(output_operand_aliasing.empty())
145       << "MLIR CustomCallOp does not support output_operand_aliasing yet";
146   TF_RET_CHECK(literal == nullptr)
147       << "MLIR CustomCallOp does not support literal yet";
148   auto op = builder_.create<mlir::mhlo::CustomCallOp>(
149       loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
150       /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
151       builder_.getStringAttr(opaque));
152   return MakeXlaOp(op.getResult(0));
153 }
154 
ReduceInternal(const Shape & shape,absl::Span<const XlaOp> all_operands,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)155 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
156     const Shape& shape, absl::Span<const XlaOp> all_operands,
157     const XlaComputation& computation,
158     absl::Span<const int64> dimensions_to_reduce) {
159   // Reduce takes two set of variadic operands inputs and init_values.
160   // all_operands contains both of these so split operands into two parts.
161   int64_t num_args = all_operands.size() / 2;
162   auto op = builder_.create<mlir::mhlo::ReduceOp>(
163       loc_, GetValues(all_operands.first(num_args)),
164       GetValues(all_operands.subspan(num_args)),
165       GetI64ElementsAttr(dimensions_to_reduce, &builder_));
166   TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
167   if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
168   auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
169   return MakeXlaOp(tuple);
170 }
171 
ReduceWindowInternal(const Shape & shape,XlaOp operand,XlaOp init_value,const XlaComputation & computation,Window window)172 StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
173     const Shape& shape, XlaOp operand, XlaOp init_value,
174     const XlaComputation& computation, Window window) {
175   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
176                                          shape, builder_));
177   llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
178   llvm::SmallVector<int64, 8> padding;
179   for (const auto& dim : window.dimensions()) {
180     sizes.push_back(dim.size());
181     strides.push_back(dim.stride());
182     base_dilations.push_back(dim.base_dilation());
183     win_dilations.push_back(dim.window_dilation());
184     padding.push_back(dim.padding_low());
185     padding.push_back(dim.padding_high());
186   }
187   auto padding_ty =
188       mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
189                                   builder_.getIntegerType(64));
190   auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
191       loc_, ty, GetValue(operand), GetValue(init_value),
192       GetI64ElementsAttr(sizes, &builder_),
193       GetI64ElementsAttr(strides, &builder_),
194       GetI64ElementsAttr(base_dilations, &builder_),
195       GetI64ElementsAttr(win_dilations, &builder_),
196       mlir::DenseIntElementsAttr::get(padding_ty, padding));
197   TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
198   return MakeXlaOp(op);
199 }
200 
Iota(const Shape & shape,int64 iota_dimension)201 XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
202   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
203     TF_ASSIGN_OR_RETURN(
204         mlir::Type ty,
205         ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
206     auto op = builder_.create<mlir::mhlo::IotaOp>(
207         loc_, ty,
208         builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
209     return MakeXlaOp(op);
210   });
211 }
212 
BitcastConvertTypeInternal(const Shape & shape,XlaOp operand)213 StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
214                                                            XlaOp operand) {
215   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
216                                          shape, builder_));
217   auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
218                                                           GetValue(operand));
219   return MakeXlaOp(op);
220 }
221 
TransposeInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> permutation)222 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
223     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
224   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
225                                          shape, builder_));
226   auto op = builder_.create<mlir::mhlo::TransposeOp>(
227       loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
228   return MakeXlaOp(op);
229 }
230 
RevInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> dimensions)231 StatusOr<XlaOp> MlirHloBuilder::RevInternal(
232     const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
233   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
234                                          shape, builder_));
235   auto op = builder_.create<mlir::mhlo::ReverseOp>(
236       loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
237   return MakeXlaOp(op);
238 }
239 
SortInternal(const Shape & shape,absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64 dimension,bool is_stable)240 StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
241                                              absl::Span<const XlaOp> operands,
242                                              const XlaComputation& comparator,
243                                              int64 dimension, bool is_stable) {
244   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
245                                          shape, builder_));
246   llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
247   if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
248     sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
249   }
250 
251   auto op = builder_.create<mlir::mhlo::SortOp>(
252       loc_, sort_types, GetValues(operands),
253       builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
254   TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
255 
256   if (ty.isa<mlir::TupleType>()) {
257     auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
258     return MakeXlaOp(tuple);
259   }
260 
261   return MakeXlaOp(op.getResult(0));
262 }
263 
WhileInternal(const Shape & shape,const XlaComputation & condition,const XlaComputation & body,XlaOp init)264 StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
265                                               const XlaComputation& condition,
266                                               const XlaComputation& body,
267                                               XlaOp init) {
268   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
269                                          shape, builder_));
270   auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
271   TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
272   TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
273   return MakeXlaOp(op);
274 }
275 
ReducePrecisionInternal(const Shape & shape,XlaOp operand,const int exponent_bits,const int mantissa_bits)276 StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
277     const Shape& shape, XlaOp operand, const int exponent_bits,
278     const int mantissa_bits) {
279   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
280                                          shape, builder_));
281   auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
282       loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
283       builder_.getI32IntegerAttr(mantissa_bits));
284   return MakeXlaOp(op);
285 }
286 
GatherInternal(const Shape & shape,XlaOp input,XlaOp start_indices,const GatherDimensionNumbers & dimension_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)287 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
288     const Shape& shape, XlaOp input, XlaOp start_indices,
289     const GatherDimensionNumbers& dimension_numbers,
290     absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
291   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
292                                          shape, builder_));
293   auto op = builder_.create<mlir::mhlo::GatherOp>(
294       loc_, ty, GetValue(input), GetValue(start_indices),
295       ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
296       GetI64ElementsAttr(slice_sizes, &builder_));
297   return MakeXlaOp(op);
298 }
299 
ScatterInternal(const Shape & shape,XlaOp input,XlaOp scatter_indices,XlaOp updates,const XlaComputation & update_computation,const ScatterDimensionNumbers & dimension_numbers,bool indices_are_sorted,bool unique_indices)300 StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
301     const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
302     const XlaComputation& update_computation,
303     const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
304     bool unique_indices) {
305   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
306                                          shape, builder_));
307   auto op = builder_.create<mlir::mhlo::ScatterOp>(
308       loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
309       ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
310       builder_.getBoolAttr(indices_are_sorted),
311       builder_.getBoolAttr(unique_indices));
312 
313   TF_RETURN_IF_ERROR(
314       ImportComputation(update_computation.proto(), &op.update_computation()));
315   return MakeXlaOp(op);
316 }
317 
SetDimensionSizeInternal(const Shape & shape,XlaOp operand,XlaOp val,int64 dimension)318 StatusOr<XlaOp> MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape,
319                                                          XlaOp operand,
320                                                          XlaOp val,
321                                                          int64 dimension) {
322   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
323                                          shape, builder_));
324   auto op = builder_.create<mlir::mhlo::SetDimensionSizeOp>(
325       loc_, ty, GetValue(operand), GetValue(val),
326       builder_.getI64IntegerAttr(dimension));
327   return MakeXlaOp(op);
328 }
329 
RngOpInternal(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)330 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
331     RandomDistribution distribution, absl::Span<const XlaOp> parameters,
332     const Shape& shape) {
333   // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform
334   // and RngNormal can be mapped to the new op.
335   std::string op_name;
336   if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
337     op_name = "mhlo.rng_uniform";
338   } else {
339     TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
340         << "Unexpected distribution: " << distribution;
341     op_name = "mhlo.rng_normal";
342   }
343 
344   if (shape.is_dynamic())
345     return Unimplemented("RngOp with dynamic dims not supported");
346   llvm::SmallVector<XlaOp, 3> operands;
347   operands.append(parameters.begin(), parameters.end());
348   operands.push_back(
349       ConstantLiteral(LiteralUtil::CreateR1<int64>(shape.dimensions())));
350   return CreateOp(op_name, shape, operands);
351 }
352 
RngBitGeneratorInternal(const Shape & full_result_shape,RandomAlgorithm algorithm,XlaOp initial_state)353 StatusOr<XlaOp> MlirHloBuilder::RngBitGeneratorInternal(
354     const Shape& full_result_shape, RandomAlgorithm algorithm,
355     XlaOp initial_state) {
356   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
357                                          full_result_shape, builder_));
358   auto op = builder_.create<mlir::mhlo::RngBitGeneratorOp>(
359       loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state));
360   return MakeXlaOp(op);
361 }
362 
ReshapeInternal(const Shape & shape,XlaOp operand,int64 inferred_dimension)363 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
364                                                 XlaOp operand,
365                                                 int64 inferred_dimension) {
366   TF_RETURN_IF_ERROR(first_error());
367 
368   if (inferred_dimension != -1)
369     return Unimplemented("inferred_dimension not yet supported for Reshape op");
370   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
371                                          shape, builder_));
372   mlir::Value value = GetValue(operand);
373   auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
374   return MakeXlaOp(op.getResult());
375 }
376 
DotGeneralInternal(const Shape & shape,XlaOp lhs,XlaOp rhs,const DotDimensionNumbers & dimension_number,const PrecisionConfig * precision_config)377 StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
378     const Shape& shape, XlaOp lhs, XlaOp rhs,
379     const DotDimensionNumbers& dimension_number,
380     const PrecisionConfig* precision_config) {
381   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
382                                          shape, builder_));
383   auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
384       loc_, ty, GetValue(lhs), GetValue(rhs),
385       ConvertDotDimensionNumbers(dimension_number, &builder_),
386       ConvertPrecisionConfig(precision_config, &builder_));
387   return MakeXlaOp(op.getResult());
388 }
389 
InDimBroadcast(const Shape & shape,XlaOp operand,absl::Span<const int64> broadcast_dimensions)390 StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
391     const Shape& shape, XlaOp operand,
392     absl::Span<const int64> broadcast_dimensions) {
393   TF_RETURN_IF_ERROR(first_error());
394   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
395                                          shape, builder_));
396   mlir::Value value = GetValue(operand);
397   auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
398       loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
399   return MakeXlaOp(op.getResult());
400 }
401 
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)402 StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
403     HloInstructionProto&& instr, HloOpcode opcode,
404     absl::Span<const XlaOp> operands) {
405   return Unimplemented("MlirHloBuilder does not support op %s",
406                        HloOpcodeString(opcode));
407 }
408 
Compare(const Shape & shape,XlaOp lhs,XlaOp rhs,ComparisonDirection direction,Comparison::Type type)409 StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
410                                         XlaOp rhs,
411                                         ComparisonDirection direction,
412                                         Comparison::Type type) {
413   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
414                                          shape, builder_));
415   auto op = builder_.create<mlir::mhlo::CompareOp>(
416       loc_, ty, GetValue(lhs), GetValue(rhs),
417       builder_.getStringAttr(ComparisonDirectionToString(direction)),
418       builder_.getStringAttr(ComparisonTypeToString(type)));
419   return MakeXlaOp(op.getResult());
420 }
421 
BinaryOpNoBroadcast(HloOpcode binop,const Shape & shape,XlaOp lhs,XlaOp rhs)422 XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
423                                           XlaOp lhs, XlaOp rhs) {
424   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
425     return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs});
426   });
427 }
428 
AddOpWithShape(HloOpcode opcode,const Shape & shape,absl::Span<const XlaOp> operands)429 StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
430     HloOpcode opcode, const Shape& shape, absl::Span<const XlaOp> operands) {
431   return CreateOp(GetMlirOpName(opcode), shape,
432                   llvm::makeArrayRef<XlaOp>(operands.data(), operands.size()));
433 }
434 
CreateToken()435 XlaOp MlirHloBuilder::CreateToken() {
436   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
437     return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
438         loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
439   });
440 }
441 
TriangularSolveInternal(const Shape & shape,XlaOp a,XlaOp b,TriangularSolveOptions options)442 StatusOr<XlaOp> MlirHloBuilder::TriangularSolveInternal(
443     const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
444   TF_ASSIGN_OR_RETURN(
445       mlir::Type result_ty,
446       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
447   auto op = builder_.create<mlir::mhlo::TriangularSolveOp>(
448       loc_, result_ty, GetValue(a), GetValue(b),
449       builder_.getBoolAttr(options.left_side()),
450       builder_.getBoolAttr(options.lower()),
451       builder_.getBoolAttr(options.unit_diagonal()),
452       builder_.getStringAttr(
453           TriangularSolveOptions::Transpose_Name(options.transpose_a())));
454   return MakeXlaOp(op);
455 }
456 
CholeskyInternal(const Shape & shape,XlaOp a,bool lower)457 StatusOr<XlaOp> MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
458                                                  bool lower) {
459   TF_ASSIGN_OR_RETURN(
460       mlir::Type result_ty,
461       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
462   auto op = builder_.create<mlir::mhlo::CholeskyOp>(
463       loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower));
464   return MakeXlaOp(op);
465 }
466 
InfeedWithTokenInternal(const Shape & infeed_instruction_shape,XlaOp token,const string & config)467 StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
468     const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
469   TF_ASSIGN_OR_RETURN(mlir::Type result_type,
470                       ConvertShapeToType<mlir::RankedTensorType>(
471                           infeed_instruction_shape, builder_));
472   mlir::ArrayAttr layout;
473   return MakeXlaOp(
474       builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
475                                             /*infeed_config=*/config,
476                                             /*layout=*/layout));
477 }
478 
OutfeedWithTokenInternal(XlaOp operand,XlaOp token,const Shape & shape_with_layout,const string & outfeed_config)479 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
480     XlaOp operand, XlaOp token, const Shape& shape_with_layout,
481     const string& outfeed_config) {
482   auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
483   return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
484       loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
485 }
486 
ConcatInDimInternal(const Shape & shape,absl::Span<const XlaOp> operands,int64 dimension)487 StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
488     const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
489   TF_ASSIGN_OR_RETURN(
490       mlir::Type result_type,
491       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
492   auto mlir_operands = GetValues(operands);
493   return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
494       loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
495 }
496 
GetTupleElementInternal(const Shape & shape,XlaOp tuple_data,int64 index)497 StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
498                                                         XlaOp tuple_data,
499                                                         int64 index) {
500   TF_ASSIGN_OR_RETURN(
501       mlir::Type result_type,
502       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
503   return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
504       loc_, result_type, GetValue(tuple_data),
505       builder_.getI32IntegerAttr(index)));
506 }
507 
SliceInternal(const Shape & shape,XlaOp operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)508 StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
509     const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
510     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
511   return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
512       loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
513       GetI64ElementsAttr(limit_indices, &builder_),
514       GetI64ElementsAttr(strides, &builder_)));
515 }
516 
DynamicSliceInternal(const Shape & shape,XlaOp operand,absl::Span<const XlaOp> start_indices,absl::Span<const int64> slice_sizes)517 StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
518     const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
519     absl::Span<const int64> slice_sizes) {
520   TF_ASSIGN_OR_RETURN(
521       mlir::Type result_ty,
522       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
523   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
524       loc_, result_ty, GetValue(operand), GetValues(start_indices),
525       GetI64ElementsAttr(slice_sizes, &builder_)));
526 }
527 
DynamicUpdateSliceInternal(const Shape & shape,XlaOp operand,XlaOp update,absl::Span<const XlaOp> start_indices)528 StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
529     const Shape& shape, XlaOp operand, XlaOp update,
530     absl::Span<const XlaOp> start_indices) {
531   TF_ASSIGN_OR_RETURN(
532       mlir::Type result_ty,
533       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
534   return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
535       loc_, result_ty, GetValue(operand), GetValue(update),
536       GetValues(start_indices)));
537 }
538 
PadInternal(const Shape & shape,XlaOp operand,XlaOp padding_value,const PaddingConfig & padding_config)539 StatusOr<XlaOp> MlirHloBuilder::PadInternal(
540     const Shape& shape, XlaOp operand, XlaOp padding_value,
541     const PaddingConfig& padding_config) {
542   TF_ASSIGN_OR_RETURN(
543       mlir::Type result_type,
544       ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
545   std::vector<int64> low;
546   std::vector<int64> high;
547   std::vector<int64> internal;
548   for (auto& dimension : padding_config.dimensions()) {
549     low.push_back(dimension.edge_padding_low());
550     high.push_back(dimension.edge_padding_high());
551     internal.push_back(dimension.interior_padding());
552   }
553   return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
554       loc_, result_type, GetValue(operand), GetValue(padding_value),
555       GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
556       GetI64ElementsAttr(internal, &builder_)));
557 }
558 
TupleInternal(const Shape & shape,absl::Span<const XlaOp> elements)559 StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
560     const Shape& shape, absl::Span<const XlaOp> elements) {
561   mlir::SmallVector<mlir::Value, 4> operands;
562   for (auto& element : elements) {
563     operands.push_back(GetValue(element));
564   }
565   return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
566 }
567 
CreateOp(const std::string & op_name,const Shape & shape,llvm::ArrayRef<XlaOp> operands,llvm::ArrayRef<mlir::NamedAttribute> attributes)568 StatusOr<XlaOp> MlirHloBuilder::CreateOp(
569     const std::string& op_name, const Shape& shape,
570     llvm::ArrayRef<XlaOp> operands,
571     llvm::ArrayRef<mlir::NamedAttribute> attributes) {
572   llvm::SmallVector<mlir::Value, 4> operand_values;
573   operand_values.reserve(operands.size());
574   for (XlaOp xla_op : operands) {
575     operand_values.push_back(GetValue(xla_op));
576   }
577   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
578                                          shape, builder_));
579   mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
580   mlir::Operation* op = builder_.createOperation(state);
581   return MakeXlaOp(op->getResult(0));
582 }
583 
ImportComputation(const HloModuleProto & computation,mlir::Region * region)584 Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
585                                          mlir::Region* region) {
586   TF_ASSIGN_OR_RETURN(auto module_config,
587                       xla::HloModule::CreateModuleConfigFromProto(
588                           computation, xla::DebugOptions()));
589   TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
590                                            computation, module_config));
591 
592   return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
593                                              region, &builder_);
594 }
595 
GetShapePtr(XlaOp op) const596 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
597   TF_RETURN_IF_ERROR(first_error());
598   TF_RETURN_IF_ERROR(CheckOpBuilder(op));
599   auto it = handle_to_shape_.find(op.handle());
600   if (it == handle_to_shape_.end()) {
601     return InvalidArgument("No XlaOp with handle %d", op.handle());
602   }
603   return it->second.get();
604 }
605 
606 }  // namespace xla
607