1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 #include "llvm/Support/SMLoc.h"
29 #include "llvm/Support/SourceMgr.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
32 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
33 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
38 #include "mlir/IR/Location.h"  // from @llvm-project
39 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
40 #include "mlir/IR/Matchers.h"  // from @llvm-project
41 #include "mlir/IR/Operation.h"  // from @llvm-project
42 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
43 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
44 #include "mlir/Pass/Pass.h"  // from @llvm-project
45 #include "mlir/Pass/PassManager.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/utils/name_utils.h"
49 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
50 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
52 #include "tensorflow/compiler/tf2xla/shape_util.h"
53 #include "tensorflow/compiler/xla/client/lib/matrix.h"
54 #include "tensorflow/compiler/xla/client/lib/quantize.h"
55 #include "tensorflow/compiler/xla/client/lib/slicing.h"
56 #include "tensorflow/compiler/xla/client/xla_builder.h"
57 #include "tensorflow/compiler/xla/comparison_util.h"
58 #include "tensorflow/compiler/xla/literal_util.h"
59 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
60 #include "tensorflow/compiler/xla/service/hlo_module.h"
61 #include "tensorflow/compiler/xla/shape_util.h"
62 #include "tensorflow/compiler/xla/status_macros.h"
63 #include "tensorflow/compiler/xla/xla_data.pb.h"
64 #include "tensorflow/core/framework/tensor_shape.h"
65 #include "tensorflow/core/framework/types.pb.h"
66 #include "tensorflow/core/platform/errors.h"
67 #include "tensorflow/stream_executor/lib/statusor.h"
68 
69 using ::stream_executor::port::StatusOr;
70 using ::tensorflow::int16;
71 using ::tensorflow::int32;
72 using ::tensorflow::int64;
73 using ::tensorflow::int8;
74 using ::tensorflow::uint16;
75 using ::tensorflow::uint32;
76 using ::tensorflow::uint64;
77 using ::tensorflow::uint8;
78 
79 constexpr char kShapeIndicesAttr[] = "shape_indices";
80 constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
81 constexpr char kShardingAttr[] = "mhlo.sharding";
82 constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
83 constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
84 
85 // Array attribute. Same shape as infeed result, but contains a
86 // minor_to_major array for every tensor.
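// For example (illustrative only), an infeed producing a single 2x3 tensor
// might carry: layout = [[1, 0]], i.e. one minor_to_major list per tensor.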
87 constexpr char kLayoutAttr[] = "layout";
88 constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
89 
90 // Passes through everything except for unique_ptr, on which it calls get().
91 // This exists to allow the generated code to call XLA functions that take a raw
92 // pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
93 // as a pointer and there is otherwise no way to avoid a memory leak.
94 template <typename T>
95 T Unwrap(T t) {
96   return t;
97 }
98 
99 template <typename T>
100 T* Unwrap(const std::unique_ptr<T>& t) {
101   return t.get();
102 }
103 
104 static mlir::LogicalResult GetXlaOp(
105     mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
106     xla::XlaOp* result, mlir::Operation* op) {
107   auto iter = val_map.find(val);
108   if (iter == val_map.end()) {
109     return op->emitOpError(
110         "requires all operands to be defined in the parent region for export");
111   }
112   *result = iter->second;
113   return mlir::success();
114 }
115 
116 // Convert APInt into an int.
117 // TODO(hpucha): This should be consolidated into a general place.
118 static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
119 
120 static uint32_t Convertuint32_t(uint32_t i) { return i; }
121 static uint64_t Convertuint64_t(uint64_t i) { return i; }
122 
123 // Convert APFloat to double.
124 static double ConvertAPFloat(llvm::APFloat value) {
125   const auto& semantics = value.getSemantics();
126   bool losesInfo = false;
127   if (&semantics != &llvm::APFloat::IEEEdouble())
128     value.convert(llvm::APFloat::IEEEdouble(),
129                   llvm::APFloat::rmNearestTiesToEven, &losesInfo);
130   return value.convertToDouble();
131 }
132 
133 static inline bool Convertbool(bool value) { return value; }
134 
135 static absl::string_view ConvertStringRef(mlir::StringRef value) {
136   return {value.data(), value.size()};
137 }
138 
139 static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
140   auto values = attr.getValues<int64>();
141   return {values.begin(), values.end()};
142 }
143 
144 static std::vector<int64> ConvertDenseIntAttr(
145     llvm::Optional<mlir::DenseIntElementsAttr> attr) {
146   if (!attr) return {};
147   return ConvertDenseIntAttr(*attr);
148 }
149 
150 // Converts the broadcast_dimensions attribute into a vector of dimension
151 // numbers (empty if the attribute is absent).
152 static std::vector<int64> Convert_broadcast_dimensions(
153     llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
154   if (!broadcast_dimensions.hasValue()) return {};
155 
156   return ConvertDenseIntAttr(*broadcast_dimensions);
157 }
158 
159 // Converts StringRef to xla FftType enum
160 static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
161   xla::FftType fft_type_enum;
162   // Illegal fft_type string would be caught by the verifier, so 'FftType_Parse'
163   // call below should never return false.
164   if (!FftType_Parse(std::string(fft_type_str), &fft_type_enum))
165     return xla::FftType::FFT;
166   return fft_type_enum;
167 }
168 
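// Converts an Nx2 dense integer padding attribute into (low, high) pairs.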
169 static std::vector<std::pair<int64, int64>> Convert_padding(
170     llvm::Optional<mlir::DenseIntElementsAttr> padding) {
171   return xla::ConvertNx2Attribute(padding).ValueOrDie();
172 }
173 
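// Converts an Nx2 dense integer attribute into (source, target) pairs.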
174 static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
175     llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
176   return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
177 }
178 
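// Converts a 2D dense integer attribute into xla::ReplicaGroup protos, one
// group per row.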
179 static std::vector<xla::ReplicaGroup> Convert_replica_groups(
180     mlir::DenseIntElementsAttr groups) {
181   return xla::ConvertReplicaGroups(groups).ValueOrDie();
182 }
183 
184 // Converts StringRef to xla Transpose enum.
185 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
186     llvm::StringRef transpose_str) {
187   return xla::ConvertTranspose(transpose_str).ValueOrDie();
188 }
189 
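// Returns the layout stored in the named attribute (a minor_to_major list) if
// present, otherwise a default descending layout for the given rank.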
190 static xla::Layout ExtractLayout(
191     mlir::Operation* op, int rank,
192     llvm::StringRef attr_name = kDefaultLayoutAttrName) {
193   if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(attr_name)) {
194     llvm::SmallVector<int64, 4> minor_to_major;
195     DCHECK_EQ(rank, attr.size());
196     minor_to_major.reserve(attr.size());
197     for (const llvm::APInt& i : attr) {
198       minor_to_major.push_back(i.getZExtValue());
199     }
200     return xla::LayoutUtil::MakeLayout(minor_to_major);
201   }
202   return xla::LayoutUtil::MakeDescendingLayout(rank);
203 }
204 
205 #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
206   static std::vector<int64> Convert_##attribute(              \
207       llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
208     return ConvertDenseIntAttr(attribute);                    \
209   }
210 
211 I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
212 I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
213 I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
214 I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
215 I64_ELEMENTS_ATTR_TO_VECTOR(strides);
216 I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
217 I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
218 I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
219 I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
220 I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
221 I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);
222 
223 #undef I64_ELEMENTS_ATTR_TO_VECTOR
224 
225 static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
226   return {values.begin(), values.end()};
227 }
228 
229 // Converts the precision config array of strings attribute into the
230 // corresponding XLA proto. All the strings are assumed to be valid names of the
231 // Precision enum. This should have been checked in the op verify method.
232 static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
233     llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
234   if (!optional_precision_config_attr.hasValue()) return nullptr;
235 
236   auto precision_config = absl::make_unique<xla::PrecisionConfig>();
237   for (auto attr : optional_precision_config_attr.getValue()) {
238     xla::PrecisionConfig::Precision p;
239     auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
240     // TODO(jpienaar): Update this to ensure this is captured by verify.
241     if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
242       precision_config->add_operand_precision(p);
243     } else {
244       auto* context = attr.getContext();
245       mlir::emitError(mlir::UnknownLoc::get(context))
246           << "unexpected operand precision " << operand_precision;
247       return nullptr;
248     }
249   }
250 
251   return precision_config;
252 }
253 
254 static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
255     mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
256   xla::DotDimensionNumbers dot_dimension_numbers;
257 
258   auto rhs_contracting_dimensions =
259       dot_dimension_numbers_attr.rhs_contracting_dimensions()
260           .cast<mlir::DenseIntElementsAttr>();
261   auto lhs_contracting_dimensions =
262       dot_dimension_numbers_attr.lhs_contracting_dimensions()
263           .cast<mlir::DenseIntElementsAttr>();
264   auto rhs_batch_dimensions =
265       dot_dimension_numbers_attr.rhs_batching_dimensions()
266           .cast<mlir::DenseIntElementsAttr>();
267   auto lhs_batch_dimensions =
268       dot_dimension_numbers_attr.lhs_batching_dimensions()
269           .cast<mlir::DenseIntElementsAttr>();
270 
271   for (const auto& val : rhs_contracting_dimensions) {
272     dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue());
273   }
274   for (const auto& val : lhs_contracting_dimensions) {
275     dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue());
276   }
277 
278   for (const auto& val : rhs_batch_dimensions) {
279     dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue());
280   }
281 
282   for (const auto& val : lhs_batch_dimensions) {
283     dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue());
284   }
285 
286   return dot_dimension_numbers;
287 }
288 
289 static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
290     mlir::mhlo::ConvDimensionNumbers input) {
291   return xla::ConvertConvDimensionNumbers(input);
292 }
293 
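// Converts an mhlo::ChannelHandle attribute into the equivalent
// xla::ChannelHandle proto.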
294 xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
295   xla::ChannelHandle channel_handle;
296   channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
297   channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
298       ConvertAPInt(attr.type().getValue())));
299   return channel_handle;
300 }
301 
302 absl::optional<xla::ChannelHandle> Convert_channel_handle(
303     llvm::Optional<mlir::mhlo::ChannelHandle> attr) {
304   if (!attr.hasValue()) return absl::nullopt;
305   return Convert_channel_handle(attr.getValue());
306 }
307 
308 // Converts the comparison_direction string attribute into the XLA enum. The
309 // string is assumed to correspond to exactly one of the allowed strings
310 // representing the enum. This should have been checked in the op verify method.
311 static xla::ComparisonDirection Convert_comparison_direction(
312     llvm::StringRef comparison_direction_string) {
313   return xla::StringToComparisonDirection(comparison_direction_string.str())
314       .ValueOrDie();
315 }
316 
317 static xla::GatherDimensionNumbers Convert_dimension_numbers(
318     mlir::mhlo::GatherDimensionNumbers input) {
319   xla::GatherDimensionNumbers output;
320 
321   auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
322   std::copy(offset_dims.begin(), offset_dims.end(),
323             tensorflow::protobuf::RepeatedFieldBackInserter(
324                 output.mutable_offset_dims()));
325 
326   auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
327   std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
328             tensorflow::protobuf::RepeatedFieldBackInserter(
329                 output.mutable_collapsed_slice_dims()));
330 
331   auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
332   std::copy(start_index_map.begin(), start_index_map.end(),
333             tensorflow::protobuf::RepeatedFieldBackInserter(
334                 output.mutable_start_index_map()));
335 
336   output.set_index_vector_dim(
337       ConvertAPInt(input.index_vector_dim().getValue()));
338   return output;
339 }
340 
341 static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
342     mlir::mhlo::ScatterDimensionNumbers input) {
343   xla::ScatterDimensionNumbers output;
344 
345   auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
346   std::copy(update_window_dims.begin(), update_window_dims.end(),
347             tensorflow::protobuf::RepeatedFieldBackInserter(
348                 output.mutable_update_window_dims()));
349 
350   auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
351   std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
352             tensorflow::protobuf::RepeatedFieldBackInserter(
353                 output.mutable_inserted_window_dims()));
354 
355   auto scatter_dims_to_operand_dims =
356       ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
357   std::copy(scatter_dims_to_operand_dims.begin(),
358             scatter_dims_to_operand_dims.end(),
359             tensorflow::protobuf::RepeatedFieldBackInserter(
360                 output.mutable_scatter_dims_to_operand_dims()));
361 
362   output.set_index_vector_dim(
363       ConvertAPInt(input.index_vector_dim().getValue()));
364   return output;
365 }
366 
367 // Extracts sharding from attribute string.
368 static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
369     llvm::StringRef sharding) {
370   xla::OpSharding sharding_proto;
371   if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
372   return sharding_proto;
373 }
374 
375 // Returns an OpSharding proto from the "sharding" attribute of the op. If the
376 // op doesn't have a sharding attribute or the sharding attribute is invalid,
377 // returns absl::nullopt.
378 static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
379     mlir::Operation* op) {
380   auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
381   if (!sharding) return absl::nullopt;
382   return CreateOpShardingFromStringRef(sharding.getValue());
383 }
384 
385 // Returns a FrontendAttributes proto from the "frontend_attributes" attribute
386 // of the op. An empty FrontendAttributes proto is returned if an op does not
387 // have frontend attributes.
388 static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
389     mlir::Operation* op) {
390   xla::FrontendAttributes frontend_attributes;
391   auto frontend_attributes_dict =
392       op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
393 
394   if (!frontend_attributes_dict) return frontend_attributes;
395 
396   for (const auto& attr : frontend_attributes_dict)
397     if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
398       frontend_attributes.mutable_map()->insert(
399           {attr.first.str(), value_str_attr.getValue().str()});
400 
401   return frontend_attributes;
402 }
403 
404 // Returns an OpMetadata proto based on the location of the op. If the location
405 // is unknown, an empty proto is returned. `op_name` is populated with the op
406 // location (converted). FileLineColLoc locations are populated by taking the
407 // file name and line number, and populating `source_file` and `source_line`
408 // respectively.
409 static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
410   xla::OpMetadata metadata;
411   if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;
412 
413   std::string name = mlir::GetNameFromLoc(op->getLoc());
414   mlir::LegalizeNodeName(name);
415   metadata.set_op_name(name);
416 
417   if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
418     metadata.set_source_file(file_line_col_loc.getFilename().str());
419     metadata.set_source_line(file_line_col_loc.getLine());
420   }
421 
422   return metadata;
423 }
424 
425 // Checks if all shardings are set.
426 static bool AllOptionalShardingsAreSet(
427     llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
428   return llvm::all_of(shardings,
429                       [](const absl::optional<xla::OpSharding>& sharding) {
430                         return sharding.has_value();
431                       });
432 }
433 
434 // Extracts argument and result shardings from function.
435 static void ExtractShardingsFromFunction(
436     mlir::FuncOp function,
437     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
438     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
439   arg_shardings->resize(function.getNumArguments(),
440                         absl::optional<xla::OpSharding>());
441   for (int i = 0, end = function.getNumArguments(); i < end; ++i)
442     if (auto sharding =
443             function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
444       (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
445 
446   ret_shardings->resize(function.getNumResults(),
447                         absl::optional<xla::OpSharding>());
448   for (int i = 0, end = function.getNumResults(); i < end; ++i)
449     if (auto sharding =
450             function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
451       (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
452 }
453 
454 namespace mlir {
455 namespace {
456 class ConvertToHloModule {
457  public:
458   using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
459   using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
460 
461   // If use_tuple_args is true, then the entry function's arguments are
462   // converted to a tuple and passed as a single parameter.
463   // Similarly, if return tuple is true, then the entry function's return values
464   // are converted to a tuple even when there is only a single return value.
465   // Multiple return values are always converted to a tuple and returned as a
466   // single value.
467   explicit ConvertToHloModule(
468       mlir::ModuleOp module, xla::XlaBuilder& module_builder,
469       bool use_tuple_args, bool return_tuple,
470       tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
471       MlirToHloConversionOptions options)
472       : module_(module),
473         module_builder_(module_builder),
474         use_tuple_args_(use_tuple_args),
475         return_tuple_(return_tuple),
476         shape_representation_fn_(shape_representation_fn),
477         options_(options) {
478     if (!shape_representation_fn_)
479       shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
480   }
481 
482   // Perform the lowering to XLA. This function returns failure if an error was
483   // encountered.
484   //
485   // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
486   LogicalResult Run() {
487     auto main = module_.lookupSymbol<mlir::FuncOp>("main");
488     if (!main)
489       return module_.emitError(
490           "conversion requires module with `main` function");
491 
492     for (auto func : module_.getOps<FuncOp>()) {
493       if (func.empty()) continue;
494       if (failed(RunOnFunction(func))) return failure();
495     }
496     return success();
497   }
498 
499   // Lower a specific function to HLO.
500   LogicalResult RunOnFunction(mlir::FuncOp f);
501 
502   // Lower a `mlir::Region` to a `XlaComputation`
503   LogicalResult LowerRegionAsComputation(mlir::Region* region,
504                                          xla::XlaComputation* func);
505 
506   // Lower a single `Block` to a `XlaComputation`
507   LogicalResult LowerBasicBlockAsFunction(
508       Block* block, xla::XlaBuilder* builder, bool is_entry_function,
509       const std::vector<bool>& entry_args_same_across_replicas,
510       llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
511       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
512       xla::XlaComputation* result);
513 
514   ::xla::HloModuleProto ConsumeMainProto() {
515     auto main = module_.lookupSymbol<mlir::FuncOp>("main");
516     // This is an invariant check as Run returns failure if there is no main
517     // function and so the main proto shouldn't be consumed in that case.
518     CHECK(main) << "requires module to have main function";  // Crash Ok.
519     return lowered_computation_[main].proto();
520   }
521 
522   // Lower function call to HLO call instruction
523   LogicalResult LowerFunctionCall(
524       mlir::CallOp call_op, xla::XlaBuilder* builder,
525       ConvertToHloModule::ValueLoweringMap* value_lowering);
526 
527   LogicalResult Lower(
528       mlir::Operation* inst, bool is_entry_function,
529       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
530       xla::XlaBuilder* builder,
531       ConvertToHloModule::ValueLoweringMap* value_lowering,
532       xla::XlaOp* return_value);
533 
534   const MlirToHloConversionOptions& GetOptions() const { return options_; }
535 
536  private:
537   LogicalResult SetEntryTupleShapesAndLeafReplication(
538       Block* block, const std::vector<bool>& entry_args_same_across_replicas,
539       llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
540       std::vector<bool>* leaf_replication);
541 
542   LogicalResult SetEntryTupleShardings(
543       Block* block, xla::XlaBuilder* builder,
544       llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
545       llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
546 
547   // The module being lowered.
548   mlir::ModuleOp module_;
549 
550   // The top-level XlaBuilder.
551   xla::XlaBuilder& module_builder_;
552 
553   // Map between function and lowered computation.
554   FunctionLoweringMap lowered_computation_;
555 
556   // Whether the entry function should take a single tuple as input.
557   bool use_tuple_args_;
558 
559   // Whether to always return a tuple.
560   bool return_tuple_;
561 
562   // Shape representation function to determine entry function argument and
563   // result shapes.
564   tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
565 
566   // Unique suffix to give to the name of the next lowered region.
567   size_t region_id_ = 0;
568 
569   MlirToHloConversionOptions options_;
570 };
571 
572 }  // namespace
573 }  // namespace mlir
574 
575 namespace {
576 
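// Per-op lowering state: the value-to-XlaOp map for the current region, the
// enclosing module converter, and the XlaBuilder for the computation being built.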
577 struct OpLoweringContext {
578   llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
579   mlir::ConvertToHloModule* converter;
580   xla::XlaBuilder* builder;
581 };
582 
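// Looks up the lowered XlaOp for each value in `values` and appends it to
// `results`; fails if any operand has not been lowered yet.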
583 mlir::LogicalResult GetTuple(mlir::Operation* op,
584                              mlir::Operation::operand_range values,
585                              OpLoweringContext ctx,
586                              llvm::SmallVectorImpl<xla::XlaOp>& results) {
587   results.reserve(values.size());
588   for (mlir::Value value : values) {
589     if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
590       return mlir::failure();
591   }
592   return mlir::success();
593 }
594 
595 }  // namespace
596 
597 namespace mlir {
598 namespace mhlo {
599 namespace {
600 
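// Each ExportXlaOp overload below lowers one MHLO op through the XLA builder
// API; returning failure() means the op has no expression in the legacy export
// format or is handled by special-purpose logic elsewhere.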
601 LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) {
602   // This op has no expression in the legacy export format. It can be expanded
603   // to a sequence of operations if needed in the future, but would feed into
604   // ops creating unsupported dynamic shapes.
605   return failure();
606 }
607 
608 LogicalResult ExportXlaOp(CstrReshapableOp, OpLoweringContext) {
609   // This op has no expression in the legacy export format.
610   return failure();
611 }
612 
613 LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
614   auto& valueMap = *ctx.values;
615   xla::XlaOp operand;
616   if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();
617   TensorType operandType = op.operand().getType().cast<TensorType>();
618   TensorType resultType = op.getType();
619   if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
620     return failure();
621   auto allGatherDim = op.all_gather_dim();
622   int64_t shardCount = resultType.getDimSize(allGatherDim) /
623                        operandType.getDimSize(allGatherDim);
624   valueMap[op] = xla::AllGather(operand, allGatherDim, shardCount,
625                                 Convert_replica_groups(op.replica_groups()),
626                                 Convert_channel_handle(op.channel_handle()));
627   return success();
628 }
629 
630 LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
631   auto& value_map = *ctx.values;
632   xla::XlaComputation computation;
633   if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
634                                                      &computation))) {
635     return failure();
636   }
637 
638   xla::XlaOp operand;
639   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
640 
641   value_map[op] = xla::AllReduce(operand, computation,
642                                  Convert_replica_groups(op.replica_groups()),
643                                  Convert_channel_handle(op.channel_handle()));
644   return success();
645 }
646 
647 LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) {
648   auto& valueMap = *ctx.values;
649   xla::XlaOp operand;
650   if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();
651   TensorType operandType = op.operand().getType().cast<TensorType>();
652   TensorType resultType = op.getType();
653   if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
654     return failure();
655   auto scatterDim = op.scatter_dimension();
656   int64_t shardCount =
657       operandType.getDimSize(scatterDim) / resultType.getDimSize(scatterDim);
658 
659   xla::XlaComputation computation;
660   if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
661                                                      &computation))) {
662     return failure();
663   }
664 
665   valueMap[op] =
666       xla::ReduceScatter(operand, computation, scatterDim, shardCount,
667                          Convert_replica_groups(op.replica_groups()),
668                          Convert_channel_handle(op.channel_handle()));
669   return success();
670 }
671 
672 LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
673   auto& value_map = *ctx.values;
674   xla::XlaOp operand;
675   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
676 
677   value_map[op] = xla::BitcastConvertType(
678       operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
679   return success();
680 }
681 
682 LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
683   auto type = op.getType().dyn_cast<RankedTensorType>();
684   if (!type) return failure();
685   auto& value_map = *ctx.values;
686   xla::XlaOp operand;
687   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
688 
689   value_map[op] =
690       BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
691                      Convert_broadcast_dimensions(op.broadcast_dimensions()));
692   return success();
693 }
694 
695 LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
696   auto& value_map = *ctx.values;
697   xla::XlaOp lhs, rhs;
698   if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
699   if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
700   xla::PrimitiveType preferred_element_type =
701       xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
702   value_map[op] = xla::DotGeneral(
703       lhs, rhs, Convert_dot_dimension_numbers(op.dot_dimension_numbers()),
704       Unwrap(Convert_precision_config(op.precision_config())),
705       preferred_element_type);
706   return mlir::success();
707 }
708 
709 LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
710   // This op has no expression in the legacy export format.
711   return failure();
712 }
713 
714 LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
715   // This op has no expression in the legacy export format.
716   return failure();
717 }
718 
719 LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
720   // This op has no expression in the legacy export format.
721   return failure();
722 }
723 
724 LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
725   xla::XlaComputation true_branch;
726   xla::XlaComputation false_branch;
727   auto& value_map = *ctx.values;
728   if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
729                                                      &true_branch)) ||
730       failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
731                                                      &false_branch))) {
732     return failure();
733   }
734   xla::XlaOp pred, true_arg, false_arg;
735   if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();
736   if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op)))
737     return failure();
738   if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op)))
739     return failure();
740 
741   value_map[op] =
742       xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);
743   return success();
744 }
745 
746 LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
747   llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
748   OperandRange operands = op.branch_operands();
749   MutableArrayRef<Region> branches = op.branches();
750   llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
751   std::vector<xla::XlaComputation> computations(branches.size());
752   std::vector<xla::XlaComputation*> computations_p(branches.size());
753 
754   for (unsigned i = 0; i < branches.size(); ++i) {
755     xla::XlaOp operand;
756     if (failed(GetXlaOp(operands[i], value_map, &operand, op)))
757       return failure();
758     branch_operands[i] = operand;
759     computations_p[i] = &computations[i];
760     if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
761                                                        computations_p[i])))
762       return failure();
763   }
764   xla::XlaOp index;
765   if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();
766 
767   xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands);
768   if (op.getNumResults() == 1) {
769     value_map[op.getResult(0)] = result;
770   } else {
771     for (auto item : llvm::enumerate(op.getResults())) {
772       value_map[item.value()] = xla::GetTupleElement(result, item.index());
773     }
774   }
775   return success();
776 }
777 
778 // Specialize CompareOp export to set broadcast_dimensions argument.
779 mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
780                                 OpLoweringContext ctx) {
781   auto& value_map = *ctx.values;
782   xla::XlaOp lhs, rhs;
783   if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
784   if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
785   auto dir = Convert_comparison_direction(op.comparison_direction());
786   auto type_attr = op.compare_typeAttr();
787 
788   xla::XlaOp xla_result;
789   if (type_attr) {
790     auto type =
791         xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
792     xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
793   } else {
794     xla_result = xla::Compare(lhs, rhs, dir);
795   }
796   value_map[op] = xla_result;
797   return mlir::success();
798 }
799 
800 LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
801   return failure();
802 }
803 
804 LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
805   // XLA client builder API does not support generating convolution instructions
806   // with window reversal.
807   if (op.hasWindowReversal()) return failure();
808   auto& value_map = *ctx.values;
809   xla::XlaOp lhs, rhs;
810   if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
811   if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
812   xla::PrimitiveType preferred_element_type =
813       xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
814   xla::XlaOp xla_result = xla::ConvGeneralDilated(
815       lhs, rhs, Convert_window_strides(op.window_strides()),
816       Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
817       Convert_rhs_dilation(op.rhs_dilation()),
818       Convert_dimension_numbers(op.dimension_numbers()),
819       Convertuint64_t(op.feature_group_count()),
820       Convertuint64_t(op.batch_group_count()),
821       Unwrap(Convert_precision_config(op.precision_config())),
822       preferred_element_type);
823   value_map[op] = xla_result;
824   return mlir::success();
825 }
826 
827 LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
828   auto& value_map = *ctx.values;
829   xla::XlaOp operand;
830   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
831 
832   value_map[op] = xla::ConvertElementType(
833       operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
834   return success();
835 }
836 
837 LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
838   // XLA client builder API does not support generating custom call instructions
839   // with side effect.
840   if (op.has_side_effect() || op.getNumResults() != 1) return failure();
841   Value result = op.getResult(0);
842   llvm::SmallVector<xla::XlaOp> args;
843   if (failed(GetTuple(op, op.args(), ctx, args))) return failure();
844   auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
845   if (!xla_api_version.ok()) return failure();
846   auto& value_map = *ctx.values;
847   value_map[result] = xla::CustomCall(
848       ctx.builder, std::string(op.call_target_name()), args,
849       xla::TypeToShape(result.getType()), std::string(op.backend_config()),
850       /*has_side_effect=*/false, /*output_operand_aliasing=*/{},
851       /*literal=*/nullptr, /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
852       /*api_version=*/*xla_api_version);
853   return success();
854 }
855 
856 LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) {
857   xla::QuantizedRange range(ConvertAPFloat(op.min_range()),
858                             ConvertAPFloat(op.max_range()));
859   auto& value_map = *ctx.values;
860   xla::XlaOp input;
861   if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure();
862 
863   auto casted = xla::ConvertElementType(input, xla::U32);
864   if (op.is_16bits()) {
865     value_map[op] = xla::Dequantize<uint16>(
866         casted, range, ConvertStringRef(op.mode()), op.transpose_output());
867   } else {
868     value_map[op] = xla::Dequantize<uint8>(
869         casted, range, ConvertStringRef(op.mode()), op.transpose_output());
870   }
871   return success();
872 }
873 
874 LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
875   auto& value_map = *ctx.values;
876   xla::XlaOp token;
877   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
878 
879   // The shape argument expected by the xla client API is the type of the first
880   // element in the result tuple.
881   auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
882   value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type),
883                                        std::string(op.infeed_config()));
884   return success();
885 }
886 
887 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
888   auto& value_map = *ctx.values;
889   value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
890                             op.iota_dimension());
891   return success();
892 }
893 
894 LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
895   auto& value_map = *ctx.values;
896   xla::XlaComputation computation;
897   if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
898                                                      &computation))) {
899     return failure();
900   }
901   llvm::SmallVector<xla::XlaOp> operands;
902   if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
903   value_map[op] = xla::Map(ctx.builder, operands, computation,
904                            Convert_dimensions(op.dimensions()));
905   return success();
906 }
907 
908 LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
909   auto& value_map = *ctx.values;
910   xla::XlaOp operand, token;
911   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
912   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
913 
914   value_map[op] = xla::OutfeedWithToken(
915       operand, token, xla::TypeToShape(op.operand().getType()),
916       std::string(op.outfeed_config()));
917   return success();
918 }
919 
920 LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
921   auto& value_map = *ctx.values;
922   xla::PaddingConfig padding_config;
923   auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
924   auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
925   auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
926   for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
927     auto* dims = padding_config.add_dimensions();
928     dims->set_edge_padding_low(edge_padding_low[i]);
929     dims->set_edge_padding_high(edge_padding_high[i]);
930     dims->set_interior_padding(interior_padding[i]);
931   }
932   xla::XlaOp operand, padding_value;
933   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
934   if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
935     return failure();
936 
937   value_map[op] = xla::Pad(operand, padding_value, padding_config);
938   return success();
939 }
940 
941 LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
942   auto& value_map = *ctx.values;
943   auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
944   xla::XlaOp token;
945   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
946 
947   if (op.is_host_transfer()) {
948     value_map[op] =
949         xla::RecvFromHost(token, xla::TypeToShape(result_type),
950                           Convert_channel_handle(op.channel_handle()));
951     return success();
952   }
953   value_map[op] =
954       xla::RecvWithToken(token, xla::TypeToShape(result_type),
955                          Convert_channel_handle(op.channel_handle()));
956   return success();
957 }
958 
959 LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
960   auto& value_map = *ctx.values;
961   xla::XlaComputation body;
962   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
963     return failure();
964   }
965   llvm::SmallVector<xla::XlaOp> operands, init_values;
966   if (failed(GetTuple(op, op.inputs(), ctx, operands)) ||
967       failed(GetTuple(op, op.init_values(), ctx, init_values))) {
968     return failure();
969   }
970   xla::XlaOp result =
971       xla::Reduce(ctx.builder, operands, init_values, body,
972                   Convert_broadcast_dimensions(op.dimensions()));
973   if (op.getNumResults() == 1) {
974     value_map[op.getResult(0)] = result;
975   } else {
976     for (auto item : llvm::enumerate(op.getResults())) {
977       value_map[item.value()] = xla::GetTupleElement(result, item.index());
978     }
979   }
980   return success();
981 }
982 
983 LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
984   auto& value_map = *ctx.values;
985   xla::XlaComputation body;
986   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
987     return failure();
988   }
989   llvm::SmallVector<xla::XlaOp> operands, init_values;
990   if (failed(GetTuple(op, op.inputs(), ctx, operands)) ||
991       failed(GetTuple(op, op.init_values(), ctx, init_values))) {
992     return failure();
993   }
994 
995   xla::XlaOp result = xla::ReduceWindowWithGeneralPadding(
996       operands, init_values, body, ConvertDenseIntAttr(op.window_dimensions()),
997       ConvertDenseIntAttr(op.window_strides()),
998       ConvertDenseIntAttr(op.base_dilations()),
999       ConvertDenseIntAttr(op.window_dilations()),
1000       Convert_padding(op.padding()));
1001 
1002   if (op.getNumResults() == 1) {
1003     value_map[op.getResult(0)] = result;
1004   } else {
1005     for (auto item : llvm::enumerate(op.getResults())) {
1006       value_map[item.value()] = xla::GetTupleElement(result, item.index());
1007     }
1008   }
1009   return success();
1010 }
1011 
1012 LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
1013   auto& value_map = *ctx.values;
1014   xla::XlaOp operand;
1015   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1016 
1017   value_map[op] =
1018       xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
1019   return success();
1020 }
1021 
1022 LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
1023   // Failure on purpose because `mhlo::ReturnOp` will be handled by
1024   // special purpose logic in `ConvertToHloModule::Lower`.
1025   return failure();
1026 }
1027 
1028 LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
1029   auto& value_map = *ctx.values;
1030   auto result = op.getResult();
1031   auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
1032   auto xla_result = xla::RngBitGenerator(
1033       static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
1034       xla::TypeToShape(result.getType()).tuple_shapes(1));
1035   value_map[result] = xla_result;
1036   return mlir::success();
1037 }
1038 
1039 LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
1040   auto& value_map = *ctx.values;
1041   xla::XlaOp mu, sigma;
1042   if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure();
1043   if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure();
1044 
1045   value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType()));
1046   return success();
1047 }
1048 
1049 LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
1050   auto& value_map = *ctx.values;
1051   xla::XlaOp a, b;
1052   if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
1053   if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();
1054 
1055   value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
1056   return success();
1057 }
1058 
1059 LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
1060   auto& value_map = *ctx.values;
1061   xla::XlaComputation update_computation;
1062   if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
1063                                                      &update_computation))) {
1064     return failure();
1065   }
1066   xla::ScatterDimensionNumbers dimension_numbers =
1067       Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
1068   xla::XlaOp operand, scatter_indices, updates;
1069   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1070   if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
1071     return failure();
1072   if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure();
1073 
1074   value_map[op] = xla::Scatter(operand, scatter_indices, updates,
1075                                update_computation, dimension_numbers,
1076                                op.indices_are_sorted(), op.unique_indices());
1077   return success();
1078 }
1079 
1080 LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
1081   auto& value_map = *ctx.values;
1082   xla::XlaComputation select;
1083   xla::XlaComputation scatter;
1084   if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
1085       failed(
1086           ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
1087     return failure();
1088   }
1089   xla::XlaOp operand, source, init_value;
1090   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1091   if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
1092   if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
1093     return failure();
1094 
1095   value_map[op] = xla::SelectAndScatterWithGeneralPadding(
1096       operand, select, ConvertDenseIntAttr(op.window_dimensions()),
1097       ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
1098       source, init_value, scatter);
1099   return success();
1100 }
1101 
1102 LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
1103   auto& value_map = *ctx.values;
1104   xla::XlaOp operand, token;
1105   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1106   if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
1107 
1108   if (op.is_host_transfer()) {
1109     value_map[op] = xla::SendToHost(
1110         operand, token, xla::TypeToShape(op.operand().getType()),
1111         Convert_channel_handle(op.channel_handle()));
1112     return success();
1113   }
1114   value_map[op] = xla::SendWithToken(
1115       operand, token, Convert_channel_handle(op.channel_handle()));
1116   return success();
1117 }
1118 
1119 LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
1120   return failure();
1121 }
1122 
1123 LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
1124   xla::XlaComputation comparator;
1125   if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
1126                                                      &comparator)))
1127     return failure();
1128 
1129   llvm::SmallVector<xla::XlaOp> operands;
1130   if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1131   auto sorted = xla::Sort(operands, comparator, op.dimension(), op.is_stable());
1132 
1133   auto& value_map = *ctx.values;
1134   auto shape_or = sorted.builder()->GetShape(sorted);
1135   if (!shape_or.ok()) {
1136     return op.emitError(shape_or.status().ToString());
1137   }
1138 
1139   xla::Shape& shape = shape_or.ValueOrDie();
1140   if (!shape.IsTuple()) {
1141     value_map[op.getResult(0)] = sorted;
1142     return success();
1143   }
1144 
1145   // MLIR's sort supports multiple results, so untuple all of XLA's results.
1146   for (auto it : llvm::enumerate(op.getResults())) {
1147     value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
1148   }
1149   return success();
1150 }
1151 
1152 LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
1153   auto& value_map = *ctx.values;
1154   xla::XlaOp operand;
1155   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1156   xla::Trace(std::string(op.tag()), operand);
1157   return success();
1158 }
1159 
1160 LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
1161   // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
1162   // operands.
1163   return failure();
1164 }
1165 
1166 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
1167   // TODO(jpienaar): Support multi-operand while op.
1168   if (op.getNumResults() != 1)
1169     return op.emitError("nyi: unable to export multi-result While op");
1170   xla::XlaComputation condition;
1171   xla::XlaComputation body;
1172   auto& value_map = *ctx.values;
1173   if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body)) ||
1174       failed(ctx.converter->LowerRegionAsComputation(&op.cond(), &condition))) {
1175     return failure();
1176   }
1177 
1178   xla::XlaOp operand;
1179   // TODO(jpienaar): Support multi-operand while op.
1180   Value val = op->getResult(0);
1181   if (failed(GetXlaOp(op->getOperand(0), value_map, &operand, op)))
1182     return failure();
1183   value_map[val] = xla::While(condition, body, operand);
1184   return success();
1185 }
1186 
1187 LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
1188   if (!op.fusion_kind()) {
1189     op.emitOpError() << "requires fusion kind for HLO translation";
1190     return failure();
1191   }
1192 
1193   xla::XlaComputation fused_computation;
1194   if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
1195                                                      &fused_computation)))
1196     return failure();
1197 
1198   auto& values = *ctx.values;
1199   llvm::SmallVector<xla::XlaOp, 4> operands;
1200   for (auto operand : op.operands()) operands.push_back(values[operand]);
1201 
1202   xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
1203       ctx.builder, operands,
1204       absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
1205       fused_computation);
1206   if (op.getNumResults() == 1) {
1207     values[op.getResult(0)] = fusion;
1208   } else {
1209     for (auto item : llvm::enumerate(op.getResults())) {
1210       values[item.value()] = xla::GetTupleElement(fusion, item.index());
1211     }
1212   }
1213   return success();
1214 }
1215 
1216 LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
1217   auto& value_map = *ctx.values;
1218   xla::XlaOp operand;
1219   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1220   xla::XlaOp bitcast = xla::internal::XlaBuilderFriend::BuildBitcast(
1221       ctx.builder, operand, xla::TypeToShape(op.getType()));
1222   value_map[op] = bitcast;
1223   if (ctx.converter->GetOptions().propagate_bitcast_layouts_to_backend_config) {
1224     // Encode the source and result layout of the bitcast into the XLA HLO
1225     // backend config as a protobuf. Note that this is a temporary solution
1226     // which will go away once XLA:GPU stops falling back to XLA HLO Elemental
1227     // IR emitters.
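    //
    // As an illustration (hypothetical layouts, not taken from this file): a
    // rank-3 source laid out as {2, 1, 0} bitcast to a rank-2 result laid out
    // as {1, 0} would serialize a BitcastBackendConfig along the lines of:
    //   source_layout { minor_to_major: 2 minor_to_major: 1 minor_to_major: 0 }
    //   result_layout { minor_to_major: 1 minor_to_major: 0 }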
    xla::HloInstructionProto* bitcast_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(bitcast);
    xla::HloInstructionProto* operand_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(operand);
    xla::LayoutProto result_layout =
        ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
                      "result_layout")
            .ToProto();
    xla::LayoutProto source_layout =
        ExtractLayout(op, operand_proto->shape().dimensions_size(),
                      "source_layout")
            .ToProto();
    xla::gpu::BitcastBackendConfig bitcast_config;
    *bitcast_config.mutable_source_layout() = source_layout;
    *bitcast_config.mutable_result_layout() = result_layout;
    *bitcast_proto->mutable_backend_config() =
        bitcast_config.SerializeAsString();
  }
  return success();
}

LogicalResult ExportXlaOp(RealDynamicSliceOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicPadOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicGatherOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicConvOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(PrintOp op, OpLoweringContext ctx) {
  return failure();
}

}  // namespace
}  // namespace mhlo
}  // namespace mlir

#include "tensorflow/compiler/mlir/xla/operator_writers.inc"

namespace mlir {
namespace {

StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
                                                  xla::Layout layout) {
  if (attr.isa<OpaqueElementsAttr>())
    return tensorflow::errors::Unimplemented(
        "Opaque elements attr not supported");

  xla::Shape shape = xla::TypeToShape(attr.getType());

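// For each supported XLA primitive type, copy the attribute's values into an
// xla::Array of the matching C++ type and build a literal with the requested
// layout.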
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
  case xla_type: {                                                           \
    xla::Array<cpp_type> source_data(shape.dimensions());                    \
    source_data.SetValues(attr.getValues<cpp_type>());                       \
    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
  }

  switch (shape.element_type()) {
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
    default:
      return tensorflow::errors::Internal(absl::StrCat(
          "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
  }
#undef ELEMENTS_ATTR_TO_LITERAL
}

LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
                            xla::ShapeProto* shape) {
  // In the case of tuples, ShapeProtos can be nested, and so can the mlir
  // attribute describing the layout. So recurse into the subshapes in both data
  // structures in parallel.
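  //
  // Purely as an illustration (not a layout taken from this file): a tuple
  // shape (f32[8,4], (s32[2], token)) could carry a layout attribute such as
  // [[1, 0], [[0], unit]], where each leaf list gives a minor-to-major order
  // and the unit attribute stands in for the token's (absent) layout.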
  if (shape->element_type() == xla::TUPLE) {
    auto subshapes = shape->mutable_tuple_shapes();
    if (layout.size() != subshapes->size()) {
      op->emitOpError() << "Expected layout of size " << layout.size()
                        << ", but found " << subshapes->size();
      return failure();
    }
    for (int i = 0; i < subshapes->size(); i++) {
      mlir::Attribute child = layout[i];
      if (child.isa<mlir::UnitAttr>()) {
        // Ignore unit attributes; they are used only for tokens.
        continue;
      }
      mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
      if (!c) {
        op->emitOpError() << "Type Error: Expected layout array attribute";
        return failure();
      }
      if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
        return failure();
      }
    }
  } else {
    int rank = shape->dimensions().size();
    if (rank) {
      if (layout.size() != rank) {
        return failure();  // pass error down
      }
      std::vector<int64> array(rank);
      for (int i = 0; i < rank; i++) {
        mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
        if (!attr) {
          op->emitOpError() << "Type Error: Expected layout integer attribute";
          return failure();
        }
        array[i] = attr.getInt();
      }
      *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
    }
  }
  return success();
}


// MHLO and XLA HLO disagree on the meaning of addition of `pred` / `i1`, so
// there has to be a special case somewhere to account for the difference.  To
// get the expected behavior of an `AddOp` on `i1`, we have to use `xor`.  Since
// the majority of the conversion is generated code, we just sidestep it here
// for this single case, and inline the code to emit an `xor`.
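// As a concrete (illustrative) check of that reasoning: in MLIR's signless
// `i1` arithmetic, 1 + 1 wraps modulo 2 to 0, which is exactly `1 xor 1`, so
// replacing the add with an xor preserves the intended semantics.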
LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst,
                                       OpLoweringContext ctx) {
  auto op = dyn_cast<mlir::mhlo::AddOp>(inst);
  if (op && op.getResult()
                .getType()
                .cast<mlir::TensorType>()
                .getElementType()
                .isSignlessInteger(1)) {
    auto& value_map = *ctx.values;
    auto result = op.getResult();
    xla::XlaOp xla_arg_0;
    if (failed(GetXlaOp(op.lhs(), value_map, &xla_arg_0, op)))
      return mlir::failure();
    xla::XlaOp xla_arg_1;
    if (failed(GetXlaOp(op.rhs(), value_map, &xla_arg_1, op)))
      return mlir::failure();
    auto xla_result = xla::Xor(Unwrap(xla_arg_0), Unwrap(xla_arg_1));
    value_map[result] = xla_result;
    return mlir::success();
  }

  return ExportXlaOperator(inst, ctx);
}

LogicalResult ConvertToHloModule::Lower(
    mlir::Operation* inst, bool is_entry_function,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering,
    xla::XlaOp* return_value) {
  // Explicitly fail for ops that are not supported for export.
  if (inst->getDialect() !=
          inst->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>() &&
      !mlir::isa<mlir::ConstantOp, mlir::CallOp, mlir::tensor::CastOp,
                 mlir::ReturnOp>(inst)) {
    inst->emitOpError("unsupported op for export to XLA");
    return failure();
  }

  *return_value = xla::XlaOp();

  // See MlirToHloConversionOptions for more about layouts.
  auto propagate_layouts = [this](mlir::Operation* inst,
                                  xla::XlaOp xla_op) -> mlir::LogicalResult {
    if (options_.propagate_layouts) {
      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
                        ->mutable_shape();
      if (shape->tuple_shapes().empty())
        // TODO(kramm): merge this with ConvertLayout.
        *shape->mutable_layout() =
            ExtractLayout(inst, shape->dimensions().size()).ToProto();
    }

    // For infeed ops stemming back to InfeedDequeueTuple, respect the layout
    // attribute, and create the corresponding layout in hlo.
    if (isa<mhlo::InfeedOp>(inst)) {
      mlir::ArrayAttr layout =
          inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
      if (layout) {
        xla::ShapeProto* shape =
            xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
                ->mutable_shape();

        if (failed(ConvertLayout(inst, layout, shape))) {
          return failure();
        }
      }
    }
    return success();
  };

  if (succeeded(
          ExportXlaOperatorWrapped(inst, {value_lowering, this, builder}))) {
    if (inst->getNumResults() == 1) {
      auto iter = value_lowering->find(inst->getResult(0));
      if (iter == value_lowering->end()) {
        inst->emitOpError(
            "inst has a result, but it's not found in value_lowering");
        return failure();
      }
      if (failed(propagate_layouts(inst, iter->second))) {
        return failure();
      }
    }
    return success();
  }

  auto& value_map = *value_lowering;
  ElementsAttr const_attr;

  if (auto call_op = dyn_cast<mlir::CallOp>(inst)) {
    return LowerFunctionCall(call_op, builder, &value_map);
  }

  if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
    Value operand = op.getOperand();
    auto ty = operand.getType().dyn_cast<ShapedType>();
    // If this was a cast from a statically shaped tensor, then it is a no-op
    // for export to HLO and we can use the operand directly.
    if (!ty || !ty.hasStaticShape()) {
      inst->emitOpError()
          << "requires static shaped operand for HLO translation";
      return failure();
    }

    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
      return failure();
    value_map[op.getResult()] = xla_operand;
    if (failed(propagate_layouts(inst, xla_operand))) {
      return failure();
    }
    return success();
  }

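  // Lower constant-like ops by materializing their attribute as an XLA
  // literal (with the layout requested on the op, if any) and emitting it as
  // an HLO constant.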
  if (matchPattern(inst, m_Constant(&const_attr))) {
    xla::Layout layout;
    layout = ExtractLayout(inst, const_attr.getType().getRank());
    auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
    if (!literal_or.ok())
      return inst->emitError(literal_or.status().ToString());
    auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
    value_map[inst->getResult(0)] = constant;

    return success();
  }

  if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
    // Construct the return value for the function. If there is a single value
    // returned, then return it directly, else create a tuple and return.
    unsigned num_return_values = inst->getNumOperands();
    if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
      const bool has_ret_shardings =
          !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);

      std::vector<xla::XlaOp> returns(num_return_values);
      for (OpOperand& ret : inst->getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
          return failure();

        returns[index] = operand;
        if (!is_entry_function || !has_ret_shardings) continue;

        xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
        StatusOr<xla::XlaOp> reshape =
            tensorflow::ReshapeWithCorrectRepresentationAndSharding(
                builder, returns[index], return_shape, shape_representation_fn_,
                ret_shardings[index], /*fast_mem=*/false);
        if (!reshape.ok())
          return inst->emitError() << reshape.status().error_message();

        returns[index] = reshape.ValueOrDie();
      }

      if (has_ret_shardings) {
        xla::OpSharding sharding;
        sharding.set_type(xla::OpSharding::TUPLE);
        for (auto& ret_sharding : ret_shardings)
          *sharding.add_tuple_shardings() = *ret_sharding;

        builder->SetSharding(sharding);
      }

      *return_value = xla::Tuple(builder, returns);
      builder->ClearSharding();
    } else if (num_return_values == 1) {
      xla::XlaOp operand;
      if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
        return failure();

      *return_value = operand;
    }

    return success();
  }

  inst->emitOpError() << "can't be translated to XLA HLO";
  return failure();
}

LogicalResult ConvertToHloModule::LowerFunctionCall(
    mlir::CallOp call_op, xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering) {
  auto& value_map = *value_lowering;
  mlir::FuncOp callee = module_.lookupSymbol<mlir::FuncOp>(call_op.callee());
  if (failed(RunOnFunction(callee))) return failure();
  std::vector<xla::XlaOp> operands;
  for (auto operand : call_op.getOperands()) {
    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
      return failure();
    operands.push_back(xla_operand);
  }
  // Each call to xla::Call inserts a copy of the callee computation into the
  // HLO, so each callsite ends up with a unique callee in the exported HLO.
  // HLO does not syntactically require all calls to have unique callees, but
  // the call graph is eventually "flattened" before lowering to make that
  // true. This is done before lowering because buffer assignment needs this
  // invariant.
  xla::XlaOp call_result =
      xla::Call(builder, lowered_computation_[callee], operands);
  // Use GetTupleElement for multiple outputs.
  unsigned num_results = call_op.getNumResults();
  if (num_results > 1) {
    for (unsigned i = 0; i != num_results; ++i) {
      value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
    }
  } else if (num_results == 1) {
    value_map[call_op.getResult(0)] = call_result;
  }
  return success();
}

LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
  if (lowered_computation_.count(f)) return success();
  if (!llvm::hasSingleElement(f)) {
    return f.emitError("only single block Function supported");
  }

  // Create a sub-builder if this is not the main function.
  std::unique_ptr<xla::XlaBuilder> builder_up;
  bool entry_function = f.getName() == "main";
  if (!entry_function)
    builder_up = module_builder_.CreateSubBuilder(f.getName().str());
  auto& builder = entry_function ? module_builder_ : *builder_up;

  xla::XlaComputation computation;
  std::vector<bool> entry_args_same_across_replicas;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
  if (entry_function) {
    bool any_arg_replicated = false;
    entry_args_same_across_replicas.reserve(f.getNumArguments());
    for (int64_t i = 0; i < f.getNumArguments(); ++i) {
      auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr);
      entry_args_same_across_replicas.push_back(attr != nullptr);
      any_arg_replicated |= entry_args_same_across_replicas.back();
      // Pass the alias info to the builder so that it will build the alias info
      // into the resulting HloModule.
      auto aliasing_output =
          f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
      if (!aliasing_output) continue;
      if (use_tuple_args_) {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/0, /*param_index=*/{i});
      } else {
        builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
                           /*param_number=*/i, /*param_index=*/{});
      }
    }
    // Do not populate this field when nothing is replicated, since empty field
    // means no replication. This avoids the need for unrelated tests to handle
    // this field.
    if (!any_arg_replicated) entry_args_same_across_replicas.clear();

    ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
  }
  if (failed(LowerBasicBlockAsFunction(
          &f.front(), &builder, entry_function, entry_args_same_across_replicas,
          arg_shardings, ret_shardings, &computation))) {
    return failure();
  }
  lowered_computation_[f] = std::move(computation);
  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
    Block* block, const std::vector<bool>& entry_args_same_across_replicas,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
    std::vector<bool>* leaf_replication) {
  arg_shapes->reserve(block->getNumArguments());
  leaf_replication->reserve(block->getNumArguments());
  for (BlockArgument& arg : block->getArguments()) {
    arg_shapes->push_back(xla::TypeToShape(arg.getType()));
    xla::Shape& arg_shape = arg_shapes->back();
    tensorflow::TensorShape arg_tensor_shape;
    auto status =
        tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    tensorflow::DataType dtype;
    status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
                                                     /*use_fast_memory=*/false);
    if (!arg_shape_status.ok())
      return block->getParentOp()->emitError()
             << arg_shape_status.status().error_message();

    arg_shape = std::move(arg_shape_status.ValueOrDie());

    if (entry_args_same_across_replicas.empty()) continue;
    for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
      leaf_replication->push_back(
          entry_args_same_across_replicas[arg.getArgNumber()]);
  }

  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShardings(
    Block* block, xla::XlaBuilder* builder,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
  if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::TUPLE);
    for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
      auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
      if (!hlo_sharding.ok())
        return block->getParentOp()->emitError()
               << hlo_sharding.status().error_message();

      auto status = tensorflow::RewriteLayoutWithShardedShape(
          hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
          shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
      if (!status.ok())
        return block->getParentOp()->emitError() << status.error_message();

      *sharding.add_tuple_shardings() = *arg_sharding.value();
    }

    builder->SetSharding(sharding);
  }

  return success();
}

LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    Block* block, xla::XlaBuilder* builder, bool is_entry_function,
    const std::vector<bool>& entry_args_same_across_replicas,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaComputation* result) {
  // Mapping from the Value to lowered XlaOp.
  ValueLoweringMap lowering;

  // If using tuples as input, then there is only one input parameter that is a
  // tuple.
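  //
  // For example (illustrative only), an entry block with arguments
  // (tensor<2xf32>, tensor<i32>) is exported as a single parameter of tuple
  // shape (f32[2], s32[]), and each block argument is rebound to a
  // GetTupleElement of that parameter below.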
  if (is_entry_function && use_tuple_args_) {
    llvm::SmallVector<xla::Shape, 4> arg_shapes;
    std::vector<bool> leaf_replication;
    if (failed(SetEntryTupleShapesAndLeafReplication(
            block, entry_args_same_across_replicas, &arg_shapes,
            &leaf_replication)))
      return failure();

    if (failed(
            SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
      return failure();

    xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
    auto tuple =
        xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
    builder->ClearSharding();

    bool set_tuple_element_sharding =
        !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
    for (BlockArgument& arg : block->getArguments()) {
      if (set_tuple_element_sharding)
        builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
      lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
    }
    builder->ClearSharding();
  } else {
    for (BlockArgument& arg : block->getArguments()) {
      auto num = arg.getArgNumber();
      xla::Shape shape = xla::TypeToShape(arg.getType());
      if (entry_args_same_across_replicas.empty()) {
        lowering[arg] =
            xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
      } else {
        // One replication flag per leaf buffer of the parameter shape.
        lowering[arg] = xla::Parameter(
            builder, num, shape, absl::StrCat("Arg_", num),
            std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
                              entry_args_same_across_replicas[num]));
      }
    }
  }

  xla::XlaOp return_value;
  for (auto& inst : *block)
    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                     &lowering, &return_value)))
      return failure();

  // Build the XlaComputation and check for failures.
  auto computation_or =
      return_value.valid() ? builder->Build(return_value) : builder->Build();
  if (!computation_or.ok()) {
    block->back().emitError(
        llvm::Twine(computation_or.status().error_message()));
    return failure();
  }
  *result = std::move(computation_or.ValueOrDie());
  return success();
}

LogicalResult ConvertToHloModule::LowerRegionAsComputation(
    mlir::Region* region, xla::XlaComputation* func) {
  std::unique_ptr<xla::XlaBuilder> builder =
      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
  return LowerBasicBlockAsFunction(
      &region->front(), builder.get(),
      /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
      /*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}

void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
                                     int arg_index, int32_t shape_index,
                                     int32_t padding_arg_index,
                                     bool use_tuple_args) {
  auto* entry = binding->add_entries();
  entry->set_target_param_dim_num(shape_index);
  if (use_tuple_args) {
    entry->set_target_param_num(0);
    entry->add_target_param_index(arg_index);
    entry->set_dynamic_param_num(0);
    entry->add_dynamic_param_index(padding_arg_index);
  } else {
    entry->set_target_param_num(arg_index);
    entry->set_dynamic_param_num(padding_arg_index);
  }
}
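
// Illustration of the helper above (hypothetical values, not from this file):
// with use_tuple_args=true, AddDynamicParameterBindingEntry(binding,
// /*arg_index=*/3, /*shape_index=*/0, /*padding_arg_index=*/7,
// /*use_tuple_args=*/true) records that dimension 0 of tuple element 3 of
// parameter 0 gets its dynamic size from tuple element 7 of parameter 0.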

// Runs the PrepareForExport pass on the ModuleOp.
Status PrepareForExport(mlir::ModuleOp module) {
  // Prepare for export to XLA HLO.
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mhlo::CreatePrepareForExport());
  if (failed(pm.run(module)))
    return tensorflow::errors::Internal("Unable to optimize for XLA export");
  return Status::OK();
}

}  // namespace

Status ConvertRegionToComputation(mlir::Region* region,
                                  xla::XlaComputation* func,
                                  MlirToHloConversionOptions options) {
  mlir::ModuleOp module;
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, true, true, {}, options);
  if (failed(converter.LowerRegionAsComputation(region, func)))
    return tensorflow::errors::Internal(
        "failed to convert region to computation");
  return Status::OK();
}

Status ConvertMlirHloToHlo(
    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
    bool return_tuple,
    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    MlirToHloConversionOptions options) {
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, use_tuple_args,
                               return_tuple, shape_representation_fn, options);
  if (failed(converter.Run())) return diag_handler.ConsumeStatus();
  auto hlo_module = converter.ConsumeMainProto();
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
  return Status::OK();
}

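// Sketch of how a caller might use ConvertMlirHloToHlo (illustrative only;
// the surrounding setup and option values are assumptions, not part of this
// file):
//
//   xla::HloProto hlo_proto;
//   TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(
//       module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/true,
//       /*shape_representation_fn=*/nullptr,
//       mlir::MlirToHloConversionOptions()));
//   xla::HloModuleProto module_proto = hlo_proto.hlo_module();
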
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           MlirToHloConversionOptions options) {
  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  ConvertToHloModule converter(module, builder,
                               /*use_tuple_args=*/false, /*return_tuple=*/false,
                               /*shape_representation_fn=*/nullptr, options);

  ConvertToHloModule::ValueLoweringMap lowering;
  if (xla_params.size() != block.getArguments().size())
    return tensorflow::errors::Internal(
        "xla_params size != block arguments size");
  for (BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    lowering[arg] = xla_params[num];
  }

  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  for (auto& inst : block) {
    if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
      returns.resize(inst.getNumOperands());
      for (OpOperand& ret : inst.getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
          return diag_handler.ConsumeStatus();
        returns[index] = operand;
      }
    } else {
      xla::XlaOp return_value;
      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
                                 /*ret_shardings=*/{}, &builder, &lowering,
                                 &return_value)))
        return diag_handler.ConsumeStatus();
    }
  }

  return Status::OK();
}

}  // namespace mlir