• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 #include "llvm/Support/SMLoc.h"
29 #include "llvm/Support/SourceMgr.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
33 #include "mlir/IR/Attributes.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/Location.h"  // from @llvm-project
37 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
38 #include "mlir/IR/Matchers.h"  // from @llvm-project
39 #include "mlir/IR/Operation.h"  // from @llvm-project
40 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
41 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
42 #include "mlir/Pass/Pass.h"  // from @llvm-project
43 #include "mlir/Pass/PassManager.h"  // from @llvm-project
44 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
46 #include "tensorflow/compiler/mlir/utils/name_utils.h"
47 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
48 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
49 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
50 #include "tensorflow/compiler/tf2xla/shape_util.h"
51 #include "tensorflow/compiler/xla/client/lib/matrix.h"
52 #include "tensorflow/compiler/xla/client/lib/quantize.h"
53 #include "tensorflow/compiler/xla/client/lib/slicing.h"
54 #include "tensorflow/compiler/xla/client/xla_builder.h"
55 #include "tensorflow/compiler/xla/comparison_util.h"
56 #include "tensorflow/compiler/xla/literal_util.h"
57 #include "tensorflow/compiler/xla/service/hlo_module.h"
58 #include "tensorflow/compiler/xla/shape_util.h"
59 #include "tensorflow/compiler/xla/status_macros.h"
60 #include "tensorflow/compiler/xla/xla_data.pb.h"
61 #include "tensorflow/core/framework/tensor_shape.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/stream_executor/lib/statusor.h"
65 
66 using ::stream_executor::port::StatusOr;
67 using ::tensorflow::int16;
68 using ::tensorflow::int32;
69 using ::tensorflow::int64;
70 using ::tensorflow::int8;
71 using ::tensorflow::uint16;
72 using ::tensorflow::uint32;
73 using ::tensorflow::uint64;
74 using ::tensorflow::uint8;
75 
// Names of the MLIR attributes this file refers to while exporting to HLO.

// Padding-map attribute and the names of its two sub-entries.
constexpr char kPaddingMapAttr[]{"mhlo.padding_map"};
constexpr char kShapeIndicesAttr[]{"shape_indices"};
constexpr char kPaddingArgIndicesAttr[]{"padding_arg_indices"};

// String attribute holding a serialized xla::OpSharding proto.
constexpr char kShardingAttr[]{"mhlo.sharding"};

// Dictionary attribute whose string-valued entries become XLA frontend
// attributes.
constexpr char kFrontendAttributesAttr[]{"mhlo.frontend_attributes"};

// NOTE(review): the identifier misspells "Replication"; it is left unchanged
// here because code outside this block may refer to it by this name.
constexpr char kRepicationAttr[]{"mhlo.is_same_data_across_replicas"};

// Array attribute. Has the same shape as the infeed result, but holds a
// minor_to_major array for every tensor.
constexpr char kLayoutAttr[]{"layout"};
86 
// Generic pass-through overload of Unwrap: hands its argument back unchanged.
// A sibling overload maps std::unique_ptr to the raw pointer it manages, which
// lets generated code call XLA functions that take a raw pointer. In
// particular, PrecisionConfig is passed to xla::Dot and xla::Conv as a pointer
// and there is otherwise no way to avoid a memory leak.
template <class T>
T Unwrap(T t) {
  return t;
}
95 
// Unwrap overload for std::unique_ptr: returns the managed raw pointer (null
// for an empty unique_ptr). Ownership stays with `t`.
template <class T>
T* Unwrap(const std::unique_ptr<T>& t) {
  T* const raw = t.get();
  return raw;
}
100 
GetXlaOp(mlir::Value val,const llvm::DenseMap<mlir::Value,xla::XlaOp> & val_map,xla::XlaOp * result,mlir::Operation * op)101 static mlir::LogicalResult GetXlaOp(
102     mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
103     xla::XlaOp* result, mlir::Operation* op) {
104   auto iter = val_map.find(val);
105   if (iter == val_map.end()) {
106     return op->emitOpError(
107         "requires all operands to be defined in the parent region for export");
108   }
109   *result = iter->second;
110   return mlir::success();
111 }
112 
113 // Convert APInt into an int.
114 // TODO(hpucha): This should be consolidated into a general place.
ConvertAPInt(llvm::APInt i)115 static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
116 
// Identity conversion for uint32_t attribute values.
static uint32_t Convertuint32_t(uint32_t i) {
  return i;
}
// Identity conversion for uint64_t attribute values.
static uint64_t Convertuint64_t(uint64_t i) {
  return i;
}
119 
120 // Convert APFloat to double.
ConvertAPFloat(llvm::APFloat value)121 static double ConvertAPFloat(llvm::APFloat value) {
122   const auto& semantics = value.getSemantics();
123   bool losesInfo = false;
124   if (&semantics != &llvm::APFloat::IEEEdouble())
125     value.convert(llvm::APFloat::IEEEdouble(),
126                   llvm::APFloat::rmNearestTiesToEven, &losesInfo);
127   return value.convertToDouble();
128 }
129 
// Identity conversion for boolean attribute values.
static inline bool Convertbool(bool value) {
  return value;
}
131 
ConvertStringRef(mlir::StringRef value)132 static absl::string_view ConvertStringRef(mlir::StringRef value) {
133   return {value.data(), value.size()};
134 }
135 
ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr)136 static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
137   auto values = attr.getValues<int64>();
138   return {values.begin(), values.end()};
139 }
140 
ConvertDenseIntAttr(llvm::Optional<mlir::DenseIntElementsAttr> attr)141 static std::vector<int64> ConvertDenseIntAttr(
142     llvm::Optional<mlir::DenseIntElementsAttr> attr) {
143   if (!attr) return {};
144   return ConvertDenseIntAttr(*attr);
145 }
146 
147 // Converts the broadcast_dimensions attribute into a vector of dimension
148 // numbers (empty if the attribute is absent).
Convert_broadcast_dimensions(llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions)149 static std::vector<int64> Convert_broadcast_dimensions(
150     llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
151   if (!broadcast_dimensions.hasValue()) return {};
152 
153   return ConvertDenseIntAttr(*broadcast_dimensions);
154 }
155 
156 // Converts StringRef to xla FftType enum
Convert_fft_type(llvm::StringRef fft_type_str)157 static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
158   xla::FftType fft_type_enum;
159   // Illegal fft_type string would be caught by the verifier, so 'FftType_Parse'
160   // call below should never return false.
161   if (!FftType_Parse(std::string(fft_type_str), &fft_type_enum))
162     return xla::FftType::FFT;
163   return fft_type_enum;
164 }
165 
Convert_padding(llvm::Optional<mlir::DenseIntElementsAttr> padding)166 static std::vector<std::pair<int64, int64>> Convert_padding(
167     llvm::Optional<mlir::DenseIntElementsAttr> padding) {
168   return xla::ConvertNx2Attribute(padding).ValueOrDie();
169 }
170 
Convert_source_target_pairs(llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs)171 static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
172     llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
173   return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
174 }
175 
Convert_replica_groups(mlir::DenseIntElementsAttr groups)176 static std::vector<xla::ReplicaGroup> Convert_replica_groups(
177     mlir::DenseIntElementsAttr groups) {
178   return xla::ConvertReplicaGroups(groups).ValueOrDie();
179 }
180 
181 // Converts StringRef to xla Transpose enum.
Convert_transpose_a(llvm::StringRef transpose_str)182 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
183     llvm::StringRef transpose_str) {
184   return xla::ConvertTranspose(transpose_str).ValueOrDie();
185 }
186 
187 #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
188   static std::vector<int64> Convert_##attribute(              \
189       llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
190     return ConvertDenseIntAttr(attribute);                    \
191   }
192 
193 I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
194 I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
195 I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
196 I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
197 I64_ELEMENTS_ATTR_TO_VECTOR(strides);
198 I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
199 I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
200 I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
201 I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
202 I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
203 I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);
204 
205 #undef I64_ELEMENTS_ATTR_TO_VECTOR
206 
Convert_ArrayRef(llvm::ArrayRef<int64_t> values)207 static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
208   return {values.begin(), values.end()};
209 }
210 
211 // Converts the precision config array of strings attribute into the
212 // corresponding XLA proto. All the strings are assumed to be valid names of the
213 // Precision enum. This should have been checked in the op verify method.
Convert_precision_config(llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr)214 static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
215     llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
216   if (!optional_precision_config_attr.hasValue()) return nullptr;
217 
218   auto precision_config = absl::make_unique<xla::PrecisionConfig>();
219   for (auto attr : optional_precision_config_attr.getValue()) {
220     xla::PrecisionConfig::Precision p;
221     auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
222     // TODO(jpienaar): Update this to ensure this is captured by verify.
223     if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
224       precision_config->add_operand_precision(p);
225     } else {
226       auto* context = attr.getContext();
227       mlir::emitError(mlir::UnknownLoc::get(context))
228           << "unexpected operand precision " << operand_precision;
229       return nullptr;
230     }
231   }
232 
233   return precision_config;
234 }
235 
Convert_dot_dimension_numbers(mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr)236 static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
237     mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
238   xla::DotDimensionNumbers dot_dimension_numbers;
239 
240   auto rhs_contracting_dimensions =
241       dot_dimension_numbers_attr.rhs_contracting_dimensions()
242           .cast<mlir::DenseIntElementsAttr>();
243   auto lhs_contracting_dimensions =
244       dot_dimension_numbers_attr.lhs_contracting_dimensions()
245           .cast<mlir::DenseIntElementsAttr>();
246   auto rhs_batch_dimensions =
247       dot_dimension_numbers_attr.rhs_batching_dimensions()
248           .cast<mlir::DenseIntElementsAttr>();
249   auto lhs_batch_dimensions =
250       dot_dimension_numbers_attr.lhs_batching_dimensions()
251           .cast<mlir::DenseIntElementsAttr>();
252 
253   for (const auto& val : rhs_contracting_dimensions) {
254     dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue());
255   }
256   for (const auto& val : lhs_contracting_dimensions) {
257     dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue());
258   }
259 
260   for (const auto& val : rhs_batch_dimensions) {
261     dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue());
262   }
263 
264   for (const auto& val : lhs_batch_dimensions) {
265     dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue());
266   }
267 
268   return dot_dimension_numbers;
269 }
270 
Convert_dimension_numbers(mlir::mhlo::ConvDimensionNumbers input)271 static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
272     mlir::mhlo::ConvDimensionNumbers input) {
273   return xla::ConvertConvDimensionNumbers(input);
274 }
275 
Convert_channel_handle(mlir::mhlo::ChannelHandle attr)276 xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
277   xla::ChannelHandle channel_handle;
278   channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
279   channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
280       ConvertAPInt(attr.type().getValue())));
281   return channel_handle;
282 }
283 
284 // Converts the comparison_direction string attribute into the XLA enum. The
285 // string is assumed to correspond to exactly one of the allowed strings
286 // representing the enum. This should have been checked in the op verify method.
Convert_comparison_direction(llvm::StringRef comparison_direction_string)287 static xla::ComparisonDirection Convert_comparison_direction(
288     llvm::StringRef comparison_direction_string) {
289   return xla::StringToComparisonDirection(comparison_direction_string.str())
290       .ValueOrDie();
291 }
292 
Convert_dimension_numbers(mlir::mhlo::GatherDimensionNumbers input)293 static xla::GatherDimensionNumbers Convert_dimension_numbers(
294     mlir::mhlo::GatherDimensionNumbers input) {
295   xla::GatherDimensionNumbers output;
296 
297   auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
298   std::copy(offset_dims.begin(), offset_dims.end(),
299             tensorflow::protobuf::RepeatedFieldBackInserter(
300                 output.mutable_offset_dims()));
301 
302   auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
303   std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
304             tensorflow::protobuf::RepeatedFieldBackInserter(
305                 output.mutable_collapsed_slice_dims()));
306 
307   auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
308   std::copy(start_index_map.begin(), start_index_map.end(),
309             tensorflow::protobuf::RepeatedFieldBackInserter(
310                 output.mutable_start_index_map()));
311 
312   output.set_index_vector_dim(
313       ConvertAPInt(input.index_vector_dim().getValue()));
314   return output;
315 }
316 
Convert_scatter_dimension_numbers(mlir::mhlo::ScatterDimensionNumbers input)317 static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
318     mlir::mhlo::ScatterDimensionNumbers input) {
319   xla::ScatterDimensionNumbers output;
320 
321   auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
322   std::copy(update_window_dims.begin(), update_window_dims.end(),
323             tensorflow::protobuf::RepeatedFieldBackInserter(
324                 output.mutable_update_window_dims()));
325 
326   auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
327   std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
328             tensorflow::protobuf::RepeatedFieldBackInserter(
329                 output.mutable_inserted_window_dims()));
330 
331   auto scatter_dims_to_operand_dims =
332       ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
333   std::copy(scatter_dims_to_operand_dims.begin(),
334             scatter_dims_to_operand_dims.end(),
335             tensorflow::protobuf::RepeatedFieldBackInserter(
336                 output.mutable_scatter_dims_to_operand_dims()));
337 
338   output.set_index_vector_dim(
339       ConvertAPInt(input.index_vector_dim().getValue()));
340   return output;
341 }
342 
343 // Extracts sharding from attribute string.
CreateOpShardingFromStringRef(llvm::StringRef sharding)344 static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
345     llvm::StringRef sharding) {
346   xla::OpSharding sharding_proto;
347   if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
348   return sharding_proto;
349 }
350 
351 // Returns an OpSharding proto from the "sharding" attribute of the op. If the
352 // op doesn't have a sharding attribute or the sharding attribute is invalid,
353 // returns absl::nullopt.
CreateOpShardingFromAttribute(mlir::Operation * op)354 static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
355     mlir::Operation* op) {
356   auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
357   if (!sharding) return absl::nullopt;
358   return CreateOpShardingFromStringRef(sharding.getValue());
359 }
360 
361 // Returns a FrontendAttributes proto from the "frontend_attributes" attribute
362 // of the op. An empty FrontendAttributes proto is returned if an op does not
363 // have frontend attributes.
CreateOpFrontendAttributesFromAttribute(mlir::Operation * op)364 static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
365     mlir::Operation* op) {
366   xla::FrontendAttributes frontend_attributes;
367   auto frontend_attributes_dict =
368       op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
369 
370   if (!frontend_attributes_dict) return frontend_attributes;
371 
372   for (const auto& attr : frontend_attributes_dict)
373     if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
374       frontend_attributes.mutable_map()->insert(
375           {attr.first.str(), value_str_attr.getValue().str()});
376 
377   return frontend_attributes;
378 }
379 
380 // Returns a OpMetadata proto based on the location of the op. If the location
381 // is unknown, an empty proto is returned. `op_name` are populated with the op
382 // location (converted). FileLineColLoc locations are populated by taking the
383 // file name and line number, and populating `source_file` and `source_line`
384 // respectively.
CreateOpMetadataFromLocation(mlir::Operation * op)385 static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
386   xla::OpMetadata metadata;
387   if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;
388 
389   std::string name = mlir::GetNameFromLoc(op->getLoc());
390   mlir::LegalizeNodeName(name);
391   metadata.set_op_name(name);
392 
393   if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
394     metadata.set_source_file(file_line_col_loc.getFilename().str());
395     metadata.set_source_line(file_line_col_loc.getLine());
396   }
397 
398   return metadata;
399 }
400 
401 // Checks if all shardings are set.
AllOptionalShardingsAreSet(llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings)402 static bool AllOptionalShardingsAreSet(
403     llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
404   return llvm::all_of(shardings,
405                       [](const absl::optional<xla::OpSharding>& sharding) {
406                         return sharding.has_value();
407                       });
408 }
409 
410 // Extracts argument and result shardings from function.
ExtractShardingsFromFunction(mlir::FuncOp function,llvm::SmallVectorImpl<absl::optional<xla::OpSharding>> * arg_shardings,llvm::SmallVectorImpl<absl::optional<xla::OpSharding>> * ret_shardings)411 static void ExtractShardingsFromFunction(
412     mlir::FuncOp function,
413     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
414     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
415   arg_shardings->resize(function.getNumArguments(),
416                         absl::optional<xla::OpSharding>());
417   for (int i = 0, end = function.getNumArguments(); i < end; ++i)
418     if (auto sharding =
419             function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
420       (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
421 
422   ret_shardings->resize(function.getNumResults(),
423                         absl::optional<xla::OpSharding>());
424   for (int i = 0, end = function.getNumResults(); i < end; ++i)
425     if (auto sharding =
426             function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
427       (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
428 }
429 
430 namespace mlir {
431 namespace {
432 class ConvertToHloModule {
433  public:
434   using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
435   using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
436 
437   // If use_tuple_args is true, then the entry function's arguments are
438   // converted to a tuple and passed as a single parameter.
439   // Similarly, if return tuple is true, then the entry function's return values
440   // are converted to a tuple even when there is only a single return value.
441   // Multiple return values are always converted to a tuple and returned as a
442   // single value.
ConvertToHloModule(mlir::ModuleOp module,xla::XlaBuilder & module_builder,bool use_tuple_args,bool return_tuple,tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,MlirToHloConversionOptions options)443   explicit ConvertToHloModule(
444       mlir::ModuleOp module, xla::XlaBuilder& module_builder,
445       bool use_tuple_args, bool return_tuple,
446       tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
447       MlirToHloConversionOptions options)
448       : module_(module),
449         module_builder_(module_builder),
450         use_tuple_args_(use_tuple_args),
451         return_tuple_(return_tuple),
452         shape_representation_fn_(shape_representation_fn),
453         options_(options) {
454     if (!shape_representation_fn_)
455       shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
456   }
457 
458   // Perform the lowering to XLA. This function returns failure if an error was
459   // encountered.
460   //
461   // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
Run()462   LogicalResult Run() {
463     auto main = module_.lookupSymbol<mlir::FuncOp>("main");
464     if (!main)
465       return module_.emitError(
466           "conversion requires module with `main` function");
467 
468     for (auto func : module_.getOps<FuncOp>()) {
469       if (func.empty()) continue;
470       if (failed(RunOnFunction(func))) return failure();
471     }
472     return success();
473   }
474 
475   // Lower a specific function to HLO.
476   LogicalResult RunOnFunction(mlir::FuncOp f);
477 
478   // Lower a `mlir::Region` to a `XlaComputation`
479   LogicalResult LowerRegionAsComputation(mlir::Region* region,
480                                          xla::XlaComputation* func);
481 
482   // Lower a single `Block` to a `XlaComputation`
483   LogicalResult LowerBasicBlockAsFunction(
484       Block* block, xla::XlaBuilder* builder, bool is_entry_function,
485       const std::vector<bool>& entry_args_same_across_replicas,
486       llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
487       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
488       xla::XlaComputation* result);
489 
ConsumeMainProto()490   ::xla::HloModuleProto ConsumeMainProto() {
491     auto main = module_.lookupSymbol<mlir::FuncOp>("main");
492     // This is an invariant check as Run returns failure if there is no main
493     // function and so the main proto shouldn't be consumed in that case.
494     CHECK(main) << "requires module to have main function";  // Crash Ok.
495     return lowered_computation_[main].proto();
496   }
497 
498   // Lower function call to HLO call instruction
499   LogicalResult LowerFunctionCall(
500       mlir::CallOp call_op, xla::XlaBuilder* builder,
501       ConvertToHloModule::ValueLoweringMap* value_lowering);
502 
503   LogicalResult Lower(
504       mlir::Operation* inst, bool is_entry_function,
505       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
506       xla::XlaBuilder* builder,
507       ConvertToHloModule::ValueLoweringMap* value_lowering,
508       xla::XlaOp* return_value);
509 
510  private:
511   LogicalResult SetEntryTupleShapesAndLeafReplication(
512       Block* block, const std::vector<bool>& entry_args_same_across_replicas,
513       llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
514       std::vector<bool>* leaf_replication);
515 
516   LogicalResult SetEntryTupleShardings(
517       Block* block, xla::XlaBuilder* builder,
518       llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
519       llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
520 
521   // The module being lowered.
522   mlir::ModuleOp module_;
523 
524   // The top-level XlaBuilder.
525   xla::XlaBuilder& module_builder_;
526 
527   // Map between function and lowered computation.
528   FunctionLoweringMap lowered_computation_;
529 
530   // Whether the entry function should take a single tuple as input.
531   bool use_tuple_args_;
532 
533   // Whether to always return a tuple.
534   bool return_tuple_;
535 
536   // Shape representation function to determine entry function argument and
537   // result shapes.
538   tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
539 
540   // Unique suffix to give to the name of the next lowered region.
541   size_t region_id_ = 0;
542 
543   MlirToHloConversionOptions options_;
544 };
545 
546 }  // namespace
547 }  // namespace mlir
548 
549 namespace {
550 
551 struct OpLoweringContext {
552   llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
553   mlir::ConvertToHloModule* converter;
554   xla::XlaBuilder* builder;
555 };
556 
GetTuple(mlir::Operation::operand_range values,OpLoweringContext ctx)557 llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
558                                           OpLoweringContext ctx) {
559   llvm::SmallVector<xla::XlaOp, 4> ops;
560   for (mlir::Value value : values) {
561     ops.push_back((*ctx.values)[value]);
562   }
563   return ops;
564 }
565 
566 }  // namespace
567 
568 namespace mlir {
569 namespace mhlo {
570 namespace {
571 
ExportXlaOp(AllReduceOp op,OpLoweringContext ctx)572 LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
573   auto& value_map = *ctx.values;
574   xla::XlaComputation computation;
575   if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
576                                                      &computation))) {
577     return failure();
578   }
579   auto replica_groups = Convert_replica_groups(op.replica_groups());
580   xla::XlaOp operand;
581   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
582 
583   if (!op.channel_id().hasValue()) {
584     value_map[op] = xla::AllReduce(operand, computation, replica_groups,
585                                    /*channel_id=*/absl::nullopt);
586     return success();
587   }
588   auto channel_id = Convert_channel_handle(op.channel_id().getValue());
589   value_map[op] =
590       xla::AllReduce(operand, computation, replica_groups, channel_id);
591   return success();
592 }
593 
ExportXlaOp(BitcastConvertOp op,OpLoweringContext ctx)594 LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
595   auto& value_map = *ctx.values;
596   xla::XlaOp operand;
597   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
598 
599   value_map[op] = xla::BitcastConvertType(
600       operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
601   return success();
602 }
603 
ExportXlaOp(BroadcastInDimOp op,OpLoweringContext ctx)604 LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
605   auto type = op.getType().dyn_cast<RankedTensorType>();
606   if (!type) return failure();
607   auto& value_map = *ctx.values;
608   xla::XlaOp operand;
609   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
610 
611   value_map[op] =
612       BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
613                      Convert_broadcast_dimensions(op.broadcast_dimensions()));
614   return success();
615 }
616 
ExportXlaOp(DynamicBroadcastInDimOp op,OpLoweringContext ctx)617 LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
618   // This op has no expression in the legacy export format.
619   return failure();
620 }
621 
ExportXlaOp(DynamicIotaOp op,OpLoweringContext ctx)622 LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
623   // This op has no expression in the legacy export format.
624   return failure();
625 }
626 
ExportXlaOp(DynamicReshapeOp op,OpLoweringContext ctx)627 LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
628   // This op has no expression in the legacy export format.
629   return failure();
630 }
631 
ExportXlaOp(IfOp op,OpLoweringContext ctx)632 LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
633   xla::XlaComputation true_branch;
634   xla::XlaComputation false_branch;
635   auto& value_map = *ctx.values;
636   if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
637                                                      &true_branch)) ||
638       failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
639                                                      &false_branch))) {
640     return failure();
641   }
642   xla::XlaOp pred, true_arg, false_arg;
643   if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();
644   if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op)))
645     return failure();
646   if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op)))
647     return failure();
648 
649   value_map[op] =
650       xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);
651   return success();
652 }
653 
ExportXlaOp(CaseOp op,OpLoweringContext ctx)654 LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
655   llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
656   OperandRange operands = op.branch_operands();
657   MutableArrayRef<Region> branches = op.branches();
658   llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
659   std::vector<xla::XlaComputation> computations(branches.size());
660   std::vector<xla::XlaComputation*> computations_p(branches.size());
661 
662   for (unsigned i = 0; i < branches.size(); ++i) {
663     xla::XlaOp operand;
664     if (failed(GetXlaOp(operands[i], value_map, &operand, op)))
665       return failure();
666     branch_operands[i] = operand;
667     computations_p[i] = &computations[i];
668     if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
669                                                        computations_p[i])))
670       return failure();
671   }
672   xla::XlaOp index;
673   if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();
674 
675   xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands);
676   if (op.getNumResults() == 1) {
677     value_map[op.getResult(0)] = result;
678   } else {
679     for (auto item : llvm::enumerate(op.getResults())) {
680       value_map[item.value()] = xla::GetTupleElement(result, item.index());
681     }
682   }
683   return success();
684 }
685 
686 // Specialize CompareOp export to set broadcast_dimensions argument.
ExportXlaOp(mlir::mhlo::CompareOp op,OpLoweringContext ctx)687 mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
688                                 OpLoweringContext ctx) {
689   auto& value_map = *ctx.values;
690   xla::XlaOp lhs, rhs;
691   if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
692   if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
693   auto dir = Convert_comparison_direction(op.comparison_direction());
694   auto type_attr = op.compare_typeAttr();
695 
696   xla::XlaOp xla_result;
697   if (type_attr) {
698     auto type =
699         xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
700     xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
701   } else {
702     xla_result = xla::Compare(lhs, rhs, dir);
703   }
704   value_map[op] = xla_result;
705   return mlir::success();
706 }
707 
ExportXlaOp(ConstOp op,OpLoweringContext ctx)708 LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
709   return failure();
710 }
711 
ExportXlaOp(mlir::mhlo::ConvOp op,OpLoweringContext ctx)712 LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
713   // XLA client builder API does not support generating convolution instructions
714   // with window reversal.
715   if (op.hasWindowReversal()) return failure();
716   auto& value_map = *ctx.values;
717   xla::XlaOp lhs, rhs;
718   if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
719   if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
720   xla::XlaOp xla_result = xla::ConvGeneralDilated(
721       lhs, rhs, Convert_window_strides(op.window_strides()),
722       Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
723       Convert_rhs_dilation(op.rhs_dilation()),
724       Convert_dimension_numbers(op.dimension_numbers()),
725       Convertuint64_t(op.feature_group_count()),
726       Convertuint64_t(op.batch_group_count()),
727       Unwrap(Convert_precision_config(op.precision_config())));
728   value_map[op] = xla_result;
729   return mlir::success();
730 }
731 
ExportXlaOp(ConvertOp op,OpLoweringContext ctx)732 LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
733   auto& value_map = *ctx.values;
734   xla::XlaOp operand;
735   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
736 
737   value_map[op] = xla::ConvertElementType(
738       operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
739   return success();
740 }
741 
ExportXlaOp(CustomCallOp op,OpLoweringContext ctx)742 LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
743   // XLA client builder API does not support generating custom call instructions
744   // with side effect.
745   if (op.has_side_effect() || op.getNumResults() != 1) return failure();
746   Value result = op.getResult(0);
747   auto& value_map = *ctx.values;
748   value_map[result] = xla::CustomCall(
749       ctx.builder, std::string(op.call_target_name()), GetTuple(op.args(), ctx),
750       xla::TypeToShape(result.getType()), std::string(op.backend_config()));
751   return success();
752 }
753 
ExportXlaOp(DequantizeOp op,OpLoweringContext ctx)754 LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) {
755   xla::QuantizedRange range(ConvertAPFloat(op.min_range()),
756                             ConvertAPFloat(op.max_range()));
757   auto& value_map = *ctx.values;
758   xla::XlaOp input;
759   if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure();
760 
761   auto casted = xla::ConvertElementType(input, xla::U32);
762   if (op.is_16bits()) {
763     value_map[op] = xla::Dequantize<uint16>(
764         casted, range, ConvertStringRef(op.mode()), op.transpose_output());
765   } else {
766     value_map[op] = xla::Dequantize<uint8>(
767         casted, range, ConvertStringRef(op.mode()), op.transpose_output());
768   }
769   return success();
770 }
771 
ExportXlaOp(InfeedOp op,OpLoweringContext ctx)772 LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
773   auto& value_map = *ctx.values;
774   xla::XlaOp token;
775   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
776 
777   // The shape argument expected by the xla client API is the type of the first
778   // element in the result tuple.
779   auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
780   value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type),
781                                        std::string(op.infeed_config()));
782   return success();
783 }
784 
ExportXlaOp(IotaOp op,OpLoweringContext ctx)785 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
786   auto& value_map = *ctx.values;
787   value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
788                             op.iota_dimension());
789   return success();
790 }
791 
ExportXlaOp(MapOp op,OpLoweringContext ctx)792 LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
793   auto& value_map = *ctx.values;
794   xla::XlaComputation computation;
795   if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
796                                                      &computation))) {
797     return failure();
798   }
799   value_map[op] = xla::Map(ctx.builder, GetTuple(op.operands(), ctx),
800                            computation, Convert_dimensions(op.dimensions()));
801   return success();
802 }
803 
ExportXlaOp(OutfeedOp op,OpLoweringContext ctx)804 LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
805   auto& value_map = *ctx.values;
806   xla::XlaOp operand, token;
807   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
808   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
809 
810   value_map[op] = xla::OutfeedWithToken(
811       operand, token, xla::TypeToShape(op.operand().getType()),
812       std::string(op.outfeed_config()));
813   return success();
814 }
815 
ExportXlaOp(PadOp op,OpLoweringContext ctx)816 LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
817   auto& value_map = *ctx.values;
818   xla::PaddingConfig padding_config;
819   auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
820   auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
821   auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
822   for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
823     auto* dims = padding_config.add_dimensions();
824     dims->set_edge_padding_low(edge_padding_low[i]);
825     dims->set_edge_padding_high(edge_padding_high[i]);
826     dims->set_interior_padding(interior_padding[i]);
827   }
828   xla::XlaOp operand, padding_value;
829   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
830   if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
831     return failure();
832 
833   value_map[op] = xla::Pad(operand, padding_value, padding_config);
834   return success();
835 }
836 
ExportXlaOp(RecvOp op,OpLoweringContext ctx)837 LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
838   auto& value_map = *ctx.values;
839   auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
840   xla::XlaOp token;
841   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
842 
843   if (op.is_host_transfer()) {
844     value_map[op] = xla::RecvFromHost(token, xla::TypeToShape(result_type),
845                                       Convert_channel_handle(op.channel_id()));
846     return success();
847   }
848   value_map[op] = xla::RecvWithToken(token, xla::TypeToShape(result_type),
849                                      Convert_channel_handle(op.channel_id()));
850   return success();
851 }
852 
ExportXlaOp(ReduceOp op,OpLoweringContext ctx)853 LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
854   auto& value_map = *ctx.values;
855   xla::XlaComputation body;
856   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
857     return failure();
858   }
859   xla::XlaOp result =
860       xla::Reduce(ctx.builder, GetTuple(op.operands(), ctx),
861                   GetTuple(op.init_values(), ctx), body,
862                   Convert_broadcast_dimensions(op.dimensions()));
863   if (op.getNumResults() == 1) {
864     value_map[op.getResult(0)] = result;
865   } else {
866     for (auto item : llvm::enumerate(op.getResults())) {
867       value_map[item.value()] = xla::GetTupleElement(result, item.index());
868     }
869   }
870   return success();
871 }
872 
ExportXlaOp(ReduceWindowOp op,OpLoweringContext ctx)873 LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
874   auto& value_map = *ctx.values;
875   xla::XlaComputation body;
876   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
877     return failure();
878   }
879   xla::XlaOp operand, init_value;
880   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
881   if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
882     return failure();
883 
884   value_map[op] = xla::ReduceWindowWithGeneralPadding(
885       operand, init_value, body, ConvertDenseIntAttr(op.window_dimensions()),
886       ConvertDenseIntAttr(op.window_strides()),
887       ConvertDenseIntAttr(op.base_dilations()),
888       ConvertDenseIntAttr(op.window_dilations()),
889       Convert_padding(op.padding()));
890   return success();
891 }
892 
ExportXlaOp(ReshapeOp op,OpLoweringContext ctx)893 LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
894   auto& value_map = *ctx.values;
895   xla::XlaOp operand;
896   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
897 
898   value_map[op] =
899       xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
900   return success();
901 }
902 
ExportXlaOp(ReturnOp op,OpLoweringContext ctx)903 LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
904   // Failure on purpose because `mhlo::ReturnOp` will be handled by
905   // special purpose logic in `ConvertToHloModule::Lower`.
906   return failure();
907 }
908 
ExportXlaOp(RngBitGeneratorOp op,OpLoweringContext ctx)909 LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
910   auto& value_map = *ctx.values;
911   auto result = op.getResult();
912   auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
913   auto xla_result = xla::RngBitGenerator(
914       static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
915       xla::TypeToShape(result.getType()).tuple_shapes(1));
916   value_map[result] = xla_result;
917   return mlir::success();
918 }
919 
ExportXlaOp(RngNormalOp op,OpLoweringContext ctx)920 LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
921   auto& value_map = *ctx.values;
922   xla::XlaOp mu, sigma;
923   if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure();
924   if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure();
925 
926   value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType()));
927   return success();
928 }
929 
ExportXlaOp(RngUniformOp op,OpLoweringContext ctx)930 LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
931   auto& value_map = *ctx.values;
932   xla::XlaOp a, b;
933   if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
934   if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();
935 
936   value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
937   return success();
938 }
939 
ExportXlaOp(ScatterOp op,OpLoweringContext ctx)940 LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
941   auto& value_map = *ctx.values;
942   xla::XlaComputation update_computation;
943   if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
944                                                      &update_computation))) {
945     return failure();
946   }
947   xla::ScatterDimensionNumbers dimension_numbers =
948       Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
949   xla::XlaOp operand, scatter_indices, updates;
950   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
951   if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
952     return failure();
953   if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure();
954 
955   value_map[op] = xla::Scatter(operand, scatter_indices, updates,
956                                update_computation, dimension_numbers,
957                                op.indices_are_sorted(), op.unique_indices());
958   return success();
959 }
960 
ExportXlaOp(SelectAndScatterOp op,OpLoweringContext ctx)961 LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
962   auto& value_map = *ctx.values;
963   xla::XlaComputation select;
964   xla::XlaComputation scatter;
965   if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
966       failed(
967           ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
968     return failure();
969   }
970   xla::XlaOp operand, source, init_value;
971   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
972   if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
973   if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
974     return failure();
975 
976   value_map[op] = xla::SelectAndScatterWithGeneralPadding(
977       operand, select, ConvertDenseIntAttr(op.window_dimensions()),
978       ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
979       source, init_value, scatter);
980   return success();
981 }
982 
ExportXlaOp(SendOp op,OpLoweringContext ctx)983 LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
984   auto& value_map = *ctx.values;
985   xla::XlaOp operand, token;
986   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
987   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
988 
989   if (op.is_host_transfer()) {
990     value_map[op] = xla::SendToHost(operand, token,
991                                     xla::TypeToShape(op.operand().getType()),
992                                     Convert_channel_handle(op.channel_id()));
993     return success();
994   }
995   value_map[op] = xla::SendWithToken(operand, token,
996                                      Convert_channel_handle(op.channel_id()));
997   return success();
998 }
999 
ExportXlaOp(SliceOp op,OpLoweringContext ctx)1000 LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
1001   return failure();
1002 }
1003 
ExportXlaOp(SortOp op,OpLoweringContext ctx)1004 LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
1005   xla::XlaComputation comparator;
1006   if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
1007                                                      &comparator)))
1008     return failure();
1009 
1010   auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator,
1011                           op.dimension(), op.is_stable());
1012 
1013   auto& value_map = *ctx.values;
1014   auto shape_or = sorted.builder()->GetShape(sorted);
1015   if (!shape_or.ok()) {
1016     return op.emitError(shape_or.status().ToString());
1017   }
1018 
1019   xla::Shape& shape = shape_or.ValueOrDie();
1020   if (!shape.IsTuple()) {
1021     value_map[op.getResult(0)] = sorted;
1022     return success();
1023   }
1024 
1025   // MLIR's sort supports multiple returns, untuple all the results of XLA's.
1026   for (auto it : llvm::enumerate(op.getResults())) {
1027     value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
1028   }
1029   return success();
1030 }
1031 
ExportXlaOp(TraceOp op,OpLoweringContext ctx)1032 LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
1033   auto& value_map = *ctx.values;
1034   xla::XlaOp operand;
1035   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1036   xla::Trace(std::string(op.tag()), operand);
1037   return success();
1038 }
1039 
ExportXlaOp(UnaryEinsumOp op,OpLoweringContext ctx)1040 LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
1041   // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
1042   // operands.
1043   return failure();
1044 }
1045 
ExportXlaOp(WhileOp op,OpLoweringContext ctx)1046 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
1047   xla::XlaComputation condition;
1048   xla::XlaComputation body;
1049   auto& value_map = *ctx.values;
1050   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body)) ||
1051       failed(ctx.converter->LowerRegionAsComputation(&op.cond(), &condition))) {
1052     return failure();
1053   }
1054 
1055   xla::XlaOp operand;
1056   if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
1057     return failure();
1058   value_map[op] = xla::While(condition, body, operand);
1059   return success();
1060 }
1061 
ExportXlaOp(FusionOp op,OpLoweringContext ctx)1062 LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
1063   if (!op.fusion_kind()) {
1064     op.emitOpError() << "requires fusion kind for HLO translation";
1065     return failure();
1066   }
1067 
1068   xla::XlaComputation fused_computation;
1069   if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
1070                                                      &fused_computation)))
1071     return failure();
1072 
1073   auto& values = *ctx.values;
1074   llvm::SmallVector<xla::XlaOp, 4> operands;
1075   for (auto operand : op.operands()) operands.push_back(values[operand]);
1076 
1077   xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
1078       ctx.builder, operands,
1079       absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
1080       fused_computation);
1081   if (op.getNumResults() == 1) {
1082     values[op.getResult(0)] = fusion;
1083   } else {
1084     for (auto item : llvm::enumerate(op.getResults())) {
1085       values[item.value()] = xla::GetTupleElement(fusion, item.index());
1086     }
1087   }
1088   return success();
1089 }
1090 
ExportXlaOp(BitcastOp op,OpLoweringContext ctx)1091 LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
1092   auto& value_map = *ctx.values;
1093   xla::XlaOp operand;
1094   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1095   value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
1096       ctx.builder, operand, xla::TypeToShape(op.getType()));
1097   return success();
1098 }
1099 
1100 }  // namespace
1101 }  // namespace mhlo
1102 }  // namespace mlir
1103 
1104 #include "tensorflow/compiler/mlir/xla/operator_writers.inc"
1105 
1106 namespace mlir {
1107 namespace {
1108 
CreateArrayLiteralFromAttr(ElementsAttr attr,xla::Layout layout)1109 StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
1110                                                   xla::Layout layout) {
1111   if (attr.isa<OpaqueElementsAttr>())
1112     return tensorflow::errors::Unimplemented(
1113         "Opaque elements attr not supported");
1114 
1115   xla::Shape shape = xla::TypeToShape(attr.getType());
1116 
1117 #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
1118   case xla_type: {                                                           \
1119     xla::Array<cpp_type> source_data(shape.dimensions());                    \
1120     source_data.SetValues(attr.getValues<cpp_type>());                       \
1121     return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
1122   }
1123 
1124   switch (shape.element_type()) {
1125     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
1126     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
1127     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
1128     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
1129     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
1130     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
1131     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
1132     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
1133     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
1134     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
1135     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
1136     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
1137     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
1138     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
1139     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
1140     default:
1141       return tensorflow::errors::Internal(absl::StrCat(
1142           "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
1143   }
1144 #undef ELEMENTS_ATTR_TO_LITERAL
1145 }
1146 
ExtractLayout(mlir::Operation * op,int rank)1147 xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
1148   if (auto attr = GetLayoutFromMlirHlo(op)) {
1149     llvm::SmallVector<int64, 4> minor_to_major;
1150     DCHECK_EQ(rank, attr.size());
1151     minor_to_major.reserve(attr.size());
1152     for (const llvm::APInt& i : attr) {
1153       minor_to_major.push_back(i.getZExtValue());
1154     }
1155     return xla::LayoutUtil::MakeLayout(minor_to_major);
1156   }
1157   return xla::LayoutUtil::MakeDescendingLayout(rank);
1158 }
1159 
ConvertLayout(mlir::Operation * op,const mlir::ArrayAttr & layout,xla::ShapeProto * shape)1160 LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
1161                             xla::ShapeProto* shape) {
1162   // In the case of tuples, ShapeProtos can be nested, and so can the mlir
1163   // attribute describing the layout. So recurse into the subshapes in both data
1164   // structures in parallel.
1165   if (shape->element_type() == xla::TUPLE) {
1166     auto subshapes = shape->mutable_tuple_shapes();
1167     if (layout.size() != subshapes->size()) {
1168       op->emitOpError() << "Expected layout of size " << layout.size()
1169                         << ", but found " << subshapes->size();
1170       return failure();
1171     }
1172     for (int i = 0; i < subshapes->size(); i++) {
1173       mlir::Attribute child = layout[i];
1174       if (child.isa<mlir::UnitAttr>()) {
1175         // ignore unit attributes, they are used only for tokens.
1176         continue;
1177       }
1178       mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
1179       if (!c) {
1180         op->emitOpError() << "Type Error: Expected layout array attribute";
1181         return failure();
1182       }
1183       if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
1184         return failure();
1185       }
1186     }
1187   } else {
1188     int rank = shape->dimensions().size();
1189     if (rank) {
1190       if (layout.size() != rank) {
1191         return failure();  // pass error down
1192       }
1193       std::vector<int64> array(rank);
1194       for (int i = 0; i < rank; i++) {
1195         mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
1196         if (!attr) {
1197           op->emitOpError() << "Type Error: Expected layout integer attribute";
1198           return failure();
1199         }
1200         array[i] = attr.getInt();
1201       }
1202       *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
1203     }
1204   }
1205   return success();
1206 }
1207 
Lower(mlir::Operation * inst,bool is_entry_function,llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,xla::XlaBuilder * builder,ConvertToHloModule::ValueLoweringMap * value_lowering,xla::XlaOp * return_value)1208 LogicalResult ConvertToHloModule::Lower(
1209     mlir::Operation* inst, bool is_entry_function,
1210     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
1211     xla::XlaBuilder* builder,
1212     ConvertToHloModule::ValueLoweringMap* value_lowering,
1213     xla::XlaOp* return_value) {
1214   *return_value = xla::XlaOp();
1215 
1216   // See MlirToHloConversionOptions for more about layouts.
1217   auto propagate_layouts = [this](mlir::Operation* inst,
1218                                   xla::XlaOp xla_op) -> mlir::LogicalResult {
1219     if (options_.propagate_layouts) {
1220       auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
1221                         ->mutable_shape();
1222       if (shape->tuple_shapes().empty())
1223         // TODO(kramm): merge this with ConvertLayout.
1224         *shape->mutable_layout() =
1225             ExtractLayout(inst, shape->dimensions().size()).ToProto();
1226     }
1227 
1228     // For infeed ops stemming back to InfeedDequeueTuple, respect the layout
1229     // attribute, and create the corresponding layout in hlo.
1230     if (isa<mhlo::InfeedOp>(inst)) {
1231       mlir::ArrayAttr layout =
1232           inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
1233       if (layout) {
1234         xla::ShapeProto* shape =
1235             xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
1236                 ->mutable_shape();
1237 
1238         if (failed(ConvertLayout(inst, layout, shape))) {
1239           return failure();
1240         }
1241       }
1242     }
1243     return success();
1244   };
1245 
1246   if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
1247     if (inst->getNumResults() == 1) {
1248       auto iter = value_lowering->find(inst->getResult(0));
1249       if (iter == value_lowering->end()) {
1250         inst->emitOpError(
1251             "inst has a result, but it's not found in value_lowering");
1252         return failure();
1253       }
1254       if (failed(propagate_layouts(inst, iter->second))) {
1255         return failure();
1256       }
1257     }
1258     return success();
1259   }
1260 
1261   auto& value_map = *value_lowering;
1262   ElementsAttr const_attr;
1263 
1264   if (auto call_op = dyn_cast<mlir::CallOp>(inst)) {
1265     return LowerFunctionCall(call_op, builder, &value_map);
1266   }
1267 
1268   if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
1269     Value operand = op.getOperand();
1270     auto ty = operand.getType().dyn_cast<ShapedType>();
1271     // If this was a cast from a static shaped tensors, then it is a noop for
1272     // export to HLO and we can use the operand.
1273     if (!ty || !ty.hasStaticShape()) {
1274       inst->emitOpError()
1275           << "requires static shaped operand for HLO translation";
1276       return failure();
1277     }
1278 
1279     xla::XlaOp xla_operand;
1280     if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
1281       return failure();
1282     value_map[op.getResult()] = xla_operand;
1283     if (failed(propagate_layouts(inst, xla_operand))) {
1284       return failure();
1285     }
1286     return success();
1287   }
1288 
1289   if (matchPattern(inst, m_Constant(&const_attr))) {
1290     xla::Layout layout;
1291     layout = ExtractLayout(inst, const_attr.getType().getRank());
1292     auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
1293     if (!literal_or.ok())
1294       return inst->emitError(literal_or.status().ToString());
1295     auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
1296     value_map[inst->getResult(0)] = constant;
1297 
1298     return success();
1299   }
1300 
1301   if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
1302     // Construct the return value for the function. If there is a single value
1303     // returned, then return it directly, else create a tuple and return.
1304     unsigned num_return_values = inst->getNumOperands();
1305     if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
1306       const bool has_ret_shardings =
1307           !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
1308 
1309       std::vector<xla::XlaOp> returns(num_return_values);
1310       for (OpOperand& ret : inst->getOpOperands()) {
1311         unsigned index = ret.getOperandNumber();
1312         xla::XlaOp operand;
1313         if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
1314           return failure();
1315 
1316         returns[index] = operand;
1317         if (!is_entry_function || !has_ret_shardings) continue;
1318 
1319         xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
1320         StatusOr<xla::XlaOp> reshape =
1321             tensorflow::ReshapeWithCorrectRepresentationAndSharding(
1322                 builder, returns[index], return_shape, shape_representation_fn_,
1323                 ret_shardings[index], /*fast_mem=*/false);
1324         if (!reshape.ok())
1325           return inst->emitError() << reshape.status().error_message();
1326 
1327         returns[index] = reshape.ValueOrDie();
1328       }
1329 
1330       if (has_ret_shardings) {
1331         xla::OpSharding sharding;
1332         sharding.set_type(xla::OpSharding::TUPLE);
1333         for (auto& ret_sharding : ret_shardings)
1334           *sharding.add_tuple_shardings() = *ret_sharding;
1335 
1336         builder->SetSharding(sharding);
1337       }
1338 
1339       *return_value = xla::Tuple(builder, returns);
1340       builder->ClearSharding();
1341     } else if (num_return_values == 1) {
1342       xla::XlaOp operand;
1343       if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
1344         return failure();
1345 
1346       *return_value = operand;
1347     }
1348 
1349     return success();
1350   }
1351 
1352   inst->emitOpError() << "can't be translated to XLA HLO";
1353   return failure();
1354 }
1355 
LowerFunctionCall(mlir::CallOp call_op,xla::XlaBuilder * builder,ConvertToHloModule::ValueLoweringMap * value_lowering)1356 LogicalResult ConvertToHloModule::LowerFunctionCall(
1357     mlir::CallOp call_op, xla::XlaBuilder* builder,
1358     ConvertToHloModule::ValueLoweringMap* value_lowering) {
1359   auto& value_map = *value_lowering;
1360   mlir::FuncOp callee = module_.lookupSymbol<mlir::FuncOp>(call_op.callee());
1361   if (failed(RunOnFunction(callee))) return failure();
1362   std::vector<xla::XlaOp> operands;
1363   for (auto operand : call_op.getOperands()) {
1364     xla::XlaOp xla_operand;
1365     if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
1366       return failure();
1367     operands.push_back(xla_operand);
1368   }
1369   // Each call to xla::Call would insert a copy of the computation to
1370   // the HLO. Thus each callsite would have a unique callee in the
1371   // exported HLO. HLO syntactically does not require all calls to have unique
1372   // callees, but eventually before lowering call graph is "flattened" to
1373   // make that true. This is done before lowering because buffer assignment
1374   // needs this invariant.
1375   xla::XlaOp call_result =
1376       xla::Call(builder, lowered_computation_[callee], operands);
1377   // Use GetTupleElement for multiple outputs
1378   unsigned num_results = call_op.getNumResults();
1379   if (num_results > 1) {
1380     for (unsigned i = 0; i != num_results; ++i) {
1381       value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
1382     }
1383   } else if (num_results == 1) {
1384     value_map[call_op.getResult(0)] = call_result;
1385   }
1386   return success();
1387 }
1388 
RunOnFunction(mlir::FuncOp f)1389 LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
1390   if (lowered_computation_.count(f)) return success();
1391   if (!llvm::hasSingleElement(f)) {
1392     return f.emitError("only single block Function supported");
1393   }
1394 
1395   // Create a sub-builder if this is not the main function.
1396   std::unique_ptr<xla::XlaBuilder> builder_up;
1397   bool entry_function = f.getName() == "main";
1398   if (!entry_function)
1399     builder_up = module_builder_.CreateSubBuilder(f.getName().str());
1400   auto& builder = entry_function ? module_builder_ : *builder_up;
1401 
1402   xla::XlaComputation computation;
1403   std::vector<bool> entry_args_same_across_replicas;
1404   llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
1405   llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
1406   if (entry_function) {
1407     bool any_arg_replicated = false;
1408     entry_args_same_across_replicas.reserve(f.getNumArguments());
1409     for (int64_t i = 0; i < f.getNumArguments(); ++i) {
1410       auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kRepicationAttr);
1411       entry_args_same_across_replicas.push_back(attr != nullptr);
1412       any_arg_replicated |= entry_args_same_across_replicas.back();
1413       // Pass the alias info to the builder so that it will build the alias info
1414       // into the resulting HloModule.
1415       auto aliasing_output =
1416           f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
1417       if (!aliasing_output) continue;
1418       if (use_tuple_args_) {
1419         builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
1420                            /*param_number=*/0, /*param_index=*/{i});
1421       } else {
1422         builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
1423                            /*param_number=*/i, /*param_index=*/{});
1424       }
1425     }
1426     // Do not populate this field when nothing is replicated, since empty field
1427     // means no replication. This avoids the need for unrelated tests to handle
1428     // this field.
1429     if (!any_arg_replicated) entry_args_same_across_replicas.clear();
1430 
1431     ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
1432   }
1433   if (failed(LowerBasicBlockAsFunction(
1434           &f.front(), &builder, entry_function, entry_args_same_across_replicas,
1435           arg_shardings, ret_shardings, &computation))) {
1436     return failure();
1437   }
1438   lowered_computation_[f] = std::move(computation);
1439   return success();
1440 }
1441 
SetEntryTupleShapesAndLeafReplication(Block * block,const std::vector<bool> & entry_args_same_across_replicas,llvm::SmallVectorImpl<xla::Shape> * arg_shapes,std::vector<bool> * leaf_replication)1442 LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
1443     Block* block, const std::vector<bool>& entry_args_same_across_replicas,
1444     llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
1445     std::vector<bool>* leaf_replication) {
1446   arg_shapes->reserve(block->getNumArguments());
1447   leaf_replication->reserve(block->getNumArguments());
1448   for (BlockArgument& arg : block->getArguments()) {
1449     arg_shapes->push_back(xla::TypeToShape(arg.getType()));
1450     xla::Shape& arg_shape = arg_shapes->back();
1451     tensorflow::TensorShape arg_tensor_shape;
1452     auto status =
1453         tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
1454     if (!status.ok())
1455       return block->getParentOp()->emitError() << status.error_message();
1456 
1457     tensorflow::DataType dtype;
1458     status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
1459     if (!status.ok())
1460       return block->getParentOp()->emitError() << status.error_message();
1461 
1462     auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
1463                                                      /*use_fast_memory=*/false);
1464     if (!arg_shape_status.ok())
1465       return block->getParentOp()->emitError()
1466              << arg_shape_status.status().error_message();
1467 
1468     arg_shape = std::move(arg_shape_status.ValueOrDie());
1469 
1470     if (entry_args_same_across_replicas.empty()) continue;
1471     for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
1472       leaf_replication->push_back(
1473           entry_args_same_across_replicas[arg.getArgNumber()]);
1474   }
1475 
1476   return success();
1477 }
1478 
SetEntryTupleShardings(Block * block,xla::XlaBuilder * builder,llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,llvm::SmallVectorImpl<xla::Shape> * arg_shapes)1479 LogicalResult ConvertToHloModule::SetEntryTupleShardings(
1480     Block* block, xla::XlaBuilder* builder,
1481     llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
1482     llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
1483   if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
1484     xla::OpSharding sharding;
1485     sharding.set_type(xla::OpSharding::TUPLE);
1486     for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
1487       auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
1488       if (!hlo_sharding.ok())
1489         return block->getParentOp()->emitError()
1490                << hlo_sharding.status().error_message();
1491 
1492       auto status = tensorflow::RewriteLayoutWithShardedShape(
1493           hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
1494           shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
1495       if (!status.ok())
1496         return block->getParentOp()->emitError() << status.error_message();
1497 
1498       *sharding.add_tuple_shardings() = *arg_sharding.value();
1499     }
1500 
1501     builder->SetSharding(sharding);
1502   }
1503 
1504   return success();
1505 }
1506 
LowerBasicBlockAsFunction(Block * block,xla::XlaBuilder * builder,bool is_entry_function,const std::vector<bool> & entry_args_same_across_replicas,llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,xla::XlaComputation * result)1507 LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
1508     Block* block, xla::XlaBuilder* builder, bool is_entry_function,
1509     const std::vector<bool>& entry_args_same_across_replicas,
1510     llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
1511     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
1512     xla::XlaComputation* result) {
1513   // Mapping from the Value to lowered XlaOp.
1514   ValueLoweringMap lowering;
1515 
1516   // If using tuples as input, then there is only one input parameter that is a
1517   // tuple.
1518   if (is_entry_function && use_tuple_args_) {
1519     llvm::SmallVector<xla::Shape, 4> arg_shapes;
1520     std::vector<bool> leaf_replication;
1521     if (failed(SetEntryTupleShapesAndLeafReplication(
1522             block, entry_args_same_across_replicas, &arg_shapes,
1523             &leaf_replication)))
1524       return failure();
1525 
1526     if (failed(
1527             SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
1528       return failure();
1529 
1530     xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
1531     auto tuple =
1532         xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
1533     builder->ClearSharding();
1534 
1535     bool set_tuple_element_sharding =
1536         !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
1537     for (BlockArgument& arg : block->getArguments()) {
1538       if (set_tuple_element_sharding)
1539         builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
1540       lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
1541     }
1542     builder->ClearSharding();
1543   } else {
1544     for (BlockArgument& arg : block->getArguments()) {
1545       auto num = arg.getArgNumber();
1546       xla::Shape shape = xla::TypeToShape(arg.getType());
1547       if (entry_args_same_across_replicas.empty()) {
1548         lowering[arg] =
1549             xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
1550       } else {
1551         lowering[arg] = xla::Parameter(
1552             builder, num, shape, absl::StrCat("Arg_", num),
1553             std::vector<bool>(entry_args_same_across_replicas[num],
1554                               xla::ShapeUtil::GetLeafCount(shape)));
1555       }
1556     }
1557   }
1558 
1559   xla::XlaOp return_value;
1560   for (auto& inst : *block)
1561     if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
1562                      &lowering, &return_value)))
1563       return failure();
1564 
1565   // Build the XlaComputation and check for failures.
1566   auto computation_or =
1567       return_value.valid() ? builder->Build(return_value) : builder->Build();
1568   if (!computation_or.ok()) {
1569     block->back().emitError(
1570         llvm::Twine(computation_or.status().error_message()));
1571     return failure();
1572   }
1573   *result = std::move(computation_or.ValueOrDie());
1574   return success();
1575 }
1576 
LowerRegionAsComputation(mlir::Region * region,xla::XlaComputation * func)1577 LogicalResult ConvertToHloModule::LowerRegionAsComputation(
1578     mlir::Region* region, xla::XlaComputation* func) {
1579   std::unique_ptr<xla::XlaBuilder> builder =
1580       module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
1581   return LowerBasicBlockAsFunction(
1582       &region->front(), builder.get(),
1583       /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
1584       /*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
1585 }
1586 
PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name,int index)1587 std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
1588   return llvm::formatv(
1589              "requires '{0}' array attribute in '{1}' dict at arg {2}",
1590              attr_name, kPaddingMapAttr, index)
1591       .str();
1592 }
1593 
PaddingMapMismatchedArraySizeMsg(int arg_index,int shape_indices_size,int padding_arg_indices_size)1594 std::string PaddingMapMismatchedArraySizeMsg(int arg_index,
1595                                              int shape_indices_size,
1596                                              int padding_arg_indices_size) {
1597   return llvm::formatv(
1598              "requires '{0}' and '{1}' array attributes in '{2}' dic at arg "
1599              "{3} to be of the same size, got sizes {4} and {5}",
1600              kShapeIndicesAttr, kPaddingArgIndicesAttr, kPaddingMapAttr,
1601              arg_index, shape_indices_size, padding_arg_indices_size)
1602       .str();
1603 }
1604 
PaddingMapBadIntAttrMsg(llvm::StringRef attr_name,int arg_index,int element_index)1605 std::string PaddingMapBadIntAttrMsg(llvm::StringRef attr_name, int arg_index,
1606                                     int element_index) {
1607   return llvm::formatv(
1608              "requires element {0} in '{1}' array of '{2}' dict at arg {3} "
1609              "to be an int attribute",
1610              element_index, attr_name, kPaddingMapAttr, arg_index)
1611       .str();
1612 }
1613 
PaddingMapBadIndexMsg(llvm::StringRef attr_name,int arg_index,int element_index,int max,int32_t value)1614 std::string PaddingMapBadIndexMsg(llvm::StringRef attr_name, int arg_index,
1615                                   int element_index, int max, int32_t value) {
1616   return llvm::formatv(
1617              "requires element {0} in '{1}' array of '{2}' dict at arg {3} "
1618              "to be in range [0, {4}), got {5}",
1619              element_index, attr_name, kPaddingMapAttr, arg_index, max, value)
1620       .str();
1621 }
1622 
PaddingMapNegativeShapeIndexMsg(int arg_index,int element_index,int32_t value)1623 std::string PaddingMapNegativeShapeIndexMsg(int arg_index, int element_index,
1624                                             int32_t value) {
1625   return llvm::formatv(
1626              "requires element {0} in '{1}' array of '{2}' dict at arg {3} to "
1627              "be non-negative, got {4}",
1628              element_index, kShapeIndicesAttr, kPaddingMapAttr, arg_index,
1629              value)
1630       .str();
1631 }
1632 
PaddingMapUniqueShapeIndexMsg(int arg_index,int element_index,int32_t value)1633 std::string PaddingMapUniqueShapeIndexMsg(int arg_index, int element_index,
1634                                           int32_t value) {
1635   return llvm::formatv(
1636              "requires elements in '{0}' array of '{1}' dict at arg {2} to be "
1637              "unique, got duplicate element {3} at index {4}",
1638              kShapeIndicesAttr, kPaddingMapAttr, arg_index, value,
1639              element_index)
1640       .str();
1641 }
1642 
AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto * binding,int arg_index,int32_t shape_index,int32_t padding_arg_index,bool use_tuple_args)1643 void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
1644                                      int arg_index, int32_t shape_index,
1645                                      int32_t padding_arg_index,
1646                                      bool use_tuple_args) {
1647   auto* entry = binding->add_entries();
1648   entry->set_target_param_dim_num(shape_index);
1649   if (use_tuple_args) {
1650     entry->set_target_param_num(0);
1651     entry->add_target_param_index(arg_index);
1652     entry->set_dynamic_param_num(0);
1653     entry->add_dynamic_param_index(padding_arg_index);
1654   } else {
1655     entry->set_target_param_num(arg_index);
1656     entry->set_dynamic_param_num(padding_arg_index);
1657   }
1658 }
1659 
1660 // Validates and populates dynamic parameter bindings from a module's entry
1661 // function `mhlo.padding_map` argument attributes to a `xla::HloModuleProto`
1662 // `DynamicParameterBindingProto`.
AddDynamicParameterBindings(mlir::ModuleOp module,xla::HloModuleProto * hlo_module_proto,bool use_tuple_args)1663 LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
1664                                           xla::HloModuleProto* hlo_module_proto,
1665                                           bool use_tuple_args) {
1666   auto entry_func = module.lookupSymbol<mlir::FuncOp>("main");
1667   if (!entry_func) return success();
1668 
1669   auto* dynamic_parameter_binding =
1670       hlo_module_proto->mutable_dynamic_parameter_binding();
1671   for (int i = 0, e = entry_func.getNumArguments(); i < e; ++i) {
1672     auto padding_map_attr = entry_func.getArgAttr(i, kPaddingMapAttr);
1673     if (!padding_map_attr) continue;
1674     auto padding_map = padding_map_attr.dyn_cast<DictionaryAttr>();
1675     if (!padding_map)
1676       return entry_func.emitError() << "requires '" << kPaddingMapAttr
1677                                     << "' dict attribute at arg " << i;
1678 
1679     auto shape_indices =
1680         padding_map.get(kShapeIndicesAttr).dyn_cast_or_null<ArrayAttr>();
1681     if (!shape_indices)
1682       return entry_func.emitError(
1683           PaddingMapBadArrayAttrMsg(kShapeIndicesAttr, i));
1684 
1685     auto padding_arg_indices =
1686         padding_map.get(kPaddingArgIndicesAttr).dyn_cast_or_null<ArrayAttr>();
1687     if (!padding_arg_indices)
1688       return entry_func.emitError(
1689           PaddingMapBadArrayAttrMsg(kPaddingArgIndicesAttr, i));
1690 
1691     if (shape_indices.size() != padding_arg_indices.size())
1692       return entry_func.emitError(PaddingMapMismatchedArraySizeMsg(
1693           i, shape_indices.size(), padding_arg_indices.size()));
1694 
1695     llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
1696     auto arg_type =
1697         entry_func.getArgument(i).getType().dyn_cast<RankedTensorType>();
1698     for (auto shape_and_padding : llvm::enumerate(llvm::zip(
1699              shape_indices.getValue(), padding_arg_indices.getValue()))) {
1700       const int element_index = shape_and_padding.index();
1701       auto shape_index_attr =
1702           std::get<0>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
1703       if (!shape_index_attr)
1704         return entry_func.emitError(
1705             PaddingMapBadIntAttrMsg(kShapeIndicesAttr, i, element_index));
1706 
1707       auto padding_arg_index_attr =
1708           std::get<1>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
1709       if (!padding_arg_index_attr)
1710         return entry_func.emitError(
1711             PaddingMapBadIntAttrMsg(kPaddingArgIndicesAttr, i, element_index));
1712 
1713       const int32_t shape_index = shape_index_attr.getInt();
1714       if (arg_type && (shape_index < 0 || shape_index >= arg_type.getRank()))
1715         return entry_func.emitError(
1716             PaddingMapBadIndexMsg(kShapeIndicesAttr, i, element_index,
1717                                   arg_type.getRank(), shape_index));
1718       else if (shape_index < 0)
1719         return entry_func.emitError(
1720             PaddingMapNegativeShapeIndexMsg(i, element_index, shape_index));
1721 
1722       if (!used_shape_indices.insert(shape_index).second)
1723         return entry_func.emitError(
1724             PaddingMapUniqueShapeIndexMsg(i, element_index, shape_index));
1725 
1726       const int32_t padding_arg_index = padding_arg_index_attr.getInt();
1727       if (padding_arg_index < 0 || padding_arg_index >= e)
1728         return entry_func.emitError(PaddingMapBadIndexMsg(
1729             kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));
1730 
1731       Type padding_arg_type =
1732           entry_func.getArgument(padding_arg_index).getType();
1733       if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
1734         if (tensor_type.getRank() != 0)
1735           return entry_func.emitError()
1736                  << "requires arg " << padding_arg_index
1737                  << " to be a scalar for use as a dynamic parameter";
1738 
1739       if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
1740         return entry_func.emitError()
1741                << "requires arg " << padding_arg_index
1742                << " to be of an int type for use as a dynamic parameter";
1743 
1744       AddDynamicParameterBindingEntry(dynamic_parameter_binding, i, shape_index,
1745                                       padding_arg_index, use_tuple_args);
1746     }
1747   }
1748 
1749   return success();
1750 }
1751 
1752 }  // namespace
1753 
ConvertRegionToComputation(mlir::Region * region,xla::XlaComputation * func,MlirToHloConversionOptions options)1754 Status ConvertRegionToComputation(mlir::Region* region,
1755                                   xla::XlaComputation* func,
1756                                   MlirToHloConversionOptions options) {
1757   mlir::ModuleOp module;
1758   xla::XlaBuilder module_builder("main");
1759   ConvertToHloModule converter(module, module_builder, true, true, {}, options);
1760   if (failed(converter.LowerRegionAsComputation(region, func)))
1761     return tensorflow::errors::Internal(
1762         "failed to convert region to computation");
1763   return Status::OK();
1764 }
1765 
ConvertMlirHloToHlo(mlir::ModuleOp module,xla::HloProto * hlo_proto,bool use_tuple_args,bool return_tuple,const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,MlirToHloConversionOptions options)1766 Status ConvertMlirHloToHlo(
1767     mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
1768     bool return_tuple,
1769     const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
1770     MlirToHloConversionOptions options) {
1771   // Prepare for export to XLA HLO.
1772   mlir::PassManager pm(module.getContext());
1773   pm.addNestedPass<mlir::FuncOp>(mhlo::CreatePrepareForExport());
1774   if (failed(pm.run(module)))
1775     return tensorflow::errors::Internal("Unable to optimize for XLA export");
1776   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
1777   xla::XlaBuilder module_builder("main");
1778   ConvertToHloModule converter(module, module_builder, use_tuple_args,
1779                                return_tuple, shape_representation_fn, options);
1780   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
1781   auto hlo_module = converter.ConsumeMainProto();
1782   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
1783   if (failed(AddDynamicParameterBindings(
1784           module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
1785     return diag_handler.ConsumeStatus();
1786   return Status::OK();
1787 }
1788 
BuildHloFromMlirHlo(mlir::Block & block,xla::XlaBuilder & builder,llvm::ArrayRef<xla::XlaOp> xla_params,std::vector<xla::XlaOp> & returns,MlirToHloConversionOptions options)1789 Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
1790                            llvm::ArrayRef<xla::XlaOp> xla_params,
1791                            std::vector<xla::XlaOp>& returns,
1792                            MlirToHloConversionOptions options) {
1793   auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
1794   ConvertToHloModule converter(module, builder,
1795                                /*use_tuple_args=*/false, /*return_tuple=*/false,
1796                                /*shape_representation_fn=*/nullptr, options);
1797 
1798   ConvertToHloModule::ValueLoweringMap lowering;
1799   if (xla_params.size() != block.getArguments().size())
1800     return tensorflow::errors::Internal(
1801         "xla_params size != block arguments size");
1802   for (BlockArgument& arg : block.getArguments()) {
1803     auto num = arg.getArgNumber();
1804     lowering[arg] = xla_params[num];
1805   }
1806 
1807   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
1808   for (auto& inst : block) {
1809     if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
1810       returns.resize(inst.getNumOperands());
1811       for (OpOperand& ret : inst.getOpOperands()) {
1812         unsigned index = ret.getOperandNumber();
1813         xla::XlaOp operand;
1814         if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
1815           return diag_handler.ConsumeStatus();
1816         returns[index] = operand;
1817       }
1818     } else {
1819       xla::XlaOp return_value;
1820       if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
1821                                  /*ret_shardings=*/{}, &builder, &lowering,
1822                                  &return_value)))
1823         return diag_handler.ConsumeStatus();
1824     }
1825   }
1826 
1827   return Status::OK();
1828 }
1829 
GetLayoutFromMlirHlo(mlir::Operation * op)1830 DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op) {
1831   return op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
1832 }
1833 
1834 }  // namespace mlir
1835