1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_experimental.h"
17
18 #include <vector>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/c_api_internal.h"
22 #include "tensorflow/c/eager/tfe_context_internal.h"
23 #include "tensorflow/c/eager/tfe_op_internal.h"
24 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
25 #include "tensorflow/c/tf_status_helper.h"
26 #include "tensorflow/core/common_runtime/composite_device.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
29 #include "tensorflow/core/lib/monitoring/counter.h"
30 #include "tensorflow/core/lib/monitoring/gauge.h"
31 #include "tensorflow/core/lib/monitoring/sampler.h"
32 #include "tensorflow/core/platform/casts.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/strcat.h"
35
36 using tensorflow::string;
37
TFE_OpReset(TFE_Op * op_to_reset,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)38 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
39 const char* raw_device_name, TF_Status* status) {
40 if (op_to_reset) {
41 tensorflow::ImmediateExecutionOperation* op =
42 tensorflow::unwrap(op_to_reset);
43 op->Clear();
44 status->status = op->Reset(op_or_function_name, raw_device_name);
45 } else {
46 TF_SetStatus(status, TF_INVALID_ARGUMENT,
47 "op_to_reset should not be nullptr");
48 }
49 }
50
TFE_ContextEnableGraphCollection(TFE_Context * ctx)51 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
52 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
53 }
54
TFE_ContextDisableGraphCollection(TFE_Context * ctx)55 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
56 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
57 }
58
TFE_GetContextId(TFE_Context * ctx)59 uint64_t TFE_GetContextId(TFE_Context* ctx) {
60 tensorflow::EagerContext* context =
61 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
62 return context->GetContextId();
63 }
64
TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell * cell,int64_t value)65 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
66 int64_t value) {
67 cell->cell.IncrementBy(value);
68 }
69
TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell * cell)70 int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
71 return cell->cell.value();
72 }
73
TFE_MonitoringNewCounter0(const char * name,TF_Status * status,const char * description)74 TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
75 TF_Status* status,
76 const char* description) {
77 auto* result = new TFE_MonitoringCounter0({name, description});
78 Set_TF_Status_from_Status(status, result->counter->GetStatus());
79 if (!result->counter->GetStatus().ok()) {
80 delete result;
81 return nullptr;
82 }
83 return result;
84 }
85
TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0 * counter)86 void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
87 delete counter;
88 }
89
TFE_MonitoringGetCellCounter0(TFE_MonitoringCounter0 * counter)90 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
91 TFE_MonitoringCounter0* counter) {
92 return static_cast<TFE_MonitoringCounterCell*>(
93 static_cast<void*>(counter->counter->GetCell()));
94 }
95
TFE_MonitoringNewCounter1(const char * name,TF_Status * status,const char * description,const char * label1)96 TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
97 TF_Status* status,
98 const char* description,
99 const char* label1) {
100 auto* result = new TFE_MonitoringCounter1({name, description, label1});
101 Set_TF_Status_from_Status(status, result->counter->GetStatus());
102 if (!result->counter->GetStatus().ok()) {
103 delete result;
104 return nullptr;
105 }
106 return result;
107 }
108
TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1 * counter)109 void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
110 delete counter;
111 }
112
TFE_MonitoringGetCellCounter1(TFE_MonitoringCounter1 * counter,const char * label1)113 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
114 TFE_MonitoringCounter1* counter, const char* label1) {
115 return static_cast<TFE_MonitoringCounterCell*>(
116 static_cast<void*>(counter->counter->GetCell(label1)));
117 }
118
TFE_MonitoringNewCounter2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)119 TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
120 TF_Status* status,
121 const char* description,
122 const char* label1,
123 const char* label2) {
124 auto* result =
125 new TFE_MonitoringCounter2({name, description, label1, label2});
126 Set_TF_Status_from_Status(status, result->counter->GetStatus());
127 if (!result->counter->GetStatus().ok()) {
128 delete result;
129 return nullptr;
130 }
131 return result;
132 }
133
TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2 * counter)134 void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
135 delete counter;
136 }
137
TFE_MonitoringGetCellCounter2(TFE_MonitoringCounter2 * counter,const char * label1,const char * label2)138 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
139 TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
140 return static_cast<TFE_MonitoringCounterCell*>(
141 static_cast<void*>(counter->counter->GetCell(label1, label2)));
142 }
143
TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell * cell,int64_t value)144 void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
145 int64_t value) {
146 cell->cell.Set(value);
147 }
148
TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell * cell)149 int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
150 return cell->cell.value();
151 }
152
TFE_MonitoringNewIntGauge0(const char * name,TF_Status * status,const char * description)153 TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
154 TF_Status* status,
155 const char* description) {
156 auto* result = new TFE_MonitoringIntGauge0({name, description});
157 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
158 if (!result->gauge->GetStatus().ok()) {
159 delete result;
160 return nullptr;
161 }
162 return result;
163 }
164
TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0 * gauge)165 void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
166 delete gauge;
167 }
168
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0 * gauge)169 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
170 TFE_MonitoringIntGauge0* gauge) {
171 return static_cast<TFE_MonitoringIntGaugeCell*>(
172 static_cast<void*>(gauge->gauge->GetCell()));
173 }
174
TFE_MonitoringNewIntGauge1(const char * name,TF_Status * status,const char * description,const char * label1)175 TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
176 TF_Status* status,
177 const char* description,
178 const char* label1) {
179 auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
180 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
181 if (!result->gauge->GetStatus().ok()) {
182 delete result;
183 return nullptr;
184 }
185 return result;
186 }
187
TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1 * gauge)188 void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
189 delete gauge;
190 }
191
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1 * gauge,const char * label1)192 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
193 TFE_MonitoringIntGauge1* gauge, const char* label1) {
194 return static_cast<TFE_MonitoringIntGaugeCell*>(
195 static_cast<void*>(gauge->gauge->GetCell(label1)));
196 }
197
TFE_MonitoringNewIntGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)198 TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
199 TF_Status* status,
200 const char* description,
201 const char* label1,
202 const char* label2) {
203 auto* result =
204 new TFE_MonitoringIntGauge2({name, description, label1, label2});
205 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
206 if (!result->gauge->GetStatus().ok()) {
207 delete result;
208 return nullptr;
209 }
210 return result;
211 }
212
TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2 * gauge)213 void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
214 delete gauge;
215 }
216
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2 * gauge,const char * label1,const char * label2)217 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
218 TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
219 return static_cast<TFE_MonitoringIntGaugeCell*>(
220 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
221 }
222
TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell * cell,const char * value)223 void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
224 const char* value) {
225 cell->cell.Set({value});
226 }
227
TFE_MonitoringStringGaugeCellValue(TFE_MonitoringStringGaugeCell * cell,TF_Buffer * buf)228 const void TFE_MonitoringStringGaugeCellValue(
229 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
230 tensorflow::string value = cell->cell.value();
231 void* data = tensorflow::port::Malloc(value.length());
232 value.copy(static_cast<char*>(data), value.length(), 0);
233 buf->data = data;
234 buf->length = value.length();
235 buf->data_deallocator = [](void* data, size_t length) {
236 tensorflow::port::Free(data);
237 };
238 }
239
TFE_MonitoringNewStringGauge0(const char * name,TF_Status * status,const char * description)240 TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
241 const char* name, TF_Status* status, const char* description) {
242 auto* result = new TFE_MonitoringStringGauge0({name, description});
243 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
244 if (!result->gauge->GetStatus().ok()) {
245 delete result;
246 return nullptr;
247 }
248 return result;
249 }
250
TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0 * gauge)251 void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
252 delete gauge;
253 }
254
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0 * gauge)255 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
256 TFE_MonitoringStringGauge0* gauge) {
257 return static_cast<TFE_MonitoringStringGaugeCell*>(
258 static_cast<void*>(gauge->gauge->GetCell()));
259 }
260
TFE_MonitoringNewStringGauge1(const char * name,TF_Status * status,const char * description,const char * label1)261 TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
262 const char* name, TF_Status* status, const char* description,
263 const char* label1) {
264 auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
265 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
266 if (!result->gauge->GetStatus().ok()) {
267 delete result;
268 return nullptr;
269 }
270 return result;
271 }
272
TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1 * gauge)273 void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
274 delete gauge;
275 }
276
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1 * gauge,const char * label1)277 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
278 TFE_MonitoringStringGauge1* gauge, const char* label1) {
279 return static_cast<TFE_MonitoringStringGaugeCell*>(
280 static_cast<void*>(gauge->gauge->GetCell(label1)));
281 }
282
TFE_MonitoringNewStringGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)283 TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
284 const char* name, TF_Status* status, const char* description,
285 const char* label1, const char* label2) {
286 auto* result =
287 new TFE_MonitoringStringGauge2({name, description, label1, label2});
288 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
289 if (!result->gauge->GetStatus().ok()) {
290 delete result;
291 return nullptr;
292 }
293 return result;
294 }
295
TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2 * gauge)296 void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
297 delete gauge;
298 }
299
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2 * gauge,const char * label1,const char * label2)300 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
301 TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
302 return static_cast<TFE_MonitoringStringGaugeCell*>(
303 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
304 }
305
TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell * cell,bool value)306 void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
307 bool value) {
308 cell->cell.Set(value);
309 }
310
TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell * cell)311 bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
312 return cell->cell.value();
313 }
314
TFE_MonitoringNewBoolGauge0(const char * name,TF_Status * status,const char * description)315 TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
316 TF_Status* status,
317 const char* description) {
318 auto* result = new TFE_MonitoringBoolGauge0({name, description});
319 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
320 if (!result->gauge->GetStatus().ok()) {
321 delete result;
322 return nullptr;
323 }
324 return result;
325 }
326
TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)327 void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
328 delete gauge;
329 }
330
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)331 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
332 TFE_MonitoringBoolGauge0* gauge) {
333 return static_cast<TFE_MonitoringBoolGaugeCell*>(
334 static_cast<void*>(gauge->gauge->GetCell()));
335 }
336
TFE_MonitoringNewBoolGauge1(const char * name,TF_Status * status,const char * description,const char * label1)337 TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
338 TF_Status* status,
339 const char* description,
340 const char* label1) {
341 auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
342 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
343 if (!result->gauge->GetStatus().ok()) {
344 delete result;
345 return nullptr;
346 }
347 return result;
348 }
349
TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1 * gauge)350 void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
351 delete gauge;
352 }
353
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1 * gauge,const char * label1)354 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
355 TFE_MonitoringBoolGauge1* gauge, const char* label1) {
356 return static_cast<TFE_MonitoringBoolGaugeCell*>(
357 static_cast<void*>(gauge->gauge->GetCell(label1)));
358 }
359
TFE_MonitoringNewBoolGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)360 TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
361 TF_Status* status,
362 const char* description,
363 const char* label1,
364 const char* label2) {
365 auto* result =
366 new TFE_MonitoringBoolGauge2({name, description, label1, label2});
367 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
368 if (!result->gauge->GetStatus().ok()) {
369 delete result;
370 return nullptr;
371 }
372 return result;
373 }
374
TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2 * gauge)375 void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
376 delete gauge;
377 }
378
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2 * gauge,const char * label1,const char * label2)379 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
380 TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
381 return static_cast<TFE_MonitoringBoolGaugeCell*>(
382 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
383 }
384
TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell * cell,double value)385 void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
386 double value) {
387 cell->cell.Add(value);
388 }
389
TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell * cell,TF_Buffer * buf)390 void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
391 TF_Buffer* buf) {
392 string content;
393 cell->cell.value().SerializeToString(&content);
394 void* data = tensorflow::port::Malloc(content.length());
395 content.copy(static_cast<char*>(data), content.length(), 0);
396 buf->data = data;
397 buf->length = content.length();
398 buf->data_deallocator = [](void* data, size_t length) {
399 tensorflow::port::Free(data);
400 };
401 }
402
TFE_MonitoringNewExponentialBuckets(double scale,double growth_factor,int bucket_count)403 TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
404 double growth_factor,
405 int bucket_count) {
406 return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
407 return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
408 bucket_count);
409 });
410 }
411
TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets * buckets)412 void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
413 delete buckets;
414 }
415
TFE_MonitoringNewSampler0(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description)416 TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
417 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
418 const char* description) {
419 auto* result = new TFE_MonitoringSampler0(
420 {name, buckets->create_buckets(), description});
421 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
422 if (!result->sampler->GetStatus().ok()) {
423 delete result;
424 return nullptr;
425 }
426 return result;
427 }
428
TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0 * sampler)429 void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
430 delete sampler;
431 }
432
TFE_MonitoringGetCellSampler0(TFE_MonitoringSampler0 * sampler)433 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
434 TFE_MonitoringSampler0* sampler) {
435 return static_cast<TFE_MonitoringSamplerCell*>(
436 static_cast<void*>(sampler->sampler->GetCell()));
437 }
438
TFE_MonitoringNewSampler1(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1)439 TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
440 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
441 const char* description, const char* label1) {
442 auto* result = new TFE_MonitoringSampler1(
443 {name, buckets->create_buckets(), description, label1});
444 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
445 if (!result->sampler->GetStatus().ok()) {
446 delete result;
447 return nullptr;
448 }
449 return result;
450 }
451
TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1 * sampler)452 void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
453 delete sampler;
454 }
455
TFE_MonitoringGetCellSampler1(TFE_MonitoringSampler1 * sampler,const char * label1)456 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
457 TFE_MonitoringSampler1* sampler, const char* label1) {
458 return static_cast<TFE_MonitoringSamplerCell*>(
459 static_cast<void*>(sampler->sampler->GetCell(label1)));
460 }
461
TFE_MonitoringNewSampler2(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1,const char * label2)462 TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
463 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
464 const char* description, const char* label1, const char* label2) {
465 auto* result = new TFE_MonitoringSampler2(
466 {name, buckets->create_buckets(), description, label1, label2});
467 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
468 if (!result->sampler->GetStatus().ok()) {
469 delete result;
470 return nullptr;
471 }
472 return result;
473 }
474
TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2 * sampler)475 void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
476 delete sampler;
477 }
478
TFE_MonitoringGetCellSampler2(TFE_MonitoringSampler2 * sampler,const char * label1,const char * label2)479 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
480 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
481 return static_cast<TFE_MonitoringSamplerCell*>(
482 static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
483 }
484
TFE_ContextOptionsSetTfrt(TFE_ContextOptions * options,bool use_tfrt)485 void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
486 options->use_tfrt = use_tfrt;
487 }
488
TFE_NewCancellationManager()489 TFE_CancellationManager* TFE_NewCancellationManager() {
490 return tensorflow::wrap(new tensorflow::CancellationManager);
491 }
492
TFE_CancellationManagerStartCancel(TFE_CancellationManager * cancellation_manager)493 void TFE_CancellationManagerStartCancel(
494 TFE_CancellationManager* cancellation_manager) {
495 tensorflow::unwrap(cancellation_manager)->StartCancel();
496 }
497
TFE_CancellationManagerIsCancelled(TFE_CancellationManager * cancellation_manager)498 bool TFE_CancellationManagerIsCancelled(
499 TFE_CancellationManager* cancellation_manager) {
500 return tensorflow::unwrap(cancellation_manager)->IsCancelled();
501 }
502
TFE_DeleteCancellationManager(TFE_CancellationManager * cancellation_manager)503 void TFE_DeleteCancellationManager(
504 TFE_CancellationManager* cancellation_manager) {
505 delete tensorflow::unwrap(cancellation_manager);
506 }
507
TFE_OpSetCancellationManager(TFE_Op * op,TFE_CancellationManager * cancellation_manager,TF_Status * status)508 void TFE_OpSetCancellationManager(TFE_Op* op,
509 TFE_CancellationManager* cancellation_manager,
510 TF_Status* status) {
511 tensorflow::EagerOperation* operation =
512 tensorflow::OperationFromInterface(tensorflow::unwrap(op));
513 operation->SetCancellationManager(tensorflow::unwrap(cancellation_manager));
514 status->status = tensorflow::Status::OK();
515 }
516
TFE_NewExecutor(bool is_async)517 TFE_Executor* TFE_NewExecutor(bool is_async) {
518 return new TFE_Executor(is_async);
519 }
520
TFE_DeleteExecutor(TFE_Executor * executor)521 void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
522
TFE_ExecutorIsAsync(TFE_Executor * executor)523 bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
524 return executor->executor()->Async();
525 }
526
TFE_ExecutorWaitForAllPendingNodes(TFE_Executor * executor,TF_Status * status)527 void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
528 TF_Status* status) {
529 status->status = executor->executor()->WaitForAllPendingNodes();
530 }
531
TFE_ExecutorClearError(TFE_Executor * executor)532 void TFE_ExecutorClearError(TFE_Executor* executor) {
533 executor->executor()->ClearError();
534 }
535
TFE_ContextSetExecutorForThread(TFE_Context * ctx,TFE_Executor * executor)536 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
537 tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
538 }
539
TFE_ContextGetExecutorForThread(TFE_Context * ctx)540 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
541 return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
542 }
543
TFE_HostAddressSpace(TFE_Context * ctx,TF_Buffer * buf)544 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
545 auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
546 tensorflow::unwrap(ctx)->HostCPUParsedName());
547 auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
548 void* data = tensorflow::port::Malloc(str.length());
549 str.copy(static_cast<char*>(data), str.length(), 0);
550 buf->data = data;
551 buf->length = str.length();
552 buf->data_deallocator = [](void* data, size_t length) {
553 tensorflow::port::Free(data);
554 };
555 }
556
TFE_ContextGetFunctionDef(TFE_Context * ctx,const char * function_name,TF_Buffer * buf,TF_Status * status)557 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
558 TF_Buffer* buf, TF_Status* status) {
559 auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
560 if (function_def == nullptr) {
561 status->status = tensorflow::errors::NotFound(
562 "Unable to find FunctionDef with name: ", function_name);
563 return;
564 }
565 string str = function_def->SerializeAsString();
566 void* data = tensorflow::port::Malloc(str.length());
567 str.copy(static_cast<char*>(data), str.length(), 0);
568 buf->data = data;
569 buf->length = str.length();
570 buf->data_deallocator = [](void* data, size_t length) {
571 tensorflow::port::Free(data);
572 };
573 status->status = tensorflow::Status::OK();
574 }
575
TFE_AllocateHostTensor(TFE_Context * ctx,TF_DataType dtype,const int64_t * dims,int num_dims,TF_Status * status)576 TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
577 const int64_t* dims, int num_dims,
578 TF_Status* status) {
579 std::vector<tensorflow::int64> dimvec(num_dims);
580 for (int i = 0; i < num_dims; ++i) {
581 dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
582 }
583
584 if (ctx == nullptr) {
585 status->status = tensorflow::errors::InvalidArgument("Invalid Context");
586 return nullptr;
587 }
588
589 tensorflow::AbstractTensorInterface* t =
590 tensorflow::unwrap(ctx)->CreateTensor(
591 static_cast<tensorflow::DataType>(dtype), dimvec);
592
593 if (t == nullptr) {
594 status->status =
595 tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
596 return nullptr;
597 }
598
599 return new TF_Tensor{t};
600 }
601
TFE_NewTensorHandleFromTensor(TFE_Context * ctx,TF_Tensor * t,TF_Status * status)602 TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
603 TF_Status* status) {
604 return tensorflow::wrap(
605 tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
606 }
607
TFE_CreatePackedTensorHandle(TFE_Context * ctx,TFE_TensorHandle ** handles,int * num_handles,TF_Status * status)608 TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
609 TFE_TensorHandle** handles,
610 int* num_handles,
611 TF_Status* status) {
612 std::vector<tensorflow::TensorHandle*> tensor_handles;
613 tensor_handles.reserve(*num_handles);
614 for (int i = 0; i < *num_handles; ++i) {
615 tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
616 tensorflow::unwrap(handles[i]);
617 if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
618 // One of the inputs we're trying to pack is on a custom device. We'll let
619 // the first custom device we see handle all of the packing.
620 auto* custom_device_handle =
621 tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
622 unwrapped_handle);
623 tensorflow::ImmediateExecutionTensorHandle* result;
624 status->status = custom_device_handle->device()->Pack(
625 absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
626 tensorflow::unwrap(handles), *num_handles),
627 &result);
628 return tensorflow::wrap(result);
629 }
630 tensor_handles.push_back(
631 tensorflow::TensorHandleFromInterface(unwrapped_handle));
632 }
633 tensorflow::EagerContext* context =
634 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
635 tensorflow::TensorHandle* handle = nullptr;
636 status->status = tensorflow::TensorHandle::CreatePackedHandle(
637 std::move(tensor_handles), context, &handle);
638 return tensorflow::wrap(handle);
639 }
640
TFE_ContextSetSoftDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)641 void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
642 TF_Status* status) {
643 tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
644 }
645
TFE_ContextSetLogDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)646 void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
647 TF_Status* status) {
648 tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
649 }
650
TFE_TensorHandleDeviceType(TFE_TensorHandle * h,TF_Status * status)651 const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
652 if (h == nullptr) {
653 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
654 return nullptr;
655 }
656 return tensorflow::unwrap(h)->DeviceType(&status->status);
657 }
658
TFE_TensorHandleDeviceID(TFE_TensorHandle * h,TF_Status * status)659 int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
660 if (h == nullptr) {
661 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
662 return -1;
663 }
664 return tensorflow::unwrap(h)->DeviceId(&status->status);
665 }
666
TFE_GetExecutedOpNames(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)667 void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
668 TF_Status* status) {
669 const std::vector<std::string>& op_names =
670 tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
671
672 std::ostringstream op_names_oss;
673 for (const auto& op : op_names) {
674 op_names_oss << op << ", ";
675 }
676 const std::string& op_names_str = op_names_oss.str();
677 void* data = tensorflow::port::Malloc(op_names_str.length());
678 op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
679 buf->data = data;
680 buf->length = op_names_str.length();
681 buf->data_deallocator = [](void* data, size_t length) {
682 tensorflow::port::Free(data);
683 };
684 status->status = tensorflow::Status::OK();
685 }
686