1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_experimental.h"
17
18 #include <vector>
19
20 #include "absl/strings/match.h"
21 #include "tensorflow/c/c_api.h"
22 #include "tensorflow/c/eager/c_api_internal.h"
23 #include "tensorflow/c/eager/tfe_context_internal.h"
24 #include "tensorflow/c/eager/tfe_op_internal.h"
25 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
26 #include "tensorflow/c/tf_status_helper.h"
27 #include "tensorflow/core/common_runtime/composite_device.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
30 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
31 #include "tensorflow/core/lib/monitoring/counter.h"
32 #include "tensorflow/core/lib/monitoring/gauge.h"
33 #include "tensorflow/core/lib/monitoring/sampler.h"
34 #include "tensorflow/core/platform/casts.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/strcat.h"
38
39 using tensorflow::string;
40
TFE_OpReset(TFE_Op * op_to_reset,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)41 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
42 const char* raw_device_name, TF_Status* status) {
43 if (op_to_reset) {
44 tensorflow::ImmediateExecutionOperation* op =
45 tensorflow::unwrap(op_to_reset);
46 op->Clear();
47 status->status = op->Reset(op_or_function_name, raw_device_name);
48 } else {
49 TF_SetStatus(status, TF_INVALID_ARGUMENT,
50 "op_to_reset should not be nullptr");
51 }
52 }
53
TFE_ContextEnableGraphCollection(TFE_Context * ctx)54 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
55 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
56 }
57
TFE_ContextDisableGraphCollection(TFE_Context * ctx)58 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
59 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
60 }
61
TFE_GetContextId(TFE_Context * ctx)62 uint64_t TFE_GetContextId(TFE_Context* ctx) {
63 tensorflow::EagerContext* context =
64 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
65 return context->GetContextId();
66 }
67
TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell * cell,int64_t value)68 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
69 int64_t value) {
70 cell->cell.IncrementBy(value);
71 }
72
TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell * cell)73 int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
74 return cell->cell.value();
75 }
76
TFE_MonitoringNewCounter0(const char * name,TF_Status * status,const char * description)77 TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
78 TF_Status* status,
79 const char* description) {
80 auto* result = new TFE_MonitoringCounter0({name, description});
81 Set_TF_Status_from_Status(status, result->counter->GetStatus());
82 if (!result->counter->GetStatus().ok()) {
83 delete result;
84 return nullptr;
85 }
86 return result;
87 }
88
TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0 * counter)89 void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
90 delete counter;
91 }
92
TFE_MonitoringGetCellCounter0(TFE_MonitoringCounter0 * counter)93 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
94 TFE_MonitoringCounter0* counter) {
95 return static_cast<TFE_MonitoringCounterCell*>(
96 static_cast<void*>(counter->counter->GetCell()));
97 }
98
TFE_MonitoringNewCounter1(const char * name,TF_Status * status,const char * description,const char * label1)99 TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
100 TF_Status* status,
101 const char* description,
102 const char* label1) {
103 auto* result = new TFE_MonitoringCounter1({name, description, label1});
104 Set_TF_Status_from_Status(status, result->counter->GetStatus());
105 if (!result->counter->GetStatus().ok()) {
106 delete result;
107 return nullptr;
108 }
109 return result;
110 }
111
TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1 * counter)112 void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
113 delete counter;
114 }
115
TFE_MonitoringGetCellCounter1(TFE_MonitoringCounter1 * counter,const char * label1)116 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
117 TFE_MonitoringCounter1* counter, const char* label1) {
118 return static_cast<TFE_MonitoringCounterCell*>(
119 static_cast<void*>(counter->counter->GetCell(label1)));
120 }
121
TFE_MonitoringNewCounter2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)122 TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
123 TF_Status* status,
124 const char* description,
125 const char* label1,
126 const char* label2) {
127 auto* result =
128 new TFE_MonitoringCounter2({name, description, label1, label2});
129 Set_TF_Status_from_Status(status, result->counter->GetStatus());
130 if (!result->counter->GetStatus().ok()) {
131 delete result;
132 return nullptr;
133 }
134 return result;
135 }
136
TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2 * counter)137 void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
138 delete counter;
139 }
140
TFE_MonitoringGetCellCounter2(TFE_MonitoringCounter2 * counter,const char * label1,const char * label2)141 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
142 TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
143 return static_cast<TFE_MonitoringCounterCell*>(
144 static_cast<void*>(counter->counter->GetCell(label1, label2)));
145 }
146
TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell * cell,int64_t value)147 void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
148 int64_t value) {
149 cell->cell.Set(value);
150 }
151
TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell * cell)152 int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
153 return cell->cell.value();
154 }
155
TFE_MonitoringNewIntGauge0(const char * name,TF_Status * status,const char * description)156 TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
157 TF_Status* status,
158 const char* description) {
159 auto* result = new TFE_MonitoringIntGauge0({name, description});
160 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
161 if (!result->gauge->GetStatus().ok()) {
162 delete result;
163 return nullptr;
164 }
165 return result;
166 }
167
TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0 * gauge)168 void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
169 delete gauge;
170 }
171
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0 * gauge)172 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
173 TFE_MonitoringIntGauge0* gauge) {
174 return static_cast<TFE_MonitoringIntGaugeCell*>(
175 static_cast<void*>(gauge->gauge->GetCell()));
176 }
177
TFE_MonitoringNewIntGauge1(const char * name,TF_Status * status,const char * description,const char * label1)178 TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
179 TF_Status* status,
180 const char* description,
181 const char* label1) {
182 auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
183 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
184 if (!result->gauge->GetStatus().ok()) {
185 delete result;
186 return nullptr;
187 }
188 return result;
189 }
190
TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1 * gauge)191 void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
192 delete gauge;
193 }
194
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1 * gauge,const char * label1)195 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
196 TFE_MonitoringIntGauge1* gauge, const char* label1) {
197 return static_cast<TFE_MonitoringIntGaugeCell*>(
198 static_cast<void*>(gauge->gauge->GetCell(label1)));
199 }
200
TFE_MonitoringNewIntGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)201 TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
202 TF_Status* status,
203 const char* description,
204 const char* label1,
205 const char* label2) {
206 auto* result =
207 new TFE_MonitoringIntGauge2({name, description, label1, label2});
208 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
209 if (!result->gauge->GetStatus().ok()) {
210 delete result;
211 return nullptr;
212 }
213 return result;
214 }
215
TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2 * gauge)216 void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
217 delete gauge;
218 }
219
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2 * gauge,const char * label1,const char * label2)220 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
221 TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
222 return static_cast<TFE_MonitoringIntGaugeCell*>(
223 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
224 }
225
TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell * cell,const char * value)226 void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
227 const char* value) {
228 cell->cell.Set({value});
229 }
230
TFE_MonitoringStringGaugeCellValue(TFE_MonitoringStringGaugeCell * cell,TF_Buffer * buf)231 const void TFE_MonitoringStringGaugeCellValue(
232 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
233 tensorflow::string value = cell->cell.value();
234 void* data = tensorflow::port::Malloc(value.length());
235 value.copy(static_cast<char*>(data), value.length(), 0);
236 buf->data = data;
237 buf->length = value.length();
238 buf->data_deallocator = [](void* data, size_t length) {
239 tensorflow::port::Free(data);
240 };
241 }
242
TFE_MonitoringNewStringGauge0(const char * name,TF_Status * status,const char * description)243 TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
244 const char* name, TF_Status* status, const char* description) {
245 auto* result = new TFE_MonitoringStringGauge0({name, description});
246 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
247 if (!result->gauge->GetStatus().ok()) {
248 delete result;
249 return nullptr;
250 }
251 return result;
252 }
253
TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0 * gauge)254 void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
255 delete gauge;
256 }
257
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0 * gauge)258 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
259 TFE_MonitoringStringGauge0* gauge) {
260 return static_cast<TFE_MonitoringStringGaugeCell*>(
261 static_cast<void*>(gauge->gauge->GetCell()));
262 }
263
TFE_MonitoringNewStringGauge1(const char * name,TF_Status * status,const char * description,const char * label1)264 TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
265 const char* name, TF_Status* status, const char* description,
266 const char* label1) {
267 auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
268 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
269 if (!result->gauge->GetStatus().ok()) {
270 delete result;
271 return nullptr;
272 }
273 return result;
274 }
275
TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1 * gauge)276 void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
277 delete gauge;
278 }
279
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1 * gauge,const char * label1)280 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
281 TFE_MonitoringStringGauge1* gauge, const char* label1) {
282 return static_cast<TFE_MonitoringStringGaugeCell*>(
283 static_cast<void*>(gauge->gauge->GetCell(label1)));
284 }
285
TFE_MonitoringNewStringGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)286 TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
287 const char* name, TF_Status* status, const char* description,
288 const char* label1, const char* label2) {
289 auto* result =
290 new TFE_MonitoringStringGauge2({name, description, label1, label2});
291 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
292 if (!result->gauge->GetStatus().ok()) {
293 delete result;
294 return nullptr;
295 }
296 return result;
297 }
298
TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2 * gauge)299 void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
300 delete gauge;
301 }
302
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2 * gauge,const char * label1,const char * label2)303 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
304 TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
305 return static_cast<TFE_MonitoringStringGaugeCell*>(
306 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
307 }
308
TFE_MonitoringNewStringGauge3(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2,const char * label3)309 TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3(
310 const char* name, TF_Status* status, const char* description,
311 const char* label1, const char* label2, const char* label3) {
312 auto* result = new TFE_MonitoringStringGauge3(
313 {name, description, label1, label2, label3});
314 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
315 if (!result->gauge->GetStatus().ok()) {
316 delete result;
317 return nullptr;
318 }
319 return result;
320 }
321
TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3 * gauge)322 void TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3* gauge) {
323 delete gauge;
324 }
325
TFE_MonitoringGetCellStringGauge3(TFE_MonitoringStringGauge3 * gauge,const char * label1,const char * label2,const char * label3)326 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge3(
327 TFE_MonitoringStringGauge3* gauge, const char* label1, const char* label2,
328 const char* label3) {
329 return static_cast<TFE_MonitoringStringGaugeCell*>(
330 static_cast<void*>(gauge->gauge->GetCell(label1, label2, label3)));
331 }
332
TFE_MonitoringNewStringGauge4(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2,const char * label3,const char * label4)333 TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4(
334 const char* name, TF_Status* status, const char* description,
335 const char* label1, const char* label2, const char* label3,
336 const char* label4) {
337 auto* result = new TFE_MonitoringStringGauge4(
338 {name, description, label1, label2, label3, label4});
339 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
340 if (!result->gauge->GetStatus().ok()) {
341 delete result;
342 return nullptr;
343 }
344 return result;
345 }
346
TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4 * gauge)347 void TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4* gauge) {
348 delete gauge;
349 }
350
TFE_MonitoringGetCellStringGauge4(TFE_MonitoringStringGauge4 * gauge,const char * label1,const char * label2,const char * label3,const char * label4)351 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge4(
352 TFE_MonitoringStringGauge4* gauge, const char* label1, const char* label2,
353 const char* label3, const char* label4) {
354 return static_cast<TFE_MonitoringStringGaugeCell*>(static_cast<void*>(
355 gauge->gauge->GetCell(label1, label2, label3, label4)));
356 }
357
TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell * cell,bool value)358 void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
359 bool value) {
360 cell->cell.Set(value);
361 }
362
TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell * cell)363 bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
364 return cell->cell.value();
365 }
366
TFE_MonitoringNewBoolGauge0(const char * name,TF_Status * status,const char * description)367 TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
368 TF_Status* status,
369 const char* description) {
370 auto* result = new TFE_MonitoringBoolGauge0({name, description});
371 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
372 if (!result->gauge->GetStatus().ok()) {
373 delete result;
374 return nullptr;
375 }
376 return result;
377 }
378
TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)379 void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
380 delete gauge;
381 }
382
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)383 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
384 TFE_MonitoringBoolGauge0* gauge) {
385 return static_cast<TFE_MonitoringBoolGaugeCell*>(
386 static_cast<void*>(gauge->gauge->GetCell()));
387 }
388
TFE_MonitoringNewBoolGauge1(const char * name,TF_Status * status,const char * description,const char * label1)389 TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
390 TF_Status* status,
391 const char* description,
392 const char* label1) {
393 auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
394 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
395 if (!result->gauge->GetStatus().ok()) {
396 delete result;
397 return nullptr;
398 }
399 return result;
400 }
401
TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1 * gauge)402 void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
403 delete gauge;
404 }
405
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1 * gauge,const char * label1)406 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
407 TFE_MonitoringBoolGauge1* gauge, const char* label1) {
408 return static_cast<TFE_MonitoringBoolGaugeCell*>(
409 static_cast<void*>(gauge->gauge->GetCell(label1)));
410 }
411
TFE_MonitoringNewBoolGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)412 TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
413 TF_Status* status,
414 const char* description,
415 const char* label1,
416 const char* label2) {
417 auto* result =
418 new TFE_MonitoringBoolGauge2({name, description, label1, label2});
419 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
420 if (!result->gauge->GetStatus().ok()) {
421 delete result;
422 return nullptr;
423 }
424 return result;
425 }
426
TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2 * gauge)427 void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
428 delete gauge;
429 }
430
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2 * gauge,const char * label1,const char * label2)431 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
432 TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
433 return static_cast<TFE_MonitoringBoolGaugeCell*>(
434 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
435 }
436
TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell * cell,double value)437 void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
438 double value) {
439 cell->cell.Add(value);
440 }
441
TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell * cell,TF_Buffer * buf)442 void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
443 TF_Buffer* buf) {
444 string content;
445 cell->cell.value().SerializeToString(&content);
446 void* data = tensorflow::port::Malloc(content.length());
447 content.copy(static_cast<char*>(data), content.length(), 0);
448 buf->data = data;
449 buf->length = content.length();
450 buf->data_deallocator = [](void* data, size_t length) {
451 tensorflow::port::Free(data);
452 };
453 }
454
TFE_MonitoringNewExponentialBuckets(double scale,double growth_factor,int bucket_count)455 TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
456 double growth_factor,
457 int bucket_count) {
458 return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
459 return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
460 bucket_count);
461 });
462 }
463
TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets * buckets)464 void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
465 delete buckets;
466 }
467
TFE_MonitoringNewSampler0(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description)468 TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
469 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
470 const char* description) {
471 auto* result = new TFE_MonitoringSampler0(
472 {name, buckets->create_buckets(), description});
473 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
474 if (!result->sampler->GetStatus().ok()) {
475 delete result;
476 return nullptr;
477 }
478 return result;
479 }
480
TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0 * sampler)481 void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
482 delete sampler;
483 }
484
TFE_MonitoringGetCellSampler0(TFE_MonitoringSampler0 * sampler)485 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
486 TFE_MonitoringSampler0* sampler) {
487 return static_cast<TFE_MonitoringSamplerCell*>(
488 static_cast<void*>(sampler->sampler->GetCell()));
489 }
490
TFE_MonitoringNewSampler1(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1)491 TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
492 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
493 const char* description, const char* label1) {
494 auto* result = new TFE_MonitoringSampler1(
495 {name, buckets->create_buckets(), description, label1});
496 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
497 if (!result->sampler->GetStatus().ok()) {
498 delete result;
499 return nullptr;
500 }
501 return result;
502 }
503
TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1 * sampler)504 void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
505 delete sampler;
506 }
507
TFE_MonitoringGetCellSampler1(TFE_MonitoringSampler1 * sampler,const char * label1)508 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
509 TFE_MonitoringSampler1* sampler, const char* label1) {
510 return static_cast<TFE_MonitoringSamplerCell*>(
511 static_cast<void*>(sampler->sampler->GetCell(label1)));
512 }
513
TFE_MonitoringNewSampler2(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1,const char * label2)514 TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
515 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
516 const char* description, const char* label1, const char* label2) {
517 auto* result = new TFE_MonitoringSampler2(
518 {name, buckets->create_buckets(), description, label1, label2});
519 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
520 if (!result->sampler->GetStatus().ok()) {
521 delete result;
522 return nullptr;
523 }
524 return result;
525 }
526
TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2 * sampler)527 void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
528 delete sampler;
529 }
530
TFE_MonitoringGetCellSampler2(TFE_MonitoringSampler2 * sampler,const char * label1,const char * label2)531 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
532 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
533 return static_cast<TFE_MonitoringSamplerCell*>(
534 static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
535 }
536
TFE_ContextOptionsSetTfrt(TFE_ContextOptions * options,bool use_tfrt)537 void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
538 options->use_tfrt = use_tfrt;
539 }
540
TFE_ContextOptionsSetTfrtDistributedRuntime(TFE_ContextOptions * options,bool use_tfrt_distributed_runtime)541 void TFE_ContextOptionsSetTfrtDistributedRuntime(
542 TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) {
543 options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime;
544 }
545
TFE_NewCancellationManager()546 TFE_CancellationManager* TFE_NewCancellationManager() {
547 return tensorflow::wrap(new tensorflow::CancellationManager);
548 }
549
TFE_CancellationManagerStartCancel(TFE_CancellationManager * cancellation_manager)550 void TFE_CancellationManagerStartCancel(
551 TFE_CancellationManager* cancellation_manager) {
552 tensorflow::unwrap(cancellation_manager)->StartCancel();
553 }
554
TFE_CancellationManagerIsCancelled(TFE_CancellationManager * cancellation_manager)555 bool TFE_CancellationManagerIsCancelled(
556 TFE_CancellationManager* cancellation_manager) {
557 return tensorflow::unwrap(cancellation_manager)->IsCancelled();
558 }
559
TFE_DeleteCancellationManager(TFE_CancellationManager * cancellation_manager)560 void TFE_DeleteCancellationManager(
561 TFE_CancellationManager* cancellation_manager) {
562 delete tensorflow::unwrap(cancellation_manager);
563 }
564
TFE_OpSetCancellationManager(TFE_Op * op,TFE_CancellationManager * cancellation_manager,TF_Status * status)565 void TFE_OpSetCancellationManager(TFE_Op* op,
566 TFE_CancellationManager* cancellation_manager,
567 TF_Status* status) {
568 tensorflow::unwrap(op)->SetCancellationManager(
569 tensorflow::unwrap(cancellation_manager));
570 status->status = ::tensorflow::OkStatus();
571 }
572
TFE_NewExecutor(bool is_async,bool enable_streaming_enqueue)573 TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) {
574 return new TFE_Executor(is_async, enable_streaming_enqueue);
575 }
576
TFE_DeleteExecutor(TFE_Executor * executor)577 void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
578
TFE_ExecutorIsAsync(TFE_Executor * executor)579 bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
580 return executor->executor()->Async();
581 }
582
TFE_ExecutorWaitForAllPendingNodes(TFE_Executor * executor,TF_Status * status)583 void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
584 TF_Status* status) {
585 status->status = executor->executor()->WaitForAllPendingNodes();
586 }
587
TFE_ExecutorClearError(TFE_Executor * executor)588 void TFE_ExecutorClearError(TFE_Executor* executor) {
589 executor->executor()->ClearError();
590 }
591
TFE_ContextSetExecutorForThread(TFE_Context * ctx,TFE_Executor * executor)592 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
593 tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
594 }
595
TFE_ContextGetExecutorForThread(TFE_Context * ctx)596 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
597 return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
598 }
599
TFE_HostAddressSpace(TFE_Context * ctx,TF_Buffer * buf)600 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
601 auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
602 tensorflow::unwrap(ctx)->HostCPUParsedName());
603 auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
604 void* data = tensorflow::port::Malloc(str.length());
605 str.copy(static_cast<char*>(data), str.length(), 0);
606 buf->data = data;
607 buf->length = str.length();
608 buf->data_deallocator = [](void* data, size_t length) {
609 tensorflow::port::Free(data);
610 };
611 }
612
TFE_ContextGetFunctionDef(TFE_Context * ctx,const char * function_name,TF_Buffer * buf,TF_Status * status)613 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
614 TF_Buffer* buf, TF_Status* status) {
615 auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
616 if (function_def == nullptr) {
617 status->status = tensorflow::errors::NotFound(
618 "Unable to find FunctionDef with name: ", function_name);
619 return;
620 }
621 string str = function_def->SerializeAsString();
622 void* data = tensorflow::port::Malloc(str.length());
623 str.copy(static_cast<char*>(data), str.length(), 0);
624 buf->data = data;
625 buf->length = str.length();
626 buf->data_deallocator = [](void* data, size_t length) {
627 tensorflow::port::Free(data);
628 };
629 status->status = ::tensorflow::OkStatus();
630 }
631
TFE_AllocateHostTensor(TFE_Context * ctx,TF_DataType dtype,const int64_t * dims,int num_dims,TF_Status * status)632 TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
633 const int64_t* dims, int num_dims,
634 TF_Status* status) {
635 std::vector<int64_t> dimvec(num_dims);
636 for (int i = 0; i < num_dims; ++i) {
637 dimvec[i] = static_cast<int64_t>(dims[i]);
638 }
639
640 if (ctx == nullptr) {
641 status->status = tensorflow::errors::InvalidArgument("Invalid Context");
642 return nullptr;
643 }
644
645 tensorflow::AbstractTensorInterface* t =
646 tensorflow::unwrap(ctx)->CreateTensor(
647 static_cast<tensorflow::DataType>(dtype), dimvec);
648
649 if (t == nullptr) {
650 status->status =
651 tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
652 return nullptr;
653 }
654
655 return new TF_Tensor{t};
656 }
657
TFE_NewTensorHandleFromTensor(TFE_Context * ctx,TF_Tensor * t,TF_Status * status)658 TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
659 TF_Status* status) {
660 return tensorflow::wrap(
661 tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
662 }
663
TFE_CreatePackedTensorHandle(TFE_Context * ctx,TFE_TensorHandle ** handles,int * num_handles,TF_Status * status)664 TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
665 TFE_TensorHandle** handles,
666 int* num_handles,
667 TF_Status* status) {
668 std::vector<tensorflow::TensorHandle*> tensor_handles;
669 tensor_handles.reserve(*num_handles);
670 for (int i = 0; i < *num_handles; ++i) {
671 tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
672 tensorflow::unwrap(handles[i]);
673 if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
674 // One of the inputs we're trying to pack is on a custom device. We'll let
675 // the first custom device we see handle all of the packing.
676 auto* custom_device_handle =
677 tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
678 unwrapped_handle);
679 tensorflow::ImmediateExecutionTensorHandle* result;
680 status->status = custom_device_handle->device()->Pack(
681 absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
682 tensorflow::unwrap(handles), *num_handles),
683 &result);
684 return tensorflow::wrap(result);
685 }
686 tensor_handles.push_back(
687 tensorflow::TensorHandleFromInterface(unwrapped_handle));
688 }
689 tensorflow::EagerContext* context =
690 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
691 tensorflow::TensorHandle* handle = nullptr;
692 status->status = tensorflow::TensorHandle::CreatePackedHandle(
693 std::move(tensor_handles), context, &handle);
694 return tensorflow::wrap(handle);
695 }
696
TFE_ContextSetSoftDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)697 void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
698 TF_Status* status) {
699 tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
700 }
701
TFE_ContextSetLogDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)702 void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
703 TF_Status* status) {
704 tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
705 }
706
TFE_ContextSetRunEagerOpAsFunction(TFE_Context * ctx,unsigned char enable,TF_Status * status)707 void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx, unsigned char enable,
708 TF_Status* status) {
709 tensorflow::unwrap(ctx)->SetRunEagerOpAsFunction(enable);
710 }
711
TFE_ContextSetJitCompileRewrite(TFE_Context * ctx,unsigned char enable,TF_Status * status)712 void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx, unsigned char enable,
713 TF_Status* status) {
714 tensorflow::unwrap(ctx)->SetJitCompileRewrite(enable);
715 }
716
TFE_TensorHandleDeviceType(TFE_TensorHandle * h,TF_Status * status)717 const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
718 if (h == nullptr) {
719 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
720 return nullptr;
721 }
722 return tensorflow::unwrap(h)->DeviceType(&status->status);
723 }
724
TFE_TensorHandleDeviceID(TFE_TensorHandle * h,TF_Status * status)725 int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
726 if (h == nullptr) {
727 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
728 return -1;
729 }
730 return tensorflow::unwrap(h)->DeviceId(&status->status);
731 }
732
TFE_TensorHandleGetStatus(TFE_TensorHandle * h,TF_Status * status)733 TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h,
734 TF_Status* status) {
735 status->status = tensorflow::unwrap(h)->TensorHandleStatus();
736 }
737
TFE_GetExecutedOpNames(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)738 void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
739 TF_Status* status) {
740 const std::vector<std::string>& op_names =
741 tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
742
743 std::ostringstream op_names_oss;
744 for (const auto& op : op_names) {
745 op_names_oss << op << ", ";
746 }
747 const std::string& op_names_str = op_names_oss.str();
748 void* data = tensorflow::port::Malloc(op_names_str.length());
749 op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
750 buf->data = data;
751 buf->length = op_names_str.length();
752 buf->data_deallocator = [](void* data, size_t length) {
753 tensorflow::port::Free(data);
754 };
755 status->status = ::tensorflow::OkStatus();
756 }
757
TFE_SetLogicalCpuDevices(TFE_Context * ctx,int num_cpus,const char * prefix,TF_Status * status)758 void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus,
759 const char* prefix, TF_Status* status) {
760 std::vector<std::unique_ptr<tensorflow::Device>> devices;
761
762 if (prefix == nullptr || strlen(prefix) == 0)
763 prefix = "/job:localhost/replica:0/task:0";
764
765 tensorflow::SessionOptions sess_options;
766 (*sess_options.config.mutable_device_count())["CPU"] = num_cpus;
767 status->status =
768 tensorflow::DeviceFactory::AddCpuDevices(sess_options, prefix, &devices);
769
770 // Remove the device that has the host device name since host device is alreay
771 // in an initialized context.
772 for (auto d = devices.begin(); d != devices.end();) {
773 if (absl::StrContains(d->get()->name(), "CPU:0")) {
774 d = devices.erase(d);
775 } else {
776 ++d;
777 }
778 }
779
780 status->status = tensorflow::unwrap(ctx)->AddDevices(std::move(devices));
781 }
782
TFE_InsertConfigKeyValue(TFE_Context * ctx,const char * key,const char * value,TF_Status * status)783 void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key,
784 const char* value, TF_Status* status) {
785 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
786 tensorflow::unwrap(ctx)->GetDistributedManager();
787 tensorflow::CoordinationServiceAgent* coord_agent =
788 dist_mgr->GetCoordinationServiceAgent();
789 if (coord_agent == nullptr) {
790 status->status = tensorflow::errors::FailedPrecondition(
791 "Coordination service agent is not enabled.");
792 return;
793 }
794 status->status = coord_agent->InsertKeyValue(key, value);
795 }
796
TFE_GetConfigKeyValue(TFE_Context * ctx,const char * key,TF_Buffer * value_buf,TF_Status * status)797 void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key,
798 TF_Buffer* value_buf, TF_Status* status) {
799 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
800 tensorflow::unwrap(ctx)->GetDistributedManager();
801 tensorflow::CoordinationServiceAgent* coord_agent =
802 dist_mgr->GetCoordinationServiceAgent();
803 if (coord_agent == nullptr) {
804 status->status = tensorflow::errors::FailedPrecondition(
805 "Coordination service is not enabled.");
806 return;
807 }
808 auto status_or_value = coord_agent->GetKeyValue(key);
809 status->status = status_or_value.status();
810 if (!status_or_value.ok()) return;
811
812 const std::string& value_string = status_or_value.ValueOrDie();
813 void* data = tensorflow::port::Malloc(value_string.length());
814 value_string.copy(static_cast<char*>(data), value_string.length(), 0);
815 value_buf->data = data;
816 value_buf->length = value_string.length();
817 value_buf->data_deallocator = [](void* data, size_t length) {
818 tensorflow::port::Free(data);
819 };
820 }
821
TFE_DeleteConfigKeyValue(TFE_Context * ctx,const char * key,TF_Status * status)822 void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key,
823 TF_Status* status) {
824 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
825 tensorflow::unwrap(ctx)->GetDistributedManager();
826 tensorflow::CoordinationServiceAgent* coord_agent =
827 dist_mgr->GetCoordinationServiceAgent();
828 if (coord_agent == nullptr) {
829 status->status = tensorflow::errors::FailedPrecondition(
830 "Coordination service is not enabled.");
831 return;
832 }
833 status->status = coord_agent->DeleteKeyValue(key);
834 }
835
TFE_ReportErrorToCluster(TFE_Context * ctx,int error_code,const char * error_message,TF_Status * status)836 void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code,
837 const char* error_message, TF_Status* status) {
838 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
839 tensorflow::unwrap(ctx)->GetDistributedManager();
840 tensorflow::CoordinationServiceAgent* coord_agent =
841 dist_mgr->GetCoordinationServiceAgent();
842 if (coord_agent == nullptr) {
843 status->status = tensorflow::errors::FailedPrecondition(
844 "Coordination service is not enabled.");
845 return;
846 }
847 tensorflow::Status s(static_cast<tensorflow::error::Code>(error_code),
848 error_message);
849 status->status = coord_agent->ReportError(s);
850 }
851