/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more
    // generally non-Python) clients. If this API is called again with `enable`
    // set to false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
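
// Usage sketch (illustrative only; TF_NewSessionOptions/TF_DeleteSessionOptions
// are the standard calls from c_api.h): a non-Python client could turn on XLA
// before creating a session roughly like this.
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   // ... create the TF_Session with `opts`, then ...
//   TF_DeleteSessionOptions(opts);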

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more
    // generally non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the
  // callback (as in the case of ConstOp kernel creation on GPU, which involves
  // copying a CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we
  // use a different TFE context for each thread of execution (for running
  // graph functions, and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
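
// Usage sketch (illustrative only; TF_NewStatus/TF_SetConfig/TF_DeleteBuffer
// are the standard calls from c_api.h): apply the serialized ConfigProto
// produced above to a session's options, then release the buffer.
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetConfig(opts, config->data, config->length, status);
//   // ... create the TF_Session with `opts` ...
//   TF_DeleteBuffer(config);
//   TF_DeleteSessionOptions(opts);
//   TF_DeleteStatus(status);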

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

// On success, returns a set of TF_Function instances from `text_proto` of
// GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}
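
// Usage sketch (illustrative only): the two helpers above assume the session's
// graph already contains nodes named "fifo_queue_enqueue_<id>",
// "arg_tensor_enqueue_<id>" and "fifo_queue_dequeue_<id>" for the given id;
// `session`, `some_tensor` and `status` are assumed to come from the caller.
//
//   TF_EnqueueNamedTensor(session, /*tensor_id=*/0, some_tensor, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
//     // ... use `out`, then TF_DeleteTensor(out); ...
//   }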

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
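
// Usage sketch (illustrative only; `ckpt_path` and `status` are assumed to be
// provided by the caller): enumerate the variables stored in a checkpoint.
//
//   TF_CheckpointReader* reader = TF_NewCheckpointReader(ckpt_path, status);
//   if (TF_GetCode(status) == TF_OK) {
//     int n = TF_CheckpointReaderSize(reader);
//     for (int i = 0; i < n; ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       // e.g. inspect TF_CheckpointReaderGetVariableDataType(reader, name)
//     }
//     TF_DeleteCheckpointReader(reader);
//   }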

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // The string buffers to make sure that any `attr_name` we pass into
  // `builder->Set()` will outlive the subsequent
  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          reinterpret_cast<const tensorflow::DataType*>(values),
                          num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
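
// Usage sketch (illustrative only; `status` is assumed to be a live
// TF_Status*): check whether a float MatMul kernel is registered for GPU.
//
//   TF_AttrBuilder* b = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(b, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(b, "GPU", status);
//   // TF_GetCode(status) == TF_OK iff such a kernel exists.
//   TF_DeleteAttrBuilder(b);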

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);

  status->status = tensorflow::Status::OK();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
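
// Usage sketch (illustrative only; `status` is assumed to be a live
// TF_Status*): wrap a host float as a scalar eager tensor handle.
//
//   float value = 1.5f;
//   TFE_TensorHandle* h =
//       TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value), status);
//   // ... use `h`, then ...
//   TFE_DeleteTensorHandle(h);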

// Enables collective ops on the context using the given ServerDef, possibly
// updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = tensorflow::unwrap(ctx)->EnableCollectiveOps(server_def);
}

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  for (size_t i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
  node_def.set_name(op->Name());
  node_def.set_op(op->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input_tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the address of the elements in `all_input_tensors`
    // below. Allocate enough space so that no reallocation happens, which will
    // make the pointers invalid.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated
  // later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    dims.reserve(input_shape.num_dims);
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (size_t j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
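
// Usage sketch (illustrative only; `ctx` and `status` are assumed to be a live
// TFE_Context* and TF_Status*): infer the output shape of an eager op without
// running it, here an Identity over a [2, 3] float input.
//
//   TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TF_ShapeAndTypeList* in = TF_NewShapeAndTypeList(1);
//   int64_t dims[] = {2, 3};
//   TF_ShapeAndTypeListSetShape(in, 0, dims, 2);
//   TF_ShapeAndTypeList* out = nullptr;
//   TFE_InferShapes(op, in, /*input_tensors=*/nullptr, nullptr, nullptr, &out,
//                   nullptr, status);
//   // On success, out->items[0] describes the inferred [2, 3] output shape.
//   TF_DeleteShapeAndTypeList(in);
//   TF_DeleteShapeAndTypeList(out);
//   TFE_DeleteOp(op);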

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Load a Pluggable Device library.
// On success, returns a handle to the library and sets an OK status.
// Otherwise, returns nullptr and sets an error status.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (status->status.ok()) {
        TF_CHECK_OK(
            tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
      } else {
        delete lib_handle;
        return nullptr;
      }
    }
    return lib_handle;
  }
#endif
}
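
// Usage sketch (illustrative only; the plugin path is a placeholder and
// `status` is assumed to be a live TF_Status*): load a PluggableDevice plugin
// so its devices and kernels get registered.
//
//   TF_Library* lib =
//       TF_LoadPluggableDeviceLibrary("/path/to/libplugin.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // Devices/kernels from the plugin are now registered globally.
//     TF_DeletePluggableDeviceLibraryHandle(lib);
//   }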

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}