1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/StringRef.h" 21 #include "llvm/ADT/StringSwitch.h" 22 #include "llvm/ADT/iterator_range.h" 23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 24 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" 25 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" 26 #include "mlir/Dialect/Complex/IR/Complex.h" 27 #include "mlir/Dialect/Math/IR/Math.h" 28 #include "mlir/Dialect/SCF/SCF.h" 29 #include "mlir/Dialect/StandardOps/IR/Ops.h" 30 #include "mlir/IR/BuiltinTypes.h" 31 #include "mlir/IR/ImplicitLocOpBuilder.h" 32 #include "mlir/IR/TypeUtilities.h" 33 34 namespace mlir { 35 namespace lmhlo { 36 namespace impl { 37 38 // A struct to map LhloBinaryOpTy type to the corresponding floating-point and 39 // integer scalar operation types. 40 template <typename LhloBinaryOpTy> 41 struct LhloToScalarOp { 42 using FOp = void; 43 using IOp = void; 44 using UOp = void; 45 using COp = void; 46 }; 47 48 template <> 49 struct LhloToScalarOp<lmhlo::AddOp> { 50 using FOp = ::mlir::AddFOp; 51 using IOp = ::mlir::AddIOp; 52 using UOp = ::mlir::AddIOp; 53 using COp = ::mlir::complex::AddOp; 54 }; 55 template <> 56 struct LhloToScalarOp<lmhlo::CompareOp> { 57 using FOp = ::mlir::CmpFOp; 58 using IOp = ::mlir::CmpIOp; 59 using UOp = ::mlir::CmpIOp; 60 }; 61 template <> 62 struct LhloToScalarOp<lmhlo::DivOp> { 63 using FOp = ::mlir::DivFOp; 64 using IOp = ::mlir::SignedDivIOp; 65 using UOp = ::mlir::UnsignedDivIOp; 66 using COp = ::mlir::complex::DivOp; 67 }; 68 template <> 69 struct LhloToScalarOp<lmhlo::MulOp> { 70 using FOp = ::mlir::MulFOp; 71 using IOp = ::mlir::MulIOp; 72 using UOp = ::mlir::MulIOp; 73 using COp = ::mlir::complex::MulOp; 74 }; 75 template <> 76 struct LhloToScalarOp<lmhlo::RemOp> { 77 using FOp = ::mlir::RemFOp; 78 using IOp = ::mlir::SignedRemIOp; 79 using UOp = ::mlir::UnsignedRemIOp; 80 }; 81 template <> 82 struct LhloToScalarOp<lmhlo::SubOp> { 83 using FOp = ::mlir::SubFOp; 84 using IOp = ::mlir::SubIOp; 85 using UOp = ::mlir::SubIOp; 86 using COp = ::mlir::complex::SubOp; 87 }; 88 89 // Alias for the map from LHLO binary op type to STD floating-point op type. 90 template <typename LhloOp> 91 using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp; 92 // Alias for the map from LHLO binary op type to STD signed integer op type. 93 template <typename LhloOp> 94 using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp; 95 // Alias for the map from LHLO binary op type to STD unsigned integer op type. 96 template <typename LhloOp> 97 using ScalarUOp = typename LhloToScalarOp<LhloOp>::UOp; 98 // Alias for the map from LHLO binary op type to STD complex op type. 99 template <typename LhloOp> 100 using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp; 101 102 template <typename... Args> 103 struct MapLhloOpToScalarOpImpl { 104 Value operator()(Location loc, ArrayRef<Type> result_types, 105 ArrayRef<Type> arg_types, ArrayRef<Value> args, 106 OpBuilder* b) { 107 return nullptr; 108 } 109 }; 110 111 template <typename StdScalarOp> 112 struct MapLhloOpToScalarOpImpl<StdScalarOp> { 113 Value operator()(Location loc, ArrayRef<Type> result_types, 114 ArrayRef<Type> arg_types, ArrayRef<Value> args, 115 OpBuilder* b) { 116 return b->template create<StdScalarOp>(loc, result_types, args, mlir::None); 117 } 118 }; 119 120 template <typename SupportedType, typename StdScalarOp, typename... Args> 121 struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> { 122 Value operator()(Location loc, ArrayRef<Type> result_types, 123 ArrayRef<Type> arg_types, ArrayRef<Value> args, 124 OpBuilder* b) { 125 Type element_type = getElementTypeOrSelf(arg_types.front()); 126 if (SupportedType{}(element_type)) { 127 return b->template create<StdScalarOp>(loc, result_types, args, 128 mlir::None); 129 } 130 return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types, 131 args, b); 132 } 133 }; 134 135 template <typename SupportedType, typename... Args> 136 struct MapLhloOpToScalarOpImpl<SupportedType, void, Args...> { 137 Value operator()(Location loc, ArrayRef<Type> result_types, 138 ArrayRef<Type> arg_types, ArrayRef<Value> args, 139 OpBuilder* b) { 140 return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types, 141 args, b); 142 } 143 }; 144 145 struct isAnyIntegerType { 146 bool operator()(Type t) { return t.isa<IntegerType>(); } 147 }; 148 149 struct isSignedIntegerType { 150 bool operator()(Type t) { 151 // Pretend that signless is signed. This will change eventually. 152 return t.isa<IntegerType>() && !t.isUnsignedInteger(); 153 } 154 }; 155 156 struct isUnsignedIntegerType { 157 bool operator()(Type t) { return t.isUnsignedInteger(); } 158 }; 159 160 struct isFloatType { 161 bool operator()(Type t) { return t.isa<FloatType>(); } 162 }; 163 164 struct isComplexType { 165 bool operator()(Type t) { return t.isa<ComplexType>(); } 166 }; 167 168 // Inserts the computation that corresponds to the body of the loop for lowered 169 // LHLO unary/binary op. Returns the value for the result. 170 template <typename LhloOpTy> 171 inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types, 172 ArrayRef<Type> arg_types, 173 ArrayRef<Value> args, OpBuilder* b) { 174 return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<LhloOpTy>, 175 isUnsignedIntegerType, ScalarUOp<LhloOpTy>, 176 isFloatType, ScalarFOp<LhloOpTy>>{}( 177 loc, result_types, arg_types, args, b); 178 } 179 180 template <> 181 inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc, 182 ArrayRef<Type> result_types, 183 ArrayRef<Type> arg_types, 184 ArrayRef<Value> args, 185 OpBuilder* b) { 186 Type element_type = getElementTypeOrSelf(arg_types.front()); 187 if (element_type.isa<FloatType>()) { 188 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AbsFOp>{}( 189 loc, result_types, arg_types, args, b); 190 } 191 if (element_type.isa<ComplexType>()) { 192 return MapLhloOpToScalarOpImpl<isComplexType, ::mlir::complex::AbsOp>{}( 193 loc, result_types, arg_types, args, b); 194 } 195 if (element_type.isSignlessInteger() || element_type.isSignedInteger()) { 196 // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) 197 Value lhs = args[0]; 198 auto integer_type = element_type.dyn_cast<IntegerType>(); 199 200 Value zero_intval = 201 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 202 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 203 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 204 } 205 auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, 206 lhs, zero_intval); 207 auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); 208 return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); 209 } 210 return nullptr; 211 } 212 213 template <> 214 inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc, 215 ArrayRef<Type> result_types, 216 ArrayRef<Type> arg_types, 217 ArrayRef<Value> args, 218 OpBuilder* b) { 219 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::AddOp>, 220 isFloatType, ScalarFOp<lmhlo::AddOp>, 221 isComplexType, ScalarCOp<lmhlo::AddOp>>{}( 222 loc, result_types, arg_types, args, b); 223 } 224 225 template <> 226 inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc, 227 ArrayRef<Type> result_types, 228 ArrayRef<Type> arg_types, 229 ArrayRef<Value> args, 230 OpBuilder* b) { 231 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AndOp>{}( 232 loc, result_types, arg_types, args, b); 233 } 234 235 template <> 236 inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc, 237 ArrayRef<Type> result_types, 238 ArrayRef<Type> arg_types, 239 ArrayRef<Value> args, 240 OpBuilder* b) { 241 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Atan2Op>{}( 242 loc, result_types, arg_types, args, b); 243 } 244 245 template <typename PredicateType> 246 inline Optional<PredicateType> getCmpPredicate(StringRef, bool) { 247 return llvm::None; 248 } 249 250 template <> 251 inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>( 252 StringRef comparison_direction, bool is_signed) { 253 assert(is_signed && "cannot have an unsigned float!"); 254 return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction) 255 .Case("EQ", CmpFPredicate::OEQ) 256 .Case("NE", CmpFPredicate::UNE) 257 .Case("GE", CmpFPredicate::OGE) 258 .Case("GT", CmpFPredicate::OGT) 259 .Case("LE", CmpFPredicate::OLE) 260 .Case("LT", CmpFPredicate::OLT) 261 .Default(llvm::None); 262 } 263 264 template <> 265 inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>( 266 StringRef comparison_direction, bool is_signed) { 267 return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction) 268 .Case("EQ", CmpIPredicate::eq) 269 .Case("NE", CmpIPredicate::ne) 270 .Case("GE", is_signed ? CmpIPredicate::sge : CmpIPredicate::uge) 271 .Case("GT", is_signed ? CmpIPredicate::sgt : CmpIPredicate::ugt) 272 .Case("LE", is_signed ? CmpIPredicate::sle : CmpIPredicate::ule) 273 .Case("LT", is_signed ? CmpIPredicate::slt : CmpIPredicate::ult) 274 .Default(llvm::None); 275 } 276 277 template <typename CompareOpTy> 278 inline Value MapCompareOpToStdScalarOp(Location loc, 279 StringRef comparison_direction, 280 ArrayRef<Type> result_types, 281 ArrayRef<Type> arg_types, 282 ArrayRef<Value> args, OpBuilder* b) { 283 const auto& lhs = args[0]; 284 const auto& rhs = args[1]; 285 Type element_type = getElementTypeOrSelf(arg_types.front()); 286 if (element_type.isa<IntegerType>()) { 287 Optional<CmpIPredicate> predicate = getCmpPredicate<CmpIPredicate>( 288 comparison_direction, !element_type.isUnsignedInteger()); 289 assert(predicate.hasValue() && "expected valid comparison direction"); 290 return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs, 291 rhs); 292 } 293 if (element_type.isa<FloatType>()) { 294 Optional<CmpFPredicate> predicate = getCmpPredicate<CmpFPredicate>( 295 comparison_direction, /*is_signed=*/true); 296 assert(predicate.hasValue() && "expected valid comparison direction"); 297 return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs, 298 rhs); 299 } 300 if (auto complex_type = element_type.dyn_cast<ComplexType>()) { 301 if (complex_type.getElementType().isa<FloatType>()) { 302 if (comparison_direction == "EQ") { 303 return b->create<complex::EqualOp>(loc, lhs, rhs); 304 } 305 if (comparison_direction == "NE") { 306 return b->create<complex::NotEqualOp>(loc, lhs, rhs); 307 } 308 } 309 } 310 return nullptr; 311 } 312 313 template <> 314 inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc, 315 ArrayRef<Type> result_types, 316 ArrayRef<Type> arg_types, 317 ArrayRef<Value> args, 318 OpBuilder* b) { 319 return args.front(); 320 } 321 322 template <> 323 inline Value MapLhloOpToStdScalarOp<lmhlo::DivOp>(Location loc, 324 ArrayRef<Type> result_types, 325 ArrayRef<Type> arg_types, 326 ArrayRef<Value> args, 327 OpBuilder* b) { 328 return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::DivOp>, 329 isUnsignedIntegerType, ScalarUOp<lmhlo::DivOp>, 330 isFloatType, ScalarFOp<lmhlo::DivOp>, 331 isComplexType, ScalarCOp<lmhlo::DivOp>>{}( 332 loc, result_types, arg_types, args, b); 333 } 334 335 template <> 336 inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc, 337 ArrayRef<Type> result_types, 338 ArrayRef<Type> arg_types, 339 ArrayRef<Value> args, 340 OpBuilder* b) { 341 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpOp, 342 isComplexType, ::mlir::complex::ExpOp>{}( 343 loc, result_types, arg_types, args, b); 344 } 345 346 template <> 347 inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc, 348 ArrayRef<Type> result_types, 349 ArrayRef<Type> arg_types, 350 ArrayRef<Value> args, 351 OpBuilder* b) { 352 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpM1Op>{}( 353 loc, result_types, arg_types, args, b); 354 } 355 356 template <> 357 inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc, 358 ArrayRef<Type> result_types, 359 ArrayRef<Type> arg_types, 360 ArrayRef<Value> args, 361 OpBuilder* b) { 362 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::CeilFOp>{}( 363 loc, result_types, arg_types, args, b); 364 } 365 366 template <> 367 inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>( 368 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 369 ArrayRef<Value> args, OpBuilder* b) { 370 return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types, 371 arg_types, args, b); 372 } 373 374 template <> 375 inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc, 376 ArrayRef<Type> result_types, 377 ArrayRef<Type> arg_types, 378 ArrayRef<Value> args, 379 OpBuilder* b) { 380 return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, arg_types, 381 args, b); 382 } 383 384 template <> 385 inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc, 386 ArrayRef<Type> result_types, 387 ArrayRef<Type> arg_types, 388 ArrayRef<Value> args, 389 OpBuilder* b) { 390 return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, arg_types, 391 args, b); 392 } 393 394 template <> 395 inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>( 396 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 397 ArrayRef<Value> args, OpBuilder* b) { 398 Type sourceType = getElementTypeOrSelf(arg_types.front()); 399 Type targetType = getElementTypeOrSelf(result_types.front()); 400 Type convertedSourceType = getElementTypeOrSelf(args.front()); 401 402 // A boolean value is considered to be unsigned when converting to 403 // floating-point. Otherwise, it will become `-1`. 404 if ((sourceType.isInteger(/*width=*/1) || sourceType.isUnsignedInteger()) && 405 mlir::UIToFPOp::areCastCompatible(convertedSourceType, targetType)) { 406 return b->create<mlir::UIToFPOp>(loc, result_types, args, mlir::None); 407 } else if (mlir::SIToFPOp::areCastCompatible(convertedSourceType, 408 targetType)) { 409 return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None); 410 } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { 411 FloatType src = sourceType.cast<FloatType>(); 412 FloatType res = targetType.cast<FloatType>(); 413 if (src.getWidth() > res.getWidth()) { 414 return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None); 415 } else if (src.getWidth() < res.getWidth()) { 416 return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None); 417 } 418 // No conversion is needed for the same width floats 419 return args.front(); 420 } 421 if (targetType.isInteger(/*width=*/1)) { 422 // When casting to bool, we need to compare whether the value is equal to 423 // zero. 424 if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) { 425 Value zero_intval = b->create<::mlir::ConstantIntOp>( 426 loc, 0, sourceType.cast<IntegerType>().getWidth()); 427 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 428 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 429 } 430 return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(), 431 zero_intval); 432 } else if (sourceType.isa<FloatType>()) { 433 Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0)); 434 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 435 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 436 } 437 return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(), 438 zero); 439 } 440 } 441 if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) { 442 IntegerType src = sourceType.cast<IntegerType>(); 443 IntegerType res = targetType.cast<IntegerType>(); 444 if (src.getWidth() > res.getWidth()) { 445 return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None); 446 } else if (src.getWidth() < res.getWidth()) { 447 // Special case boolean values, so they get casted to `1` instead of `-1`. 448 if (src.isUnsignedInteger() || src.getWidth() == 1) { 449 return b->create<mlir::ZeroExtendIOp>(loc, result_types, args, 450 mlir::None); 451 } 452 return b->create<mlir::SignExtendIOp>(loc, result_types, args, 453 mlir::None); 454 } 455 // No conversion is needed for the same width integers 456 return args.front(); 457 } 458 if (mlir::FPToSIOp::areCastCompatible(convertedSourceType, targetType)) { 459 return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None); 460 } 461 return nullptr; 462 } 463 464 template <> 465 inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc, 466 ArrayRef<Type> result_types, 467 ArrayRef<Type> arg_types, 468 ArrayRef<Value> args, 469 OpBuilder* b) { 470 // Dot Op converter from lhlo to affine only accepts float and integer types. 471 const auto& lhs = args[0]; 472 const auto& rhs = args[1]; 473 const auto& result = args[2]; 474 Type element_type = lhs.getType(); 475 if (element_type.isa<FloatType>()) { 476 Value float_mul = MapLhloOpToScalarOpImpl<isFloatType, ::mlir::MulFOp>{}( 477 loc, result_types, arg_types, {lhs, rhs}, b); 478 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AddFOp>{}( 479 loc, result_types, arg_types, {float_mul, result}, b); 480 } 481 if (element_type.isa<IntegerType>()) { 482 Value int_mul = MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::MulIOp>{}( 483 loc, result_types, arg_types, {lhs, rhs}, b); 484 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AddIOp>{}( 485 loc, result_types, arg_types, {int_mul, result}, b); 486 } 487 return nullptr; 488 } 489 490 template <> 491 inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc, 492 ArrayRef<Type> result_types, 493 ArrayRef<Type> arg_types, 494 ArrayRef<Value> args, 495 OpBuilder* b) { 496 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::CosOp>{}( 497 loc, result_types, arg_types, args, b); 498 } 499 500 template <> 501 inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc, 502 ArrayRef<Type> result_types, 503 ArrayRef<Type> arg_types, 504 ArrayRef<Value> args, 505 OpBuilder* b) { 506 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SinOp>{}( 507 loc, result_types, arg_types, args, b); 508 } 509 510 template <> 511 inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc, 512 ArrayRef<Type> result_types, 513 ArrayRef<Type> arg_types, 514 ArrayRef<Value> args, 515 OpBuilder* b) { 516 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::FloorFOp>{}( 517 loc, result_types, arg_types, args, b); 518 } 519 520 template <> 521 inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>( 522 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 523 ArrayRef<Value> args, OpBuilder* b) { 524 if (args[0].getType().isa<FloatType>()) { 525 auto pos_inf = APFloat::getInf( 526 args[0].getType().cast<FloatType>().getFloatSemantics()); 527 auto const_pos_inf = 528 b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf)); 529 Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); 530 return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, 531 const_pos_inf); 532 } 533 return nullptr; 534 } 535 536 /// Implements the conversion of HLO op to scalar op (to use within region of a 537 /// linalg.generic op) for compare-select style operations like min/max. 538 template <typename... Args> 539 struct CompareSelectOpToStdScalarOp { 540 static Value map(Location loc, StringRef comparison_direction, 541 ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 542 ArrayRef<Value> args, OpBuilder* b) { 543 return nullptr; 544 } 545 }; 546 547 /// Specialization which allows converting to a comparison operation in standard 548 /// dialect with a given predicate based on the element type of the operand. 549 template <typename SupportedType, typename StdCompareOp, typename Predicate, 550 typename... Args> 551 struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, 552 Args...> { 553 static Value map(Location loc, StringRef comparison_direction, 554 ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 555 ArrayRef<Value> args, OpBuilder* b) { 556 Type element_type = getElementTypeOrSelf(arg_types.front()); 557 if (element_type.isa<SupportedType>()) { 558 auto predicate = getCmpPredicate<Predicate>( 559 comparison_direction, !element_type.isUnsignedInteger()); 560 assert(predicate.hasValue() && "expected valid comparison direction"); 561 auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(), 562 args[0], args[1]); 563 return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); 564 } 565 return CompareSelectOpToStdScalarOp<Args...>::map( 566 loc, comparison_direction, result_types, arg_types, args, b); 567 } 568 }; 569 570 template <> 571 inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc, 572 ArrayRef<Type> result_types, 573 ArrayRef<Type> arg_types, 574 ArrayRef<Value> args, 575 OpBuilder* b) { 576 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::LogOp, 577 isComplexType, ::mlir::complex::LogOp>{}( 578 loc, result_types, arg_types, args, b); 579 } 580 581 inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc, 582 OpBuilder* b) { 583 Type element_type = getElementTypeOrSelf(args.front().getType()); 584 if (auto float_type = element_type.dyn_cast<FloatType>()) { 585 Value isnan = 586 b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]); 587 588 auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics()); 589 Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type); 590 if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) { 591 nan = b->create<::mlir::SplatOp>(loc, vec_type, nan); 592 } 593 v = b->create<mlir::SelectOp>(loc, isnan, nan, v); 594 } 595 return v; 596 } 597 598 template <> 599 inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>( 600 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 601 ArrayRef<Value> args, OpBuilder* b) { 602 auto ty = result_types.front().cast<FloatType>(); 603 Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0)); 604 Value x = args.front(); 605 Value neg_x = b->create<NegFOp>(loc, x); 606 Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x); 607 Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x); 608 return b->create<DivFOp>(loc, one, one_add_exp_neg_x); 609 } 610 611 template <> 612 inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc, 613 ArrayRef<Type> result_types, 614 ArrayRef<Type> arg_types, 615 ArrayRef<Value> args, 616 OpBuilder* b) { 617 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Log1pOp, 618 isComplexType, ::mlir::complex::Log1pOp>{}( 619 loc, result_types, arg_types, args, b); 620 } 621 622 template <> 623 inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc, 624 ArrayRef<Type> result_types, 625 ArrayRef<Type> arg_types, 626 ArrayRef<Value> args, 627 OpBuilder* b) { 628 return LhloAlwaysPropagateNaN( 629 CompareSelectOpToStdScalarOp< 630 IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, 631 ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", 632 result_types, 633 arg_types, args, b), 634 args, loc, b); 635 } 636 637 template <> 638 inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc, 639 ArrayRef<Type> result_types, 640 ArrayRef<Type> arg_types, 641 ArrayRef<Value> args, 642 OpBuilder* b) { 643 return LhloAlwaysPropagateNaN( 644 CompareSelectOpToStdScalarOp< 645 IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, 646 ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", 647 result_types, 648 arg_types, args, b), 649 args, loc, b); 650 } 651 652 template <> 653 inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc, 654 ArrayRef<Type> result_types, 655 ArrayRef<Type> arg_types, 656 ArrayRef<Value> args, 657 OpBuilder* b) { 658 return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>, 659 isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>, 660 isFloatType, ScalarFOp<lmhlo::MulOp>, 661 isComplexType, ScalarCOp<lmhlo::MulOp>>{}( 662 loc, result_types, arg_types, args, b); 663 } 664 665 template <> 666 inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc, 667 ArrayRef<Type> result_types, 668 ArrayRef<Type> arg_types, 669 ArrayRef<Value> args, 670 OpBuilder* b) { 671 assert(args.size() == 3 && "expected 3 arguments"); 672 Value lb = args[0]; 673 Value x = args[1]; 674 Value ub = args[2]; 675 676 // clamp(lb, x, ub) = max(min(x, ub), lb) 677 Value min_x_ub = MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, 678 arg_types, {x, ub}, b); 679 return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, arg_types, 680 {min_x_ub, lb}, b); 681 } 682 683 template <> 684 inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc, 685 ArrayRef<Type> result_types, 686 ArrayRef<Type> arg_types, 687 ArrayRef<Value> args, 688 OpBuilder* b) { 689 Type element_type = getElementTypeOrSelf(args.front().getType()); 690 if (element_type.isa<ComplexType, FloatType>()) { 691 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType, 692 ::mlir::complex::NegOp>{}( 693 loc, result_types, arg_types, args, b); 694 } 695 if (element_type.isa<IntegerType>()) { 696 // lmhlo.neg(x, result) -> result = sub(0, x) 697 Value lhs = args[0]; 698 auto integer_type = element_type.dyn_cast<IntegerType>(); 699 700 Value zero_intval = 701 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 702 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 703 zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); 704 } 705 return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); 706 } 707 return nullptr; 708 } 709 710 template <> 711 inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc, 712 ArrayRef<Type> result_types, 713 ArrayRef<Type> arg_types, 714 ArrayRef<Value> args, 715 OpBuilder* b) { 716 Type element_type = getElementTypeOrSelf(args.front().getType()); 717 if (auto integer_type = element_type.dyn_cast<IntegerType>()) { 718 // lmhlo.not(x) -> x ^ -1 719 Value all_ones = 720 b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); 721 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 722 all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); 723 } 724 return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); 725 } 726 return nullptr; 727 } 728 729 template <> 730 inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc, 731 ArrayRef<Type> result_types, 732 ArrayRef<Type> arg_types, 733 ArrayRef<Value> args, 734 OpBuilder* b) { 735 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::OrOp>{}( 736 loc, result_types, arg_types, args, b); 737 } 738 739 template <> 740 inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc, 741 ArrayRef<Type> result_types, 742 ArrayRef<Type> arg_types, 743 ArrayRef<Value> args, 744 OpBuilder* b) { 745 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::RsqrtOp>{}( 746 loc, result_types, arg_types, args, b); 747 } 748 749 template <> 750 inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc, 751 ArrayRef<Type> result_types, 752 ArrayRef<Type> arg_types, 753 ArrayRef<Value> args, 754 OpBuilder* b) { 755 lmhlo::PowOp::Adaptor adaptor(args); 756 auto lb = ImplicitLocOpBuilder(loc, *b); 757 // Floating point can use std::powf 758 auto result_type = result_types.front(); 759 if (result_type.isa<::mlir::FloatType>()) 760 return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, 761 arg_types, args, b); 762 763 assert(result_type.isa<::mlir::IntegerType>() && 764 "only float and integer `pow` is supported right now"); 765 766 // Exponentiation by squaring: 767 // https://en.wikipedia.org/wiki/Exponentiation_by_squaring; 768 Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1)); 769 Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0)); 770 Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1)); 771 Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2)); 772 Value step = lb.create<ConstantIndexOp>(1); 773 Value lowerBound = lb.create<ConstantIndexOp>(0); 774 // Everything else would overflow for any exponent > 1, as 2^64 775 // is the larget possible exponent for a 64-bit integer, and 776 // that's 1 << 6. 777 Value upperBound = lb.create<ConstantIndexOp>(6); 778 auto original_base = adaptor.lhs(); 779 auto original_exponent = adaptor.rhs(); 780 781 Value accum = 782 lb.create<scf::ForOp>( 783 lowerBound, upperBound, step, 784 SmallVector<Value>({one, original_base, original_exponent}), 785 [&](OpBuilder& b, Location, Value v, ValueRange iters) { 786 Value accum = iters[0]; 787 Value base = iters[1]; 788 Value exponent = iters[2]; 789 790 Value condition = b.create<CmpIOp>( 791 loc, CmpIPredicate::eq, 792 b.create<::mlir::AndOp>(loc, exponent, one), one); 793 Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base); 794 accum = 795 b.create<::mlir::SelectOp>(loc, condition, multiplied, accum); 796 base = b.create<::mlir::MulIOp>(loc, base, base); 797 exponent = 798 b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one); 799 b.create<scf::YieldOp>( 800 loc, SmallVector<Value>({accum, base, exponent})); 801 }) 802 .getResult(0); 803 804 Value rhs_is_even = lb.create<CmpIOp>( 805 CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero); 806 Value rhs_is_negative = 807 lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero); 808 Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one); 809 Value lhs_is_neg_one = 810 lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one); 811 812 // The accum is correct when the rhs is non-negative. When rhs is 813 // negative, we return 0 for integer, with the exception of lhs values of 1 814 // and -1 which have integer results for negative exponents. Specifically, the 815 // calulation is the following: 816 // 817 // - Return accum if the rhs is not negative. 818 // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. 819 // - Return 1 if lhs is 1. 820 // - Else return 0. 821 Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero); 822 Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>( 823 lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one), 824 if_lhs_is_one); 825 return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum); 826 } 827 828 template <> 829 inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>( 830 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 831 ArrayRef<Value> args, OpBuilder* b) { 832 return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, 833 arg_types, args, b); 834 } 835 836 template <> 837 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>( 838 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 839 ArrayRef<Value> args, OpBuilder* b) { 840 return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::ShiftLeftOp>{}( 841 loc, result_types, arg_types, args, b); 842 } 843 844 template <> 845 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>( 846 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 847 ArrayRef<Value> args, OpBuilder* b) { 848 return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::SignedShiftRightOp>{}( 849 loc, result_types, arg_types, args, b); 850 } 851 852 template <> 853 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>( 854 Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 855 ArrayRef<Value> args, OpBuilder* b) { 856 return MapLhloOpToScalarOpImpl<isAnyIntegerType, 857 mlir::UnsignedShiftRightOp>{}( 858 loc, result_types, arg_types, args, b); 859 } 860 861 template <> 862 inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc, 863 ArrayRef<Type> result_types, 864 ArrayRef<Type> arg_types, 865 ArrayRef<Value> args, 866 OpBuilder* b) { 867 Type element_type = getElementTypeOrSelf(args.front().getType()); 868 if (auto float_type = element_type.dyn_cast<FloatType>()) { 869 bool ignored; 870 APFloat zero_apfloat(0.0f); 871 zero_apfloat.convert(float_type.getFloatSemantics(), 872 APFloat::rmNearestTiesToEven, &ignored); 873 Value zero = 874 b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type); 875 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 876 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 877 } 878 Value ne0_i1 = 879 b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero); 880 Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType()); 881 Value copy_sign = 882 b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]); 883 auto is_nan = 884 b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]); 885 return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign); 886 } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { 887 // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) 888 Value zero = 889 b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); 890 Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( 891 loc, integer_type.getWidth() - 1, integer_type.getWidth()); 892 Value one = 893 b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); 894 if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { 895 zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); 896 bitwidth_minus_one = 897 b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); 898 one = b->create<::mlir::SplatOp>(loc, vec_type, one); 899 } 900 Value cmp = 901 b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); 902 Value ashr = 903 b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); 904 Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); 905 return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); 906 } else if (element_type.isa<ComplexType>()) { 907 return b->create<::mlir::complex::SignOp>(loc, element_type, args.front()); 908 } 909 return nullptr; 910 } 911 912 template <> 913 inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc, 914 ArrayRef<Type> result_types, 915 ArrayRef<Type> arg_types, 916 ArrayRef<Value> args, 917 OpBuilder* b) { 918 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SqrtOp>{}( 919 loc, result_types, arg_types, args, b); 920 } 921 922 template <> 923 inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc, 924 ArrayRef<Type> result_types, 925 ArrayRef<Type> arg_types, 926 ArrayRef<Value> args, 927 OpBuilder* b) { 928 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::SubOp>, 929 isFloatType, ScalarFOp<lmhlo::SubOp>, 930 isComplexType, ScalarCOp<lmhlo::SubOp>>{}( 931 loc, result_types, arg_types, args, b); 932 } 933 934 template <> 935 inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc, 936 ArrayRef<Type> result_types, 937 ArrayRef<Type> arg_types, 938 ArrayRef<Value> args, 939 OpBuilder* b) { 940 return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::TanhOp>{}( 941 loc, result_types, arg_types, args, b); 942 } 943 944 template <> 945 inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc, 946 ArrayRef<Type> result_types, 947 ArrayRef<Type> arg_types, 948 ArrayRef<Value> args, 949 OpBuilder* b) { 950 return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::XOrOp>{}( 951 loc, result_types, arg_types, args, b); 952 } 953 954 } // namespace impl 955 956 struct HloOpToStdScalarOp { 957 // Implementation for LHLO ops except lmhlo::CompareOp. 958 template <typename HloOpTy, typename LhloOpTy = HloOpTy, 959 typename = std::enable_if_t< 960 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 961 std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, 962 std::false_type>::value>> 963 static Value map(HloOpTy op, ArrayRef<Type> result_types, 964 ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { 965 return impl::MapLhloOpToStdScalarOp<LhloOpTy>( 966 op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()), 967 args, b); 968 } 969 970 // Implementation for HLO ops except mhlo::CompareOp. 971 template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>, 972 typename = std::enable_if_t< 973 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 974 !std::is_same<LhloOpTy, std::false_type>::value>> 975 static Value map(HloOpTy op, ArrayRef<Type> result_types, 976 ArrayRef<Value> args, OpBuilder* b, int i = 0) { 977 return impl::MapLhloOpToStdScalarOp<LhloOpTy>( 978 op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()), 979 args, b); 980 } 981 982 // Implementation for lmhlo::CompareOp. 983 template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< 984 LhloOpTy, lmhlo::CompareOp>::value>> 985 static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types, 986 ArrayRef<Value> args, OpBuilder* b) { 987 auto comparison_direction = op.comparison_direction(); 988 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 989 op.getLoc(), comparison_direction, result_types, 990 llvm::to_vector<4>(op->getOperandTypes()), args, b); 991 } 992 993 // Implementation for mhlo::CompareOp. 994 template <typename HloOpTy, 995 typename = 996 std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>> 997 static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types, 998 ArrayRef<Value> args, OpBuilder* b) { 999 auto comparison_direction = op.comparison_direction(); 1000 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 1001 op.getLoc(), comparison_direction, result_types, 1002 llvm::to_vector<4>(op->getOperandTypes()), args, b); 1003 } 1004 1005 // Implementation for LHLO ops except lmhlo::CompareOp. 1006 template <typename LhloOpTy, 1007 typename = std::enable_if_t< 1008 !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && 1009 std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, 1010 std::false_type>::value>> 1011 static Value map(Location loc, ArrayRef<Type> result_types, 1012 ArrayRef<Type> arg_types, ArrayRef<Value> args, OpBuilder* b, 1013 unsigned i = 0) { 1014 return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, arg_types, 1015 args, b); 1016 } 1017 1018 // Implementation for lmhlo::CompareOp. 1019 template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< 1020 LhloOpTy, lmhlo::CompareOp>::value>> 1021 static Value map(Location loc, StringRef comparison_direction, 1022 ArrayRef<Type> result_types, ArrayRef<Type> arg_types, 1023 ArrayRef<Value> args, OpBuilder* b) { 1024 return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( 1025 loc, comparison_direction, result_types, arg_types, args, b); 1026 } 1027 }; 1028 1029 } // namespace lmhlo 1030 } // namespace mlir 1031 1032 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ 1033