# XLA Custom Calls

This document describes how to write and use XLA "custom calls". Custom calls
let you invoke code written in a programming language like C++ or CUDA from an
XLA program.

Warning: Custom calls are a low-level power-user feature. It is easy to break
your program in difficult-to-debug (and even difficult-to-notice) ways using
custom calls. You shouldn't use custom calls unless you're prepared to debug
XLA yourself when something goes wrong, and you should expect relatively less
assistance from XLA developers if you run into trouble.

Warning: The custom-call API/ABI is not currently stable. We don't intend to
change it capriciously, but it may change. Some possible future changes are
described below.

## Custom-call on CPU

You can create an HLO instruction that represents a custom call via XLA's
client API. This is not exposed via TensorFlow as of writing.

For example, the following code uses a custom call to compute
`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! --
do this with regular HLO.)

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```

Notice that the function `do_custom_call` needs to know the dimensions of the
buffers it operates over. In this example we hardcode the sizes 128 and 2048.
If you don't want to do this, you can pass the dimensions in as parameters to
the call.

## Custom-call on GPU

The GPU custom-call framework is somewhat different from the one on the CPU.
Here is a CUDA example that does the same `A[i] = B[i % 128] + C[i]`
computation as the CPU code above.

```c++
void do_it() { /* same implementation as above */ }

__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64 block_dim = 64;
  const int64 grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```

Notice first that the GPU custom-call function *is still a function executed on
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
on the GPU. Here it launches a CUDA kernel, but it could also do something
else, like call cuBLAS.
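For instance, here is a minimal sketch of what a cuBLAS-based target might look
like. It assumes a *hypothetical* variant of the op in which both inputs are
`F32[2048]` and we simply want `out[i] = in0[i] + in1[i]`, so the work can be
expressed as a device-to-device copy followed by `cublasSaxpy`; the function
name is just for illustration:

```c++
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Sketch only: a hypothetical custom-call target that enqueues cuBLAS work on
// XLA's stream instead of launching a hand-written kernel. Assumes buffers[0]
// and buffers[1] are both F32[2048] inputs and buffers[2] is the F32[2048]
// output.
void do_custom_call_with_cublas(CUstream stream, void** buffers,
                                const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  // Creating and destroying a handle on every call is wasteful (cublasDestroy
  // may even synchronize); a real implementation would cache the handle.
  cublasHandle_t handle;
  cublasCreate(&handle);
  // CUstream and cudaStream_t name the same underlying stream type, so we can
  // hand XLA's stream to cuBLAS and keep all of the work asynchronous.
  cublasSetStream(handle, stream);

  // out = in1, then out += 1.0f * in0, both enqueued on `stream`.
  cudaMemcpyAsync(out, in1, 2048 * sizeof(float), cudaMemcpyDeviceToDevice,
                  stream);
  const float alpha = 1.0f;
  cublasSaxpy(handle, /*n=*/2048, &alpha, in0, /*incx=*/1, out, /*incy=*/1);

  cublasDestroy(handle);
}
```

The important point is the same as in the kernel example: the function itself
runs on the CPU and only *enqueues* device work on the stream XLA hands it.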
`buffers` is an array of pointers that lives on the host, and each element it
contains points to device (i.e. GPU) memory. The parameters come first,
followed by the output value. This is notably different from the CPU calling
convention, which has two params, `ins` and `out`. The main reason we diverge
is to make it possible to handle tuple-shaped inputs/outputs efficiently; see
the section below.

As in the CPU example, we've hardcoded the input and output buffer sizes into
our custom call. However, unlike in the CPU case, passing the buffer sizes in
as operands to the custom call would not work well. Usually we need the buffer
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
the block/grid dimensions to use. But if we were to pass the buffer sizes as
operands to our custom call, their values would live in GPU memory. We'd then
have to do an expensive synchronous device-to-host memcpy at the start of our
operation just to read the sizes.

To let you work around this, we provide the `opaque` parameter. You can set
this to an arbitrary string of bytes when you create the custom call:

```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque);
```

Since `xla::Shape` has a protocol buffer representation, you could store this
serialized proto inside of `opaque` and deserialize it within your GPU
custom call. Note, however, that although `xla::ShapeProto` does not change
frequently, it *does* change. Check the git log to see how it has changed in
the past.

## Passing tuples to custom-calls

Consider the following custom call.

```c++
using xla::ShapeUtil;
xla::Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

xla::Shape out_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```

On both CPU and GPU, a tuple is represented in memory as an array of pointers.
In C++ pseudocode, parameter 0 above is laid out as follows.

```c++
// In-memory layout of parameter 0 from the custom call above. True on both
// CPU and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];

void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;

void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```

Although the in-memory representation of tuples is the same on CPU and GPU,
they are handled differently in the CPU and GPU custom-call calling
conventions.

### Tuple outputs as temp buffers

Tuple inputs to custom calls are a convenience, but they aren't strictly
necessary. If we didn't support tuple inputs to custom calls, you could always
unpack the tuples using get-tuple-element before passing them to the custom
call, as sketched below.
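For illustration, here is a sketch of that manual unpacking, reusing `b`, `p0`,
and `out_shape` from the example above (the `leaf*` and `inner` variable names
exist only for this sketch):

```c++
// Flatten the tuple parameter with get-tuple-element and pass the leaves as
// four separate operands instead of one tuple-shaped operand.
xla::XlaOp leaf0 = xla::GetTupleElement(p0, 0);     // F32[32]
xla::XlaOp inner = xla::GetTupleElement(p0, 1);     // (F32[64], F32[128])
xla::XlaOp leaf1 = xla::GetTupleElement(inner, 0);  // F32[64]
xla::XlaOp leaf2 = xla::GetTupleElement(inner, 1);  // F32[128]
xla::XlaOp leaf3 = xla::GetTupleElement(p0, 2);     // F32[256]
xla::CustomCall(&b, "do_custom_call",
                /*operands=*/{leaf0, leaf1, leaf2, leaf3}, out_shape);
```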
On the other hand, tuple *outputs* do let you do things you couldn't otherwise.

The obvious reason to have tuple outputs is that they're how a custom call (or
any other XLA op) returns multiple independent arrays.

But less obviously, a tuple output is also a way to give your custom call temp
memory. Yes, an *output* can represent a temp buffer. Consider: an output
buffer has the property that the op can write to it, and it can read from it
after it's been written to. That's exactly what you want from a temp buffer.

In the example above, suppose we wanted to use the `F32[1024]` as a temp
buffer. Then we'd write the HLO just as above, and we'd simply never read tuple
index 1 of the custom call's output.

### Tuples in CPU custom-calls

In CPU code, we have a function `do_custom_call(const void** ins, void* out)`.
`ins` is an array with just one element, which points to `param0`. The
subbuffers of `param0` are accessible by dereferencing that pointer, and the
subbuffers of `output_tuple` are accessible by dereferencing `out`.

### Tuples in GPU custom-calls

In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
this case `buffers` is a host array of *six* device pointers, one for each leaf
buffer in the input/output. To generate the flat list, we iterate over the
parameters and the output, and for each we do a preorder traversal of its
shape. Concretely:

```c++
// Layout of the `buffers` parameter to the GPU custom-call function for the
// custom call above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1
```
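To tie this together, here is a hypothetical skeleton of a GPU target for the
tuple-shaped custom call above, which also treats the `F32[1024]` output leaf
as the temp buffer described earlier. The kernels it would launch are omitted;
the point is only how the flattened leaves are addressed:

```c++
// Hypothetical GPU target for the tuple-shaped custom call above. `buffers`
// holds the six leaf pointers in the flattened (preorder) order shown above.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in32  = reinterpret_cast<const float*>(buffers[0]);  // F32[32]
  const float* in64  = reinterpret_cast<const float*>(buffers[1]);  // F32[64]
  const float* in128 = reinterpret_cast<const float*>(buffers[2]);  // F32[128]
  const float* in256 = reinterpret_cast<const float*>(buffers[3]);  // F32[256]
  float* out512      = reinterpret_cast<float*>(buffers[4]);        // F32[512]
  float* scratch1024 = reinterpret_cast<float*>(buffers[5]);        // F32[1024]

  // Enqueue kernels on `stream` that write intermediate results into
  // scratch1024 and the final result into out512. Since the HLO never reads
  // tuple index 1 of the output, scratch1024 is effectively scratch memory.
  // (Kernel launches omitted in this sketch.)
}
```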