/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
17 #define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
18 
19 #include <unordered_map>
20 
21 #include "absl/types/optional.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
30 #include "tensorflow/compiler/xla/comparison_util.h"
31 #include "tensorflow/compiler/xla/status.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 
38 class HloModule;
39 class HloComputation;
40 class HloInstruction;
41 class Shape;
42 
43 // Helper class for importing HloComputations.
44 class HloFunctionImporter {
45  public:
46   // Imports the given computation as a function in the given module. This also
47   // imports any computations referred by instructions in this computation.
48   static Status ImportAsFunc(const xla::HloComputation& computation,
49                              mlir::ModuleOp module,
50                              std::unordered_map<const xla::HloComputation*,
51                                                 mlir::FuncOp>* function_map,
52                              mlir::Builder* builder);
53 
54   // Imports the given hlo computation to the specified region.
55   static Status ImportAsRegion(const xla::HloComputation& computation,
56                                mlir::Region* region, mlir::Builder* builder);
57 
58   // Imports the given computation to the given place specified by `builder`.
59   // `arguments` contains values for all parameters.
60   static StatusOr<mlir::Value> ImportInstructions(
61       const xla::HloComputation& computation,
62       const llvm::SmallVectorImpl<mlir::Value>& arguments,
63       mlir::OpBuilder* builder);
64 
65   static StatusOr<mlir::Operation*> ImportInstruction(
66       const xla::HloInstruction* instr,
67       const llvm::SmallVectorImpl<mlir::Value>& operands,
68       mlir::OpBuilder* builder);
69 
70   static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape,
71                                llvm::StringRef attr_name = "minor_to_major");
72 
73   // TODO(b/179166199): move this to attribute_importer.h.
74   // Converts XLA instruction source target pairs to MLIR attribute.
75   static mlir::NamedAttribute ConvertSourceTargetPairs(
76       const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
77           source_target_pairs,
78       mlir::Builder* builder);
79 
80   // TODO(b/179166199): move this to attribute_importer.h.
81   // Converts replica groups to attribute
82   static mlir::NamedAttribute ConvertReplicaGroups(
83       absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder);
84 
85  private:
HloFunctionImporter(mlir::ModuleOp module,std::unordered_map<const xla::HloComputation *,mlir::FuncOp> * function_map,mlir::Builder * builder)86   HloFunctionImporter(mlir::ModuleOp module,
87                       std::unordered_map<const xla::HloComputation*,
88                                          mlir::FuncOp>* function_map,
89                       mlir::Builder* builder)
90       : context_(module.getContext()),
91         module_(module),
92         builder_(builder),
93         function_map_(function_map) {
94     context_->loadDialect<mlir::StandardOpsDialect>();
95     context_->loadDialect<mlir::mhlo::MhloDialect>();
96   }
97 
98   // Imports the given computation as a new function, if it hasn't been already
99   // imported.
100   StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
101 
102   // Imports the given computation in the specified region.
103   tensorflow::Status ImportAsRegion(const HloComputation& computation,
104                                     mlir::Region* region);
105 
106   // Imports instructions from the given computation in the specified block.
107   // Assumes that the block already has correct arguments populated.
108   tensorflow::Status ImportInstructions(const HloComputation& computation,
109                                         mlir::Block* block);
110   StatusOr<mlir::Value> ImportInstructionsImpl(
111       const xla::HloComputation& computation,
112       const llvm::SmallVectorImpl<mlir::Value>& arguments,
113       mlir::OpBuilder* builder);
114 
115   // Imports an instruction.
116   StatusOr<mlir::Operation*> ImportInstructionWithLayout(
117       const xla::HloInstruction* instruction,
118       const llvm::SmallVectorImpl<mlir::Value>& operands,
119       mlir::OpBuilder* func_builder);
120   StatusOr<mlir::Operation*> ImportInstructionImpl(
121       const HloInstruction* instruction,
122       const llvm::SmallVectorImpl<mlir::Value>& operands,
123       mlir::OpBuilder* func_builder);
124 
125   // Gets the MLIR operand values from an HLO Instruction.
126   StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
127       const xla::HloInstruction* instruction);
128 
129   // Converts xla Tensor type to the corresponding MLIR type.
130   StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
131 
132   // Converts an XLA shape/layout to the corresponding MLIR layout
133   StatusOr<mlir::Attribute> ConvertShapeToMlirLayout(const xla::Shape& shape);
134 
135   // Returns the output type of an HloInstruction.
136   StatusOr<mlir::Type> GetReturnType(const xla::HloInstruction* instruction);
137 
138   // Takes a list of HloInstructions and generates the list of types used for
139   // input, bypassing tuples to subsets.
140   Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
141                       llvm::SmallVectorImpl<mlir::Type>* types);
142 
143   // Returns the Mlir Value for the corresponding HloInstruction.
144   StatusOr<mlir::Value> GetMlirValue(const xla::HloInstruction* instruction);
145 
146   // Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
147   mlir::NamedAttribute ConvertComparisonDirection(
148       ComparisonDirection direction);
149 
150   // Converts an XLA Comparison::Type to the corresponding MLIR attribute.
151   mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);
152 
153   // Converts the dimensions of an HLO instruction into an MLIR attribute.
154   mlir::DenseIntElementsAttr ConvertDimensions(
155       llvm::ArrayRef<tensorflow::int64> op_dimensions);
156 
157   // Converts Array ref to an DenseIntElementsAttr.
158   mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
159 
160   // Converts Array ref to padding attribute. Input is a flattened list of
161   // padding low and padding high for each of the spatial dimensions.
162   mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
163 
164   // Converts channel id to attribute
165   mlir::NamedAttribute ConvertChannelHandle(
166       absl::optional<tensorflow::int64> channel_id);
167 
168   // Converts channel handle to attribute
169   mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
170 
171   mlir::MLIRContext* context_;
172   mlir::ModuleOp module_;
173   mlir::Builder* builder_;
174 
175   // Mapping from HloComputation to the created MLIR function.
176   std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
177 
178   // Mapping from HloInstructions to the associative MLIR values.
179   std::unordered_map<const xla::HloInstruction*, mlir::Value>
180       instruction_value_map_;
181 };
182 
183 }  // namespace xla
184 
185 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
186