/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helpers useful when creating or manipulating lhlo/hlo.

#include "tensorflow/compiler/mlir/xla/hlo_utils.h"

#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using mlir::AffineMap;
using mlir::Builder;
using mlir::DenseElementsAttr;
using mlir::ShapedType;
using xla::LiteralBase;
using xla::StatusOr;

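// Builds a DenseElementsAttr of the given shaped type by viewing the
// literal's backing data as a flat span of CppType values. CppType must match
// the literal's element type; DenseElementsAttr::get copies the data into the
// MLIR context, so the literal does not need to outlive the attribute.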
template <typename CppType>
::mlir::DenseElementsAttr CreateDenseAttrFromLiteral(
    const ShapedType& type, const LiteralBase& literal) {
  auto data_span = literal.data<CppType>();
  return ::mlir::DenseElementsAttr::get(
      type, llvm::makeArrayRef(data_span.data(), data_span.size()));
}

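// Returns a strided affine layout map for `shape` when its layout is not the
// default major-to-minor order, or an empty vector when the default row-major
// map suffices. Strides are accumulated in minor-to-major order; e.g. for a
// hypothetical f32[2,3] shape with layout {0,1} (dimension 0 most minor), the
// strides come out as [1, 2] and the resulting map is
// (d0, d1) -> (d0 * 1 + d1 * 2).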
StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
    const Shape& shape, mlir::Builder builder) {
  // N.B. IsMonotonicWithDim0Major ignores tiling, and I can't change it
  // because some XLA code relies on it treating tiled layouts as equivalent
  // to untiled layouts, so the check to rule out tiling has to come /before/
  // the early-return branch, or we'd miss tiled monotonic layouts.
  if (!shape.layout().tiles().empty()) {
    return tensorflow::errors::Internal("Tiled layouts are not yet supported");
  }
  if (!shape.has_layout() ||
      LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
    return llvm::SmallVector<AffineMap, 1>{};
  }
  if (!shape.is_static()) {
    return tensorflow::errors::Internal(
        "Permutations for dynamic shapes are not yet supported");
  }
  int64_t accumulated_stride = 1;
  llvm::SmallVector<int64_t, 4> strides(shape.rank(), 1);
  for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
    strides[dim] = accumulated_stride;
    accumulated_stride *= shape.dimensions(dim);
  }
  if (accumulated_stride == 0) {
    return llvm::SmallVector<AffineMap, 1>{};
  }
  return llvm::SmallVector<AffineMap, 1>{
      makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}

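// Serializes the elements of `data` into `output` as a tightly packed buffer
// of raw host-order bytes (getNumElements() * sizeof(T) bytes in total).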
template <typename T>
void CopyDenseElementsBy(mlir::DenseElementsAttr data,
                         std::vector<uint8>* output) {
  output->resize(data.getNumElements() * sizeof(T));
  int i = 0;
  for (T element : data.getValues<T>()) {
    std::memcpy(&(*output)[i], &element, sizeof(T));
    i += sizeof(T);
  }
}

}  // namespace

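// Converts an XLA shape to an MLIR memref type, attaching a strided affine
// map when the shape carries a non-default layout. For example, a
// hypothetical f32[8,128] shape with the default layout maps to
// memref<8x128xf32>.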
StatusOr<mlir::MemRefType> ConvertTensorShapeToMemRefType(
    const Shape& shape, mlir::Builder builder) {
  auto element_type_or =
      ConvertPrimitiveTypeToMLIRType(shape.element_type(), builder);
  if (!element_type_or.ok()) return element_type_or.status();

  using mlir::MemRefType;
  auto dimensions = shape.dimensions();
  llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
  auto permutation_or = GetPermutationIfAvailable(shape, builder);
  if (!permutation_or.ok()) return permutation_or.status();
  return MemRefType::get(array, element_type_or.ValueOrDie(),
                         permutation_or.ValueOrDie());
}

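// Converts an XLA literal to an MLIR DenseElementsAttr of the corresponding
// ranked tensor type, dispatching on the literal's primitive element type.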
StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
    const LiteralBase& literal, Builder builder) {
  TF_ASSIGN_OR_RETURN(auto type,
                      ConvertTensorShapeToType<mlir::RankedTensorType>(
                          literal.shape(), builder));

  // TODO(hinsu): Support remaining XLA primitive types.
  auto element_type = literal.shape().element_type();
  switch (element_type) {
    case PrimitiveType::PRED:
      return CreateDenseAttrFromLiteral<bool>(type, literal);
    case PrimitiveType::F16:
      return CreateDenseAttrFromLiteral<half>(type, literal);
    case PrimitiveType::BF16:
      return CreateDenseAttrFromLiteral<bfloat16>(type, literal);
    case PrimitiveType::F32:
      return CreateDenseAttrFromLiteral<float>(type, literal);
    case PrimitiveType::F64:
      return CreateDenseAttrFromLiteral<double>(type, literal);
    case PrimitiveType::S8:
      return CreateDenseAttrFromLiteral<int8>(type, literal);
    case PrimitiveType::S16:
      return CreateDenseAttrFromLiteral<int16>(type, literal);
    case PrimitiveType::S32:
      return CreateDenseAttrFromLiteral<int32>(type, literal);
    case PrimitiveType::S64:
      return CreateDenseAttrFromLiteral<int64>(type, literal);
    case PrimitiveType::U8:
      return CreateDenseAttrFromLiteral<uint8>(type, literal);
    case PrimitiveType::U16:
      return CreateDenseAttrFromLiteral<uint16>(type, literal);
    case PrimitiveType::U32:
      return CreateDenseAttrFromLiteral<uint32>(type, literal);
    case PrimitiveType::U64:
      return CreateDenseAttrFromLiteral<uint64>(type, literal);
    case PrimitiveType::C64:
      return CreateDenseAttrFromLiteral<complex64>(type, literal);
    case PrimitiveType::C128:
      return CreateDenseAttrFromLiteral<complex128>(type, literal);
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}

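// Copies the contents of an MLIR DenseElementsAttr into `output` as the raw,
// tightly packed byte buffer XLA expects for a literal of the same element
// type. MLIR integers are signless here, so signed and unsigned values of the
// same width share a representation and are both copied through the unsigned
// C++ type.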
Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data,
                                        std::vector<uint8>* output) {
  mlir::Type element_type = data.getType().getElementType();

  // TODO(hinsu): Support remaining XLA primitive types.
  if (element_type.isInteger(1)) {
    CopyDenseElementsBy<bool>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(8)) {
    CopyDenseElementsBy<uint8>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(16)) {
    CopyDenseElementsBy<uint16>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(32)) {
    CopyDenseElementsBy<uint32>(data, output);
    return Status::OK();
  }
  if (element_type.isInteger(64)) {
    CopyDenseElementsBy<uint64>(data, output);
    return Status::OK();
  }
  if (element_type.isBF16()) {
    CopyDenseElementsBy<bfloat16>(data, output);
    return Status::OK();
  }
  if (element_type.isF16()) {
    CopyDenseElementsBy<half>(data, output);
    return Status::OK();
  }
  if (element_type.isF32()) {
    CopyDenseElementsBy<float>(data, output);
    return Status::OK();
  }
  if (element_type.isF64()) {
    CopyDenseElementsBy<double>(data, output);
    return Status::OK();
  }
  if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
    if (complex_type.getElementType().isF32()) {
      CopyDenseElementsBy<complex64>(data, output);
      return Status::OK();
    }
    if (complex_type.getElementType().isF64()) {
      CopyDenseElementsBy<complex128>(data, output);
      return Status::OK();
    }
  }
  return tensorflow::errors::Internal(
      "Unsupported type in CopyDenseElementsDataToXlaFormat");
}

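// Returns the size in bytes of one element of `type`: 1 for i1, twice the
// component size for complex types, and bit-width / 8 otherwise (failing if
// the bit width is not a multiple of 8).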
StatusOr<int> GetElementTypeBytes(mlir::Type type) {
  if (type.isInteger(1)) {
    return 1;
  }
  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    TF_ASSIGN_OR_RETURN(int bytes,
                        GetElementTypeBytes(complex_type.getElementType()));
    return bytes * 2;
  }
  int width = type.getIntOrFloatBitWidth();
  TF_RET_CHECK(width % 8 == 0);
  return width / 8;
}

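// Wraps a vector of int64 values in a DenseIntElementsAttr. When `shape` is
// empty the result is 1-D with one entry per element; e.g. a hypothetical
// call CreateDenseIntElementsAttrFromVector({2, 4, 6}, builder) yields an
// attribute of type tensor<3xi64>.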
mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
    const llvm::ArrayRef<int64> vector, mlir::Builder builder,
    llvm::ArrayRef<int64_t> shape) {
  return mlir::DenseIntElementsAttr::get(
      mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
                                  builder.getIntegerType(64)),
      vector);
}

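// Maps an XLA primitive type to the equivalent MLIR type, e.g. PRED -> i1,
// BF16 -> bf16, S32 -> i32, U32 -> ui32, and C64 -> complex<f32>. Returns an
// internal error for primitive types with no MLIR equivalent (e.g. TUPLE).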
StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                                    mlir::Builder builder) {
  switch (element_type) {
    case PrimitiveType::PRED:
      return builder.getI1Type();
    case PrimitiveType::F16:
      return builder.getF16Type();
    case PrimitiveType::BF16:
      return builder.getBF16Type();
    case PrimitiveType::F32:
      return builder.getF32Type();
    case PrimitiveType::F64:
      return builder.getF64Type();
    case PrimitiveType::S8:
      return builder.getIntegerType(8);
    case PrimitiveType::S16:
      return builder.getIntegerType(16);
    case PrimitiveType::S32:
      return builder.getIntegerType(32);
    case PrimitiveType::S64:
      return builder.getIntegerType(64);
    case PrimitiveType::U8:
      return builder.getIntegerType(8, /*isSigned=*/false);
    case PrimitiveType::U16:
      return builder.getIntegerType(16, /*isSigned=*/false);
    case PrimitiveType::U32:
      return builder.getIntegerType(32, /*isSigned=*/false);
    case PrimitiveType::U64:
      return builder.getIntegerType(64, /*isSigned=*/false);
    case PrimitiveType::C64:
      return mlir::ComplexType::get(builder.getF32Type());
    case PrimitiveType::C128:
      return mlir::ComplexType::get(builder.getF64Type());
    // TODO(b/130356985): Support unsigned primitive types.
    default:
      return tensorflow::errors::Internal(
          absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
  }
}

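// Translates an XLA GatherDimensionNumbers proto into the corresponding MHLO
// struct attribute: the three dimension lists become 1-D i64 elements
// attributes and index_vector_dim becomes an i64 integer attribute.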
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
    const GatherDimensionNumbers& input, mlir::Builder builder) {
  auto offset_dims = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
                                  input.offset_dims().end()},
      builder);
  auto collapsed_slice_dims = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector<int64, 4>{input.collapsed_slice_dims().begin(),
                                  input.collapsed_slice_dims().end()},
      builder);
  auto start_index_map = CreateDenseIntElementsAttrFromVector(
      llvm::SmallVector<int64, 4>{input.start_index_map().begin(),
                                  input.start_index_map().end()},
      builder);

  mlir::IntegerAttr index_vector_dim =
      builder.getI64IntegerAttr(input.index_vector_dim());

  return mlir::mhlo::GatherDimensionNumbers::get(
      offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
      builder.getContext());
}

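// Maps an MHLO or LMHLO operation to the XLA HloOpcode it corresponds to.
// The mapping is not one-to-one: for example, both mhlo.dot and
// mhlo.dot_general map to kDot, and both mhlo.broadcast and
// mhlo.broadcast_in_dim map to kBroadcast. Returns an Unimplemented error for
// unhandled ops.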
StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
  using mlir::isa;

  if (isa<mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op)) {
    return xla::HloOpcode::kConstant;
  } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
    return xla::HloOpcode::kIota;
  } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
    return xla::HloOpcode::kConvert;
  } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
    return xla::HloOpcode::kAdd;
  } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
    return xla::HloOpcode::kAtan2;
  } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
    return xla::HloOpcode::kDivide;
  } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
    return xla::HloOpcode::kMaximum;
  } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
    return xla::HloOpcode::kMinimum;
  } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
    return xla::HloOpcode::kMultiply;
  } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
    return xla::HloOpcode::kPower;
  } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
    return xla::HloOpcode::kRemainder;
  } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
    return xla::HloOpcode::kShiftLeft;
  } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
                 mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
    return xla::HloOpcode::kShiftRightArithmetic;
  } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
                 mlir::lmhlo::ShiftRightLogicalOp>(op)) {
    return xla::HloOpcode::kShiftRightLogical;
  } else if (isa<mlir::mhlo::SubOp, mlir::lmhlo::SubOp>(op)) {
    return xla::HloOpcode::kSubtract;
  } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
    return xla::HloOpcode::kXor;
  } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::InfeedOp>(op)) {
    return xla::HloOpcode::kInfeed;
  } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::OutfeedOp>(op)) {
    return xla::HloOpcode::kOutfeed;
  } else if (isa<mlir::mhlo::SendOp>(op)) {
    return xla::HloOpcode::kSend;
  } else if (isa<mlir::mhlo::RecvOp>(op)) {
    return xla::HloOpcode::kRecv;
  } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
    return xla::HloOpcode::kReplicaId;
  } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
    return xla::HloOpcode::kAfterAll;
  } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
    return xla::HloOpcode::kAllReduce;
  } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
    return xla::HloOpcode::kAllToAll;
  } else if (isa<mlir::mhlo::TupleOp>(op)) {
    return xla::HloOpcode::kTuple;
  } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
                 op)) {
    return xla::HloOpcode::kBatchNormGrad;
  } else if (isa<mlir::mhlo::BatchNormInferenceOp,
                 mlir::lmhlo::BatchNormInferenceOp>(op)) {
    return xla::HloOpcode::kBatchNormInference;
  } else if (isa<mlir::mhlo::BatchNormTrainingOp,
                 mlir::lmhlo::BatchNormTrainingOp>(op)) {
    return xla::HloOpcode::kBatchNormTraining;
  } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
                 op)) {
    return xla::HloOpcode::kBitcastConvert;
  } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
    return xla::HloOpcode::kBroadcast;
  } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
    return xla::HloOpcode::kCholesky;
  } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
    return xla::HloOpcode::kClamp;
  } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
    return xla::HloOpcode::kConcatenate;
  } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
    return xla::HloOpcode::kConvolution;
  } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
    return xla::HloOpcode::kSort;
  } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
    return xla::HloOpcode::kRngBitGenerator;
  } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
    return xla::HloOpcode::kFusion;
  } else if (isa<mlir::mhlo::BitcastOp>(op)) {
    return xla::HloOpcode::kBitcast;
  } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
    return xla::HloOpcode::kAbs;
  } else if (isa<mlir::mhlo::CbrtOp, mlir::lmhlo::CbrtOp>(op)) {
    return xla::HloOpcode::kCbrt;
  } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
    return xla::HloOpcode::kCeil;
  } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
    return xla::HloOpcode::kClz;
  } else if (isa<mlir::mhlo::CosOp, mlir::lmhlo::CosOp>(op)) {
    return xla::HloOpcode::kCos;
  } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
    return xla::HloOpcode::kExp;
  } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
    return xla::HloOpcode::kExpm1;
  } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
    return xla::HloOpcode::kFloor;
  } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
    return xla::HloOpcode::kImag;
  } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
    return xla::HloOpcode::kIsFinite;
  } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
    return xla::HloOpcode::kLog;
  } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
    return xla::HloOpcode::kLog1p;
  } else if (isa<mlir::mhlo::LogisticOp>(op)) {
    return xla::HloOpcode::kLogistic;
  } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
    return xla::HloOpcode::kNot;
  } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
    return xla::HloOpcode::kNegate;
  } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
                 op)) {
    return xla::HloOpcode::kPopulationCount;
  } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
    return xla::HloOpcode::kReal;
  } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
    return xla::HloOpcode::kRoundNearestAfz;
  } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
    return xla::HloOpcode::kRsqrt;
  } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
    return xla::HloOpcode::kSign;
  } else if (isa<mlir::mhlo::SinOp, mlir::lmhlo::SinOp>(op)) {
    return xla::HloOpcode::kSin;
  } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
    return xla::HloOpcode::kSqrt;
  } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
    return xla::HloOpcode::kTanh;
  } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
    return xla::HloOpcode::kComplex;
  } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
    return xla::HloOpcode::kAnd;
  } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
    return xla::HloOpcode::kOr;
  } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
    return xla::HloOpcode::kWhile;
  } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
    return xla::HloOpcode::kReduce;
  } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
    return xla::HloOpcode::kGetTupleElement;
  } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
    return xla::HloOpcode::kCompare;
  } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
    return xla::HloOpcode::kSlice;
  } else if (isa<mlir::mhlo::DynamicSliceOp, mlir::lmhlo::DynamicSliceOp>(op)) {
    return xla::HloOpcode::kDynamicSlice;
  } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
                 mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
    return xla::HloOpcode::kDynamicUpdateSlice;
  } else if (isa<mlir::mhlo::CollectivePermuteOp,
                 mlir::lmhlo::CollectivePermuteOp>(op)) {
    return xla::HloOpcode::kCollectivePermute;
  } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
    return xla::HloOpcode::kCopy;
  } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
    return xla::HloOpcode::kCustomCall;
  } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
    return xla::HloOpcode::kDot;
  } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
    return xla::HloOpcode::kFft;
  } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
    return xla::HloOpcode::kGather;
  } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kGetDimensionSize;
  } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
    return xla::HloOpcode::kMap;
  } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
    return xla::HloOpcode::kReshape;
  } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
    return xla::HloOpcode::kDynamicReshape;
  } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
    return xla::HloOpcode::kScatter;
  } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
    return xla::HloOpcode::kSelect;
  } else if (isa<mlir::mhlo::SelectAndScatterOp,
                 mlir::lmhlo::SelectAndScatterOp>(op)) {
    return xla::HloOpcode::kSelectAndScatter;
  } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
    return xla::HloOpcode::kSetDimensionSize;
  } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
    return xla::HloOpcode::kReverse;
  } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
    return xla::HloOpcode::kPad;
  } else if (isa<mlir::mhlo::TraceOp>(op)) {
    return xla::HloOpcode::kTrace;
  } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
    return xla::HloOpcode::kTranspose;
  } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
                 op)) {
    return xla::HloOpcode::kTriangularSolve;
  } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
    return xla::HloOpcode::kReduceWindow;
  } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
                 op)) {
    return xla::HloOpcode::kReducePrecision;
  } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
    return xla::HloOpcode::kDot;
  } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
                 op)) {
    return xla::HloOpcode::kBroadcast;
  } else {
    std::string s;
    {
      llvm::raw_string_ostream os(s);
      op->print(os);
    }
    return tensorflow::errors::Unimplemented(
        "Unimplemented MHLO -> HloOpcode: ", s);
  }
}

}  // namespace xla