1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/kernels.h"
17
18 #include <memory>
19
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/tf_status_helper.h"
22 #include "tensorflow/c/tf_tensor_internal.h"
23 #include "tensorflow/core/framework/kernel_def_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/types.h"
27 // Required for IS_MOBILE_PLATFORM definition
28 #include "tensorflow/core/platform/platform.h"
29 #include "tensorflow/core/platform/types.h"
30 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
31 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
32 #include "tensorflow/stream_executor/stream.h"
33 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
34
35 using tensorflow::errors::InvalidArgument;
36 // This file forms the basis of a stable ABI for third-party kernel
37 // implementations. It is crucial that changes to this file are made cautiously
38 // and with a focus on maintaining both source and binary compatibility.
39
40 struct TF_KernelBuilder {
41 ::tensorflow::KernelDefBuilder* cc_builder;
42
43 void* (*create_function)(TF_OpKernelConstruction*);
44 void (*compute_function)(void*, TF_OpKernelContext*);
45 void (*delete_function)(void*);
46 };
47
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))48 TF_KernelBuilder* TF_NewKernelBuilder(
49 const char* op_name, const char* device_name,
50 void* (*create_func)(TF_OpKernelConstruction*),
51 void (*compute_func)(void*, TF_OpKernelContext*),
52 void (*delete_func)(void*)) {
53 TF_KernelBuilder* result = new TF_KernelBuilder;
54 result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
55 result->cc_builder->Device(device_name);
56 result->create_function = create_func;
57 result->compute_function = compute_func;
58 result->delete_function = delete_func;
59 return result;
60 }
61
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)62 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
63 if (builder != nullptr) {
64 delete builder->cc_builder;
65 delete builder;
66 }
67 }
68
69 namespace tensorflow {
70 namespace {
71
72 #define CASE(type) \
73 case DataTypeToEnum<type>::value: { \
74 kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
75 break; \
76 }
77
AddTypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const DataType dtype,TF_Status * status)78 void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
79 const DataType dtype, TF_Status* status) {
80 // This needs to be under tensorflow:: namespace so that
81 // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
82 switch (dtype) {
83 TF_CALL_ALL_TYPES(CASE);
84 TF_CALL_QUANTIZED_TYPES(CASE);
85 TF_CALL_quint16(CASE);
86 TF_CALL_qint16(CASE);
87 default:
88 status->status = errors::Unimplemented("Unexpected type ", dtype);
89 return;
90 }
91 TF_SetStatus(status, TF_OK, "");
92 }
93 #undef CASE
94
95 } // namespace
96 } // namespace tensorflow
97
98 namespace {
GetAttrValue(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)99 const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
100 const char* attr_name,
101 TF_Status* status) {
102 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
103 const tensorflow::AttrValue* attr =
104 ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
105 if (attr == nullptr) {
106 status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
107 "' has no attr named '", attr_name, "'.");
108 }
109 return attr;
110 }
111 } // namespace
112
TF_KernelBuilder_TypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const TF_DataType type,TF_Status * status)113 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
114 const char* attr_name,
115 const TF_DataType type,
116 TF_Status* status) {
117 tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
118 tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
119 }
120
TF_KernelBuilder_HostMemory(TF_KernelBuilder * kernel_builder,const char * arg_name)121 void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
122 const char* arg_name) {
123 kernel_builder->cc_builder->HostMemory(arg_name);
124 }
125
TF_KernelBuilder_Priority(TF_KernelBuilder * kernel_builder,int32_t priority_number)126 void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
127 int32_t priority_number) {
128 kernel_builder->cc_builder->Priority(priority_number);
129 }
130
131 namespace tensorflow {
132 namespace {
133
134 // An OpKernel whose methods delegate to C function pointers.
135 class COpKernel : public OpKernel {
136 public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))137 explicit COpKernel(OpKernelConstruction* ctx,
138 void* (*create_func)(TF_OpKernelConstruction*),
139 void (*compute_func)(void*, TF_OpKernelContext*),
140 void (*delete_func)(void*))
141 : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
142 if (create_func != nullptr) {
143 c_kernel_ =
144 (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
145 } else {
146 c_kernel_ = nullptr;
147 }
148 }
149
Compute(OpKernelContext * ctx)150 void Compute(OpKernelContext* ctx) override {
151 (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
152 }
153
~COpKernel()154 ~COpKernel() override {
155 if (delete_func_ != nullptr) {
156 (*delete_func_)(c_kernel_);
157 }
158 }
159
160 private:
161 void (*compute_func_)(void*, TF_OpKernelContext* context);
162 void (*delete_func_)(void*);
163 void* c_kernel_;
164 };
165
166 // A KernelFactory that returns COpKernel instances.
167 class KernelBuilderFactory
168 : public ::tensorflow::kernel_factory::OpKernelFactory {
169 public:
KernelBuilderFactory(TF_KernelBuilder * builder)170 explicit KernelBuilderFactory(TF_KernelBuilder* builder)
171 : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)172 ::tensorflow::OpKernel* Create(
173 ::tensorflow::OpKernelConstruction* context) override {
174 return new ::tensorflow::COpKernel(context, builder_->create_function,
175 builder_->compute_function,
176 builder_->delete_function);
177 }
~KernelBuilderFactory()178 ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
179
180 private:
181 TF_KernelBuilder* builder_;
182 };
183 } // namespace
184 } // namespace tensorflow
185
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)186 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
187 TF_Status* status) {
188 using tensorflow::register_kernel::Name;
189
190 tensorflow::kernel_factory::OpKernelRegistrar(
191 builder->cc_builder->Build(), name,
192 absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
193
194 TF_SetStatus(status, TF_OK, "");
195 }
196
197 // This function is only for pluggable device.
198 // It will return nullptr in all other cases.
199 // This function is experimental and subject to change.
TF_GetStream(TF_OpKernelContext * ctx,TF_Status * status)200 SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
201 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
202 status->status = tensorflow::errors::Unimplemented(
203 "Accessing device stream is not supported on mobile. File a bug at "
204 "https://github.com/tensorflow/tensorflow/issues if this feature is "
205 "important to you");
206 return nullptr;
207 #else
208 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
209 if (cc_ctx->op_device_context() == nullptr) { // CPU Device
210 status->status = tensorflow::errors::FailedPrecondition(
211 "Accessing device stream is not supported for a CPU device.");
212 return nullptr;
213 } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
214 status->status = tensorflow::errors::FailedPrecondition(
215 "Accessing device stream is only supported for pluggable devices.");
216 return nullptr;
217 } else { // Is a PluggableDevice
218 TF_SetStatus(status, TF_OK, "");
219 auto c_stream = static_cast<stream_executor::CStream*>(
220 cc_ctx->op_device_context()->stream()->implementation());
221 return c_stream->Handle();
222 }
223 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
224 }
225
TF_NumInputs(TF_OpKernelContext * ctx)226 int TF_NumInputs(TF_OpKernelContext* ctx) {
227 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
228 return cc_ctx->num_inputs();
229 }
230
TF_NumOutputs(TF_OpKernelContext * ctx)231 int TF_NumOutputs(TF_OpKernelContext* ctx) {
232 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
233 return cc_ctx->num_outputs();
234 }
235
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)236 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
237 TF_Status* status) {
238 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
239 if (i < 0 || i >= cc_ctx->num_inputs()) {
240 TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
241 return;
242 }
243 const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
244 TF_Tensor* result =
245 ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
246 if (TF_GetCode(status) == TF_OK) {
247 *tensor = result;
248 }
249 }
250
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)251 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
252 TF_Status* status) {
253 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
254 if (i < 0 || i >= cc_ctx->num_outputs()) {
255 TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
256 return;
257 }
258 ::tensorflow::Tensor cc_tensor;
259 ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
260 TF_SetStatus(status, TF_OK, "");
261 ::tensorflow::Set_TF_Status_from_Status(status, s);
262 if (s.ok()) {
263 cc_ctx->set_output(i, cc_tensor);
264 }
265 }
266
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)267 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
268 TF_Status* status) {
269 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
270 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
271 cc_ctx->CtxFailure(s);
272 }
273
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)274 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
275 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
276 ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
277 cc_ctx->CtxFailure(s);
278 }
279
TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction * ctx,const char * attr_name,int32_t * list_size,int32_t * total_size,TF_Status * status)280 void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
281 const char* attr_name,
282 int32_t* list_size,
283 int32_t* total_size,
284 TF_Status* status) {
285 const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
286 if (!status->status.ok()) {
287 *list_size = -1;
288 *total_size = -1;
289 return;
290 }
291 switch (attr->value_case()) {
292 #define SINGLE_CASE(kK, attr_type, size_expr) \
293 case tensorflow::AttrValue::kK: \
294 *list_size = -1; \
295 *total_size = size_expr; \
296 break;
297
298 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
299 SINGLE_CASE(kI, TF_ATTR_INT, -1);
300 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
301 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
302 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
303 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
304 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
305 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
306 #undef SINGLE_CASE
307
308 case tensorflow::AttrValue::kList:
309 *list_size = 0;
310 *total_size = -1;
311 #define LIST_CASE(field, attr_type, ...) \
312 if (attr->list().field##_size() > 0) { \
313 *list_size = attr->list().field##_size(); \
314 __VA_ARGS__; \
315 break; \
316 }
317
318 LIST_CASE(
319 s, TF_ATTR_STRING, *total_size = 0;
320 for (int i = 0; i < attr->list().s_size();
321 ++i) { *total_size += attr->list().s(i).size(); });
322 LIST_CASE(i, TF_ATTR_INT);
323 LIST_CASE(f, TF_ATTR_FLOAT);
324 LIST_CASE(b, TF_ATTR_BOOL);
325 LIST_CASE(type, TF_ATTR_TYPE);
326 LIST_CASE(
327 shape, TF_ATTR_SHAPE, *total_size = 0;
328 for (int i = 0; i < attr->list().shape_size(); ++i) {
329 const auto& s = attr->list().shape(i);
330 *total_size += s.unknown_rank() ? 0 : s.dim_size();
331 });
332 LIST_CASE(tensor, TF_ATTR_TENSOR);
333 LIST_CASE(tensor, TF_ATTR_FUNC);
334 #undef LIST_CASE
335 break;
336
337 case tensorflow::AttrValue::kPlaceholder:
338 *list_size = -1;
339 *total_size = -1;
340 break;
341
342 case tensorflow::AttrValue::kFunc:
343 *list_size = -1;
344 *total_size = -1;
345 break;
346
347 case tensorflow::AttrValue::VALUE_NOT_SET:
348 status->status =
349 InvalidArgument("Attribute '", attr_name, "' has no value set");
350 break;
351 }
352 }
353
354 #define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
355 void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
356 const char* attr_name, \
357 c_type* val, TF_Status* status) { \
358 TF_SetStatus(status, TF_OK, ""); \
359 cc_type v; \
360 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
361 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \
362 ::tensorflow::Set_TF_Status_from_Status(status, s); \
363 if (s.ok()) { \
364 *val = static_cast<c_type>(v); \
365 } \
366 } \
367 void TF_OpKernelConstruction_GetAttr##func##List( \
368 TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
369 int max_vals, TF_Status* status) { \
370 TF_SetStatus(status, TF_OK, ""); \
371 const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
372 if (!status->status.ok()) return; \
373 if (attr->value_case() != tensorflow::AttrValue::kList) { \
374 status->status = \
375 InvalidArgument("Value for '", attr_name, "' is not a list."); \
376 return; \
377 } \
378 status->status = \
379 tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
380 if (!status->status.ok()) return; \
381 const auto len = std::min(max_vals, attr->list().list_field##_size()); \
382 for (int i = 0; i < len; ++i) { \
383 vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
384 } \
385 }
386
387 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
388 DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int", i)
389 DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int", i)
390 DEFINE_TF_GETATTR(Float, float, float, "float", f)
391 DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
392
TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction * ctx,const char * attr_name,char * value,size_t max_length,TF_Status * status)393 void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
394 const char* attr_name, char* value,
395 size_t max_length,
396 TF_Status* status) {
397 std::string v;
398 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
399 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
400 ::tensorflow::Set_TF_Status_from_Status(status, s);
401
402 if (!status->status.ok()) return;
403
404 if (max_length <= 0) {
405 return;
406 }
407 std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
408 }
409
TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction * ctx,const char * attr_name,char ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)410 void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
411 const char* attr_name,
412 char** values, size_t* lengths,
413 int max_values, void* storage,
414 size_t storage_size,
415 TF_Status* status) {
416 std::vector<std::string> v;
417 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
418 ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
419 ::tensorflow::Set_TF_Status_from_Status(status, s);
420
421 if (!status->status.ok()) return;
422
423 const auto len = std::min(max_values, static_cast<int>(v.size()));
424 char* p = static_cast<char*>(storage);
425 for (int i = 0; i < len; ++i) {
426 const std::string& s = v[i];
427 values[i] = p;
428 lengths[i] = s.size();
429 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
430 status->status = InvalidArgument(
431 "Not enough storage to hold the requested list of strings");
432 return;
433 }
434 memcpy(values[i], s.data(), s.size());
435 p += s.size();
436 }
437 }
438
TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)439 bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
440 const char* attr_name, TF_Status* status) {
441 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
442 return cc_ctx->HasAttr(attr_name);
443 }
444
TF_OpKernelConstruction_GetName(TF_OpKernelConstruction * ctx)445 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
446 auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
447 TF_StringView string_view_of_name;
448 string_view_of_name.data = cc_ctx->def().name().data();
449 string_view_of_name.len = cc_ctx->def().name().length();
450 return string_view_of_name;
451 }
452
TF_ExpectedOutputDataType(TF_OpKernelContext * ctx,int i)453 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
454 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
455 return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
456 }
457
TF_StepId(TF_OpKernelContext * ctx)458 int64_t TF_StepId(TF_OpKernelContext* ctx) {
459 return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
460 }
461
TF_AllocateOutput(TF_OpKernelContext * context,int index,TF_DataType dtype,int64_t * dims,int num_dims,size_t len,TF_Status * status)462 TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
463 TF_DataType dtype, int64_t* dims, int num_dims,
464 size_t len, TF_Status* status) {
465 TF_SetStatus(status, TF_OK, "");
466 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
467 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
468 "64-bit int types should match in size");
469 tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
470 reinterpret_cast<tensorflow::int64*>(dims), num_dims);
471 tensorflow::Tensor* tensor;
472 tensorflow::Status s = cc_ctx->allocate_output(
473 index, tensorflow::TensorShape(dimarray), &tensor);
474 if (!s.ok()) {
475 ::tensorflow::Set_TF_Status_from_Status(status, s);
476 return nullptr;
477 }
478 TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
479 if (!s.ok()) {
480 ::tensorflow::Set_TF_Status_from_Status(status, s);
481 return nullptr;
482 }
483 return tf_tensor;
484 }
485
TF_ForwardInputOrAllocateOutput(TF_OpKernelContext * context,int * candidate_input_indices,int num_candidate_input_indices,int output_index,int64_t * output_dims,int output_num_dims,int * forwarded_input,TF_Status * status)486 TF_Tensor* TF_ForwardInputOrAllocateOutput(
487 TF_OpKernelContext* context, int* candidate_input_indices,
488 int num_candidate_input_indices, int output_index, int64_t* output_dims,
489 int output_num_dims, int* forwarded_input, TF_Status* status) {
490 TF_SetStatus(status, TF_OK, "");
491 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
492
493 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
494 "64-bit int types should match in size");
495 tensorflow::gtl::ArraySlice<int> input_indices_array(
496 candidate_input_indices, num_candidate_input_indices);
497 tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
498 reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
499 tensorflow::Tensor* output_tensor_pointer;
500 tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
501 input_indices_array, output_index,
502 tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
503 forwarded_input);
504 if (!s.ok()) {
505 ::tensorflow::Set_TF_Status_from_Status(status, s);
506 return nullptr;
507 }
508 TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
509 if (!s.ok()) {
510 ::tensorflow::Set_TF_Status_from_Status(status, s);
511 return nullptr;
512 }
513 return tf_tensor_output;
514 }
515
TF_AllocateTemp(TF_OpKernelContext * context,TF_DataType dtype,int64_t * dims,int num_dims,TF_AllocatorAttributes * attributes,TF_Status * status)516 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
517 int64_t* dims, int num_dims,
518 TF_AllocatorAttributes* attributes,
519 TF_Status* status) {
520 auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
521 TF_SetStatus(status, TF_OK, "");
522 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
523 "64-bit int types should match in size");
524 tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
525 reinterpret_cast<tensorflow::int64*>(dims), num_dims);
526 if (attributes && !attributes->struct_size) {
527 TF_SetStatus(
528 status, TF_INVALID_ARGUMENT,
529 "TF_AllocatorAttributes struct "
530 "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
531 return nullptr;
532 }
533 tensorflow::AllocatorAttributes allocator_attr;
534 if (attributes && attributes->on_host) {
535 allocator_attr.set_on_host(true);
536 }
537 tensorflow::Status s;
538 tensorflow::Tensor tensor;
539 s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
540 tensorflow::TensorShape(dimarray), &tensor,
541 allocator_attr);
542 if (!s.ok()) {
543 ::tensorflow::Set_TF_Status_from_Status(status, s);
544 return nullptr;
545 }
546 TF_Tensor* tf_tensor;
547 tf_tensor = TF_TensorFromTensor(tensor, &s);
548 if (!s.ok()) {
549 ::tensorflow::Set_TF_Status_from_Status(status, s);
550 return nullptr;
551 }
552 return tf_tensor;
553 }
554