• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
17 #define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
18 
19 #include <unordered_map>
20 
21 #include "absl/types/optional.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
30 #include "tensorflow/compiler/xla/comparison_util.h"
31 #include "tensorflow/compiler/xla/status.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 
38 class HloModule;
39 class HloComputation;
40 class HloInstruction;
41 class Shape;
42 
43 // Helper class for importing HloComputations.
44 class HloFunctionImporter {
45  public:
46   // Imports the given computation as a function in the given module. This also
47   // imports any computations referred by instructions in this computation.
48   static Status ImportAsFunc(const xla::HloComputation& computation,
49                              mlir::ModuleOp module,
50                              std::unordered_map<const xla::HloComputation*,
51                                                 mlir::FuncOp>* function_map,
52                              mlir::Builder* builder);
53 
54   // Imports the given hlo computation to the specified region.
55   static Status ImportAsRegion(const xla::HloComputation& computation,
56                                mlir::Region* region, mlir::Builder* builder);
57 
58   // Imports the given computation to the given place specified by `builder`.
59   // `arguments` contains values for all parameters.
60   static StatusOr<mlir::Value> ImportInstructions(
61       const xla::HloComputation& computation,
62       const llvm::SmallVectorImpl<mlir::Value>& arguments,
63       mlir::OpBuilder* builder);
64 
65   static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
66 
67   // TODO(b/179166199): move this to attribute_importer.h.
68   // Converts XLA instruction source target pairs to MLIR attribute.
69   static mlir::NamedAttribute ConvertSourceTargetPairs(
70       const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
71           source_target_pairs,
72       mlir::Builder* builder);
73 
74   // TODO(b/179166199): move this to attribute_importer.h.
75   // Converts replica groups to attribute
76   static mlir::NamedAttribute ConvertReplicaGroups(
77       const std::vector<ReplicaGroup>& replica_groups, mlir::Builder* builder);
78 
79  private:
HloFunctionImporter(mlir::ModuleOp module,std::unordered_map<const xla::HloComputation *,mlir::FuncOp> * function_map,mlir::Builder * builder)80   HloFunctionImporter(mlir::ModuleOp module,
81                       std::unordered_map<const xla::HloComputation*,
82                                          mlir::FuncOp>* function_map,
83                       mlir::Builder* builder)
84       : context_(module.getContext()),
85         module_(module),
86         builder_(builder),
87         function_map_(function_map) {
88     context_->loadDialect<mlir::StandardOpsDialect>();
89     context_->loadDialect<mlir::mhlo::MhloDialect>();
90   }
91 
92   // Imports the given computation as a new function, if it hasn't been already
93   // imported.
94   StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
95 
96   // Imports the given computation in the specified region.
97   tensorflow::Status ImportAsRegion(const HloComputation& computation,
98                                     mlir::Region* region);
99 
100   // Imports instructions from the given computation in the specified block.
101   // Assumes that the block already has correct arguments populated.
102   tensorflow::Status ImportInstructions(const HloComputation& computation,
103                                         mlir::Block* block);
104   StatusOr<mlir::Value> ImportInstructionsImpl(
105       const xla::HloComputation& computation,
106       const llvm::SmallVectorImpl<mlir::Value>& arguments,
107       mlir::OpBuilder* builder);
108 
109   // Imports an instruction.
110   StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
111                                                mlir::OpBuilder* func_builder);
112   StatusOr<mlir::Operation*> ImportInstructionImpl(
113       HloInstruction* instruction, mlir::OpBuilder* func_builder);
114 
115   // Gets the MLIR operand values from an HLO Instruction.
116   StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
117       xla::HloInstruction* instruction);
118 
119   // Converts xla Tensor type to the corresponding MLIR type.
120   StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
121 
122   // Returns the output type of an HloInstruction.
123   StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);
124 
125   // Takes a list of HloInstructions and generates the list of types used for
126   // input, bypassing tuples to subsets.
127   Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
128                       llvm::SmallVectorImpl<mlir::Type>* types);
129 
130   // Returns the Mlir Value for the corresponding HloInstruction.
131   StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
132 
133   // Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
134   mlir::NamedAttribute ConvertComparisonDirection(
135       ComparisonDirection direction);
136 
137   // Converts an XLA Comparison::Type to the corresponding MLIR attribute.
138   mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);
139 
140   // Converts the dimensions of an HLO instruction into an MLIR attribute.
141   mlir::DenseIntElementsAttr ConvertDimensions(
142       llvm::ArrayRef<tensorflow::int64> op_dimensions);
143 
144   // Converts Array ref to an DenseIntElementsAttr.
145   mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
146 
147   // Converts Array ref to padding attribute. Input is a flattened list of
148   // padding low and padding high for each of the spatial dimensions.
149   mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
150 
151   // Converts channel id to attribute
152   mlir::NamedAttribute ConvertChannelHandle(
153       absl::optional<tensorflow::int64> channel_id);
154 
155   // Converts channel handle to attribute
156   mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
157 
158   mlir::MLIRContext* context_;
159   mlir::ModuleOp module_;
160   mlir::Builder* builder_;
161 
162   // Mapping from HloComputation to the created MLIR function.
163   std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
164 
165   // Mapping from HloInstructions to the associative MLIR values.
166   std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;
167 };
168 
169 }  // namespace xla
170 
171 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
172