1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/c/kernels.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
34
35 using tensorflow::errors::InvalidArgument;
36 // This file forms the basis of a stable ABI for third-party kernel
37 // implementations. It is crucial that changes to this file are made cautiously
38 // and with a focus on maintaining both source and binary compatibility.
39
40 struct TF_KernelBuilder {
41 ::tensorflow::KernelDefBuilder* cc_builder;
42
43 void* (*create_function)(TF_OpKernelConstruction*);
44 void (*compute_function)(void*, TF_OpKernelContext*);
45 void (*delete_function)(void*);
46 };
47
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))48 TF_KernelBuilder* TF_NewKernelBuilder(
49 const char* op_name, const char* device_name,
50 void* (*create_func)(TF_OpKernelConstruction*),
51 void (*compute_func)(void*, TF_OpKernelContext*),
52 void (*delete_func)(void*)) {
53 TF_KernelBuilder* result = new TF_KernelBuilder;
54 result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
55 result->cc_builder->Device(device_name);
56 result->create_function = create_func;
57 result->compute_function = compute_func;
58 result->delete_function = delete_func;
59 return result;
60 }
61
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)62 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
63 if (builder != nullptr) {
64 delete builder->cc_builder;
65 delete builder;
66 }
67 }
68
69 namespace tensorflow {
70 namespace {
71
72 #define CASE(type) \
73 case DataTypeToEnum<type>::value: { \
74 kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
75 break; \
76 }
77
AddTypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const DataType dtype,TF_Status * status)78 void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
79 const DataType dtype, TF_Status* status) {
80 // This needs to be under tensorflow:: namespace so that
81 // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
82 switch (dtype) {
83 TF_CALL_ALL_TYPES(CASE);
84 default:
85 status->status = errors::Unimplemented("Unexpected type ", dtype);
86 return;
87 }
88 TF_SetStatus(status, TF_OK, "");
89 }
90 #undef CASE
91
92 } // namespace
93 } // namespace tensorflow
94
95 namespace {
GetAttrValue(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)96 const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
97 const char* attr_name,
98 TF_Status* status) {
99 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
100 const tensorflow::AttrValue* attr =
101 ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
102 if (attr == nullptr) {
103 status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
104 "' has no attr named '", attr_name, "'.");
105 }
106 return attr;
107 }
108 } // namespace
109
TF_KernelBuilder_TypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const TF_DataType type,TF_Status * status)110 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
111 const char* attr_name,
112 const TF_DataType type,
113 TF_Status* status) {
114 tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
115 tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
116 }
117
TF_KernelBuilder_HostMemory(TF_KernelBuilder * kernel_builder,const char * arg_name)118 void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
119 const char* arg_name) {
120 kernel_builder->cc_builder->HostMemory(arg_name);
121 }
122
TF_KernelBuilder_Priority(TF_KernelBuilder * kernel_builder,int32_t priority_number)123 void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
124 int32_t priority_number) {
125 kernel_builder->cc_builder->Priority(priority_number);
126 }
127
128 namespace tensorflow {
129 namespace {
130
131 // An OpKernel whose methods delegate to C function pointers.
132 class COpKernel : public OpKernel {
133 public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))134 explicit COpKernel(OpKernelConstruction* ctx,
135 void* (*create_func)(TF_OpKernelConstruction*),
136 void (*compute_func)(void*, TF_OpKernelContext*),
137 void (*delete_func)(void*))
138 : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
139 if (create_func != nullptr) {
140 c_kernel_ =
141 (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
142 } else {
143 c_kernel_ = nullptr;
144 }
145 }
146
Compute(OpKernelContext * ctx)147 void Compute(OpKernelContext* ctx) override {
148 (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
149 }
150
~COpKernel()151 ~COpKernel() override {
152 if (delete_func_ != nullptr) {
153 (*delete_func_)(c_kernel_);
154 }
155 }
156
157 private:
158 void (*compute_func_)(void*, TF_OpKernelContext* context);
159 void (*delete_func_)(void*);
160 void* c_kernel_;
161 };
162
163 // A KernelFactory that returns COpKernel instances.
164 class KernelBuilderFactory
165 : public ::tensorflow::kernel_factory::OpKernelFactory {
166 public:
KernelBuilderFactory(TF_KernelBuilder * builder)167 explicit KernelBuilderFactory(TF_KernelBuilder* builder)
168 : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)169 ::tensorflow::OpKernel* Create(
170 ::tensorflow::OpKernelConstruction* context) override {
171 return new ::tensorflow::COpKernel(context, builder_->create_function,
172 builder_->compute_function,
173 builder_->delete_function);
174 }
~KernelBuilderFactory()175 ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
176
177 private:
178 TF_KernelBuilder* builder_;
179 };
180 } // namespace
181 } // namespace tensorflow
182
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)183 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
184 TF_Status* status) {
185 using tensorflow::register_kernel::Name;
186
187 tensorflow::kernel_factory::OpKernelRegistrar(
188 builder->cc_builder->Build(), name,
189 absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
190
191 TF_SetStatus(status, TF_OK, "");
192 }
193
194 // This function is only for pluggable device.
195 // It will return nullptr in all other cases.
196 // This function is experimental and subject to change.
TF_GetStream(TF_OpKernelContext * ctx,TF_Status * status)197 SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
198 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
199 status->status = tensorflow::errors::Unimplemented(
200 "Accessing device stream is not supported on mobile. File a bug at "
201 "https://github.com/tensorflow/tensorflow/issues if this feature is "
202 "important to you");
203 return nullptr;
204 #else
205 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
206 if (cc_ctx->op_device_context() == nullptr) { // CPU Device
207 status->status = tensorflow::errors::FailedPrecondition(
208 "Accessing device stream is not supported for a CPU device.");
209 return nullptr;
210 } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
211 status->status = tensorflow::errors::FailedPrecondition(
212 "Accessing device stream is only supported for pluggable devices.");
213 return nullptr;
214 } else { // Is a PluggableDevice
215 TF_SetStatus(status, TF_OK, "");
216 auto c_stream = static_cast<stream_executor::CStream*>(
217 cc_ctx->op_device_context()->stream()->implementation());
218 return c_stream->Handle();
219 }
220 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
221 }
222
TF_NumInputs(TF_OpKernelContext * ctx)223 int TF_NumInputs(TF_OpKernelContext* ctx) {
224 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
225 return cc_ctx->num_inputs();
226 }
227
TF_NumOutputs(TF_OpKernelContext * ctx)228 int TF_NumOutputs(TF_OpKernelContext* ctx) {
229 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
230 return cc_ctx->num_outputs();
231 }
232
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)233 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
234 TF_Status* status) {
235 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
236 if (i < 0 || i >= cc_ctx->num_inputs()) {
237 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
238 return;
239 }
240 const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
241 TF_Tensor* result =
242 ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
243 if (TF_GetCode(status) == TF_OK) {
244 *tensor = result;
245 }
246 }
247
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)248 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
249 TF_Status* status) {
250 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
251 if (i < 0 || i >= cc_ctx->num_outputs()) {
252 TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
253 return;
254 }
255 ::tensorflow::Tensor cc_tensor;
256 ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
257 TF_SetStatus(status, TF_OK, "");
258 ::tensorflow::Set_TF_Status_from_Status(status, s);
259 if (s.ok()) {
260 cc_ctx->set_output(i, cc_tensor);
261 }
262 }
263
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)264 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
265 TF_Status* status) {
266 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
267 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
268 cc_ctx->CtxFailure(s);
269 }
270
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)271 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
272 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
273 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
274 cc_ctx->CtxFailure(s);
275 }
276
TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction * ctx,const char * attr_name,int32_t * list_size,int32_t * total_size,TF_Status * status)277 void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
278 const char* attr_name,
279 int32_t* list_size,
280 int32_t* total_size,
281 TF_Status* status) {
282 const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
283 if (!status->status.ok()) {
284 *list_size = -1;
285 *total_size = -1;
286 return;
287 }
288 switch (attr->value_case()) {
289 #define SINGLE_CASE(kK, attr_type, size_expr) \
290 case tensorflow::AttrValue::kK: \
291 *list_size = -1; \
292 *total_size = size_expr; \
293 break;
294
295 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
296 SINGLE_CASE(kI, TF_ATTR_INT, -1);
297 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
298 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
299 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
300 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
301 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
302 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
303 #undef SINGLE_CASE
304
305 case tensorflow::AttrValue::kList:
306 *list_size = 0;
307 *total_size = -1;
308 #define LIST_CASE(field, attr_type, ...) \
309 if (attr->list().field##_size() > 0) { \
310 *list_size = attr->list().field##_size(); \
311 __VA_ARGS__; \
312 break; \
313 }
314
315 LIST_CASE(
316 s, TF_ATTR_STRING, *total_size = 0;
317 for (int i = 0; i < attr->list().s_size();
318 ++i) { *total_size += attr->list().s(i).size(); });
319 LIST_CASE(i, TF_ATTR_INT);
320 LIST_CASE(f, TF_ATTR_FLOAT);
321 LIST_CASE(b, TF_ATTR_BOOL);
322 LIST_CASE(type, TF_ATTR_TYPE);
323 LIST_CASE(
324 shape, TF_ATTR_SHAPE, *total_size = 0;
325 for (int i = 0; i < attr->list().shape_size(); ++i) {
326 const auto& s = attr->list().shape(i);
327 *total_size += s.unknown_rank() ? 0 : s.dim_size();
328 });
329 LIST_CASE(tensor, TF_ATTR_TENSOR);
330 LIST_CASE(tensor, TF_ATTR_FUNC);
331 #undef LIST_CASE
332 break;
333
334 case tensorflow::AttrValue::kPlaceholder:
335 *list_size = -1;
336 *total_size = -1;
337 break;
338
339 case tensorflow::AttrValue::kFunc:
340 *list_size = -1;
341 *total_size = -1;
342 break;
343
344 case tensorflow::AttrValue::VALUE_NOT_SET:
345 status->status =
346 InvalidArgument("Attribute '", attr_name, "' has no value set");
347 break;
348 }
349 }
350
351 #define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
352 void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
353 const char* attr_name, \
354 c_type* val, TF_Status* status) { \
355 TF_SetStatus(status, TF_OK, ""); \
356 cc_type v; \
357 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
358 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \
359 ::tensorflow::Set_TF_Status_from_Status(status, s); \
360 if (s.ok()) { \
361 *val = static_cast<c_type>(v); \
362 } \
363 } \
364 void TF_OpKernelConstruction_GetAttr##func##List( \
365 TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
366 int max_vals, TF_Status* status) { \
367 TF_SetStatus(status, TF_OK, ""); \
368 const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
369 if (!status->status.ok()) return; \
370 if (attr->value_case() != tensorflow::AttrValue::kList) { \
371 status->status = \
372 InvalidArgument("Value for '", attr_name, "' is not a list."); \
373 return; \
374 } \
375 status->status = \
376 tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
377 if (!status->status.ok()) return; \
378 const auto len = std::min(max_vals, attr->list().list_field##_size()); \
379 for (int i = 0; i < len; ++i) { \
380 vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
381 } \
382 }
383
384 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
385 DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
386 DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
387 DEFINE_TF_GETATTR(Float, float, float, "float", f)
388 DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
389
TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction * ctx,const char * attr_name,char * value,size_t max_length,TF_Status * status)390 void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
391 const char* attr_name, char* value,
392 size_t max_length,
393 TF_Status* status) {
394 std::string v;
395 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
396 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
397 ::tensorflow::Set_TF_Status_from_Status(status, s);
398
399 if (!status->status.ok()) return;
400
401 if (max_length <= 0) {
402 return;
403 }
404 std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
405 }
406
TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction * ctx,const char * attr_name,char ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)407 void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
408 const char* attr_name,
409 char** values, size_t* lengths,
410 int max_values, void* storage,
411 size_t storage_size,
412 TF_Status* status) {
413 std::vector<std::string> v;
414 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
415 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
416 ::tensorflow::Set_TF_Status_from_Status(status, s);
417
418 if (!status->status.ok()) return;
419
420 const auto len = std::min(max_values, static_cast<int>(v.size()));
421 char* p = static_cast<char*>(storage);
422 for (int i = 0; i < len; ++i) {
423 const std::string& s = v[i];
424 values[i] = p;
425 lengths[i] = s.size();
426 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
427 status->status = InvalidArgument(
428 "Not enough storage to hold the requested list of strings");
429 return;
430 }
431 memcpy(values[i], s.data(), s.size());
432 p += s.size();
433 }
434 }
435
TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)436 bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
437 const char* attr_name, TF_Status* status) {
438 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
439 return cc_ctx->HasAttr(attr_name);
440 }
441
TF_OpKernelConstruction_GetName(TF_OpKernelConstruction * ctx)442 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
443 auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
444 TF_StringView string_view_of_name;
445 string_view_of_name.data = cc_ctx->def().name().data();
446 string_view_of_name.len = cc_ctx->def().name().length();
447 return string_view_of_name;
448 }
449
TF_ExpectedOutputDataType(TF_OpKernelContext * ctx,int i)450 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
451 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
452 return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
453 }
454
TF_StepId(TF_OpKernelContext * ctx)455 int64_t TF_StepId(TF_OpKernelContext* ctx) {
456 return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
457 }
458
TF_AllocateOutput(TF_OpKernelContext * context,int index,TF_DataType dtype,int64_t * dims,int num_dims,size_t len,TF_Status * status)459 TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
460 TF_DataType dtype, int64_t* dims, int num_dims,
461 size_t len, TF_Status* status) {
462 TF_SetStatus(status, TF_OK, "");
463 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
464 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
465 "64-bit int types should match in size");
466 tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
467 reinterpret_cast<tensorflow::int64*>(dims), num_dims);
468 tensorflow::Tensor* tensor;
469 tensorflow::Status s = cc_ctx->allocate_output(
470 index, tensorflow::TensorShape(dimarray), &tensor);
471 if (!s.ok()) {
472 ::tensorflow::Set_TF_Status_from_Status(status, s);
473 return nullptr;
474 }
475 TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
476 if (!s.ok()) {
477 ::tensorflow::Set_TF_Status_from_Status(status, s);
478 return nullptr;
479 }
480 return tf_tensor;
481 }
482
TF_ForwardInputOrAllocateOutput(TF_OpKernelContext * context,int * candidate_input_indices,int num_candidate_input_indices,int output_index,int64_t * output_dims,int output_num_dims,int * forwarded_input,TF_Status * status)483 TF_Tensor* TF_ForwardInputOrAllocateOutput(
484 TF_OpKernelContext* context, int* candidate_input_indices,
485 int num_candidate_input_indices, int output_index, int64_t* output_dims,
486 int output_num_dims, int* forwarded_input, TF_Status* status) {
487 TF_SetStatus(status, TF_OK, "");
488 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
489
490 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
491 "64-bit int types should match in size");
492 tensorflow::gtl::ArraySlice<int> input_indices_array(
493 candidate_input_indices, num_candidate_input_indices);
494 tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
495 reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
496 tensorflow::Tensor* output_tensor_pointer;
497 tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
498 input_indices_array, output_index,
499 tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
500 forwarded_input);
501 if (!s.ok()) {
502 ::tensorflow::Set_TF_Status_from_Status(status, s);
503 return nullptr;
504 }
505 TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
506 if (!s.ok()) {
507 ::tensorflow::Set_TF_Status_from_Status(status, s);
508 return nullptr;
509 }
510 return tf_tensor_output;
511 }
512
TF_AllocateTemp(TF_OpKernelContext * context,TF_DataType dtype,int64_t * dims,int num_dims,TF_AllocatorAttributes * attributes,TF_Status * status)513 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
514 int64_t* dims, int num_dims,
515 TF_AllocatorAttributes* attributes,
516 TF_Status* status) {
517 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
518 TF_SetStatus(status, TF_OK, "");
519 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
520 "64-bit int types should match in size");
521 tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
522 reinterpret_cast<tensorflow::int64*>(dims), num_dims);
523 if (attributes && !attributes->struct_size) {
524 TF_SetStatus(
525 status, TF_INVALID_ARGUMENT,
526 "TF_AllocatorAttributes struct "
527 "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
528 return nullptr;
529 }
530 tensorflow::AllocatorAttributes allocator_attr;
531 if (attributes && attributes->on_host) {
532 allocator_attr.set_on_host(true);
533 }
534 tensorflow::Status s;
535 tensorflow::Tensor tensor;
536 s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
537 tensorflow::TensorShape(dimarray), &tensor,
538 allocator_attr);
539 if (!s.ok()) {
540 ::tensorflow::Set_TF_Status_from_Status(status, s);
541 return nullptr;
542 }
543 TF_Tensor* tf_tensor;
544 tf_tensor = TF_TensorFromTensor(tensor, &s);
545 if (!s.ok()) {
546 ::tensorflow::Set_TF_Status_from_Status(status, s);
547 return nullptr;
548 }
549 return tf_tensor;
550 }
551