# XLA Custom Calls

This document describes how to write and use XLA "custom calls". Custom calls
let you invoke code written in a programming language like C++ or CUDA from an
XLA program.

Warning: Custom calls are a low-level power-user feature. It is easy to break
your program in difficult-to-debug (and even difficult-to-notice) ways using
custom calls. You shouldn't use custom calls unless you're prepared to debug
XLA yourself when something goes wrong, and you should expect relatively little
assistance from XLA developers if you run into trouble.

Warning: The custom-call API/ABI is not currently stable. We don't intend to
change it capriciously, but it may change. Some possible future changes are
described below.

## Custom-call on CPU

You can create an HLO instruction that represents a custom call via XLA's
client API. As of this writing, custom calls are not exposed via TensorFlow.

For example, the following code uses a custom call to compute
`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! --
do this with regular HLO.)

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

// The custom-call target. `out` points at the output buffer; `in` holds one
// pointer per operand, in order.
void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```

Notice that the function `do_custom_call` needs to know the dimensions of the
buffers it operates over. In this example we hardcode the sizes 128 and 2048.
If you don't want to do this, you can pass the dimensions in as parameters to
the call, as sketched below.
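
For instance, here is a minimal sketch of that approach. The extra `S32` size
operand and the `do_custom_call_sized` target name are illustrative, not part
of the API:

```c++
// Builder side: pass the output length as an extra scalar operand.
xla::XlaOp size = xla::ConstantR0<int32_t>(&b, 2048);
xla::XlaOp call = xla::CustomCall(
    &b, "do_custom_call_sized", /*operands=*/{param0, param1, size},
    /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));

// Target side: read the size from the scalar operand's buffer instead of
// hardcoding it.
void do_custom_call_sized(void* out, const void** in) {
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  const int32_t n = *reinterpret_cast<const int32_t*>(in[2]);
  float* out_buf = reinterpret_cast<float*>(out);
  for (int32_t i = 0; i < n; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
```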

## Custom-call on GPU

The GPU custom-call framework is somewhat different from the one on the CPU.
Here is a CUDA example that does the same `A[i] = B[i % 128] + C[i]`
computation as the CPU code above.

```c++
void do_it() { /* same implementation as above */ }

// One thread per output element.
__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```

Notice first that the GPU custom-call function *is still a function executed on
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
on the GPU. Here it launches a CUDA kernel, but it could also do something
else, like call cuBLAS.

`buffers` is an array of pointers which lives on the host, and each element it
contains points to device (i.e. GPU) memory. The parameters come first,
followed by the output value. This is notably different from the CPU calling
convention, which has two params, `in` and `out`. The main reason we diverge is
to make it possible to handle tuple-shaped inputs/outputs efficiently; see the
section below.

As in the CPU example, we've hardcoded the input and output buffer sizes into
our custom call. However, unlike in the CPU case, passing the buffer sizes in
as operands to the custom call would not work well. Usually we need the buffer
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
the block/grid dimensions to use. But if we were to pass the buffer sizes as
operands to our custom call, their values would live in GPU memory. We'd then
have to do an expensive synchronous device-to-host memcpy at the start of our
operation just to read the sizes.

To let you work around this, we provide the `opaque` parameter. You can set
this to an arbitrary string of bytes when you create the custom call:

```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);
```

Since `xla::Shape` has a protocol buffer representation, you could store this
serialized proto inside of `opaque` and deserialize it within your GPU
custom-call. Note however that although `xla::ShapeProto` does not change
frequently, it *does* change. Check the git log to see how it has changed in
the past.
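
As a rough sketch of that idea, assuming the standard protobuf serialization
methods (error handling omitted):

```c++
// Builder side: serialize the output shape's proto into `opaque`.
xla::Shape out_shape = xla::ShapeUtil::MakeShape(xla::F32, {2048});
std::string opaque = out_shape.ToProto().SerializeAsString();

// Target side: recover the shape inside the GPU custom-call function.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  xla::ShapeProto proto;
  proto.ParseFromArray(opaque, static_cast<int>(opaque_len));
  xla::Shape shape(proto);
  // ... use shape.dimensions() to choose block/grid dimensions ...
}
```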

## Passing tuples to custom-calls

Consider the following custom call.

```c++
using xla::Shape;
using xla::ShapeUtil;
using xla::F32;

Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```

On both CPU and GPU, a tuple is represented in memory as an array of pointers.
In C++ pseudocode, parameter 0 above is laid out as follows.

```c++
// In-memory layout of parameter 0 from the custom call above.  True on both
// CPU and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];

void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;

void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```

Although the in-memory representation of tuples is the same on CPU and GPU,
they are handled differently by the CPU and GPU custom-call calling
conventions.

### Tuple outputs as temp buffers

Tuple inputs to custom-calls are a convenience, but they aren't strictly
necessary. If we didn't support tuple inputs to custom calls, you could always
unpack the tuples using get-tuple-element before passing them to the custom
call, as sketched below.
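
For instance, a tuple operand could be flattened by hand before the call. This
is a sketch; `tuple_param` and the shapes here are illustrative:

```c++
// Flatten a (F32[32], F32[64]) tuple into two plain operands.
xla::XlaOp elem0 = xla::GetTupleElement(tuple_param, 0);
xla::XlaOp elem1 = xla::GetTupleElement(tuple_param, 1);
xla::XlaOp call =
    xla::CustomCall(&b, "do_custom_call", /*operands=*/{elem0, elem1},
                    /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {64}));
```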

On the other hand, tuple *outputs* do let you do things you couldn't otherwise.

The obvious reason to have tuple outputs is that they're how a custom call (or
any other XLA op) returns multiple independent arrays.

But less obviously, a tuple output is also a way to give your custom call temp
memory. Yes, an *output* can represent a temp buffer. Consider: an output
buffer has the property that the op can write to it, and it can read from it
after it's been written to. That's exactly what you want from a temp buffer.

In the example above, suppose we wanted to use the `F32[1024]` as a temp
buffer. Then we'd write the HLO just as above, and we'd simply never read tuple
index 1 of the custom call's output, as sketched below.
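
Continuing the example above, the builder-side code for that trick might look
like this (a sketch; the names are illustrative):

```c++
// F32[512] is the real output; F32[1024] is scratch the program never reads.
xla::XlaOp call =
    xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
xla::XlaOp result = xla::GetTupleElement(call, 0);
// Tuple index 1 is never read, so the custom-call implementation is free to
// use that buffer as temp memory.
```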

### Tuples in CPU custom-calls

In CPU code, we have a function `do_custom_call(void* out, const void** in)`.
`in` is an array with just one element, which points to `param0`. The
subbuffers of `param0` are accessible by dereferencing that pointer, and the
subbuffers of `output_tuple` are accessible by dereferencing `out`.
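
Concretely, walking the nested pointer arrays for the example above might look
like this. This is a sketch of one way to follow the pointers; the casts and
variable names are ours, not a fixed API:

```c++
void do_custom_call(void* out, const void** in) {
  // in[0] points at param0's top-level pointer array (see the layout above).
  const void* const* param0 = reinterpret_cast<const void* const*>(in[0]);
  const float* subbuf0 = reinterpret_cast<const float*>(param0[0]);   // F32[32]
  const void* const* subtuple =
      reinterpret_cast<const void* const*>(param0[1]);
  const float* subbuf1 = reinterpret_cast<const float*>(subtuple[0]); // F32[64]
  const float* subbuf2 = reinterpret_cast<const float*>(subtuple[1]); // F32[128]
  const float* subbuf3 = reinterpret_cast<const float*>(param0[2]);   // F32[256]

  // `out` points at the output tuple's pointer array.
  void* const* out_tuple = reinterpret_cast<void* const*>(out);
  float* out0 = reinterpret_cast<float*>(out_tuple[0]);  // F32[512]
  float* out1 = reinterpret_cast<float*>(out_tuple[1]);  // F32[1024], scratch
  // ... compute using the subbuffers, writing to out0 (and out1 as temp) ...
}
```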

### Tuples in GPU custom-calls

In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
this case `buffers` is a host array of *six* device pointers, one for each leaf
buffer in the input/output. To generate the flat list, we iterate over the
parameters and output, and for each we do a preorder traversal of its shape.
Concretely:

```c++
// Layout of the `buffers` parameter to the GPU custom-call function for the
// custom call above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1
```
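
A GPU custom-call target for this example might therefore unpack `buffers`
like so (a sketch; the kernel launches themselves are omitted):

```c++
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  // Leaf buffers of parameter 0, in preorder.
  const float* subbuf0 = reinterpret_cast<const float*>(buffers[0]);  // F32[32]
  const float* subbuf1 = reinterpret_cast<const float*>(buffers[1]);  // F32[64]
  const float* subbuf2 = reinterpret_cast<const float*>(buffers[2]);  // F32[128]
  const float* subbuf3 = reinterpret_cast<const float*>(buffers[3]);  // F32[256]
  // Leaf buffers of the output tuple.
  float* out0 = reinterpret_cast<float*>(buffers[4]);  // F32[512]
  float* out1 = reinterpret_cast<float*>(buffers[5]);  // F32[1024], scratch
  // ... enqueue kernels on `stream` that read the inputs and fill out0,
  // using out1 as temp memory if desired ...
}
```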