# MLIR-HLO: A Standalone "HLO" MLIR-based Compiler

The code here exists in two places:

*   https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir_hlo;
    this is the canonical location and where contributions should be made using
    GitHub pull requests.
*   https://github.com/tensorflow/mlir-hlo; this is a standalone repository
    with a view to the same code, which allows other projects to use it
    without depending on the entire TF monorepo.

This repository implements a self-contained compiler for a set of linear
algebra operations inspired by the XLA
[HLO IR](https://www.tensorflow.org/xla/architecture#how_does_xla_work), using
MLIR components. It is designed to provide an end-to-end flow independent of
TensorFlow and XLA, but usable inside of these projects.

Coding practices and conventions in this repository follow the
[MLIR Developer Guide](https://mlir.llvm.org/getting_started/DeveloperGuide/),
as part of the intent for this repo to act as an incubator for technology to
upstream.

## QuickStart: building and testing

These instructions work on Linux; you may have to adjust them for your
platform.

To build the code in this repository, you need a clone of the LLVM/MLIR git
repository:

    $ git clone https://github.com/llvm/llvm-project.git

You need to make sure you have the right commit checked out in the LLVM
repository (you need to do this every time you pull from this repo):

    $ (cd llvm-project && git checkout $(cat ../build_tools/llvm_version.txt))

We provide a script to configure and build LLVM/MLIR:

    $ build_tools/build_mlir.sh ${PWD}/llvm-project/ ${PWD}/llvm-build

Again, this is something to do every time you pull from this repository and
the LLVM revision changes.

Finally you can build and test this repository:

    $ mkdir build && cd build
    $ cmake .. -GNinja \
        -DLLVM_ENABLE_LLD=ON \
        -DCMAKE_BUILD_TYPE=Release \
        -DLLVM_ENABLE_ASSERTIONS=On \
        -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir
    $ ninja check-mlir-hlo

## Overview

MLIR-HLO aims to provide an end-to-end compiler for CPU and GPU, as well as
building reusable blocks for other accelerators. This is heavily inspired by
the success of XLA.

[XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) is a
domain-specific compiler framework and execution environment for linear
algebra, which powers code generation for ML frameworks like TensorFlow, JAX,
and others.

A cornerstone of XLA is the HLO (High Level Optimizer) IR, which offers a
carefully selected, fixed list of operations that are mostly orthogonal to
each other. It provides an efficient optimizer for computations expressed with
this set of operations, and generates code for hardware platforms like CPUs,
GPUs, and TPUs. Its goal is to provide a uniform interface to compile and
execute these optimized HLO programs independently of the targeted device. It
is not a front-end ML system like TensorFlow or JAX; rather, it is a backend
framework that optimizes HLO and lowers it to machine code.

The HLO set of operations is closed and has well-defined semantics. HLO
operations operate on immutable tensors with static shapes (bounded shapes, to
be exact) and explicit broadcasts.
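To make this concrete, here is a sketch of what such a computation looks like
in the `mhlo` dialect introduced below: the broadcast is spelled out as its own
op and every shape is static. The exact assembly syntax varies across
MLIR/MHLO revisions, so treat this as illustrative rather than canonical.

```
// Broadcast the 1-D operand to the 2-D shape explicitly, then add.
// Tensors are immutable SSA values with static shapes.
func.func @explicit_broadcast_add(%lhs: tensor<4x3xf32>,
                                  %rhs: tensor<3xf32>) -> tensor<4x3xf32> {
  // Map dimension 0 of the input to dimension 1 of the result.
  %bcast = "mhlo.broadcast_in_dim"(%rhs)
      {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
      : (tensor<3xf32>) -> tensor<4x3xf32>
  %sum = "mhlo.add"(%lhs, %bcast)
      : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %sum : tensor<4x3xf32>
}
```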
[MLIR](https://mlir.llvm.org/) is a compiler infrastructure which intends to
come with "batteries included"; as such, it aims to provide all the blocks
required to assemble graph optimization and codegen pipelines. The longer-term
roadmap for MLIR is to provide a
[Tensor Compute Primitive](https://llvm.discourse.group/c/mlir/MLIR-TCP-WG/36)
(TCP) dialect, which should hopefully be general enough to model what HLO
represents today (see
[slides](https://drive.google.com/open?id=1iljcpTQ5NPaMfGpoPDFml1XkYxjK_6A4) and
[recording](https://drive.google.com/open?id=1jSPa8TwPKUt0WuLquGc8OgSUVYJHMvWZ)
for a technical discussion on this topic).

The work on MLIR-HLO can be seen as a stepping stone towards building TCP,
while integrating intermediate components into XLA itself by relying on the
well-proven HLO IR and introducing more pieces from upstream MLIR
([Linalg](https://mlir.llvm.org/docs/Dialects/Linalg/),
[Vector](https://mlir.llvm.org/docs/Dialects/Vector/),
[GPU](https://mlir.llvm.org/docs/Dialects/GPU/) dialect, ...).
[This document](https://www.tensorflow.org/mlir/xla_gpu_codegen) provides more
information on the current migration of the XLA GPU codegen.

## MLIR Dialects for XLA-style compilation

This repository defines three dialects to support an HLO-like compilation
pipeline using MLIR:

*   `chlo`: the "client" HLO dialect, intended to be closer to the frontend
    (including implicit broadcast semantics).
*   `mhlo`: the "meta"-HLO dialect; similar to `xla_hlo`, but with extensions
    for dynamic shape support.
*   `lmhlo`: the "late"-"meta"-HLO dialect; this is the IR after buffer
    allocation is performed. In XLA, buffer allocation is a side data
    structure which keeps track of this information, while this separate
    dialect materializes it in the IR.

We describe these in more detail below.

### HLO Client Dialect: `chlo`

*   It was originally designed to map the
    [XLA client APIs](https://www.tensorflow.org/xla/operation_semantics)
    (e.g., ops support implicit broadcast and are roughly modeled on the
    XlaBuilder API, as shown in the sketch after this list), modulo support
    for dynamic shapes and additional ops required to support dynamic
    client-side HLOs.
*   Ops can come either from the XlaBuilder, or XLA helper functions can be
    converted into ops (given the ambiguity in what constitutes these ops,
    there is some freedom to decide). The goal of this dialect is to
    correspond closely to the client level and to enable a thin layer between
    client use and op construction (making ops cheap to construct, and keeping
    optimizations on the dialect close to optimizations on the client ops).
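As a sketch of what this client level looks like, the same addition as in the
Overview can be written with the broadcast left implicit (illustrative syntax,
subject to change across MHLO revisions):

```
func.func @implicit_broadcast_add(%lhs: tensor<4x3xf32>,
                                  %rhs: tensor<3xf32>) -> tensor<4x3xf32> {
  // The rank mismatch between the operands is resolved by implicit
  // (numpy-style) broadcast semantics; legalization to mhlo materializes
  // an explicit mhlo.broadcast_in_dim.
  %result = "chlo.broadcast_add"(%lhs, %rhs)
      : (tensor<4x3xf32>, tensor<3xf32>) -> tensor<4x3xf32>
  return %result : tensor<4x3xf32>
}
```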
Entry:

*   The vast majority of old "client" interactions are via the XlaBuilder
    APIs. These APIs are used by TF2XLA kernels, JAX, the PyTorch bridge, and
    directly. The legalization path (described below) can also reuse the
    XlaBuilder's APIs to construct XLA client HLO ops directly (this uses
    MlirXlaBuilder, which is a subclass of XlaBuilder).
*   The other entry point is during legalization from TensorFlow ops in the TF
    Graph Compiler and other tools (e.g., SavedModel lowering and TFCompile).

Exit:

*   MHLO.
*   May be exported to xla::HloInstructionProto by invoking the XlaBuilder
    APIs (with a regular XlaBuilder).

The `chlo` dialect originally started as a mapping to the XLA client builder
APIs. This enables it to both be constructed from and converted back to
existing XLA interfaces using the XlaBuilder API. Due to the way that
translation into and out of the dialect works, there is no expectation that
this dialect roundtrips to XLA (e.g., it is only intended to be translated to
MLIR and then legalized to another dialect, or translated to
HloInstructionProto).

The export approach of reusing the XlaBuilder enables reusing a lot of logic
that was already implemented in terms of computing shapes, inserting
broadcasts, etc.

An important topic here is that XLA client HLO ops are not a well-defined set;
in particular, what some would consider helper functions, others would
consider ops. It should be easy to move between these, and so to define a new
op along with its helper function, or to auto-generate the helper functions
from the descriptions of the ops. For the former, a simple approach would be
to consider the context in which the op is being constructed and, if it is an
MLIR one, construct an op in the client dialect instead of making further
calls into the XlaBuilder. The latter could be implemented by adding the op
and a legalization of the op to other known ops, from which a helper function
could be generated and used as usual.

Status: Exists, but needs to be cleaned up.

### Meta HLO Dialect `mhlo`

*   The dialect is closer to the current HLO server ops (e.g., no implicit
    broadcast).
*   MHLO is the dialect where we can deviate from the requirements of the
    client or server dialects, in particular:
    *   Control flow ops with implicit capture to enable simpler optimizations
        (e.g., generic LICM, unroll & jam, etc.).
    *   Ops with multiple results (i.e., no tuples).
    *   More ops (for example, a unique op or an assert op), and ops that
        don't need to be added to either the client or server dialect.
    *   An op set not constrained by implementation (e.g., hlo.add operating
        on, say, i79 or !mydialect.weird_type is allowed even though no XLA
        backend supports it); verification on types happens at the boundaries.
    *   It does not need to preserve some deprecated XLA constructs (e.g., the
        stateful RNG HLO).
    *   More ops supporting dynamic shapes, without the need to update all
        users/backends.
*   This dialect enables evolving HLO independently from XLA in order to
    experiment with features we'd like to upstream in MLIR TCP. In particular,
    it intends to be user-extensible through
    [interfaces](https://mlir.llvm.org/docs/Interfaces/).
*   It should have no TensorFlow, proto, or other Google-internal
    dependencies.
*   It need not be a complete superset of ops compared to the XLA HLO dialect.

Entry:

*   Legalization from the `chlo` dialect or conversion from XLA HLO.
*   Directly emitted from the TF Graph Compiler.
*   Builder calls (e.g., EDSL).

Exit:

*   LMHLO, Linalg, IREE, or directly used in codegen.
*   XLA HLO.

The MHLO dialect has no direct export format; it is only meant as an
intermediate optimization dialect/format. It is also where we can experiment
cheaply with new ops. This format will be where the representation starts to
differ from existing endpoints.

Status: Exists, but needs to be cleaned up and evolved, in particular with
respect to supporting dynamic shapes.

MHLO differs from the XLA HLO op set in multiple ways, including:

1.  MHLO While accepts multiple operands and may produce multiple results,
    instead of operating on a single tuple.
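For example, a counted loop carrying an integer counter and a floating-point
accumulator can pass both values directly, with no tuple packing or unpacking.
This is a hand-written sketch; attribute spellings (such as
`comparison_direction`) and assembly syntax differ between MHLO revisions.

```
// Two loop-carried values, passed and returned directly rather than
// packed in a tuple as XLA HLO's While requires.
func.func @loop(%counter_init: tensor<i32>,
                %acc_init: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
  %final:2 = "mhlo.while"(%counter_init, %acc_init) ({
    // Condition region: keep iterating while the counter is below 10.
    ^bb0(%counter: tensor<i32>, %acc: tensor<f32>):
      %limit = "mhlo.constant"() {value = dense<10> : tensor<i32>}
          : () -> tensor<i32>
      %cond = "mhlo.compare"(%counter, %limit) {comparison_direction = "LT"}
          : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "mhlo.return"(%cond) : (tensor<i1>) -> ()
  }, {
    // Body region: increment the counter and bump the accumulator.
    ^bb0(%counter: tensor<i32>, %acc: tensor<f32>):
      %one = "mhlo.constant"() {value = dense<1> : tensor<i32>}
          : () -> tensor<i32>
      %step = "mhlo.constant"() {value = dense<5.000000e-01> : tensor<f32>}
          : () -> tensor<f32>
      %next = "mhlo.add"(%counter, %one)
          : (tensor<i32>, tensor<i32>) -> tensor<i32>
      %acc_next = "mhlo.add"(%acc, %step)
          : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "mhlo.return"(%next, %acc_next) : (tensor<i32>, tensor<f32>) -> ()
  }) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
  return %final#0, %final#1 : tensor<i32>, tensor<f32>
}
```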
### LMHLO

LMHLO corresponds to late `mhlo` and operates on the buffer domain (e.g.,
memref) with side-effecting operations. The lowering from the `mhlo` dialect
proceeds by way of scheduling and memory/buffer allocation. The current
mapping is directly on XLA client HLOs, but without implicit broadcast and
operating on memrefs. This dialect will instead be rebased on the `mhlo`
dialect, still operating on buffers.

Entry:

*   Post buffer assignment on the `mhlo` dialect, or from XLA after buffer
    assignment.

Exit:

*   Codegen (LLVM IR in the common case at the moment).

## End-to-End pipeline

TODO

## Alternative build setups

### Building Python API

Building the MHLO Python API requires building as an LLVM external project.
The instructions below presume that you have this `mlir-hlo` repo and an
`llvm-project` repo checked out side by side.

Note that the Python package produced by this procedure includes the `mlir`
package and is not suitable for deployment as-is (but it can be included in a
larger aggregate).

```
mkdir build && cd build
cmake -GNinja -B. ${LLVM_SRC_DIR}/llvm \
   -DCMAKE_BUILD_TYPE=Release \
   -DLLVM_ENABLE_PROJECTS=mlir \
   -DLLVM_EXTERNAL_PROJECTS=mlir_hlo \
   -DLLVM_EXTERNAL_MLIR_HLO_SOURCE_DIR=${MLIR_HLO_SRC_DIR} \
   -DLLVM_TARGETS_TO_BUILD=host \
   -DPython3_EXECUTABLE=$(which python) \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DMHLO_ENABLE_BINDINGS_PYTHON=ON

ninja MLIRHLOPythonModules
export PYTHONPATH=$PWD/tools/mlir_hlo/python_packages/mlir_hlo
python -c "import mlir.dialects.mhlo"
```

## External projects that depend on mlir-hlo

External projects that need to depend on `mlir-hlo` (for example via a git
submodule) can use the following setting in their cmake configuration in order
for `find_package(MHLO)` to import all mlir-hlo cmake targets into their build
setup and to have access to the required include and lib variables (see the
generated `MHLOConfig.cmake`):

```
...
   -DMHLO_DIR=<path to mlir-hlo build dir>/lib/cmake/mlir-hlo
  ...
```
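For example, a consuming project's CMakeLists.txt might use the imported
package roughly as follows. This is a minimal sketch: the consumer target, its
source file, the include variable, and the dialect library name
(`my-mhlo-pass`, `MyMhloPass.cpp`, `MHLO_INCLUDE_DIRS`, `MhloDialect`) are
assumptions to be checked against the generated `MHLOConfig.cmake` and your
mlir-hlo build.

```
cmake_minimum_required(VERSION 3.13)
project(mhlo-consumer LANGUAGES CXX)

# MHLO_DIR points at <mlir-hlo build dir>/lib/cmake/mlir-hlo, as shown above.
find_package(MHLO REQUIRED CONFIG)

# Hypothetical consumer library; variable and target names are illustrative.
add_library(my-mhlo-pass SHARED MyMhloPass.cpp)
target_include_directories(my-mhlo-pass PRIVATE ${MHLO_INCLUDE_DIRS})
target_link_libraries(my-mhlo-pass PRIVATE MhloDialect)
```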