/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

#include <memory>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
using ::tensorflow::int8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;

constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";

// Array attribute. Same shape as infeed result, but contains a
// minor_to_major array for every tensor.
constexpr char kLayoutAttr[] = "layout";
constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
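// Illustrative example (hypothetical values, not taken from any test): an
// infeed producing a 2-D tensor and a 1-D tensor might carry an attribute
// along the lines of `layout = [[1, 0], [0]]`, i.e. one minor_to_major list
// per result tensor.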

// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
  return t;
}

template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
  return t.get();
}

static mlir::LogicalResult GetXlaOp(
    mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
    xla::XlaOp* result, mlir::Operation* op) {
  auto iter = val_map.find(val);
  if (iter == val_map.end()) {
    return op->emitOpError(
        "requires all operands to be defined in the parent region for export");
  }
  *result = iter->second;
  return mlir::success();
}

// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }

static uint32_t Convertuint32_t(uint32_t i) { return i; }
static uint64_t Convertuint64_t(uint64_t i) { return i; }

// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
  const auto& semantics = value.getSemantics();
  bool losesInfo = false;
  if (&semantics != &llvm::APFloat::IEEEdouble())
    value.convert(llvm::APFloat::IEEEdouble(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return value.convertToDouble();
}

static inline bool Convertbool(bool value) { return value; }

static absl::string_view ConvertStringRef(mlir::StringRef value) {
  return {value.data(), value.size()};
}

static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
  auto values = attr.getValues<int64>();
  return {values.begin(), values.end()};
}

static std::vector<int64> ConvertDenseIntAttr(
    llvm::Optional<mlir::DenseIntElementsAttr> attr) {
  if (!attr) return {};
  return ConvertDenseIntAttr(*attr);
}

// Converts the broadcast_dimensions attribute into a vector of dimension
// numbers (empty if the attribute is absent).
static std::vector<int64> Convert_broadcast_dimensions(
    llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
  if (!broadcast_dimensions.hasValue()) return {};

  return ConvertDenseIntAttr(*broadcast_dimensions);
}

// Converts StringRef to xla FftType enum.
static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
  xla::FftType fft_type_enum;
  // An illegal fft_type string would be caught by the verifier, so the
  // 'FftType_Parse' call below should never return false.
  if (!FftType_Parse(std::string(fft_type_str), &fft_type_enum))
    return xla::FftType::FFT;
  return fft_type_enum;
}

static std::vector<std::pair<int64, int64>> Convert_padding(
    llvm::Optional<mlir::DenseIntElementsAttr> padding) {
  return xla::ConvertNx2Attribute(padding).ValueOrDie();
}

static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
    llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
  return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
}

static std::vector<xla::ReplicaGroup> Convert_replica_groups(
    mlir::DenseIntElementsAttr groups) {
  return xla::ConvertReplicaGroups(groups).ValueOrDie();
}

// Converts StringRef to xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
    llvm::StringRef transpose_str) {
  return xla::ConvertTranspose(transpose_str).ValueOrDie();
}

static xla::Layout ExtractLayout(
    mlir::Operation* op, int rank,
    llvm::StringRef attr_name = kDefaultLayoutAttrName) {
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(attr_name)) {
    llvm::SmallVector<int64, 4> minor_to_major;
    DCHECK_EQ(rank, attr.size());
    minor_to_major.reserve(attr.size());
    for (const llvm::APInt& i : attr) {
      minor_to_major.push_back(i.getZExtValue());
    }
    return xla::LayoutUtil::MakeLayout(minor_to_major);
  }
  return xla::LayoutUtil::MakeDescendingLayout(rank);
}

#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
  static std::vector<int64> Convert_##attribute(              \
      llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
    return ConvertDenseIntAttr(attribute);                    \
  }

I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);

#undef I64_ELEMENTS_ATTR_TO_VECTOR

static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
  return {values.begin(), values.end()};
}

// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
    llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
  if (!optional_precision_config_attr.hasValue()) return nullptr;

  auto precision_config = absl::make_unique<xla::PrecisionConfig>();
  for (auto attr : optional_precision_config_attr.getValue()) {
    xla::PrecisionConfig::Precision p;
    auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
    // TODO(jpienaar): Update this to ensure this is captured by verify.
    if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
      precision_config->add_operand_precision(p);
    } else {
      auto* context = attr.getContext();
      mlir::emitError(mlir::UnknownLoc::get(context))
          << "unexpected operand precision " << operand_precision;
      return nullptr;
    }
  }

  return precision_config;
}

static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
    mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
  xla::DotDimensionNumbers dot_dimension_numbers;

  auto rhs_contracting_dimensions =
      dot_dimension_numbers_attr.rhs_contracting_dimensions()
          .cast<mlir::DenseIntElementsAttr>();
  auto lhs_contracting_dimensions =
      dot_dimension_numbers_attr.lhs_contracting_dimensions()
          .cast<mlir::DenseIntElementsAttr>();
  auto rhs_batch_dimensions =
      dot_dimension_numbers_attr.rhs_batching_dimensions()
          .cast<mlir::DenseIntElementsAttr>();
  auto lhs_batch_dimensions =
      dot_dimension_numbers_attr.lhs_batching_dimensions()
          .cast<mlir::DenseIntElementsAttr>();

  for (const auto& val : rhs_contracting_dimensions) {
    dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue());
  }
  for (const auto& val : lhs_contracting_dimensions) {
    dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue());
  }

  for (const auto& val : rhs_batch_dimensions) {
    dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue());
  }

  for (const auto& val : lhs_batch_dimensions) {
    dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue());
  }

  return dot_dimension_numbers;
}

static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::ConvDimensionNumbers input) {
  return xla::ConvertConvDimensionNumbers(input);
}

xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
  channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
      ConvertAPInt(attr.type().getValue())));
  return channel_handle;
}

absl::optional<xla::ChannelHandle> Convert_channel_handle(
    llvm::Optional<mlir::mhlo::ChannelHandle> attr) {
  if (!attr.hasValue()) return absl::nullopt;
  return Convert_channel_handle(attr.getValue());
}

// Converts the comparison_direction string attribute into the XLA enum. The
// string is assumed to correspond to exactly one of the allowed strings
// representing the enum. This should have been checked in the op verify method.
static xla::ComparisonDirection Convert_comparison_direction(
    llvm::StringRef comparison_direction_string) {
  return xla::StringToComparisonDirection(comparison_direction_string.str())
      .ValueOrDie();
}

static xla::GatherDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::GatherDimensionNumbers input) {
  xla::GatherDimensionNumbers output;

  auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
  std::copy(offset_dims.begin(), offset_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_offset_dims()));

  auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
  std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_collapsed_slice_dims()));

  auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
  std::copy(start_index_map.begin(), start_index_map.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_start_index_map()));

  output.set_index_vector_dim(
      ConvertAPInt(input.index_vector_dim().getValue()));
  return output;
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
    mlir::mhlo::ScatterDimensionNumbers input) {
  xla::ScatterDimensionNumbers output;

  auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
  std::copy(update_window_dims.begin(), update_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_update_window_dims()));

  auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
  std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_inserted_window_dims()));

  auto scatter_dims_to_operand_dims =
      ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
  std::copy(scatter_dims_to_operand_dims.begin(),
            scatter_dims_to_operand_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_scatter_dims_to_operand_dims()));

  output.set_index_vector_dim(
      ConvertAPInt(input.index_vector_dim().getValue()));
  return output;
}

// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
    llvm::StringRef sharding) {
  xla::OpSharding sharding_proto;
  if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
  return sharding_proto;
}

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
    mlir::Operation* op) {
  auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
  if (!sharding) return absl::nullopt;
  return CreateOpShardingFromStringRef(sharding.getValue());
}

// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
    mlir::Operation* op) {
  xla::FrontendAttributes frontend_attributes;
  auto frontend_attributes_dict =
      op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);

  if (!frontend_attributes_dict) return frontend_attributes;

  for (const auto& attr : frontend_attributes_dict)
    if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
      frontend_attributes.mutable_map()->insert(
          {attr.first.str(), value_str_attr.getValue().str()});

  return frontend_attributes;
}

// Returns an OpMetadata proto based on the location of the op. If the location
// is unknown, an empty proto is returned. `op_name` is populated with the op
// location (converted). For FileLineColLoc locations, `source_file` and
// `source_line` are populated from the file name and line number respectively.
static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
  xla::OpMetadata metadata;
  if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;

  std::string name = mlir::GetNameFromLoc(op->getLoc());
  mlir::LegalizeNodeName(name);
  metadata.set_op_name(name);

  if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
    metadata.set_source_file(file_line_col_loc.getFilename().str());
    metadata.set_source_line(file_line_col_loc.getLine());
  }

  return metadata;
}

// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
    llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
  return llvm::all_of(shardings,
                      [](const absl::optional<xla::OpSharding>& sharding) {
                        return sharding.has_value();
                      });
}

// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
    mlir::FuncOp function,
    llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
    llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
  arg_shardings->resize(function.getNumArguments(),
                        absl::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumArguments(); i < end; ++i)
    if (auto sharding =
            function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());

  ret_shardings->resize(function.getNumResults(),
                        absl::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumResults(); i < end; ++i)
    if (auto sharding =
            function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}

namespace mlir {
namespace {
class ConvertToHloModule {
 public:
  using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
  using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;

  // If use_tuple_args is true, then the entry function's arguments are
  // converted to a tuple and passed as a single parameter.
  // Similarly, if return_tuple is true, then the entry function's return
  // values are converted to a tuple even when there is only a single return
  // value. Multiple return values are always converted to a tuple and
  // returned as a single value.
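  // For example (illustrative only): with use_tuple_args=true, a `main` taking
  // (tensor<2xf32>, tensor<i32>) is lowered to a computation with a single
  // tuple-shaped parameter (f32[2], s32[]); with return_tuple=true, even a
  // single result is wrapped in a one-element result tuple.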
  explicit ConvertToHloModule(
      mlir::ModuleOp module, xla::XlaBuilder& module_builder,
      bool use_tuple_args, bool return_tuple,
      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
      MlirToHloConversionOptions options)
      : module_(module),
        module_builder_(module_builder),
        use_tuple_args_(use_tuple_args),
        return_tuple_(return_tuple),
        shape_representation_fn_(shape_representation_fn),
        options_(options) {
    if (!shape_representation_fn_)
      shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
  }

  // Perform the lowering to XLA. This function returns failure if an error was
  // encountered.
  //
  // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
  LogicalResult Run() {
    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
    if (!main)
      return module_.emitError(
          "conversion requires module with `main` function");

    for (auto func : module_.getOps<FuncOp>()) {
      if (func.empty()) continue;
      if (failed(RunOnFunction(func))) return failure();
    }
    return success();
  }

  // Lower a specific function to HLO.
  LogicalResult RunOnFunction(mlir::FuncOp f);

  // Lower a `mlir::Region` to an `xla::XlaComputation`.
  LogicalResult LowerRegionAsComputation(mlir::Region* region,
                                         xla::XlaComputation* func);

  // Lower a single `Block` to an `xla::XlaComputation`.
  LogicalResult LowerBasicBlockAsFunction(
      Block* block, xla::XlaBuilder* builder, bool is_entry_function,
      const std::vector<bool>& entry_args_same_across_replicas,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
      xla::XlaComputation* result);

  ::xla::HloModuleProto ConsumeMainProto() {
    auto main = module_.lookupSymbol<mlir::FuncOp>("main");
    // This is an invariant check: Run returns failure if there is no main
    // function, so the main proto shouldn't be consumed in that case.
    CHECK(main) << "requires module to have main function";  // Crash Ok.
    return lowered_computation_[main].proto();
  }

  // Lower a function call to an HLO call instruction.
  LogicalResult LowerFunctionCall(
      mlir::CallOp call_op, xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering);

  LogicalResult Lower(
      mlir::Operation* inst, bool is_entry_function,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
      xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering,
      xla::XlaOp* return_value);

  const MlirToHloConversionOptions& GetOptions() const { return options_; }

 private:
  LogicalResult SetEntryTupleShapesAndLeafReplication(
      Block* block, const std::vector<bool>& entry_args_same_across_replicas,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
      std::vector<bool>* leaf_replication);

  LogicalResult SetEntryTupleShardings(
      Block* block, xla::XlaBuilder* builder,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes);

  // The module being lowered.
  mlir::ModuleOp module_;

  // The top-level XlaBuilder.
  xla::XlaBuilder& module_builder_;

  // Map between function and lowered computation.
  FunctionLoweringMap lowered_computation_;

  // Whether the entry function should take a single tuple as input.
  bool use_tuple_args_;

  // Whether to always return a tuple.
  bool return_tuple_;

  // Shape representation function to determine entry function argument and
  // result shapes.
  tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;

  // Unique suffix to give to the name of the next lowered region.
  size_t region_id_ = 0;

  MlirToHloConversionOptions options_;
};

}  // namespace
}  // namespace mlir

namespace {

struct OpLoweringContext {
  llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
  mlir::ConvertToHloModule* converter;
  xla::XlaBuilder* builder;
};

mlir::LogicalResult GetTuple(mlir::Operation* op,
                             mlir::Operation::operand_range values,
                             OpLoweringContext ctx,
                             llvm::SmallVectorImpl<xla::XlaOp>& results) {
  results.reserve(values.size());
  for (mlir::Value value : values) {
    if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
      return mlir::failure();
  }
  return mlir::success();
}

}  // namespace

namespace mlir {
namespace mhlo {
namespace {

LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) {
  // This op has no expression in the legacy export format. It can be expanded
  // to a sequence of operations if needed in the future, but would feed into
  // ops creating unsupported dynamic shapes.
  return failure();
}

LogicalResult ExportXlaOp(CstrReshapableOp, OpLoweringContext) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
  auto& valueMap = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();
  TensorType operandType = op.operand().getType().cast<TensorType>();
  TensorType resultType = op.getType();
  if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  auto allGatherDim = op.all_gather_dim();
  int64_t shardCount = resultType.getDimSize(allGatherDim) /
                       operandType.getDimSize(allGatherDim);
  valueMap[op] = xla::AllGather(operand, allGatherDim, shardCount,
                                Convert_replica_groups(op.replica_groups()),
                                Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::AllReduce(operand, computation,
                                 Convert_replica_groups(op.replica_groups()),
                                 Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) {
  auto& valueMap = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();
  TensorType operandType = op.operand().getType().cast<TensorType>();
  TensorType resultType = op.getType();
  if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  auto scatterDim = op.scatter_dimension();
  int64_t shardCount =
      operandType.getDimSize(scatterDim) / resultType.getDimSize(scatterDim);

  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  valueMap[op] =
      xla::ReduceScatter(operand, computation, scatterDim, shardCount,
                         Convert_replica_groups(op.replica_groups()),
                         Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::BitcastConvertType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
  auto type = op.getType().dyn_cast<RankedTensorType>();
  if (!type) return failure();
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] =
      BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
                     Convert_broadcast_dimensions(op.broadcast_dimensions()));
  return success();
}

LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  value_map[op] = xla::DotGeneral(
      lhs, rhs, Convert_dot_dimension_numbers(op.dot_dimension_numbers()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  return mlir::success();
}

LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
  xla::XlaComputation true_branch;
  xla::XlaComputation false_branch;
  auto& value_map = *ctx.values;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
                                                     &true_branch)) ||
      failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
                                                     &false_branch))) {
    return failure();
  }
  xla::XlaOp pred, true_arg, false_arg;
  if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();
  if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op)))
    return failure();
  if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op)))
    return failure();

  value_map[op] =
      xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);
  return success();
}

LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
  llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
  OperandRange operands = op.branch_operands();
  MutableArrayRef<Region> branches = op.branches();
  llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
  std::vector<xla::XlaComputation> computations(branches.size());
  std::vector<xla::XlaComputation*> computations_p(branches.size());

  for (unsigned i = 0; i < branches.size(); ++i) {
    xla::XlaOp operand;
    if (failed(GetXlaOp(operands[i], value_map, &operand, op)))
      return failure();
    branch_operands[i] = operand;
    computations_p[i] = &computations[i];
    if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
                                                       computations_p[i])))
      return failure();
  }
  xla::XlaOp index;
  if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();

  xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands);
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (auto item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

// Specialize CompareOp export to set broadcast_dimensions argument.
mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
                                OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  auto dir = Convert_comparison_direction(op.comparison_direction());
  auto type_attr = op.compare_typeAttr();

  xla::XlaOp xla_result;
  if (type_attr) {
    auto type =
        xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
    xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
  } else {
    xla_result = xla::Compare(lhs, rhs, dir);
  }
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
  // XLA client builder API does not support generating convolution instructions
  // with window reversal.
  if (op.hasWindowReversal()) return failure();
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  xla::XlaOp xla_result = xla::ConvGeneralDilated(
      lhs, rhs, Convert_window_strides(op.window_strides()),
      Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
      Convert_rhs_dilation(op.rhs_dilation()),
      Convert_dimension_numbers(op.dimension_numbers()),
      Convertuint64_t(op.feature_group_count()),
      Convertuint64_t(op.batch_group_count()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::ConvertElementType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
  // XLA client builder API does not support generating custom call instructions
  // with side effect.
  if (op.has_side_effect() || op.getNumResults() != 1) return failure();
  Value result = op.getResult(0);
  llvm::SmallVector<xla::XlaOp> args;
  if (failed(GetTuple(op, op.args(), ctx, args))) return failure();
  auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
  if (!xla_api_version.ok()) return failure();
  auto& value_map = *ctx.values;
  value_map[result] = xla::CustomCall(
      ctx.builder, std::string(op.call_target_name()), args,
      xla::TypeToShape(result.getType()), std::string(op.backend_config()),
      /*has_side_effect=*/false, /*output_operand_aliasing=*/{},
      /*literal=*/nullptr, /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/*xla_api_version);
  return success();
}

LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) {
  xla::QuantizedRange range(ConvertAPFloat(op.min_range()),
                            ConvertAPFloat(op.max_range()));
  auto& value_map = *ctx.values;
  xla::XlaOp input;
  if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure();

  auto casted = xla::ConvertElementType(input, xla::U32);
  if (op.is_16bits()) {
    value_map[op] = xla::Dequantize<uint16>(
        casted, range, ConvertStringRef(op.mode()), op.transpose_output());
  } else {
    value_map[op] = xla::Dequantize<uint8>(
        casted, range, ConvertStringRef(op.mode()), op.transpose_output());
  }
  return success();
}

LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  // The shape argument expected by the xla client API is the type of the first
  // element in the result tuple.
  auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
  value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type),
                                       std::string(op.infeed_config()));
  return success();
}

LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
                            op.iota_dimension());
  return success();
}

LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
  value_map[op] = xla::Map(ctx.builder, operands, computation,
                           Convert_dimensions(op.dimensions()));
  return success();
}

LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand, token;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  value_map[op] = xla::OutfeedWithToken(
      operand, token, xla::TypeToShape(op.operand().getType()),
      std::string(op.outfeed_config()));
  return success();
}

LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::PaddingConfig padding_config;
  auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
  auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
  auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
  for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
    auto* dims = padding_config.add_dimensions();
    dims->set_edge_padding_low(edge_padding_low[i]);
    dims->set_edge_padding_high(edge_padding_high[i]);
    dims->set_interior_padding(interior_padding[i]);
  }
  xla::XlaOp operand, padding_value;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
    return failure();

  value_map[op] = xla::Pad(operand, padding_value, padding_config);
  return success();
}

LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  if (op.is_host_transfer()) {
    value_map[op] =
        xla::RecvFromHost(token, xla::TypeToShape(result_type),
                          Convert_channel_handle(op.channel_handle()));
    return success();
  }
  value_map[op] =
      xla::RecvWithToken(token, xla::TypeToShape(result_type),
                         Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation body;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands, init_values;
  if (failed(GetTuple(op, op.inputs(), ctx, operands)) ||
      failed(GetTuple(op, op.init_values(), ctx, init_values))) {
    return failure();
  }
  xla::XlaOp result =
      xla::Reduce(ctx.builder, operands, init_values, body,
                  Convert_broadcast_dimensions(op.dimensions()));
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (auto item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation body;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands, init_values;
  if (failed(GetTuple(op, op.inputs(), ctx, operands)) ||
      failed(GetTuple(op, op.init_values(), ctx, init_values))) {
    return failure();
  }

  xla::XlaOp result = xla::ReduceWindowWithGeneralPadding(
      operands, init_values, body, ConvertDenseIntAttr(op.window_dimensions()),
      ConvertDenseIntAttr(op.window_strides()),
      ConvertDenseIntAttr(op.base_dilations()),
      ConvertDenseIntAttr(op.window_dilations()),
      Convert_padding(op.padding()));

  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (auto item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] =
      xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
  return success();
}

LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
  // Failure on purpose because `mhlo::ReturnOp` will be handled by
  // special purpose logic in `ConvertToHloModule::Lower`.
  return failure();
}

LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
  auto xla_result = xla::RngBitGenerator(
      static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
      xla::TypeToShape(result.getType()).tuple_shapes(1));
  value_map[result] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp mu, sigma;
  if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure();
  if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure();

  value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType()));
  return success();
}

LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp a, b;
  if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
  if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();

  value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
  return success();
}

LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation update_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
                                                     &update_computation))) {
    return failure();
  }
  xla::ScatterDimensionNumbers dimension_numbers =
      Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
  xla::XlaOp operand, scatter_indices, updates;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
    return failure();
  if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure();

  value_map[op] = xla::Scatter(operand, scatter_indices, updates,
                               update_computation, dimension_numbers,
                               op.indices_are_sorted(), op.unique_indices());
  return success();
}

LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation select;
  xla::XlaComputation scatter;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
      failed(
          ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
    return failure();
  }
  xla::XlaOp operand, source, init_value;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
  if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
    return failure();

  value_map[op] = xla::SelectAndScatterWithGeneralPadding(
      operand, select, ConvertDenseIntAttr(op.window_dimensions()),
      ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
      source, init_value, scatter);
  return success();
}

LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand, token;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  if (op.is_host_transfer()) {
    value_map[op] = xla::SendToHost(
        operand, token, xla::TypeToShape(op.operand().getType()),
        Convert_channel_handle(op.channel_handle()));
    return success();
  }
  value_map[op] = xla::SendWithToken(
      operand, token, Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
  xla::XlaComputation comparator;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
                                                     &comparator)))
    return failure();

  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
  auto sorted = xla::Sort(operands, comparator, op.dimension(), op.is_stable());

  auto& value_map = *ctx.values;
  auto shape_or = sorted.builder()->GetShape(sorted);
  if (!shape_or.ok()) {
    return op.emitError(shape_or.status().ToString());
  }

  xla::Shape& shape = shape_or.ValueOrDie();
  if (!shape.IsTuple()) {
    value_map[op.getResult(0)] = sorted;
    return success();
  }

  // MLIR's sort supports multiple results; untuple all the results of XLA's
  // sort.
  for (auto it : llvm::enumerate(op.getResults())) {
    value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
  }
  return success();
}

LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  xla::Trace(std::string(op.tag()), operand);
  return success();
}

LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
  // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
  // operands.
  return failure();
}

LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
  // TODO(jpienaar): Support multi-operand while op.
  if (op.getNumResults() != 1)
    return op.emitError("nyi: unable to export multi-result While op");
  xla::XlaComputation condition;
  xla::XlaComputation body;
  auto& value_map = *ctx.values;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body)) ||
      failed(ctx.converter->LowerRegionAsComputation(&op.cond(), &condition))) {
    return failure();
  }

  xla::XlaOp operand;
  // TODO(jpienaar): Support multi-operand while op.
  Value val = op->getResult(0);
  if (failed(GetXlaOp(op->getOperand(0), value_map, &operand, op)))
    return failure();
  value_map[val] = xla::While(condition, body, operand);
  return success();
}

LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
  if (!op.fusion_kind()) {
    op.emitOpError() << "requires fusion kind for HLO translation";
    return failure();
  }

  xla::XlaComputation fused_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
                                                     &fused_computation)))
    return failure();

  auto& values = *ctx.values;
  llvm::SmallVector<xla::XlaOp, 4> operands;
  for (auto operand : op.operands()) operands.push_back(values[operand]);

  xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
      ctx.builder, operands,
      absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
      fused_computation);
  if (op.getNumResults() == 1) {
    values[op.getResult(0)] = fusion;
  } else {
    for (auto item : llvm::enumerate(op.getResults())) {
      values[item.value()] = xla::GetTupleElement(fusion, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  xla::XlaOp bitcast = xla::internal::XlaBuilderFriend::BuildBitcast(
      ctx.builder, operand, xla::TypeToShape(op.getType()));
  value_map[op] = bitcast;
  if (ctx.converter->GetOptions().propagate_bitcast_layouts_to_backend_config) {
    // Encode the source and result layout of the bitcast into the XLA HLO
    // backend config as a protobuf. Note that this is a temporary solution
    // which will go away once XLA:GPU stops falling back to XLA HLO Elemental
    // IR emitters.
    xla::HloInstructionProto* bitcast_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(bitcast);
    xla::HloInstructionProto* operand_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(operand);
    xla::LayoutProto result_layout =
        ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
                      "result_layout")
            .ToProto();
    xla::LayoutProto source_layout =
        ExtractLayout(op, operand_proto->shape().dimensions_size(),
                      "source_layout")
            .ToProto();
    xla::gpu::BitcastBackendConfig bitcast_config;
    *bitcast_config.mutable_source_layout() = source_layout;
    *bitcast_config.mutable_result_layout() = result_layout;
    *bitcast_proto->mutable_backend_config() =
        bitcast_config.SerializeAsString();
  }
  return success();
}

LogicalResult ExportXlaOp(RealDynamicSliceOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicPadOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicGatherOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicConvOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(PrintOp op, OpLoweringContext ctx) {
  return failure();
}

}  // namespace
}  // namespace mhlo
}  // namespace mlir

#include "tensorflow/compiler/mlir/xla/operator_writers.inc"

namespace mlir {
namespace {

StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
                                                  xla::Layout layout) {
  if (attr.isa<OpaqueElementsAttr>())
    return tensorflow::errors::Unimplemented(
        "Opaque elements attr not supported");

  xla::Shape shape = xla::TypeToShape(attr.getType());

#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
  case xla_type: {                                                           \
    xla::Array<cpp_type> source_data(shape.dimensions());                    \
    source_data.SetValues(attr.getValues<cpp_type>());                       \
    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
  }

  switch (shape.element_type()) {
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
    default:
      return tensorflow::errors::Internal(absl::StrCat(
          "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
  }
#undef ELEMENTS_ATTR_TO_LITERAL
}
1315
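// Converts the layout described by the array attribute `layout` into the
// layout fields of `shape`. For a tuple shape the attribute nests
// correspondingly; for example, a (tensor, token, tensor) tuple might carry a
// layout attribute like [[1, 0], unit, [0, 1]], where unit entries (tokens)
// are skipped.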
LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
                            xla::ShapeProto* shape) {
  // In the case of tuples, ShapeProtos can be nested, and so can the mlir
  // attribute describing the layout. So recurse into the subshapes in both
  // data structures in parallel.
  if (shape->element_type() == xla::TUPLE) {
    auto subshapes = shape->mutable_tuple_shapes();
    if (layout.size() != subshapes->size()) {
      op->emitOpError() << "Expected layout of size " << layout.size()
                        << ", but found " << subshapes->size();
      return failure();
    }
    for (int i = 0; i < subshapes->size(); i++) {
      mlir::Attribute child = layout[i];
      if (child.isa<mlir::UnitAttr>()) {
        // Ignore unit attributes; they are used only for tokens.
        continue;
      }
      mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
      if (!c) {
        op->emitOpError() << "Type Error: Expected layout array attribute";
        return failure();
      }
      if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
        return failure();
      }
    }
  } else {
    int rank = shape->dimensions().size();
    if (rank) {
      if (layout.size() != rank) {
        return failure();  // pass error down
      }
      std::vector<int64> array(rank);
      for (int i = 0; i < rank; i++) {
        mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
        if (!attr) {
          op->emitOpError() << "Type Error: Expected layout integer attribute";
          return failure();
        }
        array[i] = attr.getInt();
      }
      *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
    }
  }
  return success();
}

// MHLO and XLA HLO disagree on the meaning of addition of `pred` / `i1`, so
// there has to be a special case somewhere to account for the difference. To
// get the expected behavior of an `AddOp` on `i1`, we have to use `xor`. Since
// the majority of the conversion is generated code, we just sidestep it here
// for this single case, and inline the code to emit an `xor`.
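// For example, `mhlo.add` of two `i1` values `true` and `true` is expected to
// yield `false` (addition modulo 2), which is exactly `xor(true, true)`;
// emitting a plain HLO `add` on PRED would not match that expectation.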
LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst,
                                       OpLoweringContext ctx) {
  auto op = dyn_cast<mlir::mhlo::AddOp>(inst);
  if (op && op.getResult()
                .getType()
                .cast<mlir::TensorType>()
                .getElementType()
                .isSignlessInteger(1)) {
    auto& value_map = *ctx.values;
    auto result = op.getResult();
    xla::XlaOp xla_arg_0;
    if (failed(GetXlaOp(op.lhs(), value_map, &xla_arg_0, op)))
      return mlir::failure();
    xla::XlaOp xla_arg_1;
    if (failed(GetXlaOp(op.rhs(), value_map, &xla_arg_1, op)))
      return mlir::failure();
    auto xla_result = xla::Xor(Unwrap(xla_arg_0), Unwrap(xla_arg_1));
    value_map[result] = xla_result;
    return mlir::success();
  }

  return ExportXlaOperator(inst, ctx);
}

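// Lowers a single MLIR operation into the XLA builder `builder`, recording the
// resulting xla::XlaOp in `value_lowering`. Terminators instead set
// `*return_value`.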
LogicalResult ConvertToHloModule::Lower(
    mlir::Operation* inst, bool is_entry_function,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering,
    xla::XlaOp* return_value) {
  // Explicitly fail for ops that are not supported for export.
  if (inst->getDialect() !=
          inst->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>() &&
      !mlir::isa<mlir::ConstantOp, mlir::CallOp, mlir::tensor::CastOp,
                 mlir::ReturnOp>(inst)) {
    inst->emitOpError("unsupported op for export to XLA");
    return failure();
  }

  *return_value = xla::XlaOp();

  // See MlirToHloConversionOptions for more about layouts.
  auto propagate_layouts = [this](mlir::Operation* inst,
                                  xla::XlaOp xla_op) -> mlir::LogicalResult {
    if (options_.propagate_layouts) {
      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
                        ->mutable_shape();
      if (shape->tuple_shapes().empty())
        // TODO(kramm): merge this with ConvertLayout.
        *shape->mutable_layout() =
            ExtractLayout(inst, shape->dimensions().size()).ToProto();
    }

    // For infeed ops stemming back to InfeedDequeueTuple, respect the layout
    // attribute, and create the corresponding layout in hlo.
    if (isa<mhlo::InfeedOp>(inst)) {
      mlir::ArrayAttr layout =
          inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
      if (layout) {
        xla::ShapeProto* shape =
            xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
                ->mutable_shape();

        if (failed(ConvertLayout(inst, layout, shape))) {
          return failure();
        }
      }
    }
    return success();
  };

  if (succeeded(
          ExportXlaOperatorWrapped(inst, {value_lowering, this, builder}))) {
    if (inst->getNumResults() == 1) {
      auto iter = value_lowering->find(inst->getResult(0));
      if (iter == value_lowering->end()) {
        inst->emitOpError(
            "inst has a result, but it's not found in value_lowering");
        return failure();
      }
      if (failed(propagate_layouts(inst, iter->second))) {
        return failure();
      }
    }
    return success();
  }

  auto& value_map = *value_lowering;
  ElementsAttr const_attr;

  if (auto call_op = dyn_cast<mlir::CallOp>(inst)) {
    return LowerFunctionCall(call_op, builder, &value_map);
  }

  if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
    Value operand = op.getOperand();
    auto ty = operand.getType().dyn_cast<ShapedType>();
    // If this was a cast from a statically shaped tensor, then it is a no-op
    // for export to HLO and we can use the operand.
    if (!ty || !ty.hasStaticShape()) {
      inst->emitOpError()
          << "requires static shaped operand for HLO translation";
      return failure();
    }

    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
      return failure();
    value_map[op.getResult()] = xla_operand;
    if (failed(propagate_layouts(inst, xla_operand))) {
      return failure();
    }
    return success();
  }

  if (matchPattern(inst, m_Constant(&const_attr))) {
    xla::Layout layout;
    layout = ExtractLayout(inst, const_attr.getType().getRank());
    auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
    if (!literal_or.ok())
      return inst->emitError(literal_or.status().ToString());
    auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
    value_map[inst->getResult(0)] = constant;

    return success();
  }

  if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
    // Construct the return value for the function. If there is a single value
    // returned, then return it directly, else create a tuple and return.
    unsigned num_return_values = inst->getNumOperands();
    if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
      const bool has_ret_shardings =
          !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);

      std::vector<xla::XlaOp> returns(num_return_values);
      for (OpOperand& ret : inst->getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
          return failure();

        returns[index] = operand;
        if (!is_entry_function || !has_ret_shardings) continue;

        xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
        StatusOr<xla::XlaOp> reshape =
            tensorflow::ReshapeWithCorrectRepresentationAndSharding(
                builder, returns[index], return_shape,
                shape_representation_fn_, ret_shardings[index],
                /*fast_mem=*/false);
        if (!reshape.ok())
          return inst->emitError() << reshape.status().error_message();

        returns[index] = reshape.ValueOrDie();
      }

      if (has_ret_shardings) {
        xla::OpSharding sharding;
        sharding.set_type(xla::OpSharding::TUPLE);
        for (auto& ret_sharding : ret_shardings)
          *sharding.add_tuple_shardings() = *ret_sharding;

        builder->SetSharding(sharding);
      }

      *return_value = xla::Tuple(builder, returns);
      builder->ClearSharding();
    } else if (num_return_values == 1) {
      xla::XlaOp operand;
      if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
        return failure();

      *return_value = operand;
    }

    return success();
  }

  inst->emitOpError() << "can't be translated to XLA HLO";
  return failure();
}

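// Lowers a standard CallOp to an xla::Call of the (separately lowered) callee
// computation, wiring up the call operands and unpacking multi-result calls
// with GetTupleElement.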
LogicalResult ConvertToHloModule::LowerFunctionCall(
    mlir::CallOp call_op, xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering) {
  auto& value_map = *value_lowering;
  mlir::FuncOp callee = module_.lookupSymbol<mlir::FuncOp>(call_op.callee());
  if (failed(RunOnFunction(callee))) return failure();
  std::vector<xla::XlaOp> operands;
  for (auto operand : call_op.getOperands()) {
    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
      return failure();
    operands.push_back(xla_operand);
  }
  // Each call to xla::Call inserts a copy of the callee computation into the
  // HLO, so every callsite ends up with a unique callee in the exported HLO.
  // HLO does not syntactically require callees to be unique, but the call
  // graph is eventually "flattened" before lowering to make that true, because
  // buffer assignment needs this invariant.
  xla::XlaOp call_result =
      xla::Call(builder, lowered_computation_[callee], operands);
  // Use GetTupleElement for multiple outputs.
  unsigned num_results = call_op.getNumResults();
  if (num_results > 1) {
    for (unsigned i = 0; i != num_results; ++i) {
      value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
    }
  } else if (num_results == 1) {
    value_map[call_op.getResult(0)] = call_result;
  }
  return success();
}

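// Lowers the given function into an xla::XlaComputation, caching the result in
// lowered_computation_. The entry function ("main") is lowered directly into
// the module builder; other functions get their own sub-builder.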
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
  if (lowered_computation_.count(f)) return success();
  if (!llvm::hasSingleElement(f)) {
    return f.emitError("only single block Function supported");
  }

  // Create a sub-builder if this is not the main function.
  std::unique_ptr<xla::XlaBuilder> builder_up;
  bool entry_function = f.getName() == "main";
  if (!entry_function)
    builder_up = module_builder_.CreateSubBuilder(f.getName().str());
  auto& builder = entry_function ? module_builder_ : *builder_up;

  xla::XlaComputation computation;
  std::vector<bool> entry_args_same_across_replicas;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
  if (entry_function) {
    bool any_arg_replicated = false;
    entry_args_same_across_replicas.reserve(f.getNumArguments());
    for (int64_t i = 0; i < f.getNumArguments(); ++i) {
      auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr);
      entry_args_same_across_replicas.push_back(attr != nullptr);
      any_arg_replicated |= entry_args_same_across_replicas.back();
      // Pass the alias info to the builder so that it will build the alias
      // info into the resulting HloModule.
      auto aliasing_output =
          f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
      if (!aliasing_output) continue;
      if (use_tuple_args_) {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/0, /*param_index=*/{i});
      } else {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/i, /*param_index=*/{});
      }
    }
    // Do not populate this field when nothing is replicated, since empty field
    // means no replication. This avoids the need for unrelated tests to handle
    // this field.
    if (!any_arg_replicated) entry_args_same_across_replicas.clear();

    ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
  }
  if (failed(LowerBasicBlockAsFunction(
          &f.front(), &builder, entry_function,
          entry_args_same_across_replicas, arg_shardings, ret_shardings,
          &computation))) {
    return failure();
  }
  lowered_computation_[f] = std::move(computation);
  return success();
}

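// Computes the XLA shape of each entry block argument, applying the shape
// representation function, and records per-leaf replication flags when any
// argument is marked as replicated.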
LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
    Block* block, const std::vector<bool>& entry_args_same_across_replicas,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
    std::vector<bool>* leaf_replication) {
  arg_shapes->reserve(block->getNumArguments());
  leaf_replication->reserve(block->getNumArguments());
  for (BlockArgument& arg : block->getArguments()) {
    arg_shapes->push_back(xla::TypeToShape(arg.getType()));
    xla::Shape& arg_shape = arg_shapes->back();
    tensorflow::TensorShape arg_tensor_shape;
    auto status =
        tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    tensorflow::DataType dtype;
    status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
                                                     /*use_fast_memory=*/false);
    if (!arg_shape_status.ok())
      return block->getParentOp()->emitError()
             << arg_shape_status.status().error_message();

    arg_shape = std::move(arg_shape_status.ValueOrDie());

    if (entry_args_same_across_replicas.empty()) continue;
    for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
      leaf_replication->push_back(
          entry_args_same_across_replicas[arg.getArgNumber()]);
  }

  return success();
}

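// If every argument sharding is present, builds a TUPLE OpSharding for the
// single tuple parameter, rewrites each argument's layout for its sharded
// shape, and installs the sharding on the builder.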
LogicalResult ConvertToHloModule::SetEntryTupleShardings(
    Block* block, xla::XlaBuilder* builder,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
  if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::TUPLE);
    for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
      auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
      if (!hlo_sharding.ok())
        return block->getParentOp()->emitError()
               << hlo_sharding.status().error_message();

      auto status = tensorflow::RewriteLayoutWithShardedShape(
          hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
          shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
      if (!status.ok())
        return block->getParentOp()->emitError() << status.error_message();

      *sharding.add_tuple_shardings() = *arg_sharding.value();
    }

    builder->SetSharding(sharding);
  }

  return success();
}

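// Lowers a single-block function body into an xla::XlaComputation: creates the
// parameter ops (either one tuple parameter or one parameter per argument),
// lowers every operation in the block, and finally builds the computation.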
LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    Block* block, xla::XlaBuilder* builder, bool is_entry_function,
    const std::vector<bool>& entry_args_same_across_replicas,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaComputation* result) {
  // Mapping from the Value to lowered XlaOp.
  ValueLoweringMap lowering;

  // If using tuples as input, then there is only one input parameter that is a
  // tuple.
  if (is_entry_function && use_tuple_args_) {
    llvm::SmallVector<xla::Shape, 4> arg_shapes;
    std::vector<bool> leaf_replication;
    if (failed(SetEntryTupleShapesAndLeafReplication(
            block, entry_args_same_across_replicas, &arg_shapes,
            &leaf_replication)))
      return failure();

    if (failed(
            SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
      return failure();

    xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
    auto tuple =
        xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
    builder->ClearSharding();

    bool set_tuple_element_sharding =
        !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
    for (BlockArgument& arg : block->getArguments()) {
      if (set_tuple_element_sharding)
        builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
      lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
    }
    builder->ClearSharding();
  } else {
    for (BlockArgument& arg : block->getArguments()) {
      auto num = arg.getArgNumber();
      xla::Shape shape = xla::TypeToShape(arg.getType());
      if (entry_args_same_across_replicas.empty()) {
        lowering[arg] =
            xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
      } else {
        lowering[arg] = xla::Parameter(
            builder, num, shape, absl::StrCat("Arg_", num),
            std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
                              entry_args_same_across_replicas[num]));
      }
    }
  }

  xla::XlaOp return_value;
  for (auto& inst : *block)
    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                     &lowering, &return_value)))
      return failure();

  // Build the XlaComputation and check for failures.
  auto computation_or =
      return_value.valid() ? builder->Build(return_value) : builder->Build();
  if (!computation_or.ok()) {
    block->back().emitError(
        llvm::Twine(computation_or.status().error_message()));
    return failure();
  }
  *result = std::move(computation_or.ValueOrDie());
  return success();
}

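// Lowers an op's region (e.g. the body of a reduction or loop) into its own
// XlaComputation using a fresh sub-builder.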
LogicalResult ConvertToHloModule::LowerRegionAsComputation(
    mlir::Region* region, xla::XlaComputation* func) {
  std::unique_ptr<xla::XlaBuilder> builder =
      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
  return LowerBasicBlockAsFunction(
      &region->front(), builder.get(),
      /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
      /*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}

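// Adds a DynamicParameterBindingProto entry marking dimension `shape_index` of
// parameter `arg_index` as dynamic, with its size supplied by parameter
// `padding_arg_index`. When tuple args are used, both are addressed as
// elements of the single tuple parameter 0.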
void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
                                     int arg_index, int32_t shape_index,
                                     int32_t padding_arg_index,
                                     bool use_tuple_args) {
  auto* entry = binding->add_entries();
  entry->set_target_param_dim_num(shape_index);
  if (use_tuple_args) {
    entry->set_target_param_num(0);
    entry->add_target_param_index(arg_index);
    entry->set_dynamic_param_num(0);
    entry->add_dynamic_param_index(padding_arg_index);
  } else {
    entry->set_target_param_num(arg_index);
    entry->set_dynamic_param_num(padding_arg_index);
  }
}

// Runs the PrepareForExport pass on the ModuleOp.
Status PrepareForExport(mlir::ModuleOp module) {
  // Prepare for export to XLA HLO.
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mhlo::CreatePrepareForExport());
  if (failed(pm.run(module)))
    return tensorflow::errors::Internal("Unable to optimize for XLA export");
  return Status::OK();
}

}  // namespace

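// Public entry point: converts a single MHLO region into a standalone
// xla::XlaComputation.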
Status ConvertRegionToComputation(mlir::Region* region,
                                  xla::XlaComputation* func,
                                  MlirToHloConversionOptions options) {
  mlir::ModuleOp module;
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, true, true, {}, options);
  if (failed(converter.LowerRegionAsComputation(region, func)))
    return tensorflow::errors::Internal(
        "failed to convert region to computation");
  return Status::OK();
}

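// Public entry point: converts an MHLO module into an xla::HloProto. Runs the
// PrepareForExport pipeline first, then lowers the module with
// ConvertToHloModule and swaps the resulting HloModuleProto into `hlo_proto`.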
Status ConvertMlirHloToHlo(
    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
    bool return_tuple,
    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    MlirToHloConversionOptions options) {
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, use_tuple_args,
                               return_tuple, shape_representation_fn, options);
  if (failed(converter.Run())) return diag_handler.ConsumeStatus();
  auto hlo_module = converter.ConsumeMainProto();
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
  return Status::OK();
}

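// Public entry point: lowers the ops of an MHLO block directly into an
// existing XlaBuilder, mapping block arguments onto the provided `xla_params`
// and collecting the operands of the terminator into `returns`.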
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           MlirToHloConversionOptions options) {
  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  ConvertToHloModule converter(module, builder,
                               /*use_tuple_args=*/false, /*return_tuple=*/false,
                               /*shape_representation_fn=*/nullptr, options);

  ConvertToHloModule::ValueLoweringMap lowering;
  if (xla_params.size() != block.getArguments().size())
    return tensorflow::errors::Internal(
        "xla_params size != block arguments size");
  for (BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    lowering[arg] = xla_params[num];
  }

  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  for (auto& inst : block) {
    if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
      returns.resize(inst.getNumOperands());
      for (OpOperand& ret : inst.getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
          return diag_handler.ConsumeStatus();
        returns[index] = operand;
      }
    } else {
      xla::XlaOp return_value;
      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
                                 /*ret_shardings=*/{}, &builder, &lowering,
                                 &return_value)))
        return diag_handler.ConsumeStatus();
    }
  }

  return Status::OK();
}

}  // namespace mlir