/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/kernels.h"

#include <memory>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"

// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made
// cautiously and with a focus on maintaining both source and binary
// compatibility.

struct TF_KernelBuilder {
  ::tensorflow::KernelDefBuilder* cc_builder;

  void* (*create_function)(TF_OpKernelConstruction*);
  void (*compute_function)(void*, TF_OpKernelContext*);
  void (*delete_function)(void*);
};
TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*)) {
  TF_KernelBuilder* result = new TF_KernelBuilder;
  result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
  result->cc_builder->Device(device_name);
  result->create_function = create_func;
  result->compute_function = compute_func;
  result->delete_function = delete_func;
  return result;
}
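
// Illustrative sketch (not part of this file's API surface): a third-party
// kernel would typically pair TF_NewKernelBuilder with
// TF_RegisterKernelBuilder along these lines. The op name "MyOp" and the
// MyOp_* callbacks are hypothetical. Note that registration hands ownership
// of the builder to the registry (KernelBuilderFactory below deletes it).
//
//   TF_KernelBuilder* builder = TF_NewKernelBuilder(
//       "MyOp", "CPU", &MyOp_Create, &MyOp_Compute, &MyOp_Delete);
//   TF_Status* status = TF_NewStatus();
//   TF_RegisterKernelBuilder("MyOpKernel", builder, status);
//   if (TF_GetCode(status) != TF_OK) { /* handle registration failure */ }
//   TF_DeleteStatus(status);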

void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
  if (builder != nullptr) {
    delete builder->cc_builder;
    delete builder;
  }
}

namespace tensorflow {
namespace {

#define CASE(type)                                               \
  case DataTypeToEnum<type>::value: {                            \
    kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
    break;                                                       \
  }

void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
                       const DataType dtype, TF_Status* status) {
  // This needs to be under tensorflow:: namespace so that
  // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
  switch (dtype) {
    TF_CALL_ALL_TYPES(CASE);
    default:
      status->status = errors::Unimplemented("Unexpected type ", dtype);
      return;
  }
  TF_SetStatus(status, TF_OK, "");
}
#undef CASE
}  // namespace
}  // namespace tensorflow

void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                     const char* attr_name,
                                     const TF_DataType type,
                                     TF_Status* status) {
  tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
  tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
}

void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
                                 const char* arg_name) {
  kernel_builder->cc_builder->HostMemory(arg_name);
}
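
// Illustrative sketch: constraining a kernel to float and pinning an input to
// host memory before registration might look like this from the plugin side.
// The attr name "T" and arg name "axis" are assumptions about the op's def.
//
//   TF_KernelBuilder_TypeConstraint(builder, "T", TF_FLOAT, status);
//   TF_KernelBuilder_HostMemory(builder, "axis");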

namespace tensorflow {
namespace {

// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
 public:
  explicit COpKernel(OpKernelConstruction* ctx,
                     void* (*create_func)(TF_OpKernelConstruction*),
                     void (*compute_func)(void*, TF_OpKernelContext*),
                     void (*delete_func)(void*))
      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
    if (create_func != nullptr) {
      c_kernel_ =
          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
    } else {
      c_kernel_ = nullptr;
    }
  }

  void Compute(OpKernelContext* ctx) override {
    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
  }

  ~COpKernel() override {
    if (delete_func_ != nullptr) {
      (*delete_func_)(c_kernel_);
    }
  }

 private:
  void (*compute_func_)(void*, TF_OpKernelContext* context);
  void (*delete_func_)(void*);
  void* c_kernel_;
};

// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}
  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }
  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

 private:
  TF_KernelBuilder* builder_;
};
}  // namespace
}  // namespace tensorflow

void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
                              TF_Status* status) {
  using tensorflow::register_kernel::Name;

  tensorflow::kernel_factory::OpKernelRegistrar(
      builder->cc_builder->Build(), name,
      absl::make_unique<tensorflow::KernelBuilderFactory>(builder));

  TF_SetStatus(status, TF_OK, "");
}

int TF_NumInputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_inputs();
}

int TF_NumOutputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_outputs();
}

void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
                 TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return;
  }
  const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
  TF_Tensor* result =
      ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
  if (TF_GetCode(status) == TF_OK) {
    *tensor = result;
  }
}

void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
                  TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_outputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
    return;
  }
  ::tensorflow::Tensor cc_tensor;
  ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::Set_TF_Status_from_Status(status, s);
  if (s.ok()) {
    cc_ctx->set_output(i, cc_tensor);
  }
}
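
// Illustrative sketch of a plugin compute callback built on the accessors
// above. The kernel-state pointer is ignored here; MyOp_Compute and the
// identity-style behavior (forwarding input 0 to output 0) are assumptions.
//
//   void MyOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
//     TF_Status* s = TF_NewStatus();
//     TF_Tensor* input = nullptr;
//     TF_GetInput(ctx, 0, &input, s);
//     if (TF_GetCode(s) == TF_OK) {
//       TF_SetOutput(ctx, 0, input, s);
//     }
//     if (TF_GetCode(s) != TF_OK) TF_OpKernelContext_Failure(ctx, s);
//     if (input != nullptr) TF_DeleteTensor(input);
//     TF_DeleteStatus(s);
//   }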

void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
                                     TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
  cc_ctx->CtxFailure(s);
}

void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
  cc_ctx->CtxFailure(s);
}

#define DEFINE_TF_GETATTR(func, c_type, cc_type)                               \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                             const char* attr_name,            \
                                             c_type* val, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, "");                                           \
    cc_type v;                                                                 \
    auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
    ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
    ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
    if (s.ok()) {                                                              \
      *val = static_cast<c_type>(v);                                           \
    }                                                                          \
  }

DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
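
// Illustrative sketch of a plugin create callback using the getters generated
// above. The attr name "N" and the MyOpState struct are hypothetical.
//
//   void* MyOp_Create(TF_OpKernelConstruction* ctx) {
//     TF_Status* s = TF_NewStatus();
//     int32_t n = 0;
//     TF_OpKernelConstruction_GetAttrInt32(ctx, "N", &n, s);
//     if (TF_GetCode(s) != TF_OK) TF_OpKernelConstruction_Failure(ctx, s);
//     TF_DeleteStatus(s);
//     return new MyOpState{n};  // freed later by the delete callback
//   }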

TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
}

int64_t TF_StepId(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
}

TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
                             TF_DataType dtype, int64_t* dims, int num_dims,
                             size_t len, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
  auto* allocator = cc_ctx->get_allocator(attr);
  void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
  TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len,
                                   tensorflow::deallocate_buffer, allocator);
  TF_SetOutput(context, index, result, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteTensor(result);
    return nullptr;
  }
  return result;
}
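
// Illustrative sketch: allocating a 1-D float output of four elements from a
// plugin compute callback. The output index, dtype, and shape are assumptions
// about the op's signature; the allocated buffer is written through
// TF_TensorData after TF_AllocateOutput has bound it to the output slot.
//
//   int64_t dims[1] = {4};
//   TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, dims,
//                                      /*num_dims=*/1,
//                                      /*len=*/4 * sizeof(float), status);
//   if (TF_GetCode(status) == TF_OK) {
//     float* data = static_cast<float*>(TF_TensorData(out));
//     for (int j = 0; j < 4; ++j) data[j] = 0.0f;
//   }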