• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
17 #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
18 
19 #include "absl/types/optional.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
26 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
27 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 
33 namespace mlir {
34 
35 // This class will process an HloModule with the supplied BufferAssignment and
36 // populate the MLIR ModuleOp with the computation converted in the LHLO
37 // dialect.
38 class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
39  public:
40   // Initializes internal data structures. It must be called before calling any
41   // of the visitors.
42   tensorflow::Status Initialize();
43 
LhloDialectEmitter(const xla::BufferAssignment & assignment,const xla::HloComputation & computation,ModuleOp module)44   LhloDialectEmitter(const xla::BufferAssignment& assignment,
45                      const xla::HloComputation& computation, ModuleOp module)
46       : assignment_(std::move(assignment)),
47         computation_(computation),
48         module_(module),
49         builder_(module.getContext()),
50         i8_type_(builder_.getIntegerType(8)) {}
51 
52   xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);
53 
54   xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
55       const xla::HloInstruction* instr);
56 
57  private:
58   xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
59   xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
60   xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
61       const xla::HloInstruction* instr);
62   xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
63       const xla::HloInstruction* instr);
64 
65   xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
66   xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
67       const xla::HloCustomCallInstruction* custom_call);
68   xla::StatusOr<Operation*> EmitGemm(
69       const xla::HloCustomCallInstruction* custom_call);
70   xla::StatusOr<Operation*> EmitDnnConvolution(
71       const xla::HloCustomCallInstruction* custom_call);
72   xla::StatusOr<Operation*> EmitDnnBatchNorm(
73       const xla::HloCustomCallInstruction* custom_call);
74 
75   xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(const xla::HloInstruction* instr);
76   xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
77       const xla::HloInstruction* instr);
78 
79   xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(
80       const xla::HloInstruction* instr);
81 
82   xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
83   xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
84       const xla::HloInstruction* instr);
85   xla::StatusOr<lmhlo::MapOp> EmitMapOp(const xla::HloInstruction* instr);
86 
87   xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
88       const xla::HloInstruction* instr);
89 
90   xla::StatusOr<lmhlo::AllToAllOp> EmitAllToAllOp(
91       const xla::HloInstruction* instr);
92   xla::StatusOr<lmhlo::AllGatherOp> EmitAllGatherOp(
93       const xla::HloInstruction* instr);
94   xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
95       const xla::HloInstruction* instr);
96   xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
97       const xla::HloInstruction* instr);
98 
99   xla::StatusOr<lmhlo::BroadcastInDimOp> EmitBroadcastOp(
100       const xla::HloInstruction* instr);
101 
102   xla::StatusOr<lmhlo::ConcatenateOp> EmitConcatenateOp(
103       const xla::HloInstruction* instr);
104 
105   xla::StatusOr<lmhlo::IotaOp> EmitIotaOp(const xla::HloInstruction* instr);
106 
107   xla::StatusOr<lmhlo::ReverseOp> EmitReverseOp(
108       const xla::HloInstruction* instr);
109 
110   xla::StatusOr<lmhlo::TransposeOp> EmitTransposeOp(
111       const xla::HloInstruction* instr);
112 
113   xla::StatusOr<lmhlo::PadOp> EmitPadOp(const xla::HloInstruction* instr);
114 
115   xla::StatusOr<lmhlo::ReduceWindowOp> EmitReduceWindowOp(
116       const xla::HloInstruction* instr);
117 
118   xla::StatusOr<lmhlo::SliceOp> EmitSliceOp(const xla::HloInstruction* instr);
119 
120   xla::StatusOr<lmhlo::GatherOp> EmitGatherOp(const xla::HloInstruction* instr);
121 
122   xla::StatusOr<lmhlo::DynamicSliceOp> EmitDynamicSliceOp(
123       const xla::HloInstruction* instr);
124 
125   xla::StatusOr<lmhlo::DotOp> EmitDotOp(const xla::HloInstruction* instr);
126   xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
127       const xla::HloInstruction* instr);
128   xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
129   xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
130       const xla::HloInstruction* instr);
131 
132   // Create LHLO operation operands given an XLA HLO instruction. By default,
133   // all XLA HLO operands and results are converted to MLIR and appended to
134   // `operands`. If `num_operands` is specified, only the first `num_operand`
135   // operands of the instruction are converted to MLIR. The function returns the
136   // actual number of operands and results generated for MLIR in `num_arguments`
137   // and `num_results`.
138   xla::Status CreateOperands(const xla::HloInstruction* instr,
139                              absl::optional<xla::int64> num_operands,
140                              SmallVectorImpl<Value>& operands,
141                              size_t& num_arguments, size_t& num_results);
142 
143   template <typename OpType>
144   xla::StatusOr<OpType> CreateOpWithoutAttrs(
145       const xla::HloInstruction* instr,
146       absl::optional<xla::int64> num_operands = absl::nullopt) {
147     size_t unused;
148     return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
149   }
150 
151   template <typename OpType>
152   xla::StatusOr<OpType> CreateOpWithoutAttrs(
153       const xla::HloInstruction* instr, size_t& num_arguments,
154       size_t& num_results,
155       absl::optional<xla::int64> num_operands = absl::nullopt);
156 
157   template <typename OpType>
158   OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
159                               ValueRange operands);
160 
161   template <typename T>
GetI64DenseElementsAttr(const T & container)162   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
163     return builder_.getI64TensorAttr(
164         {container.data(), static_cast<size_t>(container.size())});
165   }
166 
GetWindowElements(const xla::Window & window,std::function<int64_t (const xla::WindowDimension & dim)> getter)167   DenseIntElementsAttr GetWindowElements(
168       const xla::Window& window,
169       std::function<int64_t(const xla::WindowDimension& dim)> getter) {
170     llvm::SmallVector<int64_t, 4> elements;
171     elements.reserve(window.dimensions_size());
172     for (const xla::WindowDimension& dim : window.dimensions()) {
173       elements.push_back(getter(dim));
174     }
175     return GetI64DenseElementsAttr(elements);
176   }
177 
178   static mlir::DenseIntElementsAttr GetLayoutAttribute(
179       const xla::Layout& layout, Builder* builder);
180 
181   tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
182 
183   // Computation parameters don't need any specific handling when they are
184   // visited, they are already processed when we enter a new computation.
HandleParameter(const xla::HloInstruction * instr)185   tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
186     return tensorflow::Status::OK();
187   }
188 
189   // Helper function that recursively visits the tuple structure in
190   // `current_shape`, and reconstruct a matching lmhlo::TupleOp.
191   // Each leaf node is converted to an std.view op with corresponding offsets.
192   // If no tuple presents, it simply returns a view of the buffer.
193   tensorflow::Status GetOrCreateViewImpl(const xla::HloInstruction* instr,
194                                          const xla::Shape& current_shape,
195                                          xla::ShapeIndex* current_shape_index,
196                                          SmallVectorImpl<Value>* values);
197 
198   // Helper function to create view/tuple of views to a buffer for a given
199   // instruction result. `result_subset` can be used to for instructions that
200   // have a tuple result and MLIR conversion needs to convert only one of the
201   // tuple elements. Note that if needed, this can be extended to take a list of
202   // ShapeIndex values in case we need finer control on what elements of the
203   // output tuple to be converted to MLIR.
204   tensorflow::Status GetOrCreateView(const xla::HloInstruction* instr,
205                                      SmallVectorImpl<Value>* values,
206                                      const xla::ShapeIndex& result_subset = {});
207 
208   xla::StatusOr<Value> GetOrCreateArrayView(
209       const xla::HloInstruction* instr, const xla::Shape& current_shape,
210       const xla::ShapeIndex& current_shape_index);
211 
212   xla::StatusOr<Value> RewriteFusionOperand(const xla::HloInstruction* root,
213                                             const xla::Shape& shape,
214                                             xla::ShapeIndex* shape_index,
215                                             OpBuilder* b, Location loc);
216 
217   // Return an MLIR location for an HLO instruction.
getLocation(const xla::HloInstruction * inst)218   Location getLocation(const xla::HloInstruction* inst) {
219     return NameLoc::get(builder_.getIdentifier(inst->name()),
220                         builder_.getContext());
221   }
222 
223   // This map provides access to MLIR buffers for each HLO buffer allocation.
224   // The MLIR buffers are all `memref<{size}xi8>` and correspond to function
225   // parameters. It is populated at the beginning of the processing with all
226   // the buffer allocations and is unchanged afterward. Every HLOInstruction
227   // is using a "slice" of the buffer allocation and providing shape, layout,
228   // and Dtype. An MLIR view is used separately to model slices into the
229   // allocations (see below).
230   llvm::DenseMap<const xla::BufferAllocation*, Value> allocations_;
231 
232   // This map provides access to MLIR buffers for each HLO instruction, keyed
233   // instruction identity. A slice is contained in a BufferAllocation, and has
234   // an offset and a size.
235   //
236   // As for why we don't use HloInstruction*, see GetOrCreateView(), but
237   // mostly we want to leverage better of the aliased buffers.
238   //
239   // If the HloInstruction is a tuple, all leaf nodes are stored flattened.
240   // Otherwise, there will be a single buffer.
241   //
242   // An MLIR buffer is either an input parameter, or a ViewOp in the case
243   // where the slice is only part of its allocation.
244   //
245   // `slices_` is populated lazily in the `GetOrCreateView()` helper as we
246   // process every instruction.
247   absl::flat_hash_map<std::pair<const xla::HloInstruction*, xla::ShapeIndex>,
248                       Value>
249       slices_;
250 
251   // The BufferAssignment computed by XLA ahead of time.
252   const xla::BufferAssignment& assignment_;
253 
254   // The HLO module that will be converted.
255   const xla::HloComputation& computation_;
256 
257   // This is the MLIR module in which a function will be created for every HLO
258   // computation.
259   ModuleOp module_;
260 
261   // The builder keeps track of the current insertion point in the MLIR
262   // module.
263   OpBuilder builder_;
264   // Convenient "cached" access to this widely used MLIR type (i8).
265   Type i8_type_;
266 };
267 
268 // Populate the MLIR `module` with the computation from the `hlo_module` using
269 // the provided buffer `assignment`. The returned `Status` indicates success
270 // or failure in the conversion.
271 tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment,
272                                    const xla::HloModule& hlo_module,
273                                    ModuleOp module);
274 
275 OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
276                                                MLIRContext* context);
277 
278 }  // namespace mlir
279 
280 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
281