/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/kernels.h"

#include <memory>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

using tensorflow::errors::InvalidArgument;

// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.

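// Wraps the C++ KernelDefBuilder together with the C callbacks that define
// the kernel's lifecycle (create, compute, delete).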
struct TF_KernelBuilder {
  ::tensorflow::KernelDefBuilder* cc_builder;

  void* (*create_function)(TF_OpKernelConstruction*);
  void (*compute_function)(void*, TF_OpKernelContext*);
  void (*delete_function)(void*);
};

TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*)) {
  TF_KernelBuilder* result = new TF_KernelBuilder;
  result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
  result->cc_builder->Device(device_name);
  result->create_function = create_func;
  result->compute_function = compute_func;
  result->delete_function = delete_func;
  return result;
}

void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
  if (builder != nullptr) {
    delete builder->cc_builder;
    delete builder;
  }
}
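
// Note: TF_RegisterKernelBuilder (below) hands the builder to a
// KernelBuilderFactory whose destructor calls TF_DeleteKernelBuilder, so a
// builder that has been registered must not be deleted again by the caller.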

namespace tensorflow {
namespace {

#define CASE(type)                                               \
  case DataTypeToEnum<type>::value: {                            \
    kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
    break;                                                       \
  }

void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
                       const DataType dtype, TF_Status* status) {
  // This needs to be under tensorflow:: namespace so that
  // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
  switch (dtype) {
    TF_CALL_ALL_TYPES(CASE);
    default:
      status->status = errors::Unimplemented("Unexpected type ", dtype);
      return;
  }
  TF_SetStatus(status, TF_OK, "");
}
#undef CASE

}  // namespace
}  // namespace tensorflow

namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
                                          const char* attr_name,
                                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  const tensorflow::AttrValue* attr =
      ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}
}  // namespace

void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                     const char* attr_name,
                                     const TF_DataType type,
                                     TF_Status* status) {
  tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
  tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
}

void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
                                 const char* arg_name) {
  kernel_builder->cc_builder->HostMemory(arg_name);
}

void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
                               int32_t priority_number) {
  kernel_builder->cc_builder->Priority(priority_number);
}
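
// Example sketch (illustrative only, not part of this file): a plugin might
// configure a builder with a type constraint, host-memory placement, and a
// priority before registering it. "MyCreate", "MyCompute", and "MyDelete" are
// hypothetical plugin callbacks.
//
//   TF_Status* s = TF_NewStatus();
//   TF_KernelBuilder* b = TF_NewKernelBuilder("MyOp", "GPU", &MyCreate,
//                                             &MyCompute, &MyDelete);
//   TF_KernelBuilder_TypeConstraint(b, "T", TF_FLOAT, s);
//   TF_KernelBuilder_HostMemory(b, "shape");
//   TF_KernelBuilder_Priority(b, 1);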

namespace tensorflow {
namespace {

// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
 public:
  explicit COpKernel(OpKernelConstruction* ctx,
                     void* (*create_func)(TF_OpKernelConstruction*),
                     void (*compute_func)(void*, TF_OpKernelContext*),
                     void (*delete_func)(void*))
      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
    if (create_func != nullptr) {
      c_kernel_ =
          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
    } else {
      c_kernel_ = nullptr;
    }
  }

  void Compute(OpKernelContext* ctx) override {
    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
  }

  ~COpKernel() override {
    if (delete_func_ != nullptr) {
      (*delete_func_)(c_kernel_);
    }
  }

 private:
  void (*compute_func_)(void*, TF_OpKernelContext* context);
  void (*delete_func_)(void*);
  void* c_kernel_;
};

// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}
  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }
  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

 private:
  TF_KernelBuilder* builder_;
};
}  // namespace
}  // namespace tensorflow

void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
                              TF_Status* status) {
  using tensorflow::register_kernel::Name;

  tensorflow::kernel_factory::OpKernelRegistrar(
      builder->cc_builder->Build(), name,
      absl::make_unique<tensorflow::KernelBuilderFactory>(builder));

  TF_SetStatus(status, TF_OK, "");
}
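
// Continuing the sketch above (illustrative only): once the builder is
// configured, the plugin registers it and is done with the handle, since the
// KernelBuilderFactory created here owns the builder from that point on.
//
//   TF_RegisterKernelBuilder("MyOpKernel", b, s);
//   if (TF_GetCode(s) != TF_OK) { /* handle the registration error */ }
//   TF_DeleteStatus(s);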

// This function is only for pluggable devices.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Accessing device stream is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is not supported for a CPU device.");
    return nullptr;
  } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is only supported for pluggable devices.");
    return nullptr;
  } else {  // Is a PluggableDevice
    TF_SetStatus(status, TF_OK, "");
    auto c_stream = static_cast<stream_executor::CStream*>(
        cc_ctx->op_device_context()->stream()->implementation());
    return c_stream->Handle();
  }
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

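// Basic introspection on the kernel's OpKernelContext: the number of inputs
// and outputs seen by the node being executed.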
int TF_NumInputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_inputs();
}

int TF_NumOutputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_outputs();
}

void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
                 TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return;
  }
  const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
  TF_Tensor* result =
      ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
  if (TF_GetCode(status) == TF_OK) {
    *tensor = result;
  }
}

void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
                  TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_outputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
    return;
  }
  ::tensorflow::Tensor cc_tensor;
  ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::Set_TF_Status_from_Status(status, s);
  if (s.ok()) {
    cc_ctx->set_output(i, cc_tensor);
  }
}
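
// Example sketch (illustrative only): a minimal pass-through compute callback
// built on the accessors above. "PassThroughCompute" is a hypothetical plugin
// function, not part of this file.
//
//   void PassThroughCompute(void* kernel, TF_OpKernelContext* ctx) {
//     TF_Status* s = TF_NewStatus();
//     TF_Tensor* input = nullptr;
//     TF_GetInput(ctx, 0, &input, s);
//     if (TF_GetCode(s) == TF_OK) {
//       TF_SetOutput(ctx, 0, input, s);
//     }
//     if (TF_GetCode(s) != TF_OK) {
//       TF_OpKernelContext_Failure(ctx, s);
//     }
//     if (input != nullptr) TF_DeleteTensor(input);
//     TF_DeleteStatus(s);
//   }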

void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
                                     TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
  cc_ctx->CtxFailure(s);
}

void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
  cc_ctx->CtxFailure(s);
}

void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
                                         const char* attr_name,
                                         int32_t* list_size,
                                         int32_t* total_size,
                                         TF_Status* status) {
  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
  if (!status->status.ok()) {
    *list_size = -1;
    *total_size = -1;
    return;
  }
  switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
  case tensorflow::AttrValue::kK:             \
    *list_size = -1;                          \
    *total_size = size_expr;                  \
    break;

    SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
    SINGLE_CASE(kI, TF_ATTR_INT, -1);
    SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
    SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
    SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
    SINGLE_CASE(kShape, TF_ATTR_SHAPE,
                attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
    SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE

    case tensorflow::AttrValue::kList:
      *list_size = 0;
      *total_size = -1;
#define LIST_CASE(field, attr_type, ...)      \
  if (attr->list().field##_size() > 0) {      \
    *list_size = attr->list().field##_size(); \
    __VA_ARGS__;                              \
    break;                                    \
  }

      LIST_CASE(
          s, TF_ATTR_STRING, *total_size = 0;
          for (int i = 0; i < attr->list().s_size();
               ++i) { *total_size += attr->list().s(i).size(); });
      LIST_CASE(i, TF_ATTR_INT);
      LIST_CASE(f, TF_ATTR_FLOAT);
      LIST_CASE(b, TF_ATTR_BOOL);
      LIST_CASE(type, TF_ATTR_TYPE);
      LIST_CASE(
          shape, TF_ATTR_SHAPE, *total_size = 0;
          for (int i = 0; i < attr->list().shape_size(); ++i) {
            const auto& s = attr->list().shape(i);
            *total_size += s.unknown_rank() ? 0 : s.dim_size();
          });
      LIST_CASE(tensor, TF_ATTR_TENSOR);
      LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
      break;

    case tensorflow::AttrValue::kPlaceholder:
      *list_size = -1;
      *total_size = -1;
      break;

    case tensorflow::AttrValue::kFunc:
      *list_size = -1;
      *total_size = -1;
      break;

    case tensorflow::AttrValue::VALUE_NOT_SET:
      status->status =
          InvalidArgument("Attribute '", attr_name, "' has no value set");
      break;
  }
}

#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)        \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
                                             const char* attr_name,            \
                                             c_type* val, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, "");                                           \
    cc_type v;                                                                 \
    auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
    ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
    ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
    if (s.ok()) {                                                              \
      *val = static_cast<c_type>(v);                                           \
    }                                                                          \
  }                                                                            \
  void TF_OpKernelConstruction_GetAttr##func##List(                            \
      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,       \
      int max_vals, TF_Status* status) {                                       \
    TF_SetStatus(status, TF_OK, "");                                           \
    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);  \
    if (!status->status.ok()) return;                                          \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                  \
      status->status =                                                         \
          InvalidArgument("Value for '", attr_name, "' is not a list.");       \
      return;                                                                  \
    }                                                                          \
    status->status =                                                           \
        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");            \
    if (!status->status.ok()) return;                                          \
    const auto len = std::min(max_vals, attr->list().list_field##_size());     \
    for (int i = 0; i < len; ++i) {                                            \
      vals[i] = static_cast<c_type>(attr->list().list_field(i));               \
    }                                                                          \
  }

DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)

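// Each DEFINE_TF_GETATTR instantiation above emits both a scalar getter
// (e.g. TF_OpKernelConstruction_GetAttrInt32) and a list getter
// (e.g. TF_OpKernelConstruction_GetAttrInt32List).
//
// Example sketch (illustrative only): reading a scalar "T" attr inside a
// hypothetical create callback.
//
//   void* MyCreate(TF_OpKernelConstruction* ctx) {
//     TF_Status* s = TF_NewStatus();
//     TF_DataType t;
//     TF_OpKernelConstruction_GetAttrType(ctx, "T", &t, s);
//     if (TF_GetCode(s) != TF_OK) TF_OpKernelConstruction_Failure(ctx, s);
//     TF_DeleteStatus(s);
//     return nullptr;  // or a pointer to per-kernel state
//   }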
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
                                           const char* attr_name, char* value,
                                           size_t max_length,
                                           TF_Status* status) {
  std::string v;
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
  ::tensorflow::Set_TF_Status_from_Status(status, s);

  if (!status->status.ok()) return;

  if (max_length <= 0) {
    return;
  }
  std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
}

void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
                                               const char* attr_name,
                                               char** values, size_t* lengths,
                                               int max_values, void* storage,
                                               size_t storage_size,
                                               TF_Status* status) {
  std::vector<std::string> v;
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
  ::tensorflow::Set_TF_Status_from_Status(status, s);

  if (!status->status.ok()) return;

  const auto len = std::min(max_values, static_cast<int>(v.size()));
  char* p = static_cast<char*>(storage);
  for (int i = 0; i < len; ++i) {
    const std::string& s = v[i];
    values[i] = p;
    lengths[i] = s.size();
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    memcpy(values[i], s.data(), s.size());
    p += s.size();
  }
}
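
// Example sketch (illustrative only): the intended two-pass pattern for
// string-list attrs is to query sizes first and then fetch into caller-owned
// storage. The attr name "labels" and the local buffers are hypothetical.
//
//   int32_t list_size = 0, total_size = 0;
//   TF_OpKernelConstruction_GetAttrSize(ctx, "labels", &list_size,
//                                       &total_size, s);
//   std::vector<char*> values(list_size);
//   std::vector<size_t> lengths(list_size);
//   std::vector<char> storage(total_size);
//   TF_OpKernelConstruction_GetAttrStringList(ctx, "labels", values.data(),
//                                             lengths.data(), list_size,
//                                             storage.data(), storage.size(),
//                                             s);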

bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
                                     const char* attr_name, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  return cc_ctx->HasAttr(attr_name);
}

TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
  auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
  TF_StringView string_view_of_name;
  string_view_of_name.data = cc_ctx->def().name().data();
  string_view_of_name.len = cc_ctx->def().name().length();
  return string_view_of_name;
}

TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
}

int64_t TF_StepId(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
}

TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
                             TF_DataType dtype, int64_t* dims, int num_dims,
                             size_t len, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
      reinterpret_cast<tensorflow::int64*>(dims), num_dims);
  tensorflow::Tensor* tensor;
  tensorflow::Status s = cc_ctx->allocate_output(
      index, tensorflow::TensorShape(dimarray), &tensor);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return tf_tensor;
}
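
// Example sketch (illustrative only): allocating a float output of shape
// {batch, 1} inside a hypothetical compute callback; "batch" is a hypothetical
// dimension and "len" is the byte size the caller expects for the buffer.
//
//   int64_t dims[] = {batch, 1};
//   TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, dims,
//                                      /*num_dims=*/2,
//                                      /*len=*/batch * sizeof(float), s);
//   if (TF_GetCode(s) == TF_OK) {
//     float* data = static_cast<float*>(TF_TensorData(out));
//     // ... fill in the output values ...
//     TF_DeleteTensor(out);  // drops this handle; the context keeps the output
//   }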

TF_Tensor* TF_ForwardInputOrAllocateOutput(
    TF_OpKernelContext* context, int* candidate_input_indices,
    int num_candidate_input_indices, int output_index, int64_t* output_dims,
    int output_num_dims, int* forwarded_input, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);

  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  tensorflow::gtl::ArraySlice<int> input_indices_array(
      candidate_input_indices, num_candidate_input_indices);
  tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
      reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
  tensorflow::Tensor* output_tensor_pointer;
  tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
      input_indices_array, output_index,
      tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
      forwarded_input);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return tf_tensor_output;
}

TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
                           int64_t* dims, int num_dims,
                           TF_AllocatorAttributes* attributes,
                           TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  TF_SetStatus(status, TF_OK, "");
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
      reinterpret_cast<tensorflow::int64*>(dims), num_dims);
  if (attributes && !attributes->struct_size) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "TF_AllocatorAttributes struct "
        "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
    return nullptr;
  }
  tensorflow::AllocatorAttributes allocator_attr;
  if (attributes && attributes->on_host) {
    allocator_attr.set_on_host(true);
  }
  tensorflow::Status s;
  tensorflow::Tensor tensor;
  s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
                            tensorflow::TensorShape(dimarray), &tensor,
                            allocator_attr);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  TF_Tensor* tf_tensor;
  tf_tensor = TF_TensorFromTensor(tensor, &s);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return tf_tensor;
}
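
// Example sketch (illustrative only): allocating a small host-resident scratch
// tensor from a hypothetical compute callback.
//
//   TF_AllocatorAttributes alloc_attrs;
//   alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
//   alloc_attrs.on_host = 1;
//   int64_t dims[] = {16};
//   TF_Tensor* scratch =
//       TF_AllocateTemp(ctx, TF_FLOAT, dims, /*num_dims=*/1, &alloc_attrs, s);
//   if (TF_GetCode(s) == TF_OK) {
//     // ... use TF_TensorData(scratch) as temporary workspace ...
//     TF_DeleteTensor(scratch);
//   }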