1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_ 17 #define XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "llvm/ADT/STLExtras.h" 23 #include "llvm/Support/Error.h" 24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 25 #include "tensorflow/compiler/xla/runtime/types.h" 26 27 namespace xla { 28 namespace runtime { 29 30 //===----------------------------------------------------------------------===// 31 // Type conversion from the compile time types to the run-time types. 32 //===----------------------------------------------------------------------===// 33 34 // Type converter converts MLIR types known at compile time to the corresponding 35 // types used at run time. It provides default conversions for the canonical 36 // types (memrefs, tensors, etc...) and allows users to register custom 37 // conversions for user-defined types. 38 class TypeConverter { 39 public: 40 // Conversion function must return run time type corresponding to the compile 41 // time type if the conversion is successful, or `nullptr` if failed. 42 using ConversionFn = std::function<std::unique_ptr<Type>(mlir::Type)>; 43 44 // Adds a type conversion function with a type predicate. 45 // 46 // Example: 47 // 48 // AddConversion([](mlir::TensorType) -> std::unique_ptr<Type> { ... }); 49 // 50 // The conversion function will match only the tensor type, and return empty 51 // result for all other types, and the type converter will try the next 52 // conversion function (see `Convert` implementation). 53 template <typename Fn, typename FnTraits = llvm::function_traits<Fn>> AddConversion(Fn fn)54 void AddConversion(Fn fn) { 55 using ArgType = typename FnTraits::template arg_t<0>; 56 conversions_.emplace_back( 57 [fn = std::forward<Fn>(fn)](mlir::Type type) -> std::unique_ptr<Type> { 58 if (auto arg = type.dyn_cast<ArgType>()) return fn(arg); 59 return {}; 60 }); 61 } 62 63 // Converts MLIR element type to the DType. 64 static llvm::Expected<tfrt::DType> ConvertElementType(mlir::Type type); 65 66 // Converts MLIR type to the runtime type. Returns error if conversion was not 67 // successful and the type has no corresponding run time type. 68 llvm::Expected<std::unique_ptr<Type>> Convert(mlir::Type type) const; 69 70 // Converts MLIR function type to the runtime function type. Returns error if 71 // function has unsupported operands or results types. 72 llvm::Expected<FunctionType> Convert(mlir::FunctionType type) const; 73 74 private: 75 llvm::SmallVector<ConversionFn> conversions_; 76 }; 77 78 } // namespace runtime 79 } // namespace xla 80 81 #endif // XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_ 82