• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/kernels.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/tf_status_helper.h"
22 #include "tensorflow/c/tf_tensor_internal.h"
23 #include "tensorflow/core/framework/kernel_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 // This file forms the basis of a stable ABI for third-party kernel
30 // implementations. It is crucial that changes to this file are made cautiously
31 // and with a focus on maintaining both source and binary compatibility.
32 
33 struct TF_KernelBuilder {
34   ::tensorflow::KernelDefBuilder* cc_builder;
35 
36   void* (*create_function)(TF_OpKernelConstruction*);
37   void (*compute_function)(void*, TF_OpKernelContext*);
38   void (*delete_function)(void*);
39 };
40 
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))41 TF_KernelBuilder* TF_NewKernelBuilder(
42     const char* op_name, const char* device_name,
43     void* (*create_func)(TF_OpKernelConstruction*),
44     void (*compute_func)(void*, TF_OpKernelContext*),
45     void (*delete_func)(void*)) {
46   TF_KernelBuilder* result = new TF_KernelBuilder;
47   result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
48   result->cc_builder->Device(device_name);
49   result->create_function = create_func;
50   result->compute_function = compute_func;
51   result->delete_function = delete_func;
52   return result;
53 }
54 
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)55 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
56   if (builder != nullptr) {
57     delete builder->cc_builder;
58     delete builder;
59   }
60 }
61 
62 namespace tensorflow {
63 namespace {
64 
65 #define CASE(type)                                               \
66   case DataTypeToEnum<type>::value: {                            \
67     kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
68     break;                                                       \
69   }
70 
AddTypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const DataType dtype,TF_Status * status)71 void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
72                        const DataType dtype, TF_Status* status) {
73   // This needs to be under tensorflow:: namespace so that
74   // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
75   switch (dtype) {
76     TF_CALL_ALL_TYPES(CASE);
77     default:
78       status->status = errors::Unimplemented("Unexpected type ", dtype);
79       return;
80   }
81   TF_SetStatus(status, TF_OK, "");
82 }
83 #undef CASE
84 }  // namespace
85 }  // namespace tensorflow
86 
TF_KernelBuilder_TypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const TF_DataType type,TF_Status * status)87 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
88                                      const char* attr_name,
89                                      const TF_DataType type,
90                                      TF_Status* status) {
91   tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
92   tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
93 }
94 
TF_KernelBuilder_HostMemory(TF_KernelBuilder * kernel_builder,const char * arg_name)95 void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
96                                  const char* arg_name) {
97   kernel_builder->cc_builder->HostMemory(arg_name);
98 }
99 
100 namespace tensorflow {
101 namespace {
102 
103 // An OpKernel whose methods delegate to C function pointers.
104 class COpKernel : public OpKernel {
105  public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))106   explicit COpKernel(OpKernelConstruction* ctx,
107                      void* (*create_func)(TF_OpKernelConstruction*),
108                      void (*compute_func)(void*, TF_OpKernelContext*),
109                      void (*delete_func)(void*))
110       : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
111     if (create_func != nullptr) {
112       c_kernel_ =
113           (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
114     } else {
115       c_kernel_ = nullptr;
116     }
117   }
118 
Compute(OpKernelContext * ctx)119   void Compute(OpKernelContext* ctx) override {
120     (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
121   }
122 
~COpKernel()123   ~COpKernel() override {
124     if (delete_func_ != nullptr) {
125       (*delete_func_)(c_kernel_);
126     }
127   }
128 
129  private:
130   void (*compute_func_)(void*, TF_OpKernelContext* context);
131   void (*delete_func_)(void*);
132   void* c_kernel_;
133 };
134 
135 // A KernelFactory that returns COpKernel instances.
136 class KernelBuilderFactory
137     : public ::tensorflow::kernel_factory::OpKernelFactory {
138  public:
KernelBuilderFactory(TF_KernelBuilder * builder)139   explicit KernelBuilderFactory(TF_KernelBuilder* builder)
140       : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)141   ::tensorflow::OpKernel* Create(
142       ::tensorflow::OpKernelConstruction* context) override {
143     return new ::tensorflow::COpKernel(context, builder_->create_function,
144                                        builder_->compute_function,
145                                        builder_->delete_function);
146   }
~KernelBuilderFactory()147   ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
148 
149  private:
150   TF_KernelBuilder* builder_;
151 };
152 }  // namespace
153 }  // namespace tensorflow
154 
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)155 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
156                               TF_Status* status) {
157   using tensorflow::register_kernel::Name;
158 
159   tensorflow::kernel_factory::OpKernelRegistrar(
160       builder->cc_builder->Build(), name,
161       absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
162 
163   TF_SetStatus(status, TF_OK, "");
164 }
165 
TF_NumInputs(TF_OpKernelContext * ctx)166 int TF_NumInputs(TF_OpKernelContext* ctx) {
167   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
168   return cc_ctx->num_inputs();
169 }
170 
TF_NumOutputs(TF_OpKernelContext * ctx)171 int TF_NumOutputs(TF_OpKernelContext* ctx) {
172   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
173   return cc_ctx->num_outputs();
174 }
175 
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)176 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
177                  TF_Status* status) {
178   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
179   if (i < 0 || i >= cc_ctx->num_inputs()) {
180     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
181     return;
182   }
183   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
184   TF_Tensor* result =
185       ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
186   if (TF_GetCode(status) == TF_OK) {
187     *tensor = result;
188   }
189 }
190 
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)191 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
192                   TF_Status* status) {
193   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
194   if (i < 0 || i >= cc_ctx->num_outputs()) {
195     TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
196     return;
197   }
198   ::tensorflow::Tensor cc_tensor;
199   ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
200   TF_SetStatus(status, TF_OK, "");
201   ::tensorflow::Set_TF_Status_from_Status(status, s);
202   if (s.ok()) {
203     cc_ctx->set_output(i, cc_tensor);
204   }
205 }
206 
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)207 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
208                                      TF_Status* status) {
209   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
210   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
211   cc_ctx->CtxFailure(s);
212 }
213 
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)214 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
215   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
216   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
217   cc_ctx->CtxFailure(s);
218 }
219 
220 #define DEFINE_TF_GETATTR(func, c_type, cc_type)                               \
221   void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
222                                              const char* attr_name,            \
223                                              c_type* val, TF_Status* status) { \
224     TF_SetStatus(status, TF_OK, "");                                           \
225     cc_type v;                                                                 \
226     auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
227     ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
228     ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
229     if (s.ok()) {                                                              \
230       *val = static_cast<c_type>(v);                                           \
231     }                                                                          \
232   }
233 
DEFINE_TF_GETATTR(Type,TF_DataType,tensorflow::DataType)234 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
235 DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
236 
237 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
238   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
239   return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
240 }
241 
TF_StepId(TF_OpKernelContext * ctx)242 int64_t TF_StepId(TF_OpKernelContext* ctx) {
243   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
244 }
245 
TF_AllocateOutput(TF_OpKernelContext * context,int index,TF_DataType dtype,int64_t * dims,int num_dims,size_t len,TF_Status * status)246 TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
247                              TF_DataType dtype, int64_t* dims, int num_dims,
248                              size_t len, TF_Status* status) {
249   TF_SetStatus(status, TF_OK, "");
250   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
251   tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
252   auto* allocator = cc_ctx->get_allocator(attr);
253   void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
254   TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len,
255                                    tensorflow::deallocate_buffer, allocator);
256   TF_SetOutput(context, index, result, status);
257   if (TF_GetCode(status) != TF_OK) {
258     TF_DeleteTensor(result);
259     return nullptr;
260   }
261   return result;
262 }
263