• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helpers useful when creating or manipulating lhlo/hlo.
17 
18 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
19 
20 #include "mlir/IR/AffineMap.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/core/platform/bfloat16.h"
27 #include "tensorflow/core/platform/logging.h"
28 
29 namespace xla {
30 namespace {
31 
32 using mlir::AffineMap;
33 using mlir::Builder;
34 using mlir::DenseElementsAttr;
35 using mlir::ShapedType;
36 using xla::LiteralBase;
37 using xla::StatusOr;
38 
39 template <typename CppType>
CreateDenseAttrFromLiteral(const ShapedType & type,const LiteralBase & literal)40 ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
41     const ShapedType& type, const LiteralBase& literal) {
42   auto data_span = literal.data<CppType>();
43   return ::mlir::DenseElementsAttr::get(
44       type, llvm::makeArrayRef(data_span.data(), data_span.size()));
45 }
46 
GetPermutationIfAvailable(const Shape & shape,mlir::Builder builder)47 StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
48     const Shape& shape, mlir::Builder builder) {
49   // N.B. IsMonotonicWithDim0Major ignores tiling, and I can't change it because
50   // some XLA code relies on it treating tiled layouts as equivalent to untiled
51   // layouts, so the check to rule out tiling has to come /before/ the
52   // early-return branch, or we'd miss tiled monotonic layouts.
53   if (!shape.layout().tiles().empty()) {
54     return tensorflow::errors::Internal("Tiled layouts are not yet supported");
55   }
56   if (!shape.has_layout() ||
57       LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
58     return llvm::SmallVector<AffineMap, 1>{};
59   }
60   if (!shape.is_static()) {
61     return tensorflow::errors::Internal(
62         "Permutations for dynamic shapes are not yet supported");
63   }
64   int64_t accumulated_stride = 1;
65   llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
66   for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
67     strides[dim] = accumulated_stride;
68     accumulated_stride *= shape.dimensions(dim);
69   }
70   if (accumulated_stride == 0) {
71     return llvm::SmallVector<AffineMap, 1>{};
72   }
73   return llvm::SmallVector<AffineMap, 1>{
74       makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
75 }
76 
77 template <typename T>
CopyDenseElementsBy(mlir::DenseElementsAttr data,std::vector<uint8> * output)78 void CopyDenseElementsBy(mlir::DenseElementsAttr data,
79                          std::vector<uint8>* output) {
80   output->resize(data.getNumElements() * sizeof(T));
81   int i = 0;
82   for (T element : data.getValues<T>()) {
83     std::memcpy(&(*output)[i], &element, sizeof(T));
84     i += sizeof(T);
85   }
86 }
87 
88 }  // namespace
89 
ConvertTensorShapeToMemRefType(const Shape & shape,mlir::Builder builder)90 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
91     const Shape& shape, mlir::Builder builder) {
92   auto element_type_or =
93       ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
94   if (!element_type_or.ok()) return element_type_or.status();
95 
96   using mlir::MemRefType;
97   auto dimensions = shape.dimensions();
98   llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
99   auto permutation_or = GetPermutationIfAvailable(shape, builder);
100   if (!permutation_or.ok()) return permutation_or.status();
101   return MemRefType::get(array, element_type_or.ValueOrDie(),
102                          permutation_or.ValueOrDie());
103 }
104 
CreateDenseElementsAttrFromLiteral(const LiteralBase & literal,Builder builder)105 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
106     const LiteralBase& literal, Builder builder) {
107   TF_ASSIGN_OR_RETURN(auto type,
108                       ConvertTensorShapeToType<mlir::RankedTensorType>(
109                           literal.shape(), builder));
110 
111   // TODO(hinsu): Support remaining XLA primitive types.
112   auto element_type = literal.shape().element_type();
113   switch (element_type) {
114     case PrimitiveType::PRED:
115       return CreateDenseAttrFromLiteral<bool>(type, literal);
116     case PrimitiveType::F16:
117       return CreateDenseAttrFromLiteral<half>(type, literal);
118     case PrimitiveType::BF16:
119       return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
120     case PrimitiveType::F32:
121       return CreateDenseAttrFromLiteral<float>(type, literal);
122     case PrimitiveType::F64:
123       return CreateDenseAttrFromLiteral<double>(type, literal);
124     case PrimitiveType::S8:
125       return CreateDenseAttrFromLiteral<int8>(type, literal);
126     case PrimitiveType::S16:
127       return CreateDenseAttrFromLiteral<int16>(type, literal);
128     case PrimitiveType::S32:
129       return CreateDenseAttrFromLiteral<int32>(type, literal);
130     case PrimitiveType::S64:
131       return CreateDenseAttrFromLiteral<int64>(type, literal);
132     case PrimitiveType::U8:
133       return CreateDenseAttrFromLiteral<uint8>(type, literal);
134     case PrimitiveType::U16:
135       return CreateDenseAttrFromLiteral<uint16>(type, literal);
136     case PrimitiveType::U32:
137       return CreateDenseAttrFromLiteral<uint32>(type, literal);
138     case PrimitiveType::U64:
139       return CreateDenseAttrFromLiteral<uint64>(type, literal);
140     case PrimitiveType::C64:
141       return CreateDenseAttrFromLiteral<complex64>(type, literal);
142     case PrimitiveType::C128:
143       return CreateDenseAttrFromLiteral<complex128>(type, literal);
144     default:
145       return tensorflow::errors::Internal(
146           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
147   }
148 }
149 
CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,std::vector<uint8> * output)150 Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
151                                         std::vector<uint8>* output) {
152   mlir::Type element_type = data.getType().getElementType();
153 
154   // TODO(hinsu): Support remaining XLA primitive types.
155   if (element_type.isInteger(1)) {
156     CopyDenseElementsBy<bool>(data, output);
157     return Status::OK();
158   }
159   if (element_type.isInteger(8)) {
160     CopyDenseElementsBy<uint8>(data, output);
161     return Status::OK();
162   }
163   if (element_type.isInteger(16)) {
164     CopyDenseElementsBy<uint16>(data, output);
165     return Status::OK();
166   }
167   if (element_type.isInteger(32)) {
168     CopyDenseElementsBy<uint32>(data, output);
169     return Status::OK();
170   }
171   if (element_type.isInteger(64)) {
172     CopyDenseElementsBy<uint64>(data, output);
173     return Status::OK();
174   }
175   if (element_type.isBF16()) {
176     CopyDenseElementsBy<bfloat16>(data, output);
177     return Status::OK();
178   }
179   if (element_type.isF16()) {
180     CopyDenseElementsBy<half>(data, output);
181     return Status::OK();
182   }
183   if (element_type.isF32()) {
184     CopyDenseElementsBy<float>(data, output);
185     return Status::OK();
186   }
187   if (element_type.isF64()) {
188     CopyDenseElementsBy<double>(data, output);
189     return Status::OK();
190   }
191   if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
192     if (complex_type.getElementType().isF32()) {
193       CopyDenseElementsBy<complex64>(data, output);
194       return Status::OK();
195     }
196     if (complex_type.getElementType().isF64()) {
197       CopyDenseElementsBy<complex128>(data, output);
198       return Status::OK();
199     }
200   }
201   return tensorflow::errors::Internal(
202       "Unsupported type in CopyDenseElementsDataToXlaFormat");
203 }
204 
GetElementTypeBytes(mlir::Type type)205 StatusOr<int> GetElementTypeBytes(mlir::Type type) {
206   if (type.isInteger(1)) {
207     return 1;
208   }
209   if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
210     TF_ASSIGN_OR_RETURN(int bytes,
211                         GetElementTypeBytes(complex_type.getElementType()));
212     return bytes * 2;
213   }
214   int width = type.getIntOrFloatBitWidth();
215   TF_RET_CHECK(width % 8 == 0);
216   return width / 8;
217 }
218 
CreateDenseIntElementsAttrFromVector(const llvm::ArrayRef<int64> vector,mlir::Builder builder,llvm::ArrayRef<int64_t> shape)219 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
220     const llvm::ArrayRef<int64> vector, mlir::Builder builder,
221     llvm::ArrayRef<int64_t> shape) {
222   return mlir::DenseIntElementsAttr::get(
223       mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
224                                   builder.getIntegerType(64)),
225       vector);
226 }
227 
ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,mlir::Builder builder)228 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
229                                                     mlir::Builder builder) {
230   switch (element_type) {
231     case PrimitiveType::PRED:
232       return builder.getI1Type();
233     case PrimitiveType::F16:
234       return builder.getF16Type();
235     case PrimitiveType::BF16:
236       return builder.getBF16Type();
237     case PrimitiveType::F32:
238       return builder.getF32Type();
239     case PrimitiveType::F64:
240       return builder.getF64Type();
241     case PrimitiveType::S8:
242       return builder.getIntegerType(8);
243     case PrimitiveType::S16:
244       return builder.getIntegerType(16);
245     case PrimitiveType::S32:
246       return builder.getIntegerType(32);
247     case PrimitiveType::S64:
248       return builder.getIntegerType(64);
249     case PrimitiveType::U8:
250       return builder.getIntegerType(8, /*isSigned=*/false);
251     case PrimitiveType::U16:
252       return builder.getIntegerType(16, /*isSigned=*/false);
253     case PrimitiveType::U32:
254       return builder.getIntegerType(32, /*isSigned=*/false);
255     case PrimitiveType::U64:
256       return builder.getIntegerType(64, /*isSigned=*/false);
257     case PrimitiveType::C64:
258       return mlir::ComplexType::get(builder.getF32Type());
259     case PrimitiveType::C128:
260       return mlir::ComplexType::get(builder.getF64Type());
261     // TODO(b/130356985): Support unsigned primitive types.
262     default:
263       return tensorflow::errors::Internal(
264           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
265   }
266 }
267 
CreateGatherDimensionNumbers(const GatherDimensionNumbers & input,mlir::Builder builder)268 mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
269     const GatherDimensionNumbers& input, mlir::Builder builder) {
270   auto offset_dims = CreateDenseIntElementsAttrFromVector(
271       llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
272                                   input.offset_dims().end()},
273       builder);
274   auto collapsed_slice_dims = CreateDenseIntElementsAttrFromVector(
275       llvm::SmallVector<int64, 4>{input.collapsed_slice_dims().begin(),
276                                   input.collapsed_slice_dims().end()},
277       builder);
278   auto start_index_map = CreateDenseIntElementsAttrFromVector(
279       llvm::SmallVector<int64, 4>{input.start_index_map().begin(),
280                                   input.start_index_map().end()},
281       builder);
282 
283   mlir::IntegerAttr index_vector_dim =
284       builder.getI64IntegerAttr(input.index_vector_dim());
285 
286   return mlir::mhlo::GatherDimensionNumbers::get(
287       offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
288       builder.getContext());
289 }
290 
MhloToHloOpcode(mlir::Operation * op)291 StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
292   using mlir::isa;
293 
294   if (isa<mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op)) {
295     return xla::HloOpcode::kConstant;
296   } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
297     return xla::HloOpcode::kIota;
298   } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
299     return xla::HloOpcode::kConvert;
300   } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
301     return xla::HloOpcode::kAdd;
302   } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
303     return xla::HloOpcode::kAtan2;
304   } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
305     return xla::HloOpcode::kDivide;
306   } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
307     return xla::HloOpcode::kMaximum;
308   } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
309     return xla::HloOpcode::kMinimum;
310   } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
311     return xla::HloOpcode::kMultiply;
312   } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
313     return xla::HloOpcode::kPower;
314   } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
315     return xla::HloOpcode::kRemainder;
316   } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
317     return xla::HloOpcode::kShiftLeft;
318   } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
319                  mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
320     return xla::HloOpcode::kShiftRightArithmetic;
321   } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
322                  mlir::lmhlo::ShiftRightLogicalOp>(op)) {
323     return xla::HloOpcode::kShiftRightLogical;
324   } else if (isa<mlir::mhlo::SubOp, mlir::lmhlo::SubOp>(op)) {
325     return xla::HloOpcode::kSubtract;
326   } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
327     return xla::HloOpcode::kXor;
328   } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
329     return xla::HloOpcode::kInfeed;
330   } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
331     return xla::HloOpcode::kOutfeed;
332   } else if (isa<mlir::mhlo::SendOp>(op)) {
333     return xla::HloOpcode::kSend;
334   } else if (isa<mlir::mhlo::RecvOp>(op)) {
335     return xla::HloOpcode::kRecv;
336   } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
337     return xla::HloOpcode::kReplicaId;
338   } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
339     return xla::HloOpcode::kAfterAll;
340   } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
341     return xla::HloOpcode::kAllReduce;
342   } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
343     return xla::HloOpcode::kAllToAll;
344   } else if (isa<mlir::mhlo::TupleOp>(op)) {
345     return xla::HloOpcode::kTuple;
346   } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
347                  op)) {
348     return xla::HloOpcode::kBatchNormGrad;
349   } else if (isa<mlir::mhlo::BatchNormInferenceOp,
350                  mlir::lmhlo::BatchNormInferenceOp>(op)) {
351     return xla::HloOpcode::kBatchNormInference;
352   } else if (isa<mlir::mhlo::BatchNormTrainingOp,
353                  mlir::lmhlo::BatchNormTrainingOp>(op)) {
354     return xla::HloOpcode::kBatchNormTraining;
355   } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
356                  op)) {
357     return xla::HloOpcode::kBitcastConvert;
358   } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
359     return xla::HloOpcode::kBroadcast;
360   } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
361     return xla::HloOpcode::kCholesky;
362   } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
363     return xla::HloOpcode::kClamp;
364   } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
365     return xla::HloOpcode::kConcatenate;
366   } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
367     return xla::HloOpcode::kConvolution;
368   } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
369     return xla::HloOpcode::kSort;
370   } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
371     return xla::HloOpcode::kRngBitGenerator;
372   } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
373     return xla::HloOpcode::kFusion;
374   } else if (isa<mlir::mhlo::BitcastOp>(op)) {
375     return xla::HloOpcode::kBitcast;
376   } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
377     return xla::HloOpcode::kAbs;
378   } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
379     return xla::HloOpcode::kCbrt;
380   } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
381     return xla::HloOpcode::kCeil;
382   } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
383     return xla::HloOpcode::kClz;
384   } else if (isa<mlir::mhlo::CosOp, mlir::lmhlo::CosOp>(op)) {
385     return xla::HloOpcode::kCos;
386   } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
387     return xla::HloOpcode::kExp;
388   } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
389     return xla::HloOpcode::kExpm1;
390   } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
391     return xla::HloOpcode::kFloor;
392   } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
393     return xla::HloOpcode::kImag;
394   } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
395     return xla::HloOpcode::kIsFinite;
396   } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
397     return xla::HloOpcode::kLog;
398   } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
399     return xla::HloOpcode::kLog1p;
400   } else if (isa<mlir::mhlo::LogisticOp>(op)) {
401     return xla::HloOpcode::kLogistic;
402   } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
403     return xla::HloOpcode::kNot;
404   } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
405     return xla::HloOpcode::kNegate;
406   } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
407                  op)) {
408     return xla::HloOpcode::kPopulationCount;
409   } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
410     return xla::HloOpcode::kReal;
411   } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
412     return xla::HloOpcode::kRoundNearestAfz;
413   } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
414     return xla::HloOpcode::kRsqrt;
415   } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
416     return xla::HloOpcode::kSign;
417   } else if (isa<mlir::mhlo::SinOp, mlir::lmhlo::SinOp>(op)) {
418     return xla::HloOpcode::kSin;
419   } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
420     return xla::HloOpcode::kSqrt;
421   } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
422     return xla::HloOpcode::kTanh;
423   } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
424     return xla::HloOpcode::kComplex;
425   } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
426     return xla::HloOpcode::kAnd;
427   } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
428     return xla::HloOpcode::kOr;
429   } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
430     return xla::HloOpcode::kWhile;
431   } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
432     return xla::HloOpcode::kReduce;
433   } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
434     return xla::HloOpcode::kGetTupleElement;
435   } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
436     return xla::HloOpcode::kCompare;
437   } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
438     return xla::HloOpcode::kSlice;
439   } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
440     return xla::HloOpcode::kDynamicSlice;
441   } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
442                  mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
443     return xla::HloOpcode::kDynamicUpdateSlice;
444   } else if (isa<mlir::mhlo::CollectivePermuteOp,
445                  mlir::lmhlo::CollectivePermuteOp>(op)) {
446     return xla::HloOpcode::kCollectivePermute;
447   } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
448     return xla::HloOpcode::kCopy;
449   } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
450     return xla::HloOpcode::kCustomCall;
451   } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
452     return xla::HloOpcode::kDot;
453   } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
454     return xla::HloOpcode::kFft;
455   } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
456     return xla::HloOpcode::kGather;
457   } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
458     return xla::HloOpcode::kGetDimensionSize;
459   } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
460     return xla::HloOpcode::kMap;
461   } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
462     return xla::HloOpcode::kReshape;
463   } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
464     return xla::HloOpcode::kDynamicReshape;
465   } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
466     return xla::HloOpcode::kScatter;
467   } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
468     return xla::HloOpcode::kSelect;
469   } else if (isa<mlir::mhlo::SelectAndScatterOp,
470                  mlir::lmhlo::SelectAndScatterOp>(op)) {
471     return xla::HloOpcode::kSelectAndScatter;
472   } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
473     return xla::HloOpcode::kSetDimensionSize;
474   } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
475     return xla::HloOpcode::kReverse;
476   } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
477     return xla::HloOpcode::kPad;
478   } else if (isa<mlir::mhlo::TraceOp>(op)) {
479     return xla::HloOpcode::kTrace;
480   } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
481     return xla::HloOpcode::kTranspose;
482   } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
483                  op)) {
484     return xla::HloOpcode::kTriangularSolve;
485   } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
486     return xla::HloOpcode::kReduceWindow;
487   } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
488                  op)) {
489     return xla::HloOpcode::kReducePrecision;
490   } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
491     return xla::HloOpcode::kDot;
492   } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
493                  op)) {
494     return xla::HloOpcode::kBroadcast;
495   } else {
496     std::string s;
497     {
498       llvm::raw_string_ostream os(s);
499       op->print(os);
500     }
501     return tensorflow::errors::Unimplemented(
502         "Unimplemented MHLO -> HloOpcode: ", s);
503   }
504 }
505 
506 }  // namespace xla
507