• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# cython: language_level=2
2# distutils: language = c++
3
# Test case for defining an XLA custom call target in Cython and registering
# it via the xla_client SWIG API.
6
7from cpython.pycapsule cimport PyCapsule_New
8
cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil:
  # Custom-call body: computes data_ptr[0][0] - data_ptr[1][0], treating every
  # buffer as float (f32), and stores the result in out_ptr[0].
  #
  # Only element 0 of each buffer is read or written, so this presumably
  # expects scalar f32 operands and a scalar f32 result -- NOTE(review):
  # confirm against the computation that invokes this target.
  cdef float* minuend = <float*>(data_ptr[0])
  cdef float* subtrahend = <float*>(data_ptr[1])
  cdef float* result = <float*>(out_ptr)
  # Both operands are loaded before the store, as in a plain
  # "load a; load b; store a - b" sequence.
  result[0] = minuend[0] - subtrahend[0]
14
15
# Registry of CPU custom-call targets defined in this module: maps the name
# passed to register_custom_call_target (bytes in this file) to a PyCapsule
# wrapping the target's raw C function pointer.
cpu_custom_call_targets = {}


cdef register_custom_call_target(fn_name, void* fn):
  # Wrap the C function pointer `fn` in a PyCapsule and store it in
  # `cpu_custom_call_targets` under `fn_name`. Registering the same name
  # twice overwrites the earlier entry (plain dict assignment).
  #
  # The capsule name is a C string literal with static storage, so it
  # outlives the capsule as PyCapsule_New requires. Presumably this is the tag
  # the xla_client side checks when it unpacks the capsule -- NOTE(review):
  # confirm against the consumer. No destructor is passed (NULL).
  cdef const char* capsule_name = "xla._CUSTOM_CALL_TARGET"
  capsule = PyCapsule_New(fn, capsule_name, NULL)
  cpu_custom_call_targets[fn_name] = capsule


register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
23