1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_experimental.h"
17
18 #include "tensorflow/c/c_api.h"
19 #include "tensorflow/c/eager/c_api_internal.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/lib/monitoring/counter.h"
23 #include "tensorflow/core/lib/monitoring/gauge.h"
24 #include "tensorflow/core/lib/monitoring/sampler.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/casts.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
29 #include "tensorflow/core/profiler/rpc/profiler_server.h"
30
31 using tensorflow::string;
32
TFE_OpReset(TFE_Op * op_to_reset,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)33 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
34 const char* raw_device_name, TF_Status* status) {
35 if (op_to_reset) {
36 status->status = op_to_reset->operation.Reset(
37 op_or_function_name, raw_device_name, false, nullptr);
38 } else {
39 TF_SetStatus(status, TF_INVALID_ARGUMENT,
40 "op_to_reset should not be nullptr");
41 }
42 }
43
TFE_OpConsumeInput(TFE_Op * op,TFE_TensorHandle * h,TF_Status * status)44 void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
45 op->operation.ConsumeInput(
46 tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
47 ->Handle());
48 }
49
TFE_NewProfiler()50 TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
51
TFE_ProfilerIsOk(TFE_Profiler * profiler)52 bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
53 return profiler->profiler->Status().ok();
54 }
55
TFE_DeleteProfiler(TFE_Profiler * profiler)56 void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
57
TFE_ProfilerSerializeToString(TFE_Profiler * profiler,TF_Buffer * buf,TF_Status * status)58 void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
59 TF_Status* status) {
60 string content;
61 status->status = profiler->profiler->SerializeToString(&content);
62 void* data = tensorflow::port::Malloc(content.length());
63 content.copy(static_cast<char*>(data), content.length(), 0);
64 buf->data = data;
65 buf->length = content.length();
66 buf->data_deallocator = [](void* data, size_t length) {
67 tensorflow::port::Free(data);
68 };
69 }
70
TFE_StartProfilerServer(int port)71 void TFE_StartProfilerServer(int port) {
72 // Release child thread intentionally. The child thread can be terminated by
73 // terminating the main thread.
74 tensorflow::StartProfilerServer(port).release();
75 }
76
TFE_ContextEnableGraphCollection(TFE_Context * ctx)77 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
78 ctx->context->SetShouldStoreGraphs(true);
79 }
80
TFE_ContextDisableGraphCollection(TFE_Context * ctx)81 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
82 ctx->context->SetShouldStoreGraphs(false);
83 }
84
TFE_ProfilerClientStartTracing(const char * service_addr,const char * logdir,const char * worker_list,bool include_dataset_ops,int duration_ms,int num_tracing_attempts,TF_Status * status)85 bool TFE_ProfilerClientStartTracing(const char* service_addr,
86 const char* logdir, const char* worker_list,
87 bool include_dataset_ops, int duration_ms,
88 int num_tracing_attempts,
89 TF_Status* status) {
90 tensorflow::Status s =
91 tensorflow::profiler::ValidateHostPortPair(service_addr);
92 if (!s.ok()) {
93 Set_TF_Status_from_Status(status, s);
94 return false;
95 }
96 s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
97 include_dataset_ops, duration_ms,
98 num_tracing_attempts);
99 tensorflow::Set_TF_Status_from_Status(status, s);
100 return s.ok();
101 }
102
TFE_ProfilerClientMonitor(const char * service_addr,int duration_ms,int monitoring_level,bool display_timestamp,TF_Buffer * result,TF_Status * status)103 void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
104 int monitoring_level, bool display_timestamp,
105 TF_Buffer* result, TF_Status* status) {
106 tensorflow::Status s =
107 tensorflow::profiler::ValidateHostPortPair(service_addr);
108 if (!s.ok()) {
109 Set_TF_Status_from_Status(status, s);
110 return;
111 }
112 string content;
113 s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
114 display_timestamp, &content);
115 void* data = tensorflow::port::Malloc(content.length());
116 content.copy(static_cast<char*>(data), content.length(), 0);
117 result->data = data;
118 result->length = content.length();
119 result->data_deallocator = [](void* data, size_t length) {
120 tensorflow::port::Free(data);
121 };
122 tensorflow::Set_TF_Status_from_Status(status, s);
123 }
124
TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell * cell,int64_t value)125 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
126 int64_t value) {
127 cell->cell.IncrementBy(value);
128 }
129
TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell * cell)130 int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
131 return cell->cell.value();
132 }
133
TFE_MonitoringNewCounter0(const char * name,TF_Status * status,const char * description)134 TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
135 TF_Status* status,
136 const char* description) {
137 auto* result = new TFE_MonitoringCounter0({name, description});
138 Set_TF_Status_from_Status(status, result->counter->GetStatus());
139 if (!result->counter->GetStatus().ok()) {
140 delete result;
141 return nullptr;
142 }
143 return result;
144 }
145
TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0 * counter)146 void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
147 delete counter;
148 }
149
TFE_MonitoringGetCellCounter0(TFE_MonitoringCounter0 * counter)150 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
151 TFE_MonitoringCounter0* counter) {
152 return static_cast<TFE_MonitoringCounterCell*>(
153 static_cast<void*>(counter->counter->GetCell()));
154 }
155
TFE_MonitoringNewCounter1(const char * name,TF_Status * status,const char * description,const char * label1)156 TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
157 TF_Status* status,
158 const char* description,
159 const char* label1) {
160 auto* result = new TFE_MonitoringCounter1({name, description, label1});
161 Set_TF_Status_from_Status(status, result->counter->GetStatus());
162 if (!result->counter->GetStatus().ok()) {
163 delete result;
164 return nullptr;
165 }
166 return result;
167 }
168
TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1 * counter)169 void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
170 delete counter;
171 }
172
TFE_MonitoringGetCellCounter1(TFE_MonitoringCounter1 * counter,const char * label1)173 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
174 TFE_MonitoringCounter1* counter, const char* label1) {
175 return static_cast<TFE_MonitoringCounterCell*>(
176 static_cast<void*>(counter->counter->GetCell(label1)));
177 }
178
TFE_MonitoringNewCounter2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)179 TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
180 TF_Status* status,
181 const char* description,
182 const char* label1,
183 const char* label2) {
184 auto* result =
185 new TFE_MonitoringCounter2({name, description, label1, label2});
186 Set_TF_Status_from_Status(status, result->counter->GetStatus());
187 if (!result->counter->GetStatus().ok()) {
188 delete result;
189 return nullptr;
190 }
191 return result;
192 }
193
TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2 * counter)194 void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
195 delete counter;
196 }
197
TFE_MonitoringGetCellCounter2(TFE_MonitoringCounter2 * counter,const char * label1,const char * label2)198 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
199 TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
200 return static_cast<TFE_MonitoringCounterCell*>(
201 static_cast<void*>(counter->counter->GetCell(label1, label2)));
202 }
203
TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell * cell,int64_t value)204 void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
205 int64_t value) {
206 cell->cell.Set(value);
207 }
208
TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell * cell)209 int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
210 return cell->cell.value();
211 }
212
TFE_MonitoringNewIntGauge0(const char * name,TF_Status * status,const char * description)213 TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
214 TF_Status* status,
215 const char* description) {
216 auto* result = new TFE_MonitoringIntGauge0({name, description});
217 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
218 if (!result->gauge->GetStatus().ok()) {
219 delete result;
220 return nullptr;
221 }
222 return result;
223 }
224
TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0 * gauge)225 void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
226 delete gauge;
227 }
228
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0 * gauge)229 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
230 TFE_MonitoringIntGauge0* gauge) {
231 return static_cast<TFE_MonitoringIntGaugeCell*>(
232 static_cast<void*>(gauge->gauge->GetCell()));
233 }
234
TFE_MonitoringNewIntGauge1(const char * name,TF_Status * status,const char * description,const char * label1)235 TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
236 TF_Status* status,
237 const char* description,
238 const char* label1) {
239 auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
240 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
241 if (!result->gauge->GetStatus().ok()) {
242 delete result;
243 return nullptr;
244 }
245 return result;
246 }
247
TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1 * gauge)248 void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
249 delete gauge;
250 }
251
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1 * gauge,const char * label1)252 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
253 TFE_MonitoringIntGauge1* gauge, const char* label1) {
254 return static_cast<TFE_MonitoringIntGaugeCell*>(
255 static_cast<void*>(gauge->gauge->GetCell(label1)));
256 }
257
TFE_MonitoringNewIntGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)258 TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
259 TF_Status* status,
260 const char* description,
261 const char* label1,
262 const char* label2) {
263 auto* result =
264 new TFE_MonitoringIntGauge2({name, description, label1, label2});
265 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
266 if (!result->gauge->GetStatus().ok()) {
267 delete result;
268 return nullptr;
269 }
270 return result;
271 }
272
TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2 * gauge)273 void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
274 delete gauge;
275 }
276
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2 * gauge,const char * label1,const char * label2)277 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
278 TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
279 return static_cast<TFE_MonitoringIntGaugeCell*>(
280 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
281 }
282
TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell * cell,const char * value)283 void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
284 const char* value) {
285 cell->cell.Set({value});
286 }
287
TFE_MonitoringStringGaugeCellValue(TFE_MonitoringStringGaugeCell * cell,TF_Buffer * buf)288 const void TFE_MonitoringStringGaugeCellValue(
289 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
290 tensorflow::string value = cell->cell.value();
291 void* data = tensorflow::port::Malloc(value.length());
292 value.copy(static_cast<char*>(data), value.length(), 0);
293 buf->data = data;
294 buf->length = value.length();
295 buf->data_deallocator = [](void* data, size_t length) {
296 tensorflow::port::Free(data);
297 };
298 }
299
TFE_MonitoringNewStringGauge0(const char * name,TF_Status * status,const char * description)300 TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
301 const char* name, TF_Status* status, const char* description) {
302 auto* result = new TFE_MonitoringStringGauge0({name, description});
303 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
304 if (!result->gauge->GetStatus().ok()) {
305 delete result;
306 return nullptr;
307 }
308 return result;
309 }
310
TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0 * gauge)311 void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
312 delete gauge;
313 }
314
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0 * gauge)315 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
316 TFE_MonitoringStringGauge0* gauge) {
317 return static_cast<TFE_MonitoringStringGaugeCell*>(
318 static_cast<void*>(gauge->gauge->GetCell()));
319 }
320
TFE_MonitoringNewStringGauge1(const char * name,TF_Status * status,const char * description,const char * label1)321 TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
322 const char* name, TF_Status* status, const char* description,
323 const char* label1) {
324 auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
325 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
326 if (!result->gauge->GetStatus().ok()) {
327 delete result;
328 return nullptr;
329 }
330 return result;
331 }
332
TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1 * gauge)333 void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
334 delete gauge;
335 }
336
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1 * gauge,const char * label1)337 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
338 TFE_MonitoringStringGauge1* gauge, const char* label1) {
339 return static_cast<TFE_MonitoringStringGaugeCell*>(
340 static_cast<void*>(gauge->gauge->GetCell(label1)));
341 }
342
TFE_MonitoringNewStringGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)343 TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
344 const char* name, TF_Status* status, const char* description,
345 const char* label1, const char* label2) {
346 auto* result =
347 new TFE_MonitoringStringGauge2({name, description, label1, label2});
348 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
349 if (!result->gauge->GetStatus().ok()) {
350 delete result;
351 return nullptr;
352 }
353 return result;
354 }
355
TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2 * gauge)356 void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
357 delete gauge;
358 }
359
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2 * gauge,const char * label1,const char * label2)360 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
361 TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
362 return static_cast<TFE_MonitoringStringGaugeCell*>(
363 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
364 }
365
TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell * cell,bool value)366 void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
367 bool value) {
368 cell->cell.Set(value);
369 }
370
TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell * cell)371 bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
372 return cell->cell.value();
373 }
374
TFE_MonitoringNewBoolGauge0(const char * name,TF_Status * status,const char * description)375 TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
376 TF_Status* status,
377 const char* description) {
378 auto* result = new TFE_MonitoringBoolGauge0({name, description});
379 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
380 if (!result->gauge->GetStatus().ok()) {
381 delete result;
382 return nullptr;
383 }
384 return result;
385 }
386
TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)387 void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
388 delete gauge;
389 }
390
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)391 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
392 TFE_MonitoringBoolGauge0* gauge) {
393 return static_cast<TFE_MonitoringBoolGaugeCell*>(
394 static_cast<void*>(gauge->gauge->GetCell()));
395 }
396
TFE_MonitoringNewBoolGauge1(const char * name,TF_Status * status,const char * description,const char * label1)397 TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
398 TF_Status* status,
399 const char* description,
400 const char* label1) {
401 auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
402 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
403 if (!result->gauge->GetStatus().ok()) {
404 delete result;
405 return nullptr;
406 }
407 return result;
408 }
409
TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1 * gauge)410 void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
411 delete gauge;
412 }
413
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1 * gauge,const char * label1)414 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
415 TFE_MonitoringBoolGauge1* gauge, const char* label1) {
416 return static_cast<TFE_MonitoringBoolGaugeCell*>(
417 static_cast<void*>(gauge->gauge->GetCell(label1)));
418 }
419
TFE_MonitoringNewBoolGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)420 TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
421 TF_Status* status,
422 const char* description,
423 const char* label1,
424 const char* label2) {
425 auto* result =
426 new TFE_MonitoringBoolGauge2({name, description, label1, label2});
427 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
428 if (!result->gauge->GetStatus().ok()) {
429 delete result;
430 return nullptr;
431 }
432 return result;
433 }
434
TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2 * gauge)435 void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
436 delete gauge;
437 }
438
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2 * gauge,const char * label1,const char * label2)439 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
440 TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
441 return static_cast<TFE_MonitoringBoolGaugeCell*>(
442 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
443 }
444
TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell * cell,double value)445 void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
446 double value) {
447 cell->cell.Add(value);
448 }
449
TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell * cell,TF_Buffer * buf)450 void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
451 TF_Buffer* buf) {
452 string content;
453 cell->cell.value().SerializeToString(&content);
454 void* data = tensorflow::port::Malloc(content.length());
455 content.copy(static_cast<char*>(data), content.length(), 0);
456 buf->data = data;
457 buf->length = content.length();
458 buf->data_deallocator = [](void* data, size_t length) {
459 tensorflow::port::Free(data);
460 };
461 }
462
TFE_MonitoringNewExponentialBuckets(double scale,double growth_factor,int bucket_count)463 TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
464 double growth_factor,
465 int bucket_count) {
466 return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
467 return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
468 bucket_count);
469 });
470 }
471
TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets * buckets)472 void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
473 delete buckets;
474 }
475
TFE_MonitoringNewSampler0(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description)476 TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
477 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
478 const char* description) {
479 auto* result = new TFE_MonitoringSampler0(
480 {name, buckets->create_buckets(), description});
481 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
482 if (!result->sampler->GetStatus().ok()) {
483 delete result;
484 return nullptr;
485 }
486 return result;
487 }
488
TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0 * sampler)489 void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
490 delete sampler;
491 }
492
TFE_MonitoringGetCellSampler0(TFE_MonitoringSampler0 * sampler)493 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
494 TFE_MonitoringSampler0* sampler) {
495 return static_cast<TFE_MonitoringSamplerCell*>(
496 static_cast<void*>(sampler->sampler->GetCell()));
497 }
498
TFE_MonitoringNewSampler1(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1)499 TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
500 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
501 const char* description, const char* label1) {
502 auto* result = new TFE_MonitoringSampler1(
503 {name, buckets->create_buckets(), description, label1});
504 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
505 if (!result->sampler->GetStatus().ok()) {
506 delete result;
507 return nullptr;
508 }
509 return result;
510 }
511
TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1 * sampler)512 void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
513 delete sampler;
514 }
515
TFE_MonitoringGetCellSampler1(TFE_MonitoringSampler1 * sampler,const char * label1)516 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
517 TFE_MonitoringSampler1* sampler, const char* label1) {
518 return static_cast<TFE_MonitoringSamplerCell*>(
519 static_cast<void*>(sampler->sampler->GetCell(label1)));
520 }
521
TFE_MonitoringNewSampler2(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1,const char * label2)522 TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
523 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
524 const char* description, const char* label1, const char* label2) {
525 auto* result = new TFE_MonitoringSampler2(
526 {name, buckets->create_buckets(), description, label1, label2});
527 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
528 if (!result->sampler->GetStatus().ok()) {
529 delete result;
530 return nullptr;
531 }
532 return result;
533 }
534
TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2 * sampler)535 void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
536 delete sampler;
537 }
538
TFE_MonitoringGetCellSampler2(TFE_MonitoringSampler2 * sampler,const char * label1,const char * label2)539 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
540 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
541 return static_cast<TFE_MonitoringSamplerCell*>(
542 static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
543 }
544
TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions * options,TFE_ContextMirroringPolicy policy)545 void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
546 TFE_ContextMirroringPolicy policy) {
547 options->mirroring_policy = policy;
548 }
549
TFE_ContextSetThreadLocalMirroringPolicy(TFE_Context * ctx,TFE_ContextMirroringPolicy policy)550 void TFE_ContextSetThreadLocalMirroringPolicy(
551 TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
552 ctx->context->SetThreadLocalMirroringPolicy(
553 static_cast<tensorflow::ContextMirroringPolicy>(policy));
554 }
555
556 // Note: this function looks up a thread local policy. So it should be called in
557 // the appropriate client thread. In particular, in async mode, it may not be
558 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetMirroringPolicy(TFE_Context * ctx)559 extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
560 TFE_Context* ctx) {
561 return static_cast<TFE_ContextMirroringPolicy>(
562 ctx->context->GetMirroringPolicy());
563 }
564
TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions * options,bool lazy_copy)565 void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
566 bool lazy_copy) {
567 options->lazy_remote_inputs_copy = lazy_copy;
568 }
569
TFE_NewCancellationManager()570 TFE_CancellationManager* TFE_NewCancellationManager() {
571 return new TFE_CancellationManager;
572 }
573
TFE_CancellationManagerStartCancel(TFE_CancellationManager * cancellation_manager)574 void TFE_CancellationManagerStartCancel(
575 TFE_CancellationManager* cancellation_manager) {
576 cancellation_manager->cancellation_manager.StartCancel();
577 }
578
TFE_CancellationManagerIsCancelled(TFE_CancellationManager * cancellation_manager)579 bool TFE_CancellationManagerIsCancelled(
580 TFE_CancellationManager* cancellation_manager) {
581 return cancellation_manager->cancellation_manager.IsCancelled();
582 }
583
TFE_DeleteCancellationManager(TFE_CancellationManager * cancellation_manager)584 void TFE_DeleteCancellationManager(
585 TFE_CancellationManager* cancellation_manager) {
586 delete cancellation_manager;
587 }
588
TFE_OpSetCancellationManager(TFE_Op * op,TFE_CancellationManager * cancellation_manager,TF_Status * status)589 void TFE_OpSetCancellationManager(TFE_Op* op,
590 TFE_CancellationManager* cancellation_manager,
591 TF_Status* status) {
592 op->operation.SetCancellationManager(
593 &cancellation_manager->cancellation_manager);
594 }
595
TFE_NewExecutor(bool is_async)596 TFE_Executor* TFE_NewExecutor(bool is_async) {
597 return new TFE_Executor(is_async);
598 }
599
TFE_DeleteExecutor(TFE_Executor * executor)600 void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
601
TFE_ExecutorIsAsync(TFE_Executor * executor)602 bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
603 return executor->executor()->Async();
604 }
605
TFE_ExecutorWaitForAllPendingNodes(TFE_Executor * executor,TF_Status * status)606 void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
607 TF_Status* status) {
608 status->status = executor->executor()->WaitForAllPendingNodes();
609 }
610
TFE_ExecutorClearError(TFE_Executor * executor)611 void TFE_ExecutorClearError(TFE_Executor* executor) {
612 executor->executor()->ClearError();
613 }
614
TFE_ContextSetExecutorForThread(TFE_Context * ctx,TFE_Executor * executor)615 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
616 ctx->context->SetExecutorForThread(executor->executor());
617 }
618
TFE_ContextGetExecutorForThread(TFE_Context * ctx)619 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
620 return new TFE_Executor(&ctx->context->Executor());
621 }
622
TFE_HostAddressSpace(TFE_Context * ctx,TF_Buffer * buf)623 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
624 auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
625 ctx->context->HostCPU()->parsed_name());
626 auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
627 void* data = tensorflow::port::Malloc(str.length());
628 str.copy(static_cast<char*>(data), str.length(), 0);
629 buf->data = data;
630 buf->length = str.length();
631 buf->data_deallocator = [](void* data, size_t length) {
632 tensorflow::port::Free(data);
633 };
634 }
635