• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/kernels.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/tf_status_helper.h"
22 #include "tensorflow/c/tf_tensor_internal.h"
23 #include "tensorflow/core/framework/kernel_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/types.h"
27 // Required for IS_MOBILE_PLATFORM definition
28 #include "tensorflow/core/platform/platform.h"
29 #include "tensorflow/core/platform/types.h"
30 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
31 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
32 #include "tensorflow/stream_executor/stream.h"
33 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
34 
35 using tensorflow::errors::InvalidArgument;
36 // This file forms the basis of a stable ABI for third-party kernel
37 // implementations. It is crucial that changes to this file are made cautiously
38 // and with a focus on maintaining both source and binary compatibility.
39 
// Builder state for registering a C-ABI kernel: wraps the C++
// KernelDefBuilder plus the three lifecycle callbacks supplied by the
// plugin author. create/delete may be null (see COpKernel); compute is
// invoked unconditionally.
struct TF_KernelBuilder {
  ::tensorflow::KernelDefBuilder* cc_builder;  // owned; freed by TF_DeleteKernelBuilder

  void* (*create_function)(TF_OpKernelConstruction*);    // optional, may be null
  void (*compute_function)(void*, TF_OpKernelContext*);  // required
  void (*delete_function)(void*);                        // optional, may be null
};
47 
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))48 TF_KernelBuilder* TF_NewKernelBuilder(
49     const char* op_name, const char* device_name,
50     void* (*create_func)(TF_OpKernelConstruction*),
51     void (*compute_func)(void*, TF_OpKernelContext*),
52     void (*delete_func)(void*)) {
53   TF_KernelBuilder* result = new TF_KernelBuilder;
54   result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
55   result->cc_builder->Device(device_name);
56   result->create_function = create_func;
57   result->compute_function = compute_func;
58   result->delete_function = delete_func;
59   return result;
60 }
61 
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)62 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
63   if (builder != nullptr) {
64     delete builder->cc_builder;
65     delete builder;
66   }
67 }
68 
namespace tensorflow {
namespace {

// Expands to one switch case per dtype: adds TypeConstraint<type> for
// `attr_name` to the wrapped KernelDefBuilder.
#define CASE(type)                                               \
  case DataTypeToEnum<type>::value: {                            \
    kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
    break;                                                       \
  }

// Adds a type constraint on attribute `attr_name` for every dtype covered by
// the TF_CALL_* macro lists below. Sets `status` to Unimplemented for any
// other dtype, TF_OK on success.
void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
                       const DataType dtype, TF_Status* status) {
  // This needs to be under tensorflow:: namespace so that
  // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
  switch (dtype) {
    TF_CALL_ALL_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    TF_CALL_quint16(CASE);
    TF_CALL_qint16(CASE);
    default:
      status->status = errors::Unimplemented("Unexpected type ", dtype);
      return;
  }
  TF_SetStatus(status, TF_OK, "");
}
#undef CASE

}  // namespace
}  // namespace tensorflow
97 
98 namespace {
GetAttrValue(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)99 const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
100                                           const char* attr_name,
101                                           TF_Status* status) {
102   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
103   const tensorflow::AttrValue* attr =
104       ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
105   if (attr == nullptr) {
106     status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
107                                      "' has no attr named '", attr_name, "'.");
108   }
109   return attr;
110 }
111 }  // namespace
112 
TF_KernelBuilder_TypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const TF_DataType type,TF_Status * status)113 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
114                                      const char* attr_name,
115                                      const TF_DataType type,
116                                      TF_Status* status) {
117   tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
118   tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
119 }
120 
TF_KernelBuilder_HostMemory(TF_KernelBuilder * kernel_builder,const char * arg_name)121 void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
122                                  const char* arg_name) {
123   kernel_builder->cc_builder->HostMemory(arg_name);
124 }
125 
TF_KernelBuilder_Priority(TF_KernelBuilder * kernel_builder,int32_t priority_number)126 void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
127                                int32_t priority_number) {
128   kernel_builder->cc_builder->Priority(priority_number);
129 }
130 
namespace tensorflow {
namespace {

// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
 public:
  // `create_func` may be null, in which case the per-kernel C state stays
  // null. `compute_func` is invoked unconditionally from Compute();
  // `delete_func` (if non-null) tears the C state down in the destructor.
  explicit COpKernel(OpKernelConstruction* ctx,
                     void* (*create_func)(TF_OpKernelConstruction*),
                     void (*compute_func)(void*, TF_OpKernelContext*),
                     void (*delete_func)(void*))
      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
    if (create_func != nullptr) {
      c_kernel_ =
          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
    } else {
      c_kernel_ = nullptr;
    }
  }

  // Forwards to the plugin's compute callback with the opaque C state.
  void Compute(OpKernelContext* ctx) override {
    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
  }

  ~COpKernel() override {
    if (delete_func_ != nullptr) {
      (*delete_func_)(c_kernel_);
    }
  }

 private:
  void (*compute_func_)(void*, TF_OpKernelContext* context);
  void (*delete_func_)(void*);
  void* c_kernel_;  // opaque per-kernel state produced by create_func
};

// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}
  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }
  // The factory owns `builder_` after registration; release it on teardown.
  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

 private:
  TF_KernelBuilder* builder_;
};
}  // namespace
}  // namespace tensorflow
185 
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)186 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
187                               TF_Status* status) {
188   using tensorflow::register_kernel::Name;
189 
190   tensorflow::kernel_factory::OpKernelRegistrar(
191       builder->cc_builder->Build(), name,
192       absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
193 
194   TF_SetStatus(status, TF_OK, "");
195 }
196 
// This function is only for pluggable device.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  // Stream-executor support is compiled out of mobile/slim builds entirely.
  status->status = tensorflow::errors::Unimplemented(
      "Accessing device stream is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is not supported for a CPU device.");
    return nullptr;
  } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is only supported for pluggable devices.");
    return nullptr;
  } else {  // Is a PluggableDevice
    TF_SetStatus(status, TF_OK, "");
    // For pluggable devices the stream implementation is a CStream, so the
    // static_cast is safe after the IsPluggableDevice check above.
    auto c_stream = static_cast<stream_executor::CStream*>(
        cc_ctx->op_device_context()->stream()->implementation());
    return c_stream->Handle();
  }
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
225 
TF_NumInputs(TF_OpKernelContext * ctx)226 int TF_NumInputs(TF_OpKernelContext* ctx) {
227   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
228   return cc_ctx->num_inputs();
229 }
230 
TF_NumOutputs(TF_OpKernelContext * ctx)231 int TF_NumOutputs(TF_OpKernelContext* ctx) {
232   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
233   return cc_ctx->num_outputs();
234 }
235 
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)236 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
237                  TF_Status* status) {
238   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
239   if (i < 0 || i >= cc_ctx->num_inputs()) {
240     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
241     return;
242   }
243   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
244   TF_Tensor* result =
245       ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
246   if (TF_GetCode(status) == TF_OK) {
247     *tensor = result;
248   }
249 }
250 
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)251 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
252                   TF_Status* status) {
253   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
254   if (i < 0 || i >= cc_ctx->num_outputs()) {
255     TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
256     return;
257   }
258   ::tensorflow::Tensor cc_tensor;
259   ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
260   TF_SetStatus(status, TF_OK, "");
261   ::tensorflow::Set_TF_Status_from_Status(status, s);
262   if (s.ok()) {
263     cc_ctx->set_output(i, cc_tensor);
264   }
265 }
266 
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)267 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
268                                      TF_Status* status) {
269   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
270   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
271   cc_ctx->CtxFailure(s);
272 }
273 
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)274 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
275   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
276   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
277   cc_ctx->CtxFailure(s);
278 }
279 
TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction * ctx,const char * attr_name,int32_t * list_size,int32_t * total_size,TF_Status * status)280 void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
281                                          const char* attr_name,
282                                          int32_t* list_size,
283                                          int32_t* total_size,
284                                          TF_Status* status) {
285   const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
286   if (!status->status.ok()) {
287     *list_size = -1;
288     *total_size = -1;
289     return;
290   }
291   switch (attr->value_case()) {
292 #define SINGLE_CASE(kK, attr_type, size_expr) \
293   case tensorflow::AttrValue::kK:             \
294     *list_size = -1;                          \
295     *total_size = size_expr;                  \
296     break;
297 
298     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
299     SINGLE_CASE(kI, TF_ATTR_INT, -1);
300     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
301     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
302     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
303     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
304                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
305     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
306 #undef SINGLE_CASE
307 
308     case tensorflow::AttrValue::kList:
309       *list_size = 0;
310       *total_size = -1;
311 #define LIST_CASE(field, attr_type, ...)      \
312   if (attr->list().field##_size() > 0) {      \
313     *list_size = attr->list().field##_size(); \
314     __VA_ARGS__;                              \
315     break;                                    \
316   }
317 
318       LIST_CASE(
319           s, TF_ATTR_STRING, *total_size = 0;
320           for (int i = 0; i < attr->list().s_size();
321                ++i) { *total_size += attr->list().s(i).size(); });
322       LIST_CASE(i, TF_ATTR_INT);
323       LIST_CASE(f, TF_ATTR_FLOAT);
324       LIST_CASE(b, TF_ATTR_BOOL);
325       LIST_CASE(type, TF_ATTR_TYPE);
326       LIST_CASE(
327           shape, TF_ATTR_SHAPE, *total_size = 0;
328           for (int i = 0; i < attr->list().shape_size(); ++i) {
329             const auto& s = attr->list().shape(i);
330             *total_size += s.unknown_rank() ? 0 : s.dim_size();
331           });
332       LIST_CASE(tensor, TF_ATTR_TENSOR);
333       LIST_CASE(tensor, TF_ATTR_FUNC);
334 #undef LIST_CASE
335       break;
336 
337     case tensorflow::AttrValue::kPlaceholder:
338       *list_size = -1;
339       *total_size = -1;
340       break;
341 
342     case tensorflow::AttrValue::kFunc:
343       *list_size = -1;
344       *total_size = -1;
345       break;
346 
347     case tensorflow::AttrValue::VALUE_NOT_SET:
348       status->status =
349           InvalidArgument("Attribute '", attr_name, "' has no value set");
350       break;
351   }
352 }
353 
// Generates a pair of C attribute accessors per invocation:
//  - TF_OpKernelConstruction_GetAttr<func>: fetches a scalar attribute into
//    *val, converting cc_type -> c_type with a static_cast; `status` is set
//    from the underlying GetAttr lookup.
//  - TF_OpKernelConstruction_GetAttr<func>List: validates that the attribute
//    is a "list(<attr_type>)" and copies up to `max_vals` elements into
//    `vals`, element-wise static_cast'ed to c_type.
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)        \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                             const char* attr_name,            \
                                             c_type* val, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, "");                                           \
    cc_type v;                                                                 \
    auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
    ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
    ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
    if (s.ok()) {                                                              \
      *val = static_cast<c_type>(v);                                           \
    }                                                                          \
  }                                                                            \
  void TF_OpKernelConstruction_GetAttr##func##List(                            \
      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,       \
      int max_vals, TF_Status* status) {                                       \
    TF_SetStatus(status, TF_OK, "");                                           \
    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);  \
    if (!status->status.ok()) return;                                          \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                  \
      status->status =                                                         \
          InvalidArgument("Value for '", attr_name, "' is not a list.");       \
      return;                                                                  \
    }                                                                          \
    status->status =                                                           \
        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");            \
    if (!status->status.ok()) return;                                          \
    const auto len = std::min(max_vals, attr->list().list_field##_size());     \
    for (int i = 0; i < len; ++i) {                                            \
      vals[i] = static_cast<c_type>(attr->list().list_field(i));               \
    }                                                                          \
  }

// Scalar + list accessors for the simple attribute types. Note Int32/Int64
// both read the `i` list field; Bool converts through C++ bool to TF_Bool.
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
392 
TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction * ctx,const char * attr_name,char * value,size_t max_length,TF_Status * status)393 void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
394                                            const char* attr_name, char* value,
395                                            size_t max_length,
396                                            TF_Status* status) {
397   std::string v;
398   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
399   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
400   ::tensorflow::Set_TF_Status_from_Status(status, s);
401 
402   if (!status->status.ok()) return;
403 
404   if (max_length <= 0) {
405     return;
406   }
407   std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
408 }
409 
TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction * ctx,const char * attr_name,char ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)410 void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
411                                                const char* attr_name,
412                                                char** values, size_t* lengths,
413                                                int max_values, void* storage,
414                                                size_t storage_size,
415                                                TF_Status* status) {
416   std::vector<std::string> v;
417   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
418   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
419   ::tensorflow::Set_TF_Status_from_Status(status, s);
420 
421   if (!status->status.ok()) return;
422 
423   const auto len = std::min(max_values, static_cast<int>(v.size()));
424   char* p = static_cast<char*>(storage);
425   for (int i = 0; i < len; ++i) {
426     const std::string& s = v[i];
427     values[i] = p;
428     lengths[i] = s.size();
429     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
430       status->status = InvalidArgument(
431           "Not enough storage to hold the requested list of strings");
432       return;
433     }
434     memcpy(values[i], s.data(), s.size());
435     p += s.size();
436   }
437 }
438 
TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)439 bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
440                                      const char* attr_name, TF_Status* status) {
441   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
442   return cc_ctx->HasAttr(attr_name);
443 }
444 
TF_OpKernelConstruction_GetName(TF_OpKernelConstruction * ctx)445 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
446   auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
447   TF_StringView string_view_of_name;
448   string_view_of_name.data = cc_ctx->def().name().data();
449   string_view_of_name.len = cc_ctx->def().name().length();
450   return string_view_of_name;
451 }
452 
TF_ExpectedOutputDataType(TF_OpKernelContext * ctx,int i)453 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
454   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
455   return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
456 }
457 
TF_StepId(TF_OpKernelContext * ctx)458 int64_t TF_StepId(TF_OpKernelContext* ctx) {
459   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
460 }
461 
// Allocates output `index` with the given shape and wraps it in a freshly
// created TF_Tensor that the caller owns. Returns nullptr with `status` set
// on failure.
// NOTE(review): `dtype` and `len` are unused in this implementation — the
// dtype comes from the kernel's output spec inside allocate_output, and
// `len` is never validated against the shape. Confirm against the kernels.h
// contract before relying on either.
TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
                             TF_DataType dtype, int64_t* dims, int num_dims,
                             size_t len, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
      reinterpret_cast<tensorflow::int64*>(dims), num_dims);
  tensorflow::Tensor* tensor;
  tensorflow::Status s = cc_ctx->allocate_output(
      index, tensorflow::TensorShape(dimarray), &tensor);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return tf_tensor;
}
485 
TF_ForwardInputOrAllocateOutput(TF_OpKernelContext * context,int * candidate_input_indices,int num_candidate_input_indices,int output_index,int64_t * output_dims,int output_num_dims,int * forwarded_input,TF_Status * status)486 TF_Tensor* TF_ForwardInputOrAllocateOutput(
487     TF_OpKernelContext* context, int* candidate_input_indices,
488     int num_candidate_input_indices, int output_index, int64_t* output_dims,
489     int output_num_dims, int* forwarded_input, TF_Status* status) {
490   TF_SetStatus(status, TF_OK, "");
491   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
492 
493   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
494                 "64-bit int types should match in size");
495   tensorflow::gtl::ArraySlice<int> input_indices_array(
496       candidate_input_indices, num_candidate_input_indices);
497   tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
498       reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
499   tensorflow::Tensor* output_tensor_pointer;
500   tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
501       input_indices_array, output_index,
502       tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
503       forwarded_input);
504   if (!s.ok()) {
505     ::tensorflow::Set_TF_Status_from_Status(status, s);
506     return nullptr;
507   }
508   TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
509   if (!s.ok()) {
510     ::tensorflow::Set_TF_Status_from_Status(status, s);
511     return nullptr;
512   }
513   return tf_tensor_output;
514 }
515 
TF_AllocateTemp(TF_OpKernelContext * context,TF_DataType dtype,int64_t * dims,int num_dims,TF_AllocatorAttributes * attributes,TF_Status * status)516 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
517                            int64_t* dims, int num_dims,
518                            TF_AllocatorAttributes* attributes,
519                            TF_Status* status) {
520   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
521   TF_SetStatus(status, TF_OK, "");
522   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
523                 "64-bit int types should match in size");
524   tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
525       reinterpret_cast<tensorflow::int64*>(dims), num_dims);
526   if (attributes && !attributes->struct_size) {
527     TF_SetStatus(
528         status, TF_INVALID_ARGUMENT,
529         "TF_AllocatorAttributes struct "
530         "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
531     return nullptr;
532   }
533   tensorflow::AllocatorAttributes allocator_attr;
534   if (attributes && attributes->on_host) {
535     allocator_attr.set_on_host(true);
536   }
537   tensorflow::Status s;
538   tensorflow::Tensor tensor;
539   s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
540                             tensorflow::TensorShape(dimarray), &tensor,
541                             allocator_attr);
542   if (!s.ok()) {
543     ::tensorflow::Set_TF_Status_from_Status(status, s);
544     return nullptr;
545   }
546   TF_Tensor* tf_tensor;
547   tf_tensor = TF_TensorFromTensor(tensor, &s);
548   if (!s.ok()) {
549     ::tensorflow::Set_TF_Status_from_Status(status, s);
550     return nullptr;
551   }
552   return tf_tensor;
553 }
554