1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines helpers useful when creating or manipulating lhlo/hlo.
17
18 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
19
20 #include "mlir/IR/AffineMap.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/core/platform/bfloat16.h"
27 #include "tensorflow/core/platform/logging.h"
28
29 namespace xla {
30 namespace {
31
32 using mlir::AffineMap;
33 using mlir::Builder;
34 using mlir::DenseElementsAttr;
35 using mlir::ShapedType;
36 using xla::LiteralBase;
37 using xla::StatusOr;
38
39 template <typename CppType>
CreateDenseAttrFromLiteral(const ShapedType & type,const LiteralBase & literal)40 ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
41 const ShapedType& type, const LiteralBase& literal) {
42 auto data_span = literal.data<CppType>();
43 return ::mlir::DenseElementsAttr::get(
44 type, llvm::makeArrayRef(data_span.data(), data_span.size()));
45 }
46
GetPermutationIfAvailable(const Shape & shape,mlir::Builder builder)47 StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
48 const Shape& shape, mlir::Builder builder) {
49 if (!shape.has_layout() ||
50 LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
51 return llvm::SmallVector<AffineMap, 1>{};
52 }
53 if (!shape.is_static()) {
54 return tensorflow::errors::Internal(
55 "Permutations for dynamic shapes are not yet supported");
56 }
57 int64_t accumulated_stride = 1;
58 llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
59 for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
60 strides[dim] = accumulated_stride;
61 accumulated_stride *= shape.dimensions(dim);
62 }
63 if (accumulated_stride == 0) {
64 return llvm::SmallVector<AffineMap, 1>{};
65 }
66 return llvm::SmallVector<AffineMap, 1>{
67 makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
68 }
69
70 template <typename T>
CopyDenseElementsBy(mlir::DenseElementsAttr data,std::vector<uint8> * output)71 void CopyDenseElementsBy(mlir::DenseElementsAttr data,
72 std::vector<uint8>* output) {
73 output->resize(data.getNumElements() * sizeof(T));
74 int i = 0;
75 for (T element : data.getValues<T>()) {
76 std::memcpy(&(*output)[i], &element, sizeof(T));
77 i += sizeof(T);
78 }
79 }
80
81 } // namespace
82
ConvertTensorShapeToMemRefType(const Shape & shape,mlir::Builder builder)83 StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
84 const Shape& shape, mlir::Builder builder) {
85 auto element_type_or =
86 ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
87 if (!element_type_or.ok()) return element_type_or.status();
88
89 using mlir::MemRefType;
90 auto dimensions = shape.dimensions();
91 llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
92 auto permutation_or = GetPermutationIfAvailable(shape, builder);
93 if (!permutation_or.ok()) return permutation_or.status();
94 return MemRefType::get(array, element_type_or.ValueOrDie(),
95 permutation_or.ValueOrDie());
96 }
97
// Converts an XLA literal into an MLIR DenseElementsAttr. The attribute's
// type is the ranked tensor type derived from the literal's shape, and its
// values are copied from the literal's data buffer.
// Returns an Internal error for primitive types with no case below
// (e.g. TUPLE, OPAQUE_TYPE, TOKEN).
StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
    const LiteralBase& literal, Builder builder) {
  TF_ASSIGN_OR_RETURN(auto type,
                      ConvertTensorShapeToType<mlir::RankedTensorType>(
                          literal.shape(), builder));

  // TODO(hinsu): Support remaining XLA primitive types.
  // Dispatch on the literal's primitive type so the data is reinterpreted
  // with the matching C++ element type.
  auto element_type = literal.shape().element_type();
  switch (element_type) {
    case PrimitiveType::PRED:
      return CreateDenseAttrFromLiteral<bool>(type, literal);
    case PrimitiveType::F16:
      return CreateDenseAttrFromLiteral<half>(type, literal);
    case PrimitiveType::BF16:
      return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
    case PrimitiveType::F32:
      return CreateDenseAttrFromLiteral<float>(type, literal);
    case PrimitiveType::F64:
      return CreateDenseAttrFromLiteral<double>(type, literal);
    case PrimitiveType::S8:
      return CreateDenseAttrFromLiteral<int8>(type, literal);
    case PrimitiveType::S16:
      return CreateDenseAttrFromLiteral<int16>(type, literal);
    case PrimitiveType::S32:
      return CreateDenseAttrFromLiteral<int32>(type, literal);
    case PrimitiveType::S64:
      return CreateDenseAttrFromLiteral<int64>(type, literal);
    case PrimitiveType::U8:
      return CreateDenseAttrFromLiteral<uint8>(type, literal);
    case PrimitiveType::U16:
      return CreateDenseAttrFromLiteral<uint16>(type, literal);
    case PrimitiveType::U32:
      return CreateDenseAttrFromLiteral<uint32>(type, literal);
    case PrimitiveType::U64:
      return CreateDenseAttrFromLiteral<uint64>(type, literal);
    case PrimitiveType::C64:
      return CreateDenseAttrFromLiteral<complex64>(type, literal);
    case PrimitiveType::C128:
      return CreateDenseAttrFromLiteral<complex128>(type, literal);
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}
142
// Serializes an MLIR DenseElementsAttr into `output` as a flat byte buffer
// (raw per-element byte copies in host byte order), the format XLA expects
// for dense literal data. Returns an Internal error for element types with
// no case below.
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector<uint8>* output) {
  mlir::Type element_type = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  if (element_type.isInteger(1)) {
    CopyDenseElementsBy<bool>(data, output);
    return Status::OK();
  }
  // Integer elements are copied through unsigned C++ types of the same
  // width; only the bit pattern matters, not the signedness.
  if (element_type.isInteger(8)) {
    CopyDenseElementsBy<uint8>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(16)) {
    CopyDenseElementsBy<uint16>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(32)) {
    CopyDenseElementsBy<uint32>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(64)) {
    CopyDenseElementsBy<uint64>(data, output);
    return Status::OK();
  }
  if (element_type.isBF16()) {
    CopyDenseElementsBy<bfloat16>(data, output);
    return Status::OK();
  }
  if (element_type.isF16()) {
    CopyDenseElementsBy<half>(data, output);
    return Status::OK();
  }
  if (element_type.isF32()) {
    CopyDenseElementsBy<float>(data, output);
    return Status::OK();
  }
  if (element_type.isF64()) {
    CopyDenseElementsBy<double>(data, output);
    return Status::OK();
  }
  // Complex elements: dispatch on the component float type.
  if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
    if (complex_type.getElementType().isF32()) {
      CopyDenseElementsBy<complex64>(data, output);
      return Status::OK();
    }
    if (complex_type.getElementType().isF64()) {
      CopyDenseElementsBy<complex128>(data, output);
      return Status::OK();
    }
  }
  return tensorflow::errors::Internal(
      "Unsupported type in CopyDenseElementsDataToXlaFormat");
}
197
GetElementTypeBytes(mlir::Type type)198 StatusOr<int> GetElementTypeBytes(mlir::Type type) {
199 if (type.isInteger(1)) {
200 return 1;
201 }
202 if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
203 TF_ASSIGN_OR_RETURN(int bytes,
204 GetElementTypeBytes(complex_type.getElementType()));
205 return bytes * 2;
206 }
207 int width = type.getIntOrFloatBitWidth();
208 TF_RET_CHECK(width % 8 == 0);
209 return width / 8;
210 }
211
CreateDenseIntElementsAttrFromVector(const llvm::ArrayRef<int64> vector,mlir::Builder builder,llvm::ArrayRef<int64_t> shape)212 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
213 const llvm::ArrayRef<int64> vector, mlir::Builder builder,
214 llvm::ArrayRef<int64_t> shape) {
215 return mlir::DenseIntElementsAttr::get(
216 mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
217 builder.getIntegerType(64)),
218 vector);
219 }
220
// Maps an XLA primitive type to the corresponding MLIR type: PRED -> i1,
// floats to the builtin float types, signed integers to signless iN,
// unsigned integers to unsigned iN, and C64/C128 to complex of f32/f64.
// Returns an Internal error for types without a case below.
StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                                    mlir::Builder builder) {
  switch (element_type) {
    case PrimitiveType::PRED:
      return builder.getI1Type();
    case PrimitiveType::F16:
      return builder.getF16Type();
    case PrimitiveType::BF16:
      return builder.getBF16Type();
    case PrimitiveType::F32:
      return builder.getF32Type();
    case PrimitiveType::F64:
      return builder.getF64Type();
    // Signed XLA integers map to MLIR's default (signless) integer types.
    case PrimitiveType::S8:
      return builder.getIntegerType(8);
    case PrimitiveType::S16:
      return builder.getIntegerType(16);
    case PrimitiveType::S32:
      return builder.getIntegerType(32);
    case PrimitiveType::S64:
      return builder.getIntegerType(64);
    // Unsigned XLA integers map to explicitly-unsigned MLIR integers.
    case PrimitiveType::U8:
      return builder.getIntegerType(8, /*isSigned=*/false);
    case PrimitiveType::U16:
      return builder.getIntegerType(16, /*isSigned=*/false);
    case PrimitiveType::U32:
      return builder.getIntegerType(32, /*isSigned=*/false);
    case PrimitiveType::U64:
      return builder.getIntegerType(64, /*isSigned=*/false);
    case PrimitiveType::C64:
      return mlir::ComplexType::get(builder.getF32Type());
    case PrimitiveType::C128:
      return mlir::ComplexType::get(builder.getF64Type());
    // TODO(b/130356985): Support unsigned primitive types.
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}
260
CreateGatherDimensionNumbers(const GatherDimensionNumbers & input,mlir::Builder builder)261 mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
262 const GatherDimensionNumbers& input, mlir::Builder builder) {
263 auto offset_dims = CreateDenseIntElementsAttrFromVector(
264 llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
265 input.offset_dims().end()},
266 builder);
267 auto collapsed_slice_dims = CreateDenseIntElementsAttrFromVector(
268 llvm::SmallVector<int64, 4>{input.collapsed_slice_dims().begin(),
269 input.collapsed_slice_dims().end()},
270 builder);
271 auto start_index_map = CreateDenseIntElementsAttrFromVector(
272 llvm::SmallVector<int64, 4>{input.start_index_map().begin(),
273 input.start_index_map().end()},
274 builder);
275
276 mlir::IntegerAttr index_vector_dim =
277 builder.getI64IntegerAttr(input.index_vector_dim());
278
279 return mlir::mhlo::GatherDimensionNumbers::get(
280 offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
281 builder.getContext());
282 }
283
// Maps an MHLO or LMHLO operation to the equivalent XLA HloOpcode. Most
// mhlo ops have an lmhlo counterpart, so both dialects are matched in the
// same branch. Returns an Unimplemented error containing the printed op
// for operations with no mapping below.
StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
  using mlir::isa;

  if (isa<mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op)) {
    return xla::HloOpcode::kConstant;
  } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
    return xla::HloOpcode::kIota;
  } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
    return xla::HloOpcode::kConvert;
  } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
    return xla::HloOpcode::kAdd;
  } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
    return xla::HloOpcode::kAtan2;
  } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
    return xla::HloOpcode::kDivide;
  } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
    return xla::HloOpcode::kMaximum;
  } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
    return xla::HloOpcode::kMinimum;
  } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
    return xla::HloOpcode::kMultiply;
  } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
    return xla::HloOpcode::kPower;
  } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
    return xla::HloOpcode::kRemainder;
  } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
    return xla::HloOpcode::kShiftLeft;
  } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
                 mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
    return xla::HloOpcode::kShiftRightArithmetic;
  } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
                 mlir::lmhlo::ShiftRightLogicalOp>(op)) {
    return xla::HloOpcode::kShiftRightLogical;
  } else if (isa<mlir::mhlo::SubOp, mlir::lmhlo::SubOp>(op)) {
    return xla::HloOpcode::kSubtract;
  } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
    return xla::HloOpcode::kXor;
  } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
    return xla::HloOpcode::kInfeed;
  } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
    return xla::HloOpcode::kOutfeed;
  // The following ops are matched for the mhlo dialect only; no lmhlo
  // counterpart is handled here.
  } else if (isa<mlir::mhlo::SendOp>(op)) {
    return xla::HloOpcode::kSend;
  } else if (isa<mlir::mhlo::RecvOp>(op)) {
    return xla::HloOpcode::kRecv;
  } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
    return xla::HloOpcode::kReplicaId;
  } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
    return xla::HloOpcode::kAfterAll;
  } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
    return xla::HloOpcode::kAllReduce;
  } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
    return xla::HloOpcode::kAllToAll;
  } else if (isa<mlir::mhlo::TupleOp>(op)) {
    return xla::HloOpcode::kTuple;
  } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
                 op)) {
    return xla::HloOpcode::kBatchNormGrad;
  } else if (isa<mlir::mhlo::BatchNormInferenceOp,
                 mlir::lmhlo::BatchNormInferenceOp>(op)) {
    return xla::HloOpcode::kBatchNormInference;
  } else if (isa<mlir::mhlo::BatchNormTrainingOp,
                 mlir::lmhlo::BatchNormTrainingOp>(op)) {
    return xla::HloOpcode::kBatchNormTraining;
  } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
                 op)) {
    return xla::HloOpcode::kBitcastConvert;
  } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
    return xla::HloOpcode::kBroadcast;
  } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
    return xla::HloOpcode::kCholesky;
  } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
    return xla::HloOpcode::kClamp;
  } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
    return xla::HloOpcode::kConcatenate;
  } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
    return xla::HloOpcode::kConvolution;
  } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
    return xla::HloOpcode::kSort;
  } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
    return xla::HloOpcode::kRngBitGenerator;
  } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
    return xla::HloOpcode::kFusion;
  } else if (isa<mlir::mhlo::BitcastOp>(op)) {
    return xla::HloOpcode::kBitcast;
  } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
    return xla::HloOpcode::kAbs;
  } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
    return xla::HloOpcode::kCbrt;
  } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
    return xla::HloOpcode::kCeil;
  } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
    return xla::HloOpcode::kClz;
  } else if (isa<mlir::mhlo::CosOp, mlir::lmhlo::CosOp>(op)) {
    return xla::HloOpcode::kCos;
  } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
    return xla::HloOpcode::kExp;
  } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
    return xla::HloOpcode::kExpm1;
  } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
    return xla::HloOpcode::kFloor;
  } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
    return xla::HloOpcode::kImag;
  } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
    return xla::HloOpcode::kIsFinite;
  } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
    return xla::HloOpcode::kLog;
  } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
    return xla::HloOpcode::kLog1p;
  } else if (isa<mlir::mhlo::LogisticOp>(op)) {
    return xla::HloOpcode::kLogistic;
  } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
    return xla::HloOpcode::kNot;
  } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
    return xla::HloOpcode::kNegate;
  } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
                 op)) {
    return xla::HloOpcode::kPopulationCount;
  } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
    return xla::HloOpcode::kReal;
  } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
    return xla::HloOpcode::kRoundNearestAfz;
  } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
    return xla::HloOpcode::kRsqrt;
  } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
    return xla::HloOpcode::kSign;
  } else if (isa<mlir::mhlo::SinOp, mlir::lmhlo::SinOp>(op)) {
    return xla::HloOpcode::kSin;
  } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
    return xla::HloOpcode::kSqrt;
  } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
    return xla::HloOpcode::kTanh;
  } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
    return xla::HloOpcode::kComplex;
  } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
    return xla::HloOpcode::kAnd;
  } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
    return xla::HloOpcode::kOr;
  } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
    return xla::HloOpcode::kWhile;
  } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
    return xla::HloOpcode::kReduce;
  } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
    return xla::HloOpcode::kGetTupleElement;
  } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
    return xla::HloOpcode::kCompare;
  } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
    return xla::HloOpcode::kSlice;
  } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
    return xla::HloOpcode::kDynamicSlice;
  } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
                 mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
    return xla::HloOpcode::kDynamicUpdateSlice;
  } else if (isa<mlir::mhlo::CollectivePermuteOp,
                 mlir::lmhlo::CollectivePermuteOp>(op)) {
    return xla::HloOpcode::kCollectivePermute;
  } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
    return xla::HloOpcode::kCopy;
  } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
    return xla::HloOpcode::kCustomCall;
  } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
    return xla::HloOpcode::kDot;
  } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
    return xla::HloOpcode::kFft;
  } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
    return xla::HloOpcode::kGather;
  } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kGetDimensionSize;
  } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
    return xla::HloOpcode::kMap;
  } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
    return xla::HloOpcode::kReshape;
  } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
    return xla::HloOpcode::kDynamicReshape;
  } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
    return xla::HloOpcode::kScatter;
  } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
    return xla::HloOpcode::kSelect;
  } else if (isa<mlir::mhlo::SelectAndScatterOp,
                 mlir::lmhlo::SelectAndScatterOp>(op)) {
    return xla::HloOpcode::kSelectAndScatter;
  } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kSetDimensionSize;
  } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
    return xla::HloOpcode::kReverse;
  } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
    return xla::HloOpcode::kPad;
  } else if (isa<mlir::mhlo::TraceOp>(op)) {
    return xla::HloOpcode::kTrace;
  } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
    return xla::HloOpcode::kTranspose;
  } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
                 op)) {
    return xla::HloOpcode::kTriangularSolve;
  } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
    return xla::HloOpcode::kReduceWindow;
  } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
                 op)) {
    return xla::HloOpcode::kReducePrecision;
  // DotGeneralOp intentionally maps to kDot as well: HLO uses a single
  // opcode for both dot forms.
  } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
    return xla::HloOpcode::kDot;
  // Likewise, BroadcastInDimOp maps to the same kBroadcast opcode as
  // BroadcastOp above.
  } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
                 op)) {
    return xla::HloOpcode::kBroadcast;
  } else {
    // Unmapped op: include its printed form in the error for debugging.
    std::string s;
    {
      llvm::raw_string_ostream os(s);
      op->print(os);
    }
    return tensorflow::errors::Unimplemented(
        "Unimplemented MHLO -> HloOpcode: ", s);
  }
}
498
499 } // namespace xla
500