/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_

#include "absl/types/optional.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace mlir {

// This class processes an HloModule with the supplied BufferAssignment and
// populates the MLIR ModuleOp with the computation converted to the LHLO
// dialect.
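//
// Illustrative usage sketch (hypothetical driver code; `assignment`,
// `hlo_module`, and `mlir_module` are assumed to come from the caller, and
// HloToLhloModule() declared at the bottom of this file is the module-level
// conversion helper):
//
//   LhloDialectEmitter emitter(assignment, *hlo_module.entry_computation(),
//                              mlir_module);
//   TF_RETURN_IF_ERROR(emitter.Initialize());
//   for (const xla::HloInstruction* instr :
//        hlo_module.entry_computation()->MakeInstructionPostOrder()) {
//     TF_RETURN_IF_ERROR(emitter.EmitOp(instr).status());
//   }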
class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
 public:
  // Initializes internal data structures. It must be called before calling any
  // of the visitors.
  tensorflow::Status Initialize();

  LhloDialectEmitter(const xla::BufferAssignment& assignment,
                     const xla::HloComputation& computation, ModuleOp module)
      : assignment_(std::move(assignment)),
        computation_(computation),
        module_(module),
        builder_(module.getContext()),
        i8_type_(builder_.getIntegerType(8)) {}

  xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);

  static xla::StatusOr<mhlo::ScatterDimensionNumbers>
  GetScatterDimensionNumbers(const xla::HloInstruction* instr,
                             mlir::MLIRContext* context);

 private:
  xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitGemm(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnConvolution(
      const xla::HloCustomCallInstruction* custom_call);
  xla::StatusOr<Operation*> EmitDnnBatchNorm(
      const xla::HloCustomCallInstruction* custom_call);

  xla::StatusOr<memref::GetGlobalOp> EmitConstant(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::ReduceScatterOp> EmitReduceScatterOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
      const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
      const xla::HloInstruction* instr);
  xla::StatusOr<Operation*> EmitBitcast(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::CaseOp> EmitCaseOp(const xla::HloInstruction* instr);

  xla::StatusOr<lmhlo::WhileOp> EmitWhileOp(const xla::HloInstruction* instr);

  xla::Status ImportAsLmhloRegion(xla::HloComputation* computation,
                                  mlir::Region* region);

  // Since the LMHLO dialect does not define token types, this enum controls
  // how token operands/results from XLA:HLO are lowered to MLIR.
  enum class TokenLoweringMode {
    kFailToLower,  // Fail lowering if token inputs are encountered.
    kUseNull,      // Use a null Value in the operand list for each token.
    // kSkip,        // Skip any token inputs or outputs (not yet needed)
  };

  // Create LHLO operation operands given an XLA HLO instruction. By default,
  // all XLA HLO operands and results are converted to MLIR and appended to
  // `operands`. If `num_operands` is specified, only the first `num_operands`
  // operands of the instruction are converted to MLIR. The function returns
  // the actual number of operands and results generated for MLIR in
  // `num_arguments` and `num_results`.
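  //
  // Illustrative (hypothetical) call, lowering only the first two HLO
  // operands and substituting a null Value for any token operand; `instr` is
  // assumed to be the instruction currently being converted:
  //
  //   llvm::SmallVector<Value, 4> operands;
  //   size_t num_arguments, num_results;
  //   TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/2,
  //                                     TokenLoweringMode::kUseNull, operands,
  //                                     num_arguments, num_results));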
  xla::Status CreateOperands(const xla::HloInstruction* instr,
                             absl::optional<xla::int64> num_operands,
                             TokenLoweringMode token_mode,
                             SmallVectorImpl<Value>& operands,
                             size_t& num_arguments, size_t& num_results);

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr,
      absl::optional<xla::int64> num_operands = absl::nullopt) {
    size_t unused;
    return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
  }

  template <typename OpType>
  xla::StatusOr<OpType> CreateOpWithoutAttrs(
      const xla::HloInstruction* instr, size_t& num_arguments,
      size_t& num_results,
      absl::optional<xla::int64> num_operands = absl::nullopt);

  template <typename OpType>
  OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
                              ValueRange operands);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr, ValueRange buffer_operands,
      size_t num_arguments, size_t num_results);

  xla::StatusOr<mlir::Operation*> CreateOpInFusion(
      const xla::HloInstruction* instr);

  template <typename T>
  DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
    return builder_.getI64TensorAttr(
        {container.data(), static_cast<size_t>(container.size())});
  }

  DenseIntElementsAttr GetWindowElements(
      const xla::Window& window,
      std::function<int64_t(const xla::WindowDimension& dim)> getter) {
    llvm::SmallVector<int64_t, 4> elements;
    elements.reserve(window.dimensions_size());
    for (const xla::WindowDimension& dim : window.dimensions()) {
      elements.push_back(getter(dim));
    }
    return GetI64DenseElementsAttr(elements);
  }

  static mlir::DenseIntElementsAttr GetLayoutAttribute(
      const xla::Layout& layout, Builder* builder);

  tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;

  // Computation parameters don't need any specific handling when they are
  // visited; they are already processed when we enter a new computation.
  tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
    return tensorflow::Status::OK();
  }

  // Helper function that recursively visits the tuple structure in
  // `current_shape` and reconstructs a matching lmhlo::TupleOp.
  // Each leaf node is converted to a std.view op with corresponding offsets.
  // If no tuple is present, it simply returns a view of the buffer.
  tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
                                         const xla::Shape& current_shape,
                                         xla::ShapeIndex* current_shape_index,
                                         SmallVectorImpl<Value>* values,
                                         TokenLoweringMode token_mode);

  // Helper function to create a view/tuple of views to a buffer for a given
  // instruction result. `result_subset` can be used for instructions that
  // have a tuple result when the MLIR conversion needs to convert only one of
  // the tuple elements. Note that, if needed, this can be extended to take a
  // list of ShapeIndex values in case we need finer control over which
  // elements of the output tuple are converted to MLIR.
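  //
  // Illustrative (hypothetical) call retrieving only the views for element
  // {1} of an instruction with a tuple-shaped result; `instr` is an
  // assumption standing in for the instruction being lowered:
  //
  //   llvm::SmallVector<Value, 4> values;
  //   TF_RETURN_IF_ERROR(
  //       GetOrCreateView(instr, &values, /*result_subset=*/{1}));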
  tensorflow::Status GetOrCreateView(
      const xla::HloInstruction* instr, SmallVectorImpl<Value>* values,
      const xla::ShapeIndex& result_subset = {},
      TokenLoweringMode token_mode = TokenLoweringMode::kFailToLower);

  xla::StatusOr<Value> GetOrCreateArrayView(
      const xla::HloInstruction* instr, const xla::Shape& current_shape,
      const xla::ShapeIndex& current_shape_index);

  xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
                                            const xla::Shape& shape,
                                            xla::ShapeIndex* shape_index,
                                            OpBuilder* b, Location loc);

  // Return an MLIR location for an HLO instruction.
  Location getLocation(const xla::HloInstruction* inst) {
    return NameLoc::get(builder_.getIdentifier(inst->name()));
  }

  // This map provides access to MLIR buffers for each HLO buffer allocation.
  // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
  // parameters. It is populated at the beginning of the processing with all
  // the buffer allocations and is unchanged afterward. Every HloInstruction
  // uses a "slice" of a buffer allocation and provides the shape, layout,
  // and dtype. An MLIR view is used separately to model slices into the
  // allocations (see below).
  llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;

  // This map provides access to MLIR buffers for each HLO instruction, keyed
  // by instruction identity. A slice is contained in a BufferAllocation, and
  // has an offset and a size.
  //
  // As for why we don't use HloInstruction* as the key, see GetOrCreateView();
  // mostly we want to make better use of the aliased buffers.
  //
  // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
  // Otherwise, there will be a single buffer.
  //
  // An MLIR buffer is either an input parameter, or a ViewOp in the case
  // where the slice is only part of its allocation.
  //
  // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
  // process every instruction.
  absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
                      Value>
      slices_;

  // The BufferAssignment computed by XLA ahead of time.
  const xla::BufferAssignment& assignment_;

  // The HLO computation that will be converted.
  const xla::HloComputation& computation_;

  // This is the MLIR module in which a function will be created for every HLO
  // computation.
  ModuleOp module_;

  // The builder keeps track of the current insertion point in the MLIR
  // module.
  OpBuilder builder_;
  // Convenient "cached" access to this widely used MLIR type (i8).
  Type i8_type_;

  // Map all-reduce-start ops to their LHLO op, so we can connect the
  // all-reduce-done op with the correct token.
  absl::flat_hash_map<const xla::HloInstruction*, lmhlo_gpu::AllReduceStartOp>
      all_reduce_start_ops_;
};

// Populate the MLIR `module` with the computation from the `hlo_module` using
// the provided buffer `assignment`. The returned `Status` indicates success
// or failure of the conversion.
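//
// Illustrative (hypothetical) call site; the surrounding context/module setup
// shown here is an assumption and omits dialect registration details:
//
//   mlir::MLIRContext context;
//   mlir::OwningModuleRef mlir_module =
//       ModuleOp::create(UnknownLoc::get(&context));
//   TF_RETURN_IF_ERROR(
//       HloToLhloModule(*buffer_assignment, *hlo_module, *mlir_module));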
tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
                                   const xla::HloModule& hlo_module,
                                   ModuleOp module);

tensorflow::Status OptimizeAndConvertHloToLmhlo(
    std::unique_ptr<xla::HloModule> hlo_module, ModuleOp module,
    StringRef platform_name);

OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
                                               MLIRContext* context);

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_