• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
17 #define XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
18 
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Error.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/xla/runtime/types.h"
26 
27 namespace xla {
28 namespace runtime {
29 
30 //===----------------------------------------------------------------------===//
31 // Type conversion from the compile time types to the run-time types.
32 //===----------------------------------------------------------------------===//
33 
34 // Type converter converts MLIR types known at compile time to the corresponding
35 // types used at run time. It provides default conversions for the canonical
36 // types (memrefs, tensors, etc...) and allows users to register custom
37 // conversions for user-defined types.
38 class TypeConverter {
39  public:
40   // Conversion function must return run time type corresponding to the compile
41   // time type if the conversion is successful, or `nullptr` if failed.
42   using ConversionFn = std::function<std::unique_ptr<Type>(mlir::Type)>;
43 
44   // Adds a type conversion function with a type predicate.
45   //
46   // Example:
47   //
48   //   AddConversion([](mlir::TensorType) -> std::unique_ptr<Type> { ... });
49   //
50   // The conversion function will match only the tensor type, and return empty
51   // result for all other types, and the type converter will try the next
52   // conversion function (see `Convert` implementation).
53   template <typename Fn, typename FnTraits = llvm::function_traits<Fn>>
AddConversion(Fn fn)54   void AddConversion(Fn fn) {
55     using ArgType = typename FnTraits::template arg_t<0>;
56     conversions_.emplace_back(
57         [fn = std::forward<Fn>(fn)](mlir::Type type) -> std::unique_ptr<Type> {
58           if (auto arg = type.dyn_cast<ArgType>()) return fn(arg);
59           return {};
60         });
61   }
62 
63   // Converts MLIR element type to the DType.
64   static llvm::Expected<tfrt::DType> ConvertElementType(mlir::Type type);
65 
66   // Converts MLIR type to the runtime type. Returns error if conversion was not
67   // successful and the type has no corresponding run time type.
68   llvm::Expected<std::unique_ptr<Type>> Convert(mlir::Type type) const;
69 
70   // Converts MLIR function type to the runtime function type. Returns error if
71   // function has unsupported operands or results types.
72   llvm::Expected<FunctionType> Convert(mlir::FunctionType type) const;
73 
74  private:
75   llvm::SmallVector<ConversionFn> conversions_;
76 };
77 
78 }  // namespace runtime
79 }  // namespace xla
80 
81 #endif  // XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
82