• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
26 #include "mlir/Dialect/Complex/IR/Complex.h"
27 #include "mlir/Dialect/Math/IR/Math.h"
28 #include "mlir/Dialect/SCF/SCF.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/ImplicitLocOpBuilder.h"
32 #include "mlir/IR/TypeUtilities.h"
33 
34 namespace mlir {
35 namespace lmhlo {
36 namespace impl {
37 
38 // A struct to map LhloBinaryOpTy type to the corresponding floating-point and
39 // integer scalar operation types.
40 template <typename LhloBinaryOpTy>
41 struct LhloToScalarOp {
42   using FOp = void;
43   using IOp = void;
44   using UOp = void;
45   using COp = void;
46 };
47 
48 template <>
49 struct LhloToScalarOp<lmhlo::AddOp> {
50   using FOp = ::mlir::AddFOp;
51   using IOp = ::mlir::AddIOp;
52   using UOp = ::mlir::AddIOp;
53   using COp = ::mlir::complex::AddOp;
54 };
55 template <>
56 struct LhloToScalarOp<lmhlo::CompareOp> {
57   using FOp = ::mlir::CmpFOp;
58   using IOp = ::mlir::CmpIOp;
59   using UOp = ::mlir::CmpIOp;
60 };
61 template <>
62 struct LhloToScalarOp<lmhlo::DivOp> {
63   using FOp = ::mlir::DivFOp;
64   using IOp = ::mlir::SignedDivIOp;
65   using UOp = ::mlir::UnsignedDivIOp;
66   using COp = ::mlir::complex::DivOp;
67 };
68 template <>
69 struct LhloToScalarOp<lmhlo::MulOp> {
70   using FOp = ::mlir::MulFOp;
71   using IOp = ::mlir::MulIOp;
72   using UOp = ::mlir::MulIOp;
73   using COp = ::mlir::complex::MulOp;
74 };
75 template <>
76 struct LhloToScalarOp<lmhlo::RemOp> {
77   using FOp = ::mlir::RemFOp;
78   using IOp = ::mlir::SignedRemIOp;
79   using UOp = ::mlir::UnsignedRemIOp;
80 };
81 template <>
82 struct LhloToScalarOp<lmhlo::SubOp> {
83   using FOp = ::mlir::SubFOp;
84   using IOp = ::mlir::SubIOp;
85   using UOp = ::mlir::SubIOp;
86   using COp = ::mlir::complex::SubOp;
87 };
88 
89 // Alias for the map from LHLO binary op type to STD floating-point op type.
90 template <typename LhloOp>
91 using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
92 // Alias for the map from LHLO binary op type to STD signed integer op type.
93 template <typename LhloOp>
94 using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
95 // Alias for the map from LHLO binary op type to STD unsigned integer op type.
96 template <typename LhloOp>
97 using ScalarUOp = typename LhloToScalarOp<LhloOp>::UOp;
98 // Alias for the map from LHLO binary op type to STD complex op type.
99 template <typename LhloOp>
100 using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
101 
102 template <typename... Args>
103 struct MapLhloOpToScalarOpImpl {
104   Value operator()(Location loc, ArrayRef<Type> result_types,
105                    ArrayRef<Type> arg_types, ArrayRef<Value> args,
106                    OpBuilder* b) {
107     return nullptr;
108   }
109 };
110 
111 template <typename StdScalarOp>
112 struct MapLhloOpToScalarOpImpl<StdScalarOp> {
113   Value operator()(Location loc, ArrayRef<Type> result_types,
114                    ArrayRef<Type> arg_types, ArrayRef<Value> args,
115                    OpBuilder* b) {
116     return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
117   }
118 };
119 
120 template <typename SupportedType, typename StdScalarOp, typename... Args>
121 struct MapLhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
122   Value operator()(Location loc, ArrayRef<Type> result_types,
123                    ArrayRef<Type> arg_types, ArrayRef<Value> args,
124                    OpBuilder* b) {
125     Type element_type = getElementTypeOrSelf(arg_types.front());
126     if (SupportedType{}(element_type)) {
127       return b->template create<StdScalarOp>(loc, result_types, args,
128                                              mlir::None);
129     }
130     return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types,
131                                               args, b);
132   }
133 };
134 
135 template <typename SupportedType, typename... Args>
136 struct MapLhloOpToScalarOpImpl<SupportedType, void, Args...> {
137   Value operator()(Location loc, ArrayRef<Type> result_types,
138                    ArrayRef<Type> arg_types, ArrayRef<Value> args,
139                    OpBuilder* b) {
140     return MapLhloOpToScalarOpImpl<Args...>{}(loc, result_types, arg_types,
141                                               args, b);
142   }
143 };
144 
145 struct isAnyIntegerType {
146   bool operator()(Type t) { return t.isa<IntegerType>(); }
147 };
148 
149 struct isSignedIntegerType {
150   bool operator()(Type t) {
151     // Pretend that signless is signed. This will change eventually.
152     return t.isa<IntegerType>() && !t.isUnsignedInteger();
153   }
154 };
155 
156 struct isUnsignedIntegerType {
157   bool operator()(Type t) { return t.isUnsignedInteger(); }
158 };
159 
160 struct isFloatType {
161   bool operator()(Type t) { return t.isa<FloatType>(); }
162 };
163 
164 struct isComplexType {
165   bool operator()(Type t) { return t.isa<ComplexType>(); }
166 };
167 
168 // Inserts the computation that corresponds to the body of the loop for lowered
169 // LHLO unary/binary op. Returns the value for the result.
170 template <typename LhloOpTy>
171 inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
172                                     ArrayRef<Type> arg_types,
173                                     ArrayRef<Value> args, OpBuilder* b) {
174   return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<LhloOpTy>,
175                                  isUnsignedIntegerType, ScalarUOp<LhloOpTy>,
176                                  isFloatType, ScalarFOp<LhloOpTy>>{}(
177       loc, result_types, arg_types, args, b);
178 }
179 
180 template <>
181 inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
182                                                   ArrayRef<Type> result_types,
183                                                   ArrayRef<Type> arg_types,
184                                                   ArrayRef<Value> args,
185                                                   OpBuilder* b) {
186   Type element_type = getElementTypeOrSelf(arg_types.front());
187   if (element_type.isa<FloatType>()) {
188     return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AbsFOp>{}(
189         loc, result_types, arg_types, args, b);
190   }
191   if (element_type.isa<ComplexType>()) {
192     return MapLhloOpToScalarOpImpl<isComplexType, ::mlir::complex::AbsOp>{}(
193         loc, result_types, arg_types, args, b);
194   }
195   if (element_type.isSignlessInteger() || element_type.isSignedInteger()) {
196     // lmhlo.abs(x, result) ->  result = select((x > 0), x, sub(0, x))
197     Value lhs = args[0];
198     auto integer_type = element_type.dyn_cast<IntegerType>();
199 
200     Value zero_intval =
201         b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
202     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
203       zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
204     }
205     auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
206                                                        lhs, zero_intval);
207     auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
208     return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
209   }
210   return nullptr;
211 }
212 
213 template <>
214 inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
215                                                   ArrayRef<Type> result_types,
216                                                   ArrayRef<Type> arg_types,
217                                                   ArrayRef<Value> args,
218                                                   OpBuilder* b) {
219   return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::AddOp>,
220                                  isFloatType, ScalarFOp<lmhlo::AddOp>,
221                                  isComplexType, ScalarCOp<lmhlo::AddOp>>{}(
222       loc, result_types, arg_types, args, b);
223 }
224 
225 template <>
226 inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
227                                                   ArrayRef<Type> result_types,
228                                                   ArrayRef<Type> arg_types,
229                                                   ArrayRef<Value> args,
230                                                   OpBuilder* b) {
231   return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AndOp>{}(
232       loc, result_types, arg_types, args, b);
233 }
234 
235 template <>
236 inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
237                                                     ArrayRef<Type> result_types,
238                                                     ArrayRef<Type> arg_types,
239                                                     ArrayRef<Value> args,
240                                                     OpBuilder* b) {
241   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Atan2Op>{}(
242       loc, result_types, arg_types, args, b);
243 }
244 
245 template <typename PredicateType>
246 inline Optional<PredicateType> getCmpPredicate(StringRef, bool) {
247   return llvm::None;
248 }
249 
250 template <>
251 inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
252     StringRef comparison_direction, bool is_signed) {
253   assert(is_signed && "cannot have an unsigned float!");
254   return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
255       .Case("EQ", CmpFPredicate::OEQ)
256       .Case("NE", CmpFPredicate::UNE)
257       .Case("GE", CmpFPredicate::OGE)
258       .Case("GT", CmpFPredicate::OGT)
259       .Case("LE", CmpFPredicate::OLE)
260       .Case("LT", CmpFPredicate::OLT)
261       .Default(llvm::None);
262 }
263 
264 template <>
265 inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
266     StringRef comparison_direction, bool is_signed) {
267   return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
268       .Case("EQ", CmpIPredicate::eq)
269       .Case("NE", CmpIPredicate::ne)
270       .Case("GE", is_signed ? CmpIPredicate::sge : CmpIPredicate::uge)
271       .Case("GT", is_signed ? CmpIPredicate::sgt : CmpIPredicate::ugt)
272       .Case("LE", is_signed ? CmpIPredicate::sle : CmpIPredicate::ule)
273       .Case("LT", is_signed ? CmpIPredicate::slt : CmpIPredicate::ult)
274       .Default(llvm::None);
275 }
276 
277 template <typename CompareOpTy>
278 inline Value MapCompareOpToStdScalarOp(Location loc,
279                                        StringRef comparison_direction,
280                                        ArrayRef<Type> result_types,
281                                        ArrayRef<Type> arg_types,
282                                        ArrayRef<Value> args, OpBuilder* b) {
283   const auto& lhs = args[0];
284   const auto& rhs = args[1];
285   Type element_type = getElementTypeOrSelf(arg_types.front());
286   if (element_type.isa<IntegerType>()) {
287     Optional<CmpIPredicate> predicate = getCmpPredicate<CmpIPredicate>(
288         comparison_direction, !element_type.isUnsignedInteger());
289     assert(predicate.hasValue() && "expected valid comparison direction");
290     return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
291                                              rhs);
292   }
293   if (element_type.isa<FloatType>()) {
294     Optional<CmpFPredicate> predicate = getCmpPredicate<CmpFPredicate>(
295         comparison_direction, /*is_signed=*/true);
296     assert(predicate.hasValue() && "expected valid comparison direction");
297     return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs,
298                                              rhs);
299   }
300   if (auto complex_type = element_type.dyn_cast<ComplexType>()) {
301     if (complex_type.getElementType().isa<FloatType>()) {
302       if (comparison_direction == "EQ") {
303         return b->create<complex::EqualOp>(loc, lhs, rhs);
304       }
305       if (comparison_direction == "NE") {
306         return b->create<complex::NotEqualOp>(loc, lhs, rhs);
307       }
308     }
309   }
310   return nullptr;
311 }
312 
313 template <>
314 inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
315                                                    ArrayRef<Type> result_types,
316                                                    ArrayRef<Type> arg_types,
317                                                    ArrayRef<Value> args,
318                                                    OpBuilder* b) {
319   return args.front();
320 }
321 
322 template <>
323 inline Value MapLhloOpToStdScalarOp<lmhlo::DivOp>(Location loc,
324                                                   ArrayRef<Type> result_types,
325                                                   ArrayRef<Type> arg_types,
326                                                   ArrayRef<Value> args,
327                                                   OpBuilder* b) {
328   return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::DivOp>,
329                                  isUnsignedIntegerType, ScalarUOp<lmhlo::DivOp>,
330                                  isFloatType, ScalarFOp<lmhlo::DivOp>,
331                                  isComplexType, ScalarCOp<lmhlo::DivOp>>{}(
332       loc, result_types, arg_types, args, b);
333 }
334 
335 template <>
336 inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
337                                                   ArrayRef<Type> result_types,
338                                                   ArrayRef<Type> arg_types,
339                                                   ArrayRef<Value> args,
340                                                   OpBuilder* b) {
341   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpOp,
342                                  isComplexType, ::mlir::complex::ExpOp>{}(
343       loc, result_types, arg_types, args, b);
344 }
345 
346 template <>
347 inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
348                                                     ArrayRef<Type> result_types,
349                                                     ArrayRef<Type> arg_types,
350                                                     ArrayRef<Value> args,
351                                                     OpBuilder* b) {
352   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::ExpM1Op>{}(
353       loc, result_types, arg_types, args, b);
354 }
355 
356 template <>
357 inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
358                                                    ArrayRef<Type> result_types,
359                                                    ArrayRef<Type> arg_types,
360                                                    ArrayRef<Value> args,
361                                                    OpBuilder* b) {
362   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::CeilFOp>{}(
363       loc, result_types, arg_types, args, b);
364 }
365 
366 template <>
367 inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
368     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
369     ArrayRef<Value> args, OpBuilder* b) {
370   return MapLhloOpToScalarOpImpl<complex::CreateOp>{}(loc, result_types,
371                                                       arg_types, args, b);
372 }
373 
374 template <>
375 inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
376                                                    ArrayRef<Type> result_types,
377                                                    ArrayRef<Type> arg_types,
378                                                    ArrayRef<Value> args,
379                                                    OpBuilder* b) {
380   return MapLhloOpToScalarOpImpl<complex::ReOp>{}(loc, result_types, arg_types,
381                                                   args, b);
382 }
383 
384 template <>
385 inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
386                                                    ArrayRef<Type> result_types,
387                                                    ArrayRef<Type> arg_types,
388                                                    ArrayRef<Value> args,
389                                                    OpBuilder* b) {
390   return MapLhloOpToScalarOpImpl<complex::ImOp>{}(loc, result_types, arg_types,
391                                                   args, b);
392 }
393 
394 template <>
395 inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
396     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
397     ArrayRef<Value> args, OpBuilder* b) {
398   Type sourceType = getElementTypeOrSelf(arg_types.front());
399   Type targetType = getElementTypeOrSelf(result_types.front());
400   Type convertedSourceType = getElementTypeOrSelf(args.front());
401 
402   // A boolean value is considered to be unsigned when converting to
403   // floating-point. Otherwise, it will become `-1`.
404   if ((sourceType.isInteger(/*width=*/1) || sourceType.isUnsignedInteger()) &&
405       mlir::UIToFPOp::areCastCompatible(convertedSourceType, targetType)) {
406     return b->create<mlir::UIToFPOp>(loc, result_types, args, mlir::None);
407   } else if (mlir::SIToFPOp::areCastCompatible(convertedSourceType,
408                                                targetType)) {
409     return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
410   } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
411     FloatType src = sourceType.cast<FloatType>();
412     FloatType res = targetType.cast<FloatType>();
413     if (src.getWidth() > res.getWidth()) {
414       return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
415     } else if (src.getWidth() < res.getWidth()) {
416       return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
417     }
418     // No conversion is needed for the same width floats
419     return args.front();
420   }
421   if (targetType.isInteger(/*width=*/1)) {
422     // When casting to bool, we need to compare whether the value is equal to
423     // zero.
424     if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) {
425       Value zero_intval = b->create<::mlir::ConstantIntOp>(
426           loc, 0, sourceType.cast<IntegerType>().getWidth());
427       if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
428         zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
429       }
430       return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
431                                      zero_intval);
432     } else if (sourceType.isa<FloatType>()) {
433       Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
434       if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
435         zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
436       }
437       return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(),
438                                      zero);
439     }
440   }
441   if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
442     IntegerType src = sourceType.cast<IntegerType>();
443     IntegerType res = targetType.cast<IntegerType>();
444     if (src.getWidth() > res.getWidth()) {
445       return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
446     } else if (src.getWidth() < res.getWidth()) {
447       // Special case boolean values, so they get casted to `1` instead of `-1`.
448       if (src.isUnsignedInteger() || src.getWidth() == 1) {
449         return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
450                                               mlir::None);
451       }
452       return b->create<mlir::SignExtendIOp>(loc, result_types, args,
453                                             mlir::None);
454     }
455     // No conversion is needed for the same width integers
456     return args.front();
457   }
458   if (mlir::FPToSIOp::areCastCompatible(convertedSourceType, targetType)) {
459     return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
460   }
461   return nullptr;
462 }
463 
464 template <>
465 inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
466                                                   ArrayRef<Type> result_types,
467                                                   ArrayRef<Type> arg_types,
468                                                   ArrayRef<Value> args,
469                                                   OpBuilder* b) {
470   // Dot Op converter from lhlo to affine only accepts float and integer types.
471   const auto& lhs = args[0];
472   const auto& rhs = args[1];
473   const auto& result = args[2];
474   Type element_type = lhs.getType();
475   if (element_type.isa<FloatType>()) {
476     Value float_mul = MapLhloOpToScalarOpImpl<isFloatType, ::mlir::MulFOp>{}(
477         loc, result_types, arg_types, {lhs, rhs}, b);
478     return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::AddFOp>{}(
479         loc, result_types, arg_types, {float_mul, result}, b);
480   }
481   if (element_type.isa<IntegerType>()) {
482     Value int_mul = MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::MulIOp>{}(
483         loc, result_types, arg_types, {lhs, rhs}, b);
484     return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::AddIOp>{}(
485         loc, result_types, arg_types, {int_mul, result}, b);
486   }
487   return nullptr;
488 }
489 
490 template <>
491 inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
492                                                   ArrayRef<Type> result_types,
493                                                   ArrayRef<Type> arg_types,
494                                                   ArrayRef<Value> args,
495                                                   OpBuilder* b) {
496   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::CosOp>{}(
497       loc, result_types, arg_types, args, b);
498 }
499 
500 template <>
501 inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
502                                                   ArrayRef<Type> result_types,
503                                                   ArrayRef<Type> arg_types,
504                                                   ArrayRef<Value> args,
505                                                   OpBuilder* b) {
506   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SinOp>{}(
507       loc, result_types, arg_types, args, b);
508 }
509 
510 template <>
511 inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
512                                                     ArrayRef<Type> result_types,
513                                                     ArrayRef<Type> arg_types,
514                                                     ArrayRef<Value> args,
515                                                     OpBuilder* b) {
516   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::FloorFOp>{}(
517       loc, result_types, arg_types, args, b);
518 }
519 
520 template <>
521 inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
522     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
523     ArrayRef<Value> args, OpBuilder* b) {
524   if (args[0].getType().isa<FloatType>()) {
525     auto pos_inf = APFloat::getInf(
526         args[0].getType().cast<FloatType>().getFloatSemantics());
527     auto const_pos_inf =
528         b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
529     Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
530     return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
531                                      const_pos_inf);
532   }
533   return nullptr;
534 }
535 
536 /// Implements the conversion of HLO op to scalar op (to use within region of a
537 /// linalg.generic op) for compare-select style operations like min/max.
538 template <typename... Args>
539 struct CompareSelectOpToStdScalarOp {
540   static Value map(Location loc, StringRef comparison_direction,
541                    ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
542                    ArrayRef<Value> args, OpBuilder* b) {
543     return nullptr;
544   }
545 };
546 
547 /// Specialization which allows converting to a comparison operation in standard
548 /// dialect with a given predicate based on the element type of the operand.
549 template <typename SupportedType, typename StdCompareOp, typename Predicate,
550           typename... Args>
551 struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
552                                     Args...> {
553   static Value map(Location loc, StringRef comparison_direction,
554                    ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
555                    ArrayRef<Value> args, OpBuilder* b) {
556     Type element_type = getElementTypeOrSelf(arg_types.front());
557     if (element_type.isa<SupportedType>()) {
558       auto predicate = getCmpPredicate<Predicate>(
559           comparison_direction, !element_type.isUnsignedInteger());
560       assert(predicate.hasValue() && "expected valid comparison direction");
561       auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
562                                                   args[0], args[1]);
563       return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
564     }
565     return CompareSelectOpToStdScalarOp<Args...>::map(
566         loc, comparison_direction, result_types, arg_types, args, b);
567   }
568 };
569 
570 template <>
571 inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
572                                                   ArrayRef<Type> result_types,
573                                                   ArrayRef<Type> arg_types,
574                                                   ArrayRef<Value> args,
575                                                   OpBuilder* b) {
576   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::LogOp,
577                                  isComplexType, ::mlir::complex::LogOp>{}(
578       loc, result_types, arg_types, args, b);
579 }
580 
581 inline Value LhloAlwaysPropagateNaN(Value v, ArrayRef<Value> args, Location loc,
582                                     OpBuilder* b) {
583   Type element_type = getElementTypeOrSelf(args.front().getType());
584   if (auto float_type = element_type.dyn_cast<FloatType>()) {
585     Value isnan =
586         b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[1]);
587 
588     auto nan_apfloat = APFloat::getQNaN(float_type.getFloatSemantics());
589     Value nan = b->create<mlir::ConstantFloatOp>(loc, nan_apfloat, float_type);
590     if (VectorType vec_type = args[0].getType().dyn_cast<VectorType>()) {
591       nan = b->create<::mlir::SplatOp>(loc, vec_type, nan);
592     }
593     v = b->create<mlir::SelectOp>(loc, isnan, nan, v);
594   }
595   return v;
596 }
597 
598 template <>
599 inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
600     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
601     ArrayRef<Value> args, OpBuilder* b) {
602   auto ty = result_types.front().cast<FloatType>();
603   Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
604   Value x = args.front();
605   Value neg_x = b->create<NegFOp>(loc, x);
606   Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x);
607   Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x);
608   return b->create<DivFOp>(loc, one, one_add_exp_neg_x);
609 }
610 
611 template <>
612 inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
613                                                     ArrayRef<Type> result_types,
614                                                     ArrayRef<Type> arg_types,
615                                                     ArrayRef<Value> args,
616                                                     OpBuilder* b) {
617   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::Log1pOp,
618                                  isComplexType, ::mlir::complex::Log1pOp>{}(
619       loc, result_types, arg_types, args, b);
620 }
621 
622 template <>
623 inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
624                                                   ArrayRef<Type> result_types,
625                                                   ArrayRef<Type> arg_types,
626                                                   ArrayRef<Value> args,
627                                                   OpBuilder* b) {
628   return LhloAlwaysPropagateNaN(
629       CompareSelectOpToStdScalarOp<
630           IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
631           ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
632                                                            result_types,
633                                                            arg_types, args, b),
634       args, loc, b);
635 }
636 
637 template <>
638 inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
639                                                   ArrayRef<Type> result_types,
640                                                   ArrayRef<Type> arg_types,
641                                                   ArrayRef<Value> args,
642                                                   OpBuilder* b) {
643   return LhloAlwaysPropagateNaN(
644       CompareSelectOpToStdScalarOp<
645           IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
646           ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
647                                                            result_types,
648                                                            arg_types, args, b),
649       args, loc, b);
650 }
651 
652 template <>
653 inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
654                                                   ArrayRef<Type> result_types,
655                                                   ArrayRef<Type> arg_types,
656                                                   ArrayRef<Value> args,
657                                                   OpBuilder* b) {
658   return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
659                                  isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
660                                  isFloatType, ScalarFOp<lmhlo::MulOp>,
661                                  isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
662       loc, result_types, arg_types, args, b);
663 }
664 
665 template <>
666 inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
667                                                     ArrayRef<Type> result_types,
668                                                     ArrayRef<Type> arg_types,
669                                                     ArrayRef<Value> args,
670                                                     OpBuilder* b) {
671   assert(args.size() == 3 && "expected 3 arguments");
672   Value lb = args[0];
673   Value x = args[1];
674   Value ub = args[2];
675 
676   // clamp(lb, x, ub) = max(min(x, ub), lb)
677   Value min_x_ub = MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types,
678                                                         arg_types, {x, ub}, b);
679   return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, arg_types,
680                                               {min_x_ub, lb}, b);
681 }
682 
683 template <>
684 inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
685                                                   ArrayRef<Type> result_types,
686                                                   ArrayRef<Type> arg_types,
687                                                   ArrayRef<Value> args,
688                                                   OpBuilder* b) {
689   Type element_type = getElementTypeOrSelf(args.front().getType());
690   if (element_type.isa<ComplexType, FloatType>()) {
691     return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType,
692                                    ::mlir::complex::NegOp>{}(
693         loc, result_types, arg_types, args, b);
694   }
695   if (element_type.isa<IntegerType>()) {
696     // lmhlo.neg(x, result) -> result = sub(0, x)
697     Value lhs = args[0];
698     auto integer_type = element_type.dyn_cast<IntegerType>();
699 
700     Value zero_intval =
701         b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
702     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
703       zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
704     }
705     return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
706   }
707   return nullptr;
708 }
709 
710 template <>
711 inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
712                                                   ArrayRef<Type> result_types,
713                                                   ArrayRef<Type> arg_types,
714                                                   ArrayRef<Value> args,
715                                                   OpBuilder* b) {
716   Type element_type = getElementTypeOrSelf(args.front().getType());
717   if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
718     // lmhlo.not(x) -> x ^ -1
719     Value all_ones =
720         b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
721     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
722       all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones);
723     }
724     return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
725   }
726   return nullptr;
727 }
728 
729 template <>
730 inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
731                                                  ArrayRef<Type> result_types,
732                                                  ArrayRef<Type> arg_types,
733                                                  ArrayRef<Value> args,
734                                                  OpBuilder* b) {
735   return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::OrOp>{}(
736       loc, result_types, arg_types, args, b);
737 }
738 
739 template <>
740 inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
741                                                     ArrayRef<Type> result_types,
742                                                     ArrayRef<Type> arg_types,
743                                                     ArrayRef<Value> args,
744                                                     OpBuilder* b) {
745   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::RsqrtOp>{}(
746       loc, result_types, arg_types, args, b);
747 }
748 
749 template <>
750 inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
751                                                   ArrayRef<Type> result_types,
752                                                   ArrayRef<Type> arg_types,
753                                                   ArrayRef<Value> args,
754                                                   OpBuilder* b) {
755   lmhlo::PowOp::Adaptor adaptor(args);
756   auto lb = ImplicitLocOpBuilder(loc, *b);
757   // Floating point can use std::powf
758   auto result_type = result_types.front();
759   if (result_type.isa<::mlir::FloatType>())
760     return MapLhloOpToScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
761                                                            arg_types, args, b);
762 
763   assert(result_type.isa<::mlir::IntegerType>() &&
764          "only float and integer `pow` is supported right now");
765 
766   // Exponentiation by squaring:
767   // https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
768   Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1));
769   Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0));
770   Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1));
771   Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2));
772   Value step = lb.create<ConstantIndexOp>(1);
773   Value lowerBound = lb.create<ConstantIndexOp>(0);
774   // Everything else would overflow for any exponent > 1, as 2^64
775   // is the larget possible exponent for a 64-bit integer, and
776   // that's 1 << 6.
777   Value upperBound = lb.create<ConstantIndexOp>(6);
778   auto original_base = adaptor.lhs();
779   auto original_exponent = adaptor.rhs();
780 
781   Value accum =
782       lb.create<scf::ForOp>(
783             lowerBound, upperBound, step,
784             SmallVector<Value>({one, original_base, original_exponent}),
785             [&](OpBuilder& b, Location, Value v, ValueRange iters) {
786               Value accum = iters[0];
787               Value base = iters[1];
788               Value exponent = iters[2];
789 
790               Value condition = b.create<CmpIOp>(
791                   loc, CmpIPredicate::eq,
792                   b.create<::mlir::AndOp>(loc, exponent, one), one);
793               Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base);
794               accum =
795                   b.create<::mlir::SelectOp>(loc, condition, multiplied, accum);
796               base = b.create<::mlir::MulIOp>(loc, base, base);
797               exponent =
798                   b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one);
799               b.create<scf::YieldOp>(
800                   loc, SmallVector<Value>({accum, base, exponent}));
801             })
802           .getResult(0);
803 
804   Value rhs_is_even = lb.create<CmpIOp>(
805       CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero);
806   Value rhs_is_negative =
807       lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero);
808   Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one);
809   Value lhs_is_neg_one =
810       lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one);
811 
812   // The accum is correct when the rhs is non-negative. When rhs is
813   // negative, we return 0 for integer, with the exception of lhs values of 1
814   // and -1 which have integer results for negative exponents. Specifically, the
815   // calulation is the following:
816   //
817   // - Return accum if the rhs is not negative.
818   // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
819   // - Return 1 if lhs is 1.
820   // - Else return 0.
821   Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero);
822   Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>(
823       lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one),
824       if_lhs_is_one);
825   return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum);
826 }
827 
828 template <>
829 inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
830     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
831     ArrayRef<Value> args, OpBuilder* b) {
832   return MapLhloOpToScalarOpImpl<::mlir::SelectOp>{}(loc, result_types,
833                                                      arg_types, args, b);
834 }
835 
836 template <>
837 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>(
838     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
839     ArrayRef<Value> args, OpBuilder* b) {
840   return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::ShiftLeftOp>{}(
841       loc, result_types, arg_types, args, b);
842 }
843 
844 template <>
845 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>(
846     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
847     ArrayRef<Value> args, OpBuilder* b) {
848   return MapLhloOpToScalarOpImpl<isAnyIntegerType, mlir::SignedShiftRightOp>{}(
849       loc, result_types, arg_types, args, b);
850 }
851 
852 template <>
853 inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>(
854     Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
855     ArrayRef<Value> args, OpBuilder* b) {
856   return MapLhloOpToScalarOpImpl<isAnyIntegerType,
857                                  mlir::UnsignedShiftRightOp>{}(
858       loc, result_types, arg_types, args, b);
859 }
860 
861 template <>
862 inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
863                                                    ArrayRef<Type> result_types,
864                                                    ArrayRef<Type> arg_types,
865                                                    ArrayRef<Value> args,
866                                                    OpBuilder* b) {
867   Type element_type = getElementTypeOrSelf(args.front().getType());
868   if (auto float_type = element_type.dyn_cast<FloatType>()) {
869     bool ignored;
870     APFloat zero_apfloat(0.0f);
871     zero_apfloat.convert(float_type.getFloatSemantics(),
872                          APFloat::rmNearestTiesToEven, &ignored);
873     Value zero =
874         b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
875     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
876       zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
877     }
878     Value ne0_i1 =
879         b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
880     Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType());
881     Value copy_sign =
882         b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
883     auto is_nan =
884         b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
885     return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
886   } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
887     // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
888     Value zero =
889         b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
890     Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
891         loc, integer_type.getWidth() - 1, integer_type.getWidth());
892     Value one =
893         b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
894     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
895       zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
896       bitwidth_minus_one =
897           b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one);
898       one = b->create<::mlir::SplatOp>(loc, vec_type, one);
899     }
900     Value cmp =
901         b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
902     Value ashr =
903         b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
904     Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
905     return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
906   } else if (element_type.isa<ComplexType>()) {
907     return b->create<::mlir::complex::SignOp>(loc, element_type, args.front());
908   }
909   return nullptr;
910 }
911 
912 template <>
913 inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
914                                                    ArrayRef<Type> result_types,
915                                                    ArrayRef<Type> arg_types,
916                                                    ArrayRef<Value> args,
917                                                    OpBuilder* b) {
918   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::SqrtOp>{}(
919       loc, result_types, arg_types, args, b);
920 }
921 
922 template <>
923 inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
924                                                   ArrayRef<Type> result_types,
925                                                   ArrayRef<Type> arg_types,
926                                                   ArrayRef<Value> args,
927                                                   OpBuilder* b) {
928   return MapLhloOpToScalarOpImpl<isAnyIntegerType, ScalarIOp<lmhlo::SubOp>,
929                                  isFloatType, ScalarFOp<lmhlo::SubOp>,
930                                  isComplexType, ScalarCOp<lmhlo::SubOp>>{}(
931       loc, result_types, arg_types, args, b);
932 }
933 
934 template <>
935 inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
936                                                    ArrayRef<Type> result_types,
937                                                    ArrayRef<Type> arg_types,
938                                                    ArrayRef<Value> args,
939                                                    OpBuilder* b) {
940   return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::math::TanhOp>{}(
941       loc, result_types, arg_types, args, b);
942 }
943 
944 template <>
945 inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
946                                                   ArrayRef<Type> result_types,
947                                                   ArrayRef<Type> arg_types,
948                                                   ArrayRef<Value> args,
949                                                   OpBuilder* b) {
950   return MapLhloOpToScalarOpImpl<isAnyIntegerType, ::mlir::XOrOp>{}(
951       loc, result_types, arg_types, args, b);
952 }
953 
954 }  // namespace impl
955 
// Entry point for lowering a single HLO/LHLO op to scalar std/math/complex ops.
// Every overload forwards to impl::MapLhloOpToStdScalarOp or
// impl::MapCompareOpToStdScalarOp. The overloads are selected with
// std::enable_if_t so that CompareOp (which also needs a comparison direction)
// is routed separately from all other ops.
struct HloOpToStdScalarOp {
  // Implementation for LHLO ops except lmhlo::CompareOp.
  // Enabled when mhlo::HloToLhloOp<HloOpTy> is std::false_type, i.e. the op
  // type has no HLO->LHLO mapping (presumably because it is already an LHLO
  // op). Operand types are taken from the op itself.
  // NOTE: the unused trailing `i` parameter (unsigned here, int in the HLO
  // overload below) gives the two templates distinct signatures; default
  // template arguments alone do not distinguish function templates.
  template <typename HloOpTy, typename LhloOpTy = HloOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
        op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
        args, b);
  }

  // Implementation for HLO ops except mhlo::CompareOp.
  // The HLO op is translated to its LHLO counterpart via mhlo::HloToLhloOp
  // and lowered with that LHLO op's specialization. Disabled when no mapping
  // exists (std::false_type) or when the mapping is lmhlo::CompareOp, which
  // has its own overload.
  template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                !std::is_same<LhloOpTy, std::false_type>::value>>
  static Value map(HloOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, int i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(
        op.getLoc(), result_types, llvm::to_vector<4>(op->getOperandTypes()),
        args, b);
  }

  // Implementation for lmhlo::CompareOp.
  // LhloOpTy does not appear in the parameter list, so it cannot be deduced;
  // callers must spell it out: map<lmhlo::CompareOp>(op, ...). The comparison
  // direction is read from the op's attribute.
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
                                   LhloOpTy, lmhlo::CompareOp>::value>>
  static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types,
        llvm::to_vector<4>(op->getOperandTypes()), args, b);
  }

  // Implementation for mhlo::CompareOp.
  // As above, HloOpTy must be given explicitly: map<mhlo::CompareOp>(op, ...).
  // Lowering reuses the lmhlo::CompareOp mapping.
  template <typename HloOpTy,
            typename =
                std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
  static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types,
        llvm::to_vector<4>(op->getOperandTypes()), args, b);
  }

  // Implementation for LHLO ops except lmhlo::CompareOp.
  // Op-less variant: the caller supplies the location and operand types
  // directly, and LhloOpTy must be given as an explicit template argument.
  template <typename LhloOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Type> arg_types, ArrayRef<Value> args, OpBuilder* b,
                   unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, arg_types,
                                                  args, b);
  }

  // Implementation for lmhlo::CompareOp.
  // Op-less variant: the comparison direction is passed as a string (the
  // MaxOp/MinOp lowerings in this file use "GT"/"LT" with the same helper
  // family).
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
                                   LhloOpTy, lmhlo::CompareOp>::value>>
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>(
        loc, comparison_direction, result_types, arg_types, args, b);
  }
};
1028 
1029 }  // namespace lmhlo
1030 }  // namespace mlir
1031 
1032 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
1033