1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 #include "llvm/Support/SMLoc.h"
29 #include "llvm/Support/SourceMgr.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
32 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
33 #include "mlir/IR/Attributes.h" // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
36 #include "mlir/IR/Location.h" // from @llvm-project
37 #include "mlir/IR/MLIRContext.h" // from @llvm-project
38 #include "mlir/IR/Matchers.h" // from @llvm-project
39 #include "mlir/IR/Operation.h" // from @llvm-project
40 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
41 #include "mlir/IR/UseDefLists.h" // from @llvm-project
42 #include "mlir/Pass/Pass.h" // from @llvm-project
43 #include "mlir/Pass/PassManager.h" // from @llvm-project
44 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
46 #include "tensorflow/compiler/mlir/utils/name_utils.h"
47 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
48 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
49 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
50 #include "tensorflow/compiler/tf2xla/shape_util.h"
51 #include "tensorflow/compiler/xla/client/lib/matrix.h"
52 #include "tensorflow/compiler/xla/client/lib/quantize.h"
53 #include "tensorflow/compiler/xla/client/lib/slicing.h"
54 #include "tensorflow/compiler/xla/client/xla_builder.h"
55 #include "tensorflow/compiler/xla/comparison_util.h"
56 #include "tensorflow/compiler/xla/literal_util.h"
57 #include "tensorflow/compiler/xla/service/hlo_module.h"
58 #include "tensorflow/compiler/xla/shape_util.h"
59 #include "tensorflow/compiler/xla/status_macros.h"
60 #include "tensorflow/compiler/xla/xla_data.pb.h"
61 #include "tensorflow/core/framework/tensor_shape.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/stream_executor/lib/statusor.h"
65
66 using ::stream_executor::port::StatusOr;
67 using ::tensorflow::int16;
68 using ::tensorflow::int32;
69 using ::tensorflow::int64;
70 using ::tensorflow::int8;
71 using ::tensorflow::uint16;
72 using ::tensorflow::uint32;
73 using ::tensorflow::uint64;
74 using ::tensorflow::uint8;
75
76 constexpr char kPaddingMapAttr[] = "mhlo.padding_map";
77 constexpr char kShapeIndicesAttr[] = "shape_indices";
78 constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
79 constexpr char kShardingAttr[] = "mhlo.sharding";
80 constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
81 constexpr char kRepicationAttr[] = "mhlo.is_same_data_across_replicas";
82
83 // Array attribute. Same shape as infeed result, but contains a
84 // minor_to_major array for every tensor.
85 constexpr char kLayoutAttr[] = "layout";
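// Illustrative example (assumed shape, not taken from this file): an infeed
// whose result is a tuple of a rank-2 tensor and a rank-1 tensor could carry
//   layout = [[1, 0], [0]]
// i.e. one minor_to_major list per tensor in the result.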
86
87 // Passes through everything except for unique_ptr, on which it calls get().
88 // This exists to allow the generated code to call XLA functions that take a raw
89 // pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
90 // as a pointer and there is otherwise no way to avoid a memory leak.
91 template <typename T>
92 T Unwrap(T t) {
93 return t;
94 }
95
96 template <typename T>
97 T* Unwrap(const std::unique_ptr<T>& t) {
98 return t.get();
99 }
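// Usage sketch (illustrative only): a generated call such as
//   xla::Dot(lhs, rhs, Unwrap(Convert_precision_config(attr)))
// passes the PrecisionConfig as a raw pointer while the unique_ptr retains
// ownership; for plain values, e.g. Unwrap(5), the argument passes through
// unchanged.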
100
101 static mlir::LogicalResult GetXlaOp(
102 mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
103 xla::XlaOp* result, mlir::Operation* op) {
104 auto iter = val_map.find(val);
105 if (iter == val_map.end()) {
106 return op->emitOpError(
107 "requires all operands to be defined in the parent region for export");
108 }
109 *result = iter->second;
110 return mlir::success();
111 }
112
113 // Convert APInt into an int.
114 // TODO(hpucha): This should be consolidated into a general place.
115 static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
116
117 static uint32_t Convertuint32_t(uint32_t i) { return i; }
118 static uint64_t Convertuint64_t(uint64_t i) { return i; }
119
120 // Convert APFloat to double.
121 static double ConvertAPFloat(llvm::APFloat value) {
122 const auto& semantics = value.getSemantics();
123 bool losesInfo = false;
124 if (&semantics != &llvm::APFloat::IEEEdouble())
125 value.convert(llvm::APFloat::IEEEdouble(),
126 llvm::APFloat::rmNearestTiesToEven, &losesInfo);
127 return value.convertToDouble();
128 }
129
130 static inline bool Convertbool(bool value) { return value; }
131
132 static absl::string_view ConvertStringRef(mlir::StringRef value) {
133 return {value.data(), value.size()};
134 }
135
136 static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
137 auto values = attr.getValues<int64>();
138 return {values.begin(), values.end()};
139 }
140
141 static std::vector<int64> ConvertDenseIntAttr(
142 llvm::Optional<mlir::DenseIntElementsAttr> attr) {
143 if (!attr) return {};
144 return ConvertDenseIntAttr(*attr);
145 }
146
147 // Converts the broadcast_dimensions attribute into a vector of dimension
148 // numbers (empty if the attribute is absent).
149 static std::vector<int64> Convert_broadcast_dimensions(
150 llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
151 if (!broadcast_dimensions.hasValue()) return {};
152
153 return ConvertDenseIntAttr(*broadcast_dimensions);
154 }
155
156 // Converts StringRef to xla FftType enum
157 static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
158 xla::FftType fft_type_enum;
159 // Illegal fft_type string would be caught by the verifier, so 'FftType_Parse'
160 // call below should never return false.
161 if (!FftType_Parse(std::string(fft_type_str), &fft_type_enum))
162 return xla::FftType::FFT;
163 return fft_type_enum;
164 }
165
166 static std::vector<std::pair<int64, int64>> Convert_padding(
167 llvm::Optional<mlir::DenseIntElementsAttr> padding) {
168 return xla::ConvertNx2Attribute(padding).ValueOrDie();
169 }
170
171 static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
172 llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
173 return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
174 }
175
176 static std::vector<xla::ReplicaGroup> Convert_replica_groups(
177 mlir::DenseIntElementsAttr groups) {
178 return xla::ConvertReplicaGroups(groups).ValueOrDie();
179 }
180
181 // Converts StringRef to xla Transpose enum.
182 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
183 llvm::StringRef transpose_str) {
184 return xla::ConvertTranspose(transpose_str).ValueOrDie();
185 }
186
187 #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
188 static std::vector<int64> Convert_##attribute( \
189 llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
190 return ConvertDenseIntAttr(attribute); \
191 }
192
193 I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
194 I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
195 I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
196 I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
197 I64_ELEMENTS_ATTR_TO_VECTOR(strides);
198 I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
199 I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
200 I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
201 I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
202 I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
203 I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);
204
205 #undef I64_ELEMENTS_ATTR_TO_VECTOR
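// For example (illustrative), I64_ELEMENTS_ATTR_TO_VECTOR(permutation) above
// expands to a Convert_permutation() helper that simply forwards its optional
// DenseIntElementsAttr to ConvertDenseIntAttr.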
206
207 static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
208 return {values.begin(), values.end()};
209 }
210
211 // Converts the precision config array of strings attribute into the
212 // corresponding XLA proto. All the strings are assumed to be valid names of the
213 // Precision enum. This should have been checked in the op verify method.
214 static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
215 llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
216 if (!optional_precision_config_attr.hasValue()) return nullptr;
217
218 auto precision_config = absl::make_unique<xla::PrecisionConfig>();
219 for (auto attr : optional_precision_config_attr.getValue()) {
220 xla::PrecisionConfig::Precision p;
221 auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
222 // TODO(jpienaar): Update this to ensure this is captured by verify.
223 if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
224 precision_config->add_operand_precision(p);
225 } else {
226 auto* context = attr.getContext();
227 mlir::emitError(mlir::UnknownLoc::get(context))
228 << "unexpected operand precision " << operand_precision;
229 return nullptr;
230 }
231 }
232
233 return precision_config;
234 }
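// Example (illustrative): precision_config = ["HIGHEST", "DEFAULT"] becomes a
// PrecisionConfig proto whose operand_precision list is {HIGHEST, DEFAULT}.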
235
236 static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
237 mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
238 xla::DotDimensionNumbers dot_dimension_numbers;
239
240 auto rhs_contracting_dimensions =
241 dot_dimension_numbers_attr.rhs_contracting_dimensions()
242 .cast<mlir::DenseIntElementsAttr>();
243 auto lhs_contracting_dimensions =
244 dot_dimension_numbers_attr.lhs_contracting_dimensions()
245 .cast<mlir::DenseIntElementsAttr>();
246 auto rhs_batch_dimensions =
247 dot_dimension_numbers_attr.rhs_batching_dimensions()
248 .cast<mlir::DenseIntElementsAttr>();
249 auto lhs_batch_dimensions =
250 dot_dimension_numbers_attr.lhs_batching_dimensions()
251 .cast<mlir::DenseIntElementsAttr>();
252
253 for (const auto& val : rhs_contracting_dimensions) {
254 dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue());
255 }
256 for (const auto& val : lhs_contracting_dimensions) {
257 dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue());
258 }
259
260 for (const auto& val : rhs_batch_dimensions) {
261 dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue());
262 }
263
264 for (const auto& val : lhs_batch_dimensions) {
265 dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue());
266 }
267
268 return dot_dimension_numbers;
269 }
270
271 static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
272 mlir::mhlo::ConvDimensionNumbers input) {
273 return xla::ConvertConvDimensionNumbers(input);
274 }
275
276 xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
277 xla::ChannelHandle channel_handle;
278 channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
279 channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
280 ConvertAPInt(attr.type().getValue())));
281 return channel_handle;
282 }
283
284 // Converts the comparison_direction string attribute into the XLA enum. The
285 // string is assumed to correspond to exactly one of the allowed strings
286 // representing the enum. This should have been checked in the op verify method.
287 static xla::ComparisonDirection Convert_comparison_direction(
288 llvm::StringRef comparison_direction_string) {
289 return xla::StringToComparisonDirection(comparison_direction_string.str())
290 .ValueOrDie();
291 }
292
293 static xla::GatherDimensionNumbers Convert_dimension_numbers(
294 mlir::mhlo::GatherDimensionNumbers input) {
295 xla::GatherDimensionNumbers output;
296
297 auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
298 std::copy(offset_dims.begin(), offset_dims.end(),
299 tensorflow::protobuf::RepeatedFieldBackInserter(
300 output.mutable_offset_dims()));
301
302 auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
303 std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
304 tensorflow::protobuf::RepeatedFieldBackInserter(
305 output.mutable_collapsed_slice_dims()));
306
307 auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
308 std::copy(start_index_map.begin(), start_index_map.end(),
309 tensorflow::protobuf::RepeatedFieldBackInserter(
310 output.mutable_start_index_map()));
311
312 output.set_index_vector_dim(
313 ConvertAPInt(input.index_vector_dim().getValue()));
314 return output;
315 }
316
317 static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
318 mlir::mhlo::ScatterDimensionNumbers input) {
319 xla::ScatterDimensionNumbers output;
320
321 auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
322 std::copy(update_window_dims.begin(), update_window_dims.end(),
323 tensorflow::protobuf::RepeatedFieldBackInserter(
324 output.mutable_update_window_dims()));
325
326 auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
327 std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
328 tensorflow::protobuf::RepeatedFieldBackInserter(
329 output.mutable_inserted_window_dims()));
330
331 auto scatter_dims_to_operand_dims =
332 ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
333 std::copy(scatter_dims_to_operand_dims.begin(),
334 scatter_dims_to_operand_dims.end(),
335 tensorflow::protobuf::RepeatedFieldBackInserter(
336 output.mutable_scatter_dims_to_operand_dims()));
337
338 output.set_index_vector_dim(
339 ConvertAPInt(input.index_vector_dim().getValue()));
340 return output;
341 }
342
343 // Extracts sharding from attribute string.
344 static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
345 llvm::StringRef sharding) {
346 xla::OpSharding sharding_proto;
347 if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
348 return sharding_proto;
349 }
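// Sketch of the expected input (assumed): the attribute value is the binary
// serialization of an xla.OpSharding proto, e.g. the bytes produced by
// SerializeToString(), not its text format.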
350
351 // Returns an OpSharding proto from the "sharding" attribute of the op. If the
352 // op doesn't have a sharding attribute or the sharding attribute is invalid,
353 // returns absl::nullopt.
354 static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
355 mlir::Operation* op) {
356 auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
357 if (!sharding) return absl::nullopt;
358 return CreateOpShardingFromStringRef(sharding.getValue());
359 }
360
361 // Returns a FrontendAttributes proto from the "frontend_attributes" attribute
362 // of the op. An empty FrontendAttributes proto is returned if an op does not
363 // have frontend attributes.
364 static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
365 mlir::Operation* op) {
366 xla::FrontendAttributes frontend_attributes;
367 auto frontend_attributes_dict =
368 op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
369
370 if (!frontend_attributes_dict) return frontend_attributes;
371
372 for (const auto& attr : frontend_attributes_dict)
373 if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
374 frontend_attributes.mutable_map()->insert(
375 {attr.first.str(), value_str_attr.getValue().str()});
376
377 return frontend_attributes;
378 }
379
380 // Returns an OpMetadata proto based on the location of the op. If the location
381 // is unknown, an empty proto is returned. `op_name` is populated with the op
382 // location (converted). FileLineColLoc locations are populated by taking the
383 // file name and line number, and populating `source_file` and `source_line`
384 // respectively.
385 static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
386 xla::OpMetadata metadata;
387 if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;
388
389 std::string name = mlir::GetNameFromLoc(op->getLoc());
390 mlir::LegalizeNodeName(name);
391 metadata.set_op_name(name);
392
393 if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
394 metadata.set_source_file(file_line_col_loc.getFilename().str());
395 metadata.set_source_line(file_line_col_loc.getLine());
396 }
397
398 return metadata;
399 }
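// Illustrative examples (assumed): loc("my_op") yields op_name = "my_op" with
// no source info, loc("bar.mlir":7:9) yields source_file = "bar.mlir" and
// source_line = 7, and an unknown location yields an empty proto.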
400
401 // Checks if all shardings are set.
402 static bool AllOptionalShardingsAreSet(
403 llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
404 return llvm::all_of(shardings,
405 [](const absl::optional<xla::OpSharding>& sharding) {
406 return sharding.has_value();
407 });
408 }
409
410 // Extracts argument and result shardings from function.
411 static void ExtractShardingsFromFunction(
412 mlir::FuncOp function,
413 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
414 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
415 arg_shardings->resize(function.getNumArguments(),
416 absl::optional<xla::OpSharding>());
417 for (int i = 0, end = function.getNumArguments(); i < end; ++i)
418 if (auto sharding =
419 function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
420 (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
421
422 ret_shardings->resize(function.getNumResults(),
423 absl::optional<xla::OpSharding>());
424 for (int i = 0, end = function.getNumResults(); i < end; ++i)
425 if (auto sharding =
426 function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
427 (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
428 }
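// For example (illustrative): an entry function argument declared as
//   %arg0: tensor<8xf32> {mhlo.sharding = "..."}
// contributes its parsed OpSharding at index 0 of *arg_shardings, while an
// argument without the attribute leaves an absl::nullopt entry.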
429
430 namespace mlir {
431 namespace {
432 class ConvertToHloModule {
433 public:
434 using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
435 using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
436
437 // If use_tuple_args is true, then the entry function's arguments are
438 // converted to a tuple and passed as a single parameter.
439 // Similarly, if return tuple is true, then the entry function's return values
440 // are converted to a tuple even when there is only a single return value.
441 // Multiple return values are always converted to a tuple and returned as a
442 // single value.
443 explicit ConvertToHloModule(
444 mlir::ModuleOp module, xla::XlaBuilder& module_builder,
445 bool use_tuple_args, bool return_tuple,
446 tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
447 MlirToHloConversionOptions options)
448 : module_(module),
449 module_builder_(module_builder),
450 use_tuple_args_(use_tuple_args),
451 return_tuple_(return_tuple),
452 shape_representation_fn_(shape_representation_fn),
453 options_(options) {
454 if (!shape_representation_fn_)
455 shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
456 }
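// For example (illustrative): with use_tuple_args = true, an entry function
// main(%a: tensor<f32>, %b: tensor<2xf32>) is exported with a single tuple
// parameter (f32[], f32[2]); with return_tuple = true its results are wrapped
// in a tuple even when there is only one.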
457
458 // Perform the lowering to XLA. This function returns failure if an error was
459 // encountered.
460 //
461 // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
462 LogicalResult Run() {
463 auto main = module_.lookupSymbol<mlir::FuncOp>("main");
464 if (!main)
465 return module_.emitError(
466 "conversion requires module with `main` function");
467
468 for (auto func : module_.getOps<FuncOp>()) {
469 if (func.empty()) continue;
470 if (failed(RunOnFunction(func))) return failure();
471 }
472 return success();
473 }
474
475 // Lower a specific function to HLO.
476 LogicalResult RunOnFunction(mlir::FuncOp f);
477
478 // Lower a `mlir::Region` to a `XlaComputation`
479 LogicalResult LowerRegionAsComputation(mlir::Region* region,
480 xla::XlaComputation* func);
481
482 // Lower a single `Block` to a `XlaComputation`
483 LogicalResult LowerBasicBlockAsFunction(
484 Block* block, xla::XlaBuilder* builder, bool is_entry_function,
485 const std::vector<bool>& entry_args_same_across_replicas,
486 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
487 llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
488 xla::XlaComputation* result);
489
490 ::xla::HloModuleProto ConsumeMainProto() {
491 auto main = module_.lookupSymbol<mlir::FuncOp>("main");
492 // This is an invariant check as Run returns failure if there is no main
493 // function and so the main proto shouldn't be consumed in that case.
494 CHECK(main) << "requires module to have main function"; // Crash Ok.
495 return lowered_computation_[main].proto();
496 }
497
498 // Lower function call to HLO call instruction
499 LogicalResult LowerFunctionCall(
500 mlir::CallOp call_op, xla::XlaBuilder* builder,
501 ConvertToHloModule::ValueLoweringMap* value_lowering);
502
503 LogicalResult Lower(
504 mlir::Operation* inst, bool is_entry_function,
505 llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
506 xla::XlaBuilder* builder,
507 ConvertToHloModule::ValueLoweringMap* value_lowering,
508 xla::XlaOp* return_value);
509
510 private:
511 LogicalResult SetEntryTupleShapesAndLeafReplication(
512 Block* block, const std::vector<bool>& entry_args_same_across_replicas,
513 llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
514 std::vector<bool>* leaf_replication);
515
516 LogicalResult SetEntryTupleShardings(
517 Block* block, xla::XlaBuilder* builder,
518 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
519 llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
520
521 // The module being lowered.
522 mlir::ModuleOp module_;
523
524 // The top-level XlaBuilder.
525 xla::XlaBuilder& module_builder_;
526
527 // Map between function and lowered computation.
528 FunctionLoweringMap lowered_computation_;
529
530 // Whether the entry function should take a single tuple as input.
531 bool use_tuple_args_;
532
533 // Whether to always return a tuple.
534 bool return_tuple_;
535
536 // Shape representation function to determine entry function argument and
537 // result shapes.
538 tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
539
540 // Unique suffix to give to the name of the next lowered region.
541 size_t region_id_ = 0;
542
543 MlirToHloConversionOptions options_;
544 };
545
546 } // namespace
547 } // namespace mlir
548
549 namespace {
550
551 struct OpLoweringContext {
552 llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
553 mlir::ConvertToHloModule* converter;
554 xla::XlaBuilder* builder;
555 };
556
557 llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
558 OpLoweringContext ctx) {
559 llvm::SmallVector<xla::XlaOp, 4> ops;
560 for (mlir::Value value : values) {
561 ops.push_back((*ctx.values)[value]);
562 }
563 return ops;
564 }
565
566 } // namespace
567
568 namespace mlir {
569 namespace mhlo {
570 namespace {
571
572 LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
573 auto& value_map = *ctx.values;
574 xla::XlaComputation computation;
575 if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
576 &computation))) {
577 return failure();
578 }
579 auto replica_groups = Convert_replica_groups(op.replica_groups());
580 xla::XlaOp operand;
581 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
582
583 if (!op.channel_id().hasValue()) {
584 value_map[op] = xla::AllReduce(operand, computation, replica_groups,
585 /*channel_id=*/absl::nullopt);
586 return success();
587 }
588 auto channel_id = Convert_channel_handle(op.channel_id().getValue());
589 value_map[op] =
590 xla::AllReduce(operand, computation, replica_groups, channel_id);
591 return success();
592 }
593
594 LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
595 auto& value_map = *ctx.values;
596 xla::XlaOp operand;
597 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
598
599 value_map[op] = xla::BitcastConvertType(
600 operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
601 return success();
602 }
603
604 LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
605 auto type = op.getType().dyn_cast<RankedTensorType>();
606 if (!type) return failure();
607 auto& value_map = *ctx.values;
608 xla::XlaOp operand;
609 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
610
611 value_map[op] =
612 BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
613 Convert_broadcast_dimensions(op.broadcast_dimensions()));
614 return success();
615 }
616
617 LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
618 // This op has no expression in the legacy export format.
619 return failure();
620 }
621
622 LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
623 // This op has no expression in the legacy export format.
624 return failure();
625 }
626
627 LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
628 // This op has no expression in the legacy export format.
629 return failure();
630 }
631
632 LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
633 xla::XlaComputation true_branch;
634 xla::XlaComputation false_branch;
635 auto& value_map = *ctx.values;
636 if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
637 &true_branch)) ||
638 failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
639 &false_branch))) {
640 return failure();
641 }
642 xla::XlaOp pred, true_arg, false_arg;
643 if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();
644 if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op)))
645 return failure();
646 if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op)))
647 return failure();
648
649 value_map[op] =
650 xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);
651 return success();
652 }
653
654 LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
655 llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
656 OperandRange operands = op.branch_operands();
657 MutableArrayRef<Region> branches = op.branches();
658 llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
659 std::vector<xla::XlaComputation> computations(branches.size());
660 std::vector<xla::XlaComputation*> computations_p(branches.size());
661
662 for (unsigned i = 0; i < branches.size(); ++i) {
663 xla::XlaOp operand;
664 if (failed(GetXlaOp(operands[i], value_map, &operand, op)))
665 return failure();
666 branch_operands[i] = operand;
667 computations_p[i] = &computations[i];
668 if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
669 computations_p[i])))
670 return failure();
671 }
672 xla::XlaOp index;
673 if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();
674
675 xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands);
676 if (op.getNumResults() == 1) {
677 value_map[op.getResult(0)] = result;
678 } else {
679 for (auto item : llvm::enumerate(op.getResults())) {
680 value_map[item.value()] = xla::GetTupleElement(result, item.index());
681 }
682 }
683 return success();
684 }
685
686 // Specialize CompareOp export to set broadcast_dimensions argument.
687 mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
688 OpLoweringContext ctx) {
689 auto& value_map = *ctx.values;
690 xla::XlaOp lhs, rhs;
691 if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
692 if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
693 auto dir = Convert_comparison_direction(op.comparison_direction());
694 auto type_attr = op.compare_typeAttr();
695
696 xla::XlaOp xla_result;
697 if (type_attr) {
698 auto type =
699 xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
700 xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
701 } else {
702 xla_result = xla::Compare(lhs, rhs, dir);
703 }
704 value_map[op] = xla_result;
705 return mlir::success();
706 }
707
708 LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
709 return failure();
710 }
711
712 LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
713 // XLA client builder API does not support generating convolution instructions
714 // with window reversal.
715 if (op.hasWindowReversal()) return failure();
716 auto& value_map = *ctx.values;
717 xla::XlaOp lhs, rhs;
718 if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
719 if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
720 xla::XlaOp xla_result = xla::ConvGeneralDilated(
721 lhs, rhs, Convert_window_strides(op.window_strides()),
722 Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
723 Convert_rhs_dilation(op.rhs_dilation()),
724 Convert_dimension_numbers(op.dimension_numbers()),
725 Convertuint64_t(op.feature_group_count()),
726 Convertuint64_t(op.batch_group_count()),
727 Unwrap(Convert_precision_config(op.precision_config())));
728 value_map[op] = xla_result;
729 return mlir::success();
730 }
731
732 LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
733 auto& value_map = *ctx.values;
734 xla::XlaOp operand;
735 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
736
737 value_map[op] = xla::ConvertElementType(
738 operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
739 return success();
740 }
741
742 LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
743 // XLA client builder API does not support generating custom call instructions
744 // with side effects.
745 if (op.has_side_effect() || op.getNumResults() != 1) return failure();
746 Value result = op.getResult(0);
747 auto& value_map = *ctx.values;
748 value_map[result] = xla::CustomCall(
749 ctx.builder, std::string(op.call_target_name()), GetTuple(op.args(), ctx),
750 xla::TypeToShape(result.getType()), std::string(op.backend_config()));
751 return success();
752 }
753
754 LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) {
755 xla::QuantizedRange range(ConvertAPFloat(op.min_range()),
756 ConvertAPFloat(op.max_range()));
757 auto& value_map = *ctx.values;
758 xla::XlaOp input;
759 if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure();
760
761 auto casted = xla::ConvertElementType(input, xla::U32);
762 if (op.is_16bits()) {
763 value_map[op] = xla::Dequantize<uint16>(
764 casted, range, ConvertStringRef(op.mode()), op.transpose_output());
765 } else {
766 value_map[op] = xla::Dequantize<uint8>(
767 casted, range, ConvertStringRef(op.mode()), op.transpose_output());
768 }
769 return success();
770 }
771
772 LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
773 auto& value_map = *ctx.values;
774 xla::XlaOp token;
775 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
776
777 // The shape argument expected by the xla client API is the type of the first
778 // element in the result tuple.
779 auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
780 value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type),
781 std::string(op.infeed_config()));
782 return success();
783 }
784
785 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
786 auto& value_map = *ctx.values;
787 value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
788 op.iota_dimension());
789 return success();
790 }
791
792 LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
793 auto& value_map = *ctx.values;
794 xla::XlaComputation computation;
795 if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
796 &computation))) {
797 return failure();
798 }
799 value_map[op] = xla::Map(ctx.builder, GetTuple(op.operands(), ctx),
800 computation, Convert_dimensions(op.dimensions()));
801 return success();
802 }
803
804 LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
805 auto& value_map = *ctx.values;
806 xla::XlaOp operand, token;
807 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
808 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
809
810 value_map[op] = xla::OutfeedWithToken(
811 operand, token, xla::TypeToShape(op.operand().getType()),
812 std::string(op.outfeed_config()));
813 return success();
814 }
815
816 LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
817 auto& value_map = *ctx.values;
818 xla::PaddingConfig padding_config;
819 auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
820 auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
821 auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
822 for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
823 auto* dims = padding_config.add_dimensions();
824 dims->set_edge_padding_low(edge_padding_low[i]);
825 dims->set_edge_padding_high(edge_padding_high[i]);
826 dims->set_interior_padding(interior_padding[i]);
827 }
828 xla::XlaOp operand, padding_value;
829 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
830 if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
831 return failure();
832
833 value_map[op] = xla::Pad(operand, padding_value, padding_config);
834 return success();
835 }
836
837 LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
838 auto& value_map = *ctx.values;
839 auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
840 xla::XlaOp token;
841 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
842
843 if (op.is_host_transfer()) {
844 value_map[op] = xla::RecvFromHost(token, xla::TypeToShape(result_type),
845 Convert_channel_handle(op.channel_id()));
846 return success();
847 }
848 value_map[op] = xla::RecvWithToken(token, xla::TypeToShape(result_type),
849 Convert_channel_handle(op.channel_id()));
850 return success();
851 }
852
853 LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
854 auto& value_map = *ctx.values;
855 xla::XlaComputation body;
856 if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
857 return failure();
858 }
859 xla::XlaOp result =
860 xla::Reduce(ctx.builder, GetTuple(op.operands(), ctx),
861 GetTuple(op.init_values(), ctx), body,
862 Convert_broadcast_dimensions(op.dimensions()));
863 if (op.getNumResults() == 1) {
864 value_map[op.getResult(0)] = result;
865 } else {
866 for (auto item : llvm::enumerate(op.getResults())) {
867 value_map[item.value()] = xla::GetTupleElement(result, item.index());
868 }
869 }
870 return success();
871 }
872
873 LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
874 auto& value_map = *ctx.values;
875 xla::XlaComputation body;
876 if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
877 return failure();
878 }
879 xla::XlaOp operand, init_value;
880 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
881 if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
882 return failure();
883
884 value_map[op] = xla::ReduceWindowWithGeneralPadding(
885 operand, init_value, body, ConvertDenseIntAttr(op.window_dimensions()),
886 ConvertDenseIntAttr(op.window_strides()),
887 ConvertDenseIntAttr(op.base_dilations()),
888 ConvertDenseIntAttr(op.window_dilations()),
889 Convert_padding(op.padding()));
890 return success();
891 }
892
893 LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
894 auto& value_map = *ctx.values;
895 xla::XlaOp operand;
896 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
897
898 value_map[op] =
899 xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
900 return success();
901 }
902
903 LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
904 // Failure on purpose because `mhlo::ReturnOp` will be handled by
905 // special purpose logic in `ConvertToHloModule::Lower`.
906 return failure();
907 }
908
909 LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
910 auto& value_map = *ctx.values;
911 auto result = op.getResult();
912 auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
913 auto xla_result = xla::RngBitGenerator(
914 static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
915 xla::TypeToShape(result.getType()).tuple_shapes(1));
916 value_map[result] = xla_result;
917 return mlir::success();
918 }
919
920 LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
921 auto& value_map = *ctx.values;
922 xla::XlaOp mu, sigma;
923 if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure();
924 if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure();
925
926 value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType()));
927 return success();
928 }
929
930 LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
931 auto& value_map = *ctx.values;
932 xla::XlaOp a, b;
933 if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
934 if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();
935
936 value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
937 return success();
938 }
939
940 LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
941 auto& value_map = *ctx.values;
942 xla::XlaComputation update_computation;
943 if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
944 &update_computation))) {
945 return failure();
946 }
947 xla::ScatterDimensionNumbers dimension_numbers =
948 Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
949 xla::XlaOp operand, scatter_indices, updates;
950 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
951 if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
952 return failure();
953 if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure();
954
955 value_map[op] = xla::Scatter(operand, scatter_indices, updates,
956 update_computation, dimension_numbers,
957 op.indices_are_sorted(), op.unique_indices());
958 return success();
959 }
960
961 LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
962 auto& value_map = *ctx.values;
963 xla::XlaComputation select;
964 xla::XlaComputation scatter;
965 if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
966 failed(
967 ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
968 return failure();
969 }
970 xla::XlaOp operand, source, init_value;
971 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
972 if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
973 if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
974 return failure();
975
976 value_map[op] = xla::SelectAndScatterWithGeneralPadding(
977 operand, select, ConvertDenseIntAttr(op.window_dimensions()),
978 ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
979 source, init_value, scatter);
980 return success();
981 }
982
983 LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
984 auto& value_map = *ctx.values;
985 xla::XlaOp operand, token;
986 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
987 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
988
989 if (op.is_host_transfer()) {
990 value_map[op] = xla::SendToHost(operand, token,
991 xla::TypeToShape(op.operand().getType()),
992 Convert_channel_handle(op.channel_id()));
993 return success();
994 }
995 value_map[op] = xla::SendWithToken(operand, token,
996 Convert_channel_handle(op.channel_id()));
997 return success();
998 }
999
1000 LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
1001 return failure();
1002 }
1003
1004 LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
1005 xla::XlaComputation comparator;
1006 if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
1007 &comparator)))
1008 return failure();
1009
1010 auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator,
1011 op.dimension(), op.is_stable());
1012
1013 auto& value_map = *ctx.values;
1014 auto shape_or = sorted.builder()->GetShape(sorted);
1015 if (!shape_or.ok()) {
1016 return op.emitError(shape_or.status().ToString());
1017 }
1018
1019 xla::Shape& shape = shape_or.ValueOrDie();
1020 if (!shape.IsTuple()) {
1021 value_map[op.getResult(0)] = sorted;
1022 return success();
1023 }
1024
1025 // MLIR's sort supports multiple results, so untuple all the results of XLA's.
1026 for (auto it : llvm::enumerate(op.getResults())) {
1027 value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
1028 }
1029 return success();
1030 }
1031
1032 LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
1033 auto& value_map = *ctx.values;
1034 xla::XlaOp operand;
1035 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1036 xla::Trace(std::string(op.tag()), operand);
1037 return success();
1038 }
1039
1040 LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
1041 // Intentional failure, as UnaryEinsumOp is always lowered to the EinsumOp with two
1042 // operands.
1043 return failure();
1044 }
1045
1046 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
1047 xla::XlaComputation condition;
1048 xla::XlaComputation body;
1049 auto& value_map = *ctx.values;
1050 if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body)) ||
1051 failed(ctx.converter->LowerRegionAsComputation(&op.cond(), &condition))) {
1052 return failure();
1053 }
1054
1055 xla::XlaOp operand;
1056 if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
1057 return failure();
1058 value_map[op] = xla::While(condition, body, operand);
1059 return success();
1060 }
1061
1062 LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
1063 if (!op.fusion_kind()) {
1064 op.emitOpError() << "requires fusion kind for HLO translation";
1065 return failure();
1066 }
1067
1068 xla::XlaComputation fused_computation;
1069 if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
1070 &fused_computation)))
1071 return failure();
1072
1073 auto& values = *ctx.values;
1074 llvm::SmallVector<xla::XlaOp, 4> operands;
1075 for (auto operand : op.operands()) operands.push_back(values[operand]);
1076
1077 xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
1078 ctx.builder, operands,
1079 absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
1080 fused_computation);
1081 if (op.getNumResults() == 1) {
1082 values[op.getResult(0)] = fusion;
1083 } else {
1084 for (auto item : llvm::enumerate(op.getResults())) {
1085 values[item.value()] = xla::GetTupleElement(fusion, item.index());
1086 }
1087 }
1088 return success();
1089 }
1090
1091 LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
1092 auto& value_map = *ctx.values;
1093 xla::XlaOp operand;
1094 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1095 value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
1096 ctx.builder, operand, xla::TypeToShape(op.getType()));
1097 return success();
1098 }
1099
1100 } // namespace
1101 } // namespace mhlo
1102 } // namespace mlir
1103
1104 #include "tensorflow/compiler/mlir/xla/operator_writers.inc"
1105
1106 namespace mlir {
1107 namespace {
1108
1109 StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
1110 xla::Layout layout) {
1111 if (attr.isa<OpaqueElementsAttr>())
1112 return tensorflow::errors::Unimplemented(
1113 "Opaque elements attr not supported");
1114
1115 xla::Shape shape = xla::TypeToShape(attr.getType());
1116
1117 #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
1118 case xla_type: { \
1119 xla::Array<cpp_type> source_data(shape.dimensions()); \
1120 source_data.SetValues(attr.getValues<cpp_type>()); \
1121 return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
1122 }
1123
1124 switch (shape.element_type()) {
1125 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
1126 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
1127 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
1128 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
1129 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
1130 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
1131 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
1132 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
1133 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
1134 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
1135 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
1136 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
1137 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
1138 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
1139 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
1140 default:
1141 return tensorflow::errors::Internal(absl::StrCat(
1142 "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
1143 }
1144 #undef ELEMENTS_ATTR_TO_LITERAL
1145 }
1146
1147 xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
1148 if (auto attr = GetLayoutFromMlirHlo(op)) {
1149 llvm::SmallVector<int64, 4> minor_to_major;
1150 DCHECK_EQ(rank, attr.size());
1151 minor_to_major.reserve(attr.size());
1152 for (const llvm::APInt& i : attr) {
1153 minor_to_major.push_back(i.getZExtValue());
1154 }
1155 return xla::LayoutUtil::MakeLayout(minor_to_major);
1156 }
1157 return xla::LayoutUtil::MakeDescendingLayout(rank);
1158 }
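// Illustrative example (assumed attribute form): a rank-2 result annotated
// with an explicit minor_to_major of [0, 1] produces a column-major layout,
// while an unannotated result falls back to the descending default [1, 0].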
1159
1160 LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
1161 xla::ShapeProto* shape) {
1162 // In the case of tuples, ShapeProtos can be nested, and so can the mlir
1163 // attribute describing the layout. So recurse into the subshapes in both data
1164 // structures in parallel.
1165 if (shape->element_type() == xla::TUPLE) {
1166 auto subshapes = shape->mutable_tuple_shapes();
1167 if (layout.size() != subshapes->size()) {
1168 op->emitOpError() << "Expected layout of size " << layout.size()
1169 << ", but found " << subshapes->size();
1170 return failure();
1171 }
1172 for (int i = 0; i < subshapes->size(); i++) {
1173 mlir::Attribute child = layout[i];
1174 if (child.isa<mlir::UnitAttr>()) {
1175 // Ignore unit attributes; they are used only for tokens.
1176 continue;
1177 }
1178 mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
1179 if (!c) {
1180 op->emitOpError() << "Type Error: Expected layout array attribute";
1181 return failure();
1182 }
1183 if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
1184 return failure();
1185 }
1186 }
1187 } else {
1188 int rank = shape->dimensions().size();
1189 if (rank) {
1190 if (layout.size() != rank) {
1191 return failure(); // pass error down
1192 }
1193 std::vector<int64> array(rank);
1194 for (int i = 0; i < rank; i++) {
1195 mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
1196 if (!attr) {
1197 op->emitOpError() << "Type Error: Expected layout integer attribute";
1198 return failure();
1199 }
1200 array[i] = attr.getInt();
1201 }
1202 *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
1203 }
1204 }
1205 return success();
1206 }
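// Illustrative example (assumed): for a tuple shape (f32[4,4], token[]), a
// layout attribute of [[1, 0], unit] assigns minor_to_major {1, 0} to the
// tensor and skips the token entry via the UnitAttr case above.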
1207
1208 LogicalResult ConvertToHloModule::Lower(
1209 mlir::Operation* inst, bool is_entry_function,
1210 llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
1211 xla::XlaBuilder* builder,
1212 ConvertToHloModule::ValueLoweringMap* value_lowering,
1213 xla::XlaOp* return_value) {
1214 *return_value = xla::XlaOp();
1215
1216 // See MlirToHloConversionOptions for more about layouts.
1217 auto propagate_layouts = [this](mlir::Operation* inst,
1218 xla::XlaOp xla_op) -> mlir::LogicalResult {
1219 if (options_.propagate_layouts) {
1220 auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
1221 ->mutable_shape();
1222 if (shape->tuple_shapes().empty())
1223 // TODO(kramm): merge this with ConvertLayout.
1224 *shape->mutable_layout() =
1225 ExtractLayout(inst, shape->dimensions().size()).ToProto();
1226 }
1227
1228 // For infeed ops stemming back to InfeedDequeueTuple, respect the layout
1229 // attribute, and create the corresponding layout in hlo.
1230 if (isa<mhlo::InfeedOp>(inst)) {
1231 mlir::ArrayAttr layout =
1232 inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
1233 if (layout) {
1234 xla::ShapeProto* shape =
1235 xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
1236 ->mutable_shape();
1237
1238 if (failed(ConvertLayout(inst, layout, shape))) {
1239 return failure();
1240 }
1241 }
1242 }
1243 return success();
1244 };
1245
1246 if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
1247 if (inst->getNumResults() == 1) {
1248 auto iter = value_lowering->find(inst->getResult(0));
1249 if (iter == value_lowering->end()) {
1250 inst->emitOpError(
1251 "inst has a result, but it's not found in value_lowering");
1252 return failure();
1253 }
1254 if (failed(propagate_layouts(inst, iter->second))) {
1255 return failure();
1256 }
1257 }
1258 return success();
1259 }
1260
1261 auto& value_map = *value_lowering;
1262 ElementsAttr const_attr;
1263
1264 if (auto call_op = dyn_cast<mlir::CallOp>(inst)) {
1265 return LowerFunctionCall(call_op, builder, &value_map);
1266 }
1267
1268 if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
1269 Value operand = op.getOperand();
1270 auto ty = operand.getType().dyn_cast<ShapedType>();
1271 // If this was a cast from a statically shaped tensor, then it is a no-op for
1272 // export to HLO and we can use the operand.
1273 if (!ty || !ty.hasStaticShape()) {
1274 inst->emitOpError()
1275 << "requires static shaped operand for HLO translation";
1276 return failure();
1277 }
1278
1279 xla::XlaOp xla_operand;
1280 if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
1281 return failure();
1282 value_map[op.getResult()] = xla_operand;
1283 if (failed(propagate_layouts(inst, xla_operand))) {
1284 return failure();
1285 }
1286 return success();
1287 }
1288
1289 if (matchPattern(inst, m_Constant(&const_attr))) {
1290 xla::Layout layout;
1291 layout = ExtractLayout(inst, const_attr.getType().getRank());
1292 auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
1293 if (!literal_or.ok())
1294 return inst->emitError(literal_or.status().ToString());
1295 auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
1296 value_map[inst->getResult(0)] = constant;
1297
1298 return success();
1299 }
1300
1301 if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
1302 // Construct the return value for the function. If there is a single value
1303 // returned, then return it directly, else create a tuple and return.
1304 unsigned num_return_values = inst->getNumOperands();
1305 if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
1306 const bool has_ret_shardings =
1307 !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
1308
1309 std::vector<xla::XlaOp> returns(num_return_values);
1310 for (OpOperand& ret : inst->getOpOperands()) {
1311 unsigned index = ret.getOperandNumber();
1312 xla::XlaOp operand;
1313 if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
1314 return failure();
1315
1316 returns[index] = operand;
1317 if (!is_entry_function || !has_ret_shardings) continue;
1318
1319 xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
1320 StatusOr<xla::XlaOp> reshape =
1321 tensorflow::ReshapeWithCorrectRepresentationAndSharding(
1322 builder, returns[index], return_shape, shape_representation_fn_,
1323 ret_shardings[index], /*fast_mem=*/false);
1324 if (!reshape.ok())
1325 return inst->emitError() << reshape.status().error_message();
1326
1327 returns[index] = reshape.ValueOrDie();
1328 }
1329
1330 if (has_ret_shardings) {
1331 xla::OpSharding sharding;
1332 sharding.set_type(xla::OpSharding::TUPLE);
1333 for (auto& ret_sharding : ret_shardings)
1334 *sharding.add_tuple_shardings() = *ret_sharding;
1335
1336 builder->SetSharding(sharding);
1337 }
1338
      *return_value = xla::Tuple(builder, returns);
      builder->ClearSharding();
    } else if (num_return_values == 1) {
      xla::XlaOp operand;
      if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
        return failure();

      *return_value = operand;
    }

    return success();
  }

  inst->emitOpError() << "can't be translated to XLA HLO";
  return failure();
}

LogicalResult ConvertToHloModule::LowerFunctionCall(
    mlir::CallOp call_op, xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering) {
  auto& value_map = *value_lowering;
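  // Lower the callee first (if it has not been lowered already) so that its
  // XlaComputation is available in lowered_computation_ below.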
  mlir::FuncOp callee = module_.lookupSymbol<mlir::FuncOp>(call_op.callee());
  if (failed(RunOnFunction(callee))) return failure();
  std::vector<xla::XlaOp> operands;
  for (auto operand : call_op.getOperands()) {
    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
      return failure();
    operands.push_back(xla_operand);
  }
  // Each xla::Call inserts a separate copy of the callee computation into the
  // HLO, so every call site ends up with a unique callee in the exported HLO.
  // HLO does not syntactically require call sites to have unique callees, but
  // the call graph is eventually "flattened" before lowering to make that
  // true, because buffer assignment relies on this invariant.
  xla::XlaOp call_result =
      xla::Call(builder, lowered_computation_[callee], operands);
  // Use GetTupleElement for multiple outputs.
  unsigned num_results = call_op.getNumResults();
  if (num_results > 1) {
    for (unsigned i = 0; i != num_results; ++i) {
      value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
    }
  } else if (num_results == 1) {
    value_map[call_op.getResult(0)] = call_result;
  }
  return success();
}

LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
  if (lowered_computation_.count(f)) return success();
  if (!llvm::hasSingleElement(f)) {
    return f.emitError("only single block Function supported");
  }

  // Create a sub-builder if this is not the main function.
  std::unique_ptr<xla::XlaBuilder> builder_up;
  bool entry_function = f.getName() == "main";
  if (!entry_function)
    builder_up = module_builder_.CreateSubBuilder(f.getName().str());
  auto& builder = entry_function ? module_builder_ : *builder_up;

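  // For the entry function, collect per-argument replication info and
  // input/output aliasing, and extract any parameter/result shardings before
  // lowering the body.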
  xla::XlaComputation computation;
  std::vector<bool> entry_args_same_across_replicas;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
  if (entry_function) {
    bool any_arg_replicated = false;
    entry_args_same_across_replicas.reserve(f.getNumArguments());
    for (int64_t i = 0; i < f.getNumArguments(); ++i) {
      auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kRepicationAttr);
      entry_args_same_across_replicas.push_back(attr != nullptr);
      any_arg_replicated |= entry_args_same_across_replicas.back();
      // Pass the alias info to the builder so that it will build the alias
      // info into the resulting HloModule.
      auto aliasing_output =
          f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
      if (!aliasing_output) continue;
      if (use_tuple_args_) {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/0, /*param_index=*/{i});
      } else {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/i, /*param_index=*/{});
      }
    }
    // Do not populate this field when nothing is replicated, since empty field
    // means no replication. This avoids the need for unrelated tests to handle
    // this field.
    if (!any_arg_replicated) entry_args_same_across_replicas.clear();

    ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
  }
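  // Lower the function's single block into an XlaComputation and cache it, so
  // that repeated calls to the same callee reuse the lowered computation.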
  if (failed(LowerBasicBlockAsFunction(
          &f.front(), &builder, entry_function, entry_args_same_across_replicas,
          arg_shardings, ret_shardings, &computation))) {
    return failure();
  }
  lowered_computation_[f] = std::move(computation);
  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
    Block* block, const std::vector<bool>& entry_args_same_across_replicas,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
    std::vector<bool>* leaf_replication) {
  arg_shapes->reserve(block->getNumArguments());
  leaf_replication->reserve(block->getNumArguments());
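  // For every block argument: convert its MLIR type to an XLA shape, run the
  // shape through shape_representation_fn_, and (when replication info is
  // present) record one replication bit per leaf of the resulting shape.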
  for (BlockArgument& arg : block->getArguments()) {
    arg_shapes->push_back(xla::TypeToShape(arg.getType()));
    xla::Shape& arg_shape = arg_shapes->back();
    tensorflow::TensorShape arg_tensor_shape;
    auto status =
        tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    tensorflow::DataType dtype;
    status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
                                                     /*use_fast_memory=*/false);
    if (!arg_shape_status.ok())
      return block->getParentOp()->emitError()
             << arg_shape_status.status().error_message();

    arg_shape = std::move(arg_shape_status.ValueOrDie());

    if (entry_args_same_across_replicas.empty()) continue;
    for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
      leaf_replication->push_back(
          entry_args_same_across_replicas[arg.getArgNumber()]);
  }

  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShardings(
    Block* block, xla::XlaBuilder* builder,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
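  // Only build a tuple sharding (and rewrite the per-argument layouts) when
  // every entry argument carries a sharding; otherwise the parameter is left
  // unsharded.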
  if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::TUPLE);
    for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
      auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
      if (!hlo_sharding.ok())
        return block->getParentOp()->emitError()
               << hlo_sharding.status().error_message();

      auto status = tensorflow::RewriteLayoutWithShardedShape(
          hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
          shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
      if (!status.ok())
        return block->getParentOp()->emitError() << status.error_message();

      *sharding.add_tuple_shardings() = *arg_sharding.value();
    }

    builder->SetSharding(sharding);
  }

  return success();
}

LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    Block* block, xla::XlaBuilder* builder, bool is_entry_function,
    const std::vector<bool>& entry_args_same_across_replicas,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaComputation* result) {
  // Mapping from the Value to lowered XlaOp.
  ValueLoweringMap lowering;

  // If using tuples as input, then there is only one input parameter that is a
  // tuple.
  if (is_entry_function && use_tuple_args_) {
    llvm::SmallVector<xla::Shape, 4> arg_shapes;
    std::vector<bool> leaf_replication;
    if (failed(SetEntryTupleShapesAndLeafReplication(
            block, entry_args_same_across_replicas, &arg_shapes,
            &leaf_replication)))
      return failure();

    if (failed(
            SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
      return failure();

    xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
    auto tuple =
        xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
    builder->ClearSharding();

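    // Unpack the tuple parameter into one value per block argument, tagging
    // each GetTupleElement with the argument's sharding when all shardings
    // are set.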
    bool set_tuple_element_sharding =
        !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
    for (BlockArgument& arg : block->getArguments()) {
      if (set_tuple_element_sharding)
        builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
      lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
    }
    builder->ClearSharding();
  } else {
    for (BlockArgument& arg : block->getArguments()) {
      auto num = arg.getArgNumber();
      xla::Shape shape = xla::TypeToShape(arg.getType());
      if (entry_args_same_across_replicas.empty()) {
        lowering[arg] =
            xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
      } else {
        // Replicate the per-argument flag across all leaves of the shape:
        // std::vector<bool>(count, value).
        lowering[arg] = xla::Parameter(
            builder, num, shape, absl::StrCat("Arg_", num),
            std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
                              entry_args_same_across_replicas[num]));
      }
    }
  }

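  // Lower each operation in the block in order; the terminator (handled in
  // Lower) populates return_value.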
  xla::XlaOp return_value;
  for (auto& inst : *block)
    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                     &lowering, &return_value)))
      return failure();

  // Build the XlaComputation and check for failures.
  auto computation_or =
      return_value.valid() ? builder->Build(return_value) : builder->Build();
  if (!computation_or.ok()) {
    block->back().emitError(
        llvm::Twine(computation_or.status().error_message()));
    return failure();
  }
  *result = std::move(computation_or.ValueOrDie());
  return success();
}

LogicalResult ConvertToHloModule::LowerRegionAsComputation(
    mlir::Region* region, xla::XlaComputation* func) {
  std::unique_ptr<xla::XlaBuilder> builder =
      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
  return LowerBasicBlockAsFunction(
      &region->front(), builder.get(),
      /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
      /*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}

std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
  return llvm::formatv(
             "requires '{0}' array attribute in '{1}' dict at arg {2}",
             attr_name, kPaddingMapAttr, index)
      .str();
}

std::string PaddingMapMismatchedArraySizeMsg(int arg_index,
                                             int shape_indices_size,
                                             int padding_arg_indices_size) {
  return llvm::formatv(
             "requires '{0}' and '{1}' array attributes in '{2}' dict at arg "
             "{3} to be of the same size, got sizes {4} and {5}",
             kShapeIndicesAttr, kPaddingArgIndicesAttr, kPaddingMapAttr,
             arg_index, shape_indices_size, padding_arg_indices_size)
      .str();
}

std::string PaddingMapBadIntAttrMsg(llvm::StringRef attr_name, int arg_index,
                                    int element_index) {
  return llvm::formatv(
             "requires element {0} in '{1}' array of '{2}' dict at arg {3} "
             "to be an int attribute",
             element_index, attr_name, kPaddingMapAttr, arg_index)
      .str();
}

std::string PaddingMapBadIndexMsg(llvm::StringRef attr_name, int arg_index,
                                  int element_index, int max, int32_t value) {
  return llvm::formatv(
             "requires element {0} in '{1}' array of '{2}' dict at arg {3} "
             "to be in range [0, {4}), got {5}",
             element_index, attr_name, kPaddingMapAttr, arg_index, max, value)
      .str();
}

std::string PaddingMapNegativeShapeIndexMsg(int arg_index, int element_index,
                                            int32_t value) {
  return llvm::formatv(
             "requires element {0} in '{1}' array of '{2}' dict at arg {3} to "
             "be non-negative, got {4}",
             element_index, kShapeIndicesAttr, kPaddingMapAttr, arg_index,
             value)
      .str();
}

std::string PaddingMapUniqueShapeIndexMsg(int arg_index, int element_index,
                                          int32_t value) {
  return llvm::formatv(
             "requires elements in '{0}' array of '{1}' dict at arg {2} to be "
             "unique, got duplicate element {3} at index {4}",
             kShapeIndicesAttr, kPaddingMapAttr, arg_index, value,
             element_index)
      .str();
}

void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
                                     int arg_index, int32_t shape_index,
                                     int32_t padding_arg_index,
                                     bool use_tuple_args) {
  auto* entry = binding->add_entries();
  entry->set_target_param_dim_num(shape_index);
  if (use_tuple_args) {
    entry->set_target_param_num(0);
    entry->add_target_param_index(arg_index);
    entry->set_dynamic_param_num(0);
    entry->add_dynamic_param_index(padding_arg_index);
  } else {
    entry->set_target_param_num(arg_index);
    entry->set_dynamic_param_num(padding_arg_index);
  }
}

// Validates and populates dynamic parameter bindings from a module's entry
// function `mhlo.padding_map` argument attributes to a `xla::HloModuleProto`
// `DynamicParameterBindingProto`.
LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
                                          xla::HloModuleProto* hlo_module_proto,
                                          bool use_tuple_args) {
  auto entry_func = module.lookupSymbol<mlir::FuncOp>("main");
  if (!entry_func) return success();

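  // Each `mhlo.padding_map` argument attribute is a dict with two parallel
  // arrays, e.g. (illustrative only):
  //   mhlo.padding_map = {shape_indices = [0 : i32],
  //                       padding_arg_indices = [2 : i32]}
  // meaning dimension 0 of this argument is dynamic and its runtime size is
  // supplied by entry argument 2.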
  auto* dynamic_parameter_binding =
      hlo_module_proto->mutable_dynamic_parameter_binding();
  for (int i = 0, e = entry_func.getNumArguments(); i < e; ++i) {
    auto padding_map_attr = entry_func.getArgAttr(i, kPaddingMapAttr);
    if (!padding_map_attr) continue;
    auto padding_map = padding_map_attr.dyn_cast<DictionaryAttr>();
    if (!padding_map)
      return entry_func.emitError() << "requires '" << kPaddingMapAttr
                                    << "' dict attribute at arg " << i;

    auto shape_indices =
        padding_map.get(kShapeIndicesAttr).dyn_cast_or_null<ArrayAttr>();
    if (!shape_indices)
      return entry_func.emitError(
          PaddingMapBadArrayAttrMsg(kShapeIndicesAttr, i));

    auto padding_arg_indices =
        padding_map.get(kPaddingArgIndicesAttr).dyn_cast_or_null<ArrayAttr>();
    if (!padding_arg_indices)
      return entry_func.emitError(
          PaddingMapBadArrayAttrMsg(kPaddingArgIndicesAttr, i));

    if (shape_indices.size() != padding_arg_indices.size())
      return entry_func.emitError(PaddingMapMismatchedArraySizeMsg(
          i, shape_indices.size(), padding_arg_indices.size()));

    llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
    auto arg_type =
        entry_func.getArgument(i).getType().dyn_cast<RankedTensorType>();
    for (auto shape_and_padding : llvm::enumerate(llvm::zip(
             shape_indices.getValue(), padding_arg_indices.getValue()))) {
      const int element_index = shape_and_padding.index();
      auto shape_index_attr =
          std::get<0>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
      if (!shape_index_attr)
        return entry_func.emitError(
            PaddingMapBadIntAttrMsg(kShapeIndicesAttr, i, element_index));

      auto padding_arg_index_attr =
          std::get<1>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
      if (!padding_arg_index_attr)
        return entry_func.emitError(
            PaddingMapBadIntAttrMsg(kPaddingArgIndicesAttr, i, element_index));

      const int32_t shape_index = shape_index_attr.getInt();
      if (arg_type && (shape_index < 0 || shape_index >= arg_type.getRank()))
        return entry_func.emitError(
            PaddingMapBadIndexMsg(kShapeIndicesAttr, i, element_index,
                                  arg_type.getRank(), shape_index));
      else if (shape_index < 0)
        return entry_func.emitError(
            PaddingMapNegativeShapeIndexMsg(i, element_index, shape_index));

      if (!used_shape_indices.insert(shape_index).second)
        return entry_func.emitError(
            PaddingMapUniqueShapeIndexMsg(i, element_index, shape_index));

      const int32_t padding_arg_index = padding_arg_index_attr.getInt();
      if (padding_arg_index < 0 || padding_arg_index >= e)
        return entry_func.emitError(PaddingMapBadIndexMsg(
            kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));

      Type padding_arg_type =
          entry_func.getArgument(padding_arg_index).getType();
      if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
        if (tensor_type.getRank() != 0)
          return entry_func.emitError()
                 << "requires arg " << padding_arg_index
                 << " to be a scalar for use as a dynamic parameter";

      if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
        return entry_func.emitError()
               << "requires arg " << padding_arg_index
               << " to be of an int type for use as a dynamic parameter";

      AddDynamicParameterBindingEntry(dynamic_parameter_binding, i, shape_index,
                                      padding_arg_index, use_tuple_args);
    }
  }

  return success();
}

}  // namespace

Status ConvertRegionToComputation(mlir::Region* region,
                                  xla::XlaComputation* func,
                                  MlirToHloConversionOptions options) {
  mlir::ModuleOp module;
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, true, true, {}, options);
  if (failed(converter.LowerRegionAsComputation(region, func)))
    return tensorflow::errors::Internal(
        "failed to convert region to computation");
  return Status::OK();
}

Status ConvertMlirHloToHlo(
    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
    bool return_tuple,
    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    MlirToHloConversionOptions options) {
  // Prepare for export to XLA HLO.
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mhlo::CreatePrepareForExport());
  if (failed(pm.run(module)))
    return tensorflow::errors::Internal("Unable to optimize for XLA export");
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, use_tuple_args,
                               return_tuple, shape_representation_fn, options);
  if (failed(converter.Run())) return diag_handler.ConsumeStatus();
  auto hlo_module = converter.ConsumeMainProto();
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
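  // Dynamic dimension bindings are handled separately: they are read from the
  // entry function's argument attributes and written straight into the
  // HloModuleProto's DynamicParameterBinding.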
  if (failed(AddDynamicParameterBindings(
          module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
    return diag_handler.ConsumeStatus();
  return Status::OK();
}
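
// Example use of ConvertMlirHloToHlo above (a minimal sketch; assumes `module`
// holds valid MHLO and that passing a null shape representation function
// selects the default identity representation):
//
//   xla::HloProto hlo_proto;
//   TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(
//       module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false,
//       /*shape_representation_fn=*/nullptr, /*options=*/{}));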

Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           MlirToHloConversionOptions options) {
  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
  ConvertToHloModule converter(module, builder,
                               /*use_tuple_args=*/false, /*return_tuple=*/false,
                               /*shape_representation_fn=*/nullptr, options);

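  // Seed the value-lowering map with the caller-provided XlaOps for the block
  // arguments instead of creating new parameters.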
  ConvertToHloModule::ValueLoweringMap lowering;
  if (xla_params.size() != block.getArguments().size())
    return tensorflow::errors::Internal(
        "xla_params size != block arguments size");
  for (BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    lowering[arg] = xla_params[num];
  }

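  // Lower every op in the block with the caller's builder. Return ops are not
  // turned into an HLO root here; their operands are surfaced through
  // `returns` so the caller can decide how to terminate the computation.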
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  for (auto& inst : block) {
    if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
      returns.resize(inst.getNumOperands());
      for (OpOperand& ret : inst.getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
          return diag_handler.ConsumeStatus();
        returns[index] = operand;
      }
    } else {
      xla::XlaOp return_value;
      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
                                 /*ret_shardings=*/{}, &builder, &lowering,
                                 &return_value)))
        return diag_handler.ConsumeStatus();
    }
  }

  return Status::OK();
}

DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op) {
  return op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
}

}  // namespace mlir