# XLA Custom Calls

This document describes how to write and use XLA "custom calls". Custom calls
let you invoke code written in a programming language like C++ or CUDA from an
XLA program.

Warning: Custom calls are a low-level power-user feature. It is easy to break
your program in difficult-to-debug (and even difficult-to-notice) ways using
custom calls. You shouldn't use custom calls unless you're prepared to debug XLA
yourself when something goes wrong, and you should expect relatively little
assistance from XLA developers if you run into trouble.

Warning: The custom-call API/ABI is not currently stable. We don't intend to
change it capriciously, but it may change. Some possible future changes are
described below.

## Custom-call on CPU

You can create an HLO instruction that represents a custom call via XLA's
client API. As of this writing, this is not exposed via TensorFlow.

For example, the following code uses a custom call to compute `A[i] = B[i %
128] + C[i]` on the CPU. (Of course you could -- and should! -- do this with
regular HLO.)

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```

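For comparison, here is roughly how you could express the same computation with
regular HLO via the client API. This is a minimal sketch using the `param0` and
`param1` ops above: `Broadcast` prepends a dimension of size 16, and the
reshape flattens the result into the `B[i % 128]` tiling.

```c++
// Tile param0 (F32[128]) 16 times into F32[2048], then add param1 elementwise.
xla::XlaOp tiled = xla::Reshape(xla::Broadcast(param0, {16}), {2048});
xla::XlaOp sum = xla::Add(tiled, param1);
```
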
Notice that the function `do_custom_call` needs to know the dimensions of the
buffers it operates over. In this example we hardcode the sizes 128 and 2048. If
you don't want to do this, you can pass the dimensions in as parameters to the
call.

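For example, here is a minimal sketch of that approach; the target name
`do_custom_call_with_size` and the extra `S64` operand are illustrative, not
part of XLA's API. On the CPU the operand buffers are handed to you as host
pointers, so the size can be read directly:

```c++
// When building the graph (inside do_it above):
xla::XlaOp n = xla::ConstantR0<int64_t>(&b, 2048);  // output length as an operand
xla::XlaOp custom_call = xla::CustomCall(
    &b, "do_custom_call_with_size", /*operands=*/{param0, param1, n},
    /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));

// The callee reads the size from the extra operand instead of hardcoding it.
void do_custom_call_with_size(void* out, const void** in) {
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  const int64_t n = *reinterpret_cast<const int64_t*>(in[2]);  // the S64 operand
  float* out_buf = reinterpret_cast<float*>(out);
  for (int64_t i = 0; i < n; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call_with_size, "Host");
```
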
## Custom-call on GPU

The GPU custom call framework is somewhat different from that on the CPU. Here
is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as
the CPU code above.
```c++
void do_it() { /* same implementation as above */ }

__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```

Notice first that the GPU custom call function *is still a function executed on
the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work
on the GPU. Here it launches a CUDA kernel, but it could also do something else,
like call cuBLAS.

`buffers` is an array of pointers which lives on the host, and each element it
contains points to device (i.e. GPU) memory. The parameters come first, followed
by the output value. This is notably different from the CPU calling convention,
which has two params, `in` and `out`. The main reason we diverge is to make it
possible to handle tuple-shaped inputs/outputs efficiently; see the section
below.

As in the CPU example, we've hardcoded the input and output buffer sizes into
our custom call. However, unlike in the CPU case, passing the buffer sizes in as
operands to the custom call would not work well. Usually we need the buffer
sizes available to us on the CPU; e.g. when launching a kernel, we need to know
the block/grid dimensions to use. But if we were to pass the buffer sizes as
operands to our custom call, their values would live in GPU memory. We'd then
have to do an expensive synchronous device-to-host memcpy at the start of our
operation just to read the sizes.

To let you work around this, we provide the `opaque` parameter. You can set this
to an arbitrary string of bytes when you create the custom call:

```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);
```

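Inside the GPU custom-call function, `opaque` and `opaque_len` hand you back
exactly the bytes you attached to the op. As a minimal sketch (the
decimal-string encoding here is just an illustration, not something XLA
prescribes), you could pack the element count into `opaque` and parse it before
launching the kernel:

```c++
// When building the op: pack the element count into `opaque`.
std::string opaque = std::to_string(2048);
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

// In the custom-call implementation: recover the count and size the launch.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const int64_t n = std::stoll(std::string(opaque, opaque_len));
  const int64_t block_dim = 64;
  const int64_t grid_dim = (n + block_dim - 1) / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
      reinterpret_cast<const float*>(buffers[0]),
      reinterpret_cast<const float*>(buffers[1]),
      reinterpret_cast<float*>(buffers[2]));
}
```
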
Since `xla::Shape` has a protocol buffer representation, you could instead store
the serialized proto inside `opaque` and deserialize it within your GPU
custom-call. Note, however, that although `xla::ShapeProto` does not change
frequently, it *does* change. Check the git log to see how it has changed in the
past.

## Signalling an error

If your custom call encounters an error, you can signal the error to the XLA
runtime (instead of e.g. crashing or returning nonsense in the output buffers)
by using the following signature for your function on CPU:

```c++
#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
```

... and on GPU:

```c++
#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status);
```

You can signal failure by using `XlaCustomCallStatusSetFailure`, e.g.:

```c++
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    const char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}
```

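The same mechanism works on GPU. The sketch below (which reuses the
`custom_call_kernel` from the earlier example) reports a failed kernel launch to
XLA instead of crashing:

```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status) {
  custom_call_kernel<<<2048 / 64, 64, 0, stream>>>(
      reinterpret_cast<const float*>(buffers[0]),
      reinterpret_cast<const float*>(buffers[1]),
      reinterpret_cast<float*>(buffers[2]));

  // The launch is asynchronous; cudaGetLastError() catches launch-configuration
  // errors without synchronizing the stream.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    const char* message = cudaGetErrorString(err);
    XlaCustomCallStatusSetFailure(status, message, strlen(message));
  }
}
```
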
You can also use `XlaCustomCallStatusSetSuccess` to indicate success, but the
`XlaCustomCallStatus` is in a success state by default, so ignoring it
completely will also indicate success.

When using custom call functions with this signature, you must create the
corresponding `custom-call` op with the appropriate API version set, e.g.:

```c++
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);
```

NOTE: In the future all clients will be required to migrate their custom call
functions to the new API version and the old one will be deprecated. For custom
calls that can't fail, you can simply add the new `XlaCustomCallStatus*`
parameter and then ignore it.

On failure, none of the custom call outputs will be used; the XLA runtime will
terminate the computation. It is not possible for an HLO computation to recover
from the error (e.g. by catching and handling it).

## Passing tuples to custom-calls

Consider the following custom-call.

```c++
using xla::ShapeUtil;
using xla::F32;
xla::Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

xla::Shape out_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```

On both CPU and GPU, a tuple is represented in memory as an array of pointers.
In C++ pseudocode, parameter 0 above is laid out as follows.

```c++
// In-memory layout of parameter 0 from the custom-call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];

void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;

void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```

Although the in-memory representation of tuples is the same on CPU and GPU, they
are handled differently by the CPU and GPU custom-call calling conventions.

### Tuple outputs as temp buffers

Tuple inputs to custom-calls are a convenience, but they aren't strictly
necessary. If we didn't support tuple inputs to custom calls, you could always
unpack the tuples using get-tuple-element before passing them to the custom
call.

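For instance, a minimal sketch of manually flattening the tuple parameter from
the example above (the `leaf*` names are just for illustration):

```c++
xla::XlaOp leaf0 = xla::GetTupleElement(p0, 0);     // F32[32]
xla::XlaOp inner = xla::GetTupleElement(p0, 1);     // (F32[64], F32[128])
xla::XlaOp leaf1 = xla::GetTupleElement(inner, 0);  // F32[64]
xla::XlaOp leaf2 = xla::GetTupleElement(inner, 1);  // F32[128]
xla::XlaOp leaf3 = xla::GetTupleElement(p0, 2);     // F32[256]
xla::CustomCall(&b, "do_custom_call",
                /*operands=*/{leaf0, leaf1, leaf2, leaf3}, out_shape);
```
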
On the other hand, tuple *outputs* do let you do things you couldn't otherwise.

The obvious reason to have tuple outputs is that they're how a custom call (or
any other XLA op) returns multiple independent arrays.

But less obviously, a tuple output is also a way to give your custom call temp
memory. Yes, an *output* can represent a temp buffer. Consider: an output buffer
has the property that the op can write to it, and it can read from it after it's
been written to. That's exactly what you want from a temp buffer.

In the example above, suppose we wanted to use the `F32[1024]` as a temp buffer.
Then we'd write the HLO just as above, and we'd simply never read tuple index 1
of the custom call's output.

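In terms of the client API, that simply means only ever extracting tuple index 0
of the result, e.g. (a sketch using the shapes above):

```c++
xla::XlaOp call =
    xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
// Only the F32[512] element is ever read; the F32[1024] element serves as
// scratch space for the custom call.
xla::XlaOp result = xla::GetTupleElement(call, 0);
```
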
### Tuples in CPU custom-calls

In CPU code, we have a function `do_custom_call(void* out, const void** in)`.
`in` is an array with just one element, which points to `param0`. The
subbuffers of `param0` are accessible by dereferencing that pointer, and the
subbuffers of `output_tuple` are accessible by dereferencing `out`.

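Concretely, for the example above, the CPU callee could unpack its arguments
like this; a minimal sketch that follows the pointer layout shown earlier (the
variable names are illustrative):

```c++
void do_custom_call(void* out, const void** in) {
  // in[0] points to param0's tuple, which is an array of pointers.
  const void* const* param0 = reinterpret_cast<const void* const*>(in[0]);
  const float* subbuf0 = reinterpret_cast<const float*>(param0[0]);    // F32[32]
  const void* const* subtuple =
      reinterpret_cast<const void* const*>(param0[1]);
  const float* subbuf1 = reinterpret_cast<const float*>(subtuple[0]);  // F32[64]
  const float* subbuf2 = reinterpret_cast<const float*>(subtuple[1]);  // F32[128]
  const float* subbuf3 = reinterpret_cast<const float*>(param0[2]);    // F32[256]

  // `out` points to the output tuple's array of pointers.
  void** out_tuple = reinterpret_cast<void**>(out);
  float* out0 = reinterpret_cast<float*>(out_tuple[0]);  // F32[512]
  float* out1 = reinterpret_cast<float*>(out_tuple[1]);  // F32[1024]
  // ... use the subbuffers and fill out0/out1 here.
}
```
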
### Tuples in GPU custom-calls

In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In
this case `buffers` is a host array of *six* device pointers, one for each leaf
buffer in the input/output. To generate the flat list, we iterate over the
parameters and output, and for each we do a preorder traversal of its shape.
Concretely:

```c++
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1
```