/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"

#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable` set to false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}
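
// Example (illustrative sketch, not part of the API): enabling XLA JIT on a
// session created from these options. Assumes `graph` is an existing
// TF_Graph* and `status` is a valid TF_Status* owned by the caller.
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   TF_Session* session = TF_NewSession(graph, opts, status);
//   TF_DeleteSessionOptions(opts);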

unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  bool original = flags->tf_xla_cpu_global_jit;
  flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
  return static_cast<unsigned char>(original);
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (and, more
    // generally, non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of the GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the
  // callback (as in the case of ConstOp kernel creation on GPU, which involves
  // copying a CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we
  // use a different TFE context for each thread of execution (for running
  // graph functions and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}
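
// Example (sketch): applying the serialized ConfigProto to session options.
// Assumes `status` is a valid TF_Status* owned by the caller.
//
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetConfig(opts, config->data, config->length, status);
//   TF_DeleteBuffer(config);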

TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}

const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}
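
// Note: both debug-string functions above return a heap buffer allocated with
// malloc(); the caller owns it and must release it with free(). A minimal
// sketch:
//
//   size_t len;
//   char* s = TF_FunctionDebugString(func, &len);
//   fprintf(stderr, "%.*s\n", static_cast<int>(len), s);
//   free(s);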

// On success, returns a set of TF_Function instances from `text_proto` of
// GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}

TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}
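
// Example (sketch): building a serialized ServerDef from a text proto. The
// single-worker cluster spec below is illustrative only.
//
//   const char* text =
//       "cluster { job { name: 'worker' "
//       "tasks { key: 0 value: 'localhost:2222' } } } "
//       "job_name: 'worker' task_index: 0";
//   TF_Buffer* server_def = TFE_GetServerDef(text, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteBuffer(server_def);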

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}
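
// Example (sketch): listing the variables in a checkpoint. The checkpoint
// prefix path is hypothetical; `status` is a valid TF_Status* owned by the
// caller.
//
//   TF_CheckpointReader* reader =
//       TF_NewCheckpointReader("/tmp/model.ckpt", status);
//   if (TF_GetCode(status) == TF_OK) {
//     for (int i = 0; i < TF_CheckpointReaderSize(reader); ++i) {
//       const char* name = TF_CheckpointReaderGetVariable(reader, i);
//       int num_dims = TF_CheckpointReaderGetVariableNumDims(reader, name);
//       (void)num_dims;  // e.g. query shapes/dtypes or fetch the tensor.
//     }
//     TF_DeleteCheckpointReader(reader);
//   }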

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // The string buffers to make sure that any `attr_name` we pass into
  // `builder->Set()` will outlive the subsequent
  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                          reinterpret_cast<const tensorflow::DataType*>(values),
                          num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
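
// Example (sketch): checking whether a float MatMul kernel is registered for
// GPU. "MatMul" and its "T" attr come from the standard op registry; `status`
// is a valid TF_Status* owned by the caller.
//
//   TF_AttrBuilder* builder = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(builder, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(builder, "GPU", status);
//   bool has_kernel = TF_GetCode(status) == TF_OK;
//   TF_DeleteAttrBuilder(builder);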

const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);

  status->status = tensorflow::Status::OK();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
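
// Example (sketch): wrapping a C float in a scalar eager tensor handle.
// Assumes `status` is a valid TF_Status* owned by the caller.
//
//   float value = 1.0f;
//   TFE_TensorHandle* h = TFE_NewTensorHandleFromScalar(
//       TF_FLOAT, &value, sizeof(value), status);
//   if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(h);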

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
          server_def);
}
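
// Example (sketch): enabling collective ops from a serialized ServerDef, e.g.
// one produced by TFE_GetServerDef above. `ctx` is an existing TFE_Context*.
//
//   TF_Buffer* server_def = TFE_GetServerDef(text_proto, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_EnableCollectiveOps(ctx, server_def->data, server_def->length,
//                             status);
//     TF_DeleteBuffer(server_def);
//   }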

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
    TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  tensorflow::Notification done;
  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
      task, timeout_in_ms, [&done, status](const Status& s) {
        status->status = s;
        done.Notify();
      });
  done.WaitForNotification();
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items =
      (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  // `num_items` is an int; use a signed index to avoid a sign-compare warning.
  for (int i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
  node_def.set_name(op->Name());
  node_def.set_op(op->Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input tensor vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects.
  std::vector<Tensor> all_input_tensors;
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    // Note that we take the address of the elements in `all_input_tensors`
    // below. Allocate enough space so that no reallocation happens, which
    // would make the pointers invalid.
    all_input_tensors.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated
  // later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    dims.reserve(input_shape.num_dims);
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (int j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}
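
// Example (sketch): inferring output shapes for an eager op with a single
// 2x3 input. `op` is an already-configured TFE_Op*; error checking elided.
//
//   TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(1);
//   int64_t dims[] = {2, 3};
//   TF_ShapeAndTypeListSetShape(input_shapes, 0, dims, 2);
//   TF_ShapeAndTypeList* output_shapes = nullptr;
//   TFE_InferShapes(op, input_shapes, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr,
//                   &output_shapes,
//                   /*output_resource_shapes_and_types=*/nullptr, status);
//   TF_DeleteShapeAndTypeList(input_shapes);
//   TF_DeleteShapeAndTypeList(output_shapes);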

void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}

// Loads a PluggableDevice library.
// On success, returns a handle to the library and OK in `status`. Otherwise
// returns nullptr and an error in `status`.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Devices and kernels/ops are registered as globals when a library is loaded
// for the first time.
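//
// Example (sketch; the library name is hypothetical):
//
//   TF_Library* lib =
//       TF_LoadPluggableDeviceLibrary("libmy_device_plugin.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... create sessions/contexts that can see the new device ...
//     TF_DeletePluggableDeviceLibraryHandle(lib);
//   }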
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
                                          TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "PluggableDevice plugin functionality is not supported on mobile");
  return nullptr;
#else
  TF_Library* lib_handle = new TF_Library;
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_map<std::string, void*>* loaded_libs =
      new std::unordered_map<std::string, void*>();
  tensorflow::Env* env = tensorflow::Env::Default();
  {
    tensorflow::mutex_lock lock(mu);
    auto it = loaded_libs->find(library_filename);
    if (it != loaded_libs->end()) {
      lib_handle->lib_handle = it->second;
    } else {
      status->status =
          env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
      if (!status->status.ok()) {
        delete lib_handle;
        return nullptr;
      }
      // Remember the handle so that subsequent loads of the same library
      // return the cached handle, as documented above.
      (*loaded_libs)[library_filename] = lib_handle->lib_handle;
    }
    return lib_handle;
  }
#endif
}

void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
  delete lib_handle;
}