# XLA Custom Calls

This document describes how to write and use XLA "custom calls". Custom calls
let you invoke code written in a programming language like C++ or CUDA from an
XLA program.

Warning: Custom calls are a low-level power-user feature. It is easy to break
your program in difficult-to-debug (and even difficult-to-notice) ways using
custom-calls. You shouldn't use custom calls unless you're prepared to debug XLA
yourself when something goes wrong, and you should expect relatively less
assistance from XLA developers if you run into trouble.

Warning: The custom-call API/ABI is not currently stable. We don't intend to
change it capriciously, but it may change. Some possible future changes are
described below.

## Custom-call on CPU

You can create an HLO instruction which represents a custom-call via XLA's
client API. This is not exposed via TensorFlow as of writing.

For example, the following code uses a custom-call to compute
`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! --
do this with regular HLO.)

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

// Builds a computation whose only work is a custom-call to the target
// registered below under the name "do_custom_call".
void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

// Host implementation: `in` holds one pointer per operand, and `out` points
// to the F32[2048] result buffer.
void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
// Register the target for the CPU ("Host") platform.
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```

Notice that the function `do_custom_call` needs to know the dimensions of the
buffers it operates over. In this example we hardcode the sizes 128 and 2048. If
you don't want to do this, you can pass the dimensions in as parameters to the
call.

## Custom-call on GPU

The GPU custom call framework is somewhat different from that on the CPU. Here
is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as
the CPU code above.

```c++
#include <cstdint>
#include <cuda.h>  // For CUstream.

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void do_it() { /* same implementation as above */ }

// Each thread computes one element of the output.
__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

// Runs on the host; its job is to enqueue GPU work on `stream`.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;  // 32 blocks cover 2048 elements.
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```

Notice first that the GPU custom call function *is still a function executed on
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
on the GPU. Here it launches a CUDA kernel, but it could also do something else,
like call cuBLAS.

`buffers` is an array of pointers that lives on the host, and each element it
contains points to device (i.e. GPU) memory. The parameters come first, followed
by the output value. This is notably different from the CPU calling convention,
which has two params, `out` and `in`. The main reason we diverge is to make it
possible to handle tuple-shaped inputs/outputs efficiently; see the section on
tuples below.

As in the CPU example, we've hardcoded the input and output buffer sizes into
our custom call. However, unlike in the CPU case, passing the buffer sizes in as
operands to the custom call would not work well. Usually we need the buffer
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
the block/grid dimensions to use. But if we were to pass the buffer sizes as
operands to our custom call, their values would live in GPU memory.
We'd then have to do an expensive synchronous device-to-host memcpy at the start
of our operation just to read the sizes.

To let you work around this, we provide the `opaque` parameter. You can set this
to an arbitrary string of bytes when you create the custom call:

```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);
```

Since `xla::Shape` has a protocol buffer representation, you could store a
serialized `xla::ShapeProto` inside `opaque` and deserialize it within your GPU
custom-call. Note, however, that although `xla::ShapeProto` does not change
frequently, it *does* change. Check the git log to see how it has changed in the
past.

## Signalling an error

If your custom call encounters an error, you can signal the error to the XLA
runtime (instead of e.g. crashing or returning nonsense in the output buffers)
by using the following signature for your function on CPU:

```c++
#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
```

... and on GPU:

```c++
#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status);
```

You can signal failure by using `XlaCustomCallStatusSetFailure`, e.g.:

```c++
#include <cstring>  // For strlen.

#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {  // `bad_condition` stands in for your own check.
    const char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}
```

You can also use `XlaCustomCallStatusSetSuccess` to indicate success, but the
`XlaCustomCallStatus` is in a success state by default, so ignoring it
completely will also indicate success.

When using custom call functions with this signature, you must create the
corresponding `custom-call` op with the appropriate API version set, e.g.:

```c++
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/
                xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
```

NOTE: In the future all clients will be required to migrate their custom call
functions to the new API version, and the old one will be deprecated. For custom
calls that can't fail, you can simply add the new `XlaCustomCallStatus*`
parameter and then ignore it.

On failure, none of the custom call outputs will be used; the XLA runtime will
terminate the computation. It is not possible for an HLO computation to recover
from the error (e.g. by catching and handling it).

## Passing tuples to custom-calls

Consider the following custom-call.

```c++
using xla::F32;
using xla::Shape;
using xla::ShapeUtil;

xla::XlaBuilder b("tuple_example");

// p0 is the nested tuple (F32[32], (F32[64], F32[128]), F32[256]).
Shape p0_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTupleShape({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

// The output is the tuple (F32[512], F32[1024]).
Shape out_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```

On both CPU and GPU, a tuple is represented in memory as an array of pointers.
In C++-pseudocode, parameter 0 above is laid out as follows.

```c++
// In-memory layout of parameter 0 from the custom-call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];

// The nested tuple is an array of two pointers to its leaf buffers.
void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;

// The outer tuple is an array of three pointers; element 1 points to the
// nested tuple's pointer array.
void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```

Although the in-memory representation of tuples is the same on CPU and GPU, they
are handled differently in the CPU and GPU custom-call calling conventions.

### Tuple outputs as temp buffers

Tuple inputs to custom-calls are a convenience, but they aren't strictly
necessary. If we didn't support tuple inputs to custom calls, you could always
unpack the tuples using get-tuple-element before passing them to the custom
call.

On the other hand, tuple *outputs* do let you do things you couldn't otherwise.

The obvious reason to have tuple outputs is that a tuple is how a custom call
(or any other XLA op) returns multiple independent arrays.

But less obviously, a tuple output is also a way to give your custom call temp
memory. Yes, an *output* can represent a temp buffer. Consider: an output buffer
has the property that the op can write to it, and it can read from it after it's
been written to. That's exactly what you want from a temp buffer.

In the example above, suppose we wanted to use the `F32[1024]` output as a temp
buffer. Then we'd write the HLO just as above, and we'd simply never read tuple
index 1 of the custom call's output.

### Tuples in CPU custom-calls

In CPU code, we have a function `do_custom_call(void* out, const void** in)`.
`in` is an array with just one element, which points to `p0`. The subbuffers of
`p0` are accessible by dereferencing that pointer, and the subbuffers of the
output tuple are accessible by dereferencing `out`.

### Tuples in GPU custom-calls

In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
this case `buffers` is a host array of *six* device pointers, one for each leaf
buffer in the input/output: four leaves in `p0` plus two in the output. To
generate the flat list, we iterate over the parameters and output, and for each
we do a preorder traversal of its shape, collecting its leaf buffers.
Concretely:

```c++
// Layout of the `buffers` parameter to the GPU custom call function for the
// custom-call above.
buffers[0] == subbuf0         // F32[32],   p0 tuple index {0}
buffers[1] == subbuf1         // F32[64],   p0 tuple index {1, 0}
buffers[2] == subbuf2         // F32[128],  p0 tuple index {1, 1}
buffers[3] == subbuf3         // F32[256],  p0 tuple index {2}
buffers[4] == output_subbuf0  // F32[512],  output tuple index {0}
buffers[5] == output_subbuf1  // F32[1024], output tuple index {1}
```