• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api_experimental.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/c_api_internal.h"
22 #include "tensorflow/c/eager/tfe_context_internal.h"
23 #include "tensorflow/c/eager/tfe_op_internal.h"
24 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
25 #include "tensorflow/c/tf_status_helper.h"
26 #include "tensorflow/core/common_runtime/composite_device.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
29 #include "tensorflow/core/lib/monitoring/counter.h"
30 #include "tensorflow/core/lib/monitoring/gauge.h"
31 #include "tensorflow/core/lib/monitoring/sampler.h"
32 #include "tensorflow/core/platform/casts.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/strcat.h"
35 
36 using tensorflow::string;
37 
TFE_OpReset(TFE_Op * op_to_reset,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)38 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
39                  const char* raw_device_name, TF_Status* status) {
40   if (op_to_reset) {
41     tensorflow::ImmediateExecutionOperation* op =
42         tensorflow::unwrap(op_to_reset);
43     op->Clear();
44     status->status = op->Reset(op_or_function_name, raw_device_name);
45   } else {
46     TF_SetStatus(status, TF_INVALID_ARGUMENT,
47                  "op_to_reset should not be nullptr");
48   }
49 }
50 
TFE_ContextEnableGraphCollection(TFE_Context * ctx)51 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
52   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
53 }
54 
TFE_ContextDisableGraphCollection(TFE_Context * ctx)55 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
56   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
57 }
58 
TFE_GetContextId(TFE_Context * ctx)59 uint64_t TFE_GetContextId(TFE_Context* ctx) {
60   tensorflow::EagerContext* context =
61       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
62   return context->GetContextId();
63 }
64 
TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell * cell,int64_t value)65 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
66                                           int64_t value) {
67   cell->cell.IncrementBy(value);
68 }
69 
TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell * cell)70 int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
71   return cell->cell.value();
72 }
73 
TFE_MonitoringNewCounter0(const char * name,TF_Status * status,const char * description)74 TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
75                                                   TF_Status* status,
76                                                   const char* description) {
77   auto* result = new TFE_MonitoringCounter0({name, description});
78   Set_TF_Status_from_Status(status, result->counter->GetStatus());
79   if (!result->counter->GetStatus().ok()) {
80     delete result;
81     return nullptr;
82   }
83   return result;
84 }
85 
TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0 * counter)86 void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
87   delete counter;
88 }
89 
TFE_MonitoringGetCellCounter0(TFE_MonitoringCounter0 * counter)90 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
91     TFE_MonitoringCounter0* counter) {
92   return static_cast<TFE_MonitoringCounterCell*>(
93       static_cast<void*>(counter->counter->GetCell()));
94 }
95 
TFE_MonitoringNewCounter1(const char * name,TF_Status * status,const char * description,const char * label1)96 TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
97                                                   TF_Status* status,
98                                                   const char* description,
99                                                   const char* label1) {
100   auto* result = new TFE_MonitoringCounter1({name, description, label1});
101   Set_TF_Status_from_Status(status, result->counter->GetStatus());
102   if (!result->counter->GetStatus().ok()) {
103     delete result;
104     return nullptr;
105   }
106   return result;
107 }
108 
TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1 * counter)109 void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
110   delete counter;
111 }
112 
TFE_MonitoringGetCellCounter1(TFE_MonitoringCounter1 * counter,const char * label1)113 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
114     TFE_MonitoringCounter1* counter, const char* label1) {
115   return static_cast<TFE_MonitoringCounterCell*>(
116       static_cast<void*>(counter->counter->GetCell(label1)));
117 }
118 
TFE_MonitoringNewCounter2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)119 TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
120                                                   TF_Status* status,
121                                                   const char* description,
122                                                   const char* label1,
123                                                   const char* label2) {
124   auto* result =
125       new TFE_MonitoringCounter2({name, description, label1, label2});
126   Set_TF_Status_from_Status(status, result->counter->GetStatus());
127   if (!result->counter->GetStatus().ok()) {
128     delete result;
129     return nullptr;
130   }
131   return result;
132 }
133 
TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2 * counter)134 void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
135   delete counter;
136 }
137 
TFE_MonitoringGetCellCounter2(TFE_MonitoringCounter2 * counter,const char * label1,const char * label2)138 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
139     TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
140   return static_cast<TFE_MonitoringCounterCell*>(
141       static_cast<void*>(counter->counter->GetCell(label1, label2)));
142 }
143 
TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell * cell,int64_t value)144 void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
145                                    int64_t value) {
146   cell->cell.Set(value);
147 }
148 
TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell * cell)149 int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
150   return cell->cell.value();
151 }
152 
TFE_MonitoringNewIntGauge0(const char * name,TF_Status * status,const char * description)153 TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
154                                                     TF_Status* status,
155                                                     const char* description) {
156   auto* result = new TFE_MonitoringIntGauge0({name, description});
157   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
158   if (!result->gauge->GetStatus().ok()) {
159     delete result;
160     return nullptr;
161   }
162   return result;
163 }
164 
TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0 * gauge)165 void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
166   delete gauge;
167 }
168 
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0 * gauge)169 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
170     TFE_MonitoringIntGauge0* gauge) {
171   return static_cast<TFE_MonitoringIntGaugeCell*>(
172       static_cast<void*>(gauge->gauge->GetCell()));
173 }
174 
TFE_MonitoringNewIntGauge1(const char * name,TF_Status * status,const char * description,const char * label1)175 TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
176                                                     TF_Status* status,
177                                                     const char* description,
178                                                     const char* label1) {
179   auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
180   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
181   if (!result->gauge->GetStatus().ok()) {
182     delete result;
183     return nullptr;
184   }
185   return result;
186 }
187 
TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1 * gauge)188 void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
189   delete gauge;
190 }
191 
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1 * gauge,const char * label1)192 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
193     TFE_MonitoringIntGauge1* gauge, const char* label1) {
194   return static_cast<TFE_MonitoringIntGaugeCell*>(
195       static_cast<void*>(gauge->gauge->GetCell(label1)));
196 }
197 
TFE_MonitoringNewIntGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)198 TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
199                                                     TF_Status* status,
200                                                     const char* description,
201                                                     const char* label1,
202                                                     const char* label2) {
203   auto* result =
204       new TFE_MonitoringIntGauge2({name, description, label1, label2});
205   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
206   if (!result->gauge->GetStatus().ok()) {
207     delete result;
208     return nullptr;
209   }
210   return result;
211 }
212 
TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2 * gauge)213 void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
214   delete gauge;
215 }
216 
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2 * gauge,const char * label1,const char * label2)217 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
218     TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
219   return static_cast<TFE_MonitoringIntGaugeCell*>(
220       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
221 }
222 
TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell * cell,const char * value)223 void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
224                                       const char* value) {
225   cell->cell.Set({value});
226 }
227 
TFE_MonitoringStringGaugeCellValue(TFE_MonitoringStringGaugeCell * cell,TF_Buffer * buf)228 const void TFE_MonitoringStringGaugeCellValue(
229     TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
230   tensorflow::string value = cell->cell.value();
231   void* data = tensorflow::port::Malloc(value.length());
232   value.copy(static_cast<char*>(data), value.length(), 0);
233   buf->data = data;
234   buf->length = value.length();
235   buf->data_deallocator = [](void* data, size_t length) {
236     tensorflow::port::Free(data);
237   };
238 }
239 
TFE_MonitoringNewStringGauge0(const char * name,TF_Status * status,const char * description)240 TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
241     const char* name, TF_Status* status, const char* description) {
242   auto* result = new TFE_MonitoringStringGauge0({name, description});
243   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
244   if (!result->gauge->GetStatus().ok()) {
245     delete result;
246     return nullptr;
247   }
248   return result;
249 }
250 
TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0 * gauge)251 void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
252   delete gauge;
253 }
254 
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0 * gauge)255 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
256     TFE_MonitoringStringGauge0* gauge) {
257   return static_cast<TFE_MonitoringStringGaugeCell*>(
258       static_cast<void*>(gauge->gauge->GetCell()));
259 }
260 
TFE_MonitoringNewStringGauge1(const char * name,TF_Status * status,const char * description,const char * label1)261 TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
262     const char* name, TF_Status* status, const char* description,
263     const char* label1) {
264   auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
265   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
266   if (!result->gauge->GetStatus().ok()) {
267     delete result;
268     return nullptr;
269   }
270   return result;
271 }
272 
TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1 * gauge)273 void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
274   delete gauge;
275 }
276 
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1 * gauge,const char * label1)277 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
278     TFE_MonitoringStringGauge1* gauge, const char* label1) {
279   return static_cast<TFE_MonitoringStringGaugeCell*>(
280       static_cast<void*>(gauge->gauge->GetCell(label1)));
281 }
282 
TFE_MonitoringNewStringGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)283 TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
284     const char* name, TF_Status* status, const char* description,
285     const char* label1, const char* label2) {
286   auto* result =
287       new TFE_MonitoringStringGauge2({name, description, label1, label2});
288   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
289   if (!result->gauge->GetStatus().ok()) {
290     delete result;
291     return nullptr;
292   }
293   return result;
294 }
295 
TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2 * gauge)296 void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
297   delete gauge;
298 }
299 
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2 * gauge,const char * label1,const char * label2)300 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
301     TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
302   return static_cast<TFE_MonitoringStringGaugeCell*>(
303       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
304 }
305 
TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell * cell,bool value)306 void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
307                                     bool value) {
308   cell->cell.Set(value);
309 }
310 
TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell * cell)311 bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
312   return cell->cell.value();
313 }
314 
TFE_MonitoringNewBoolGauge0(const char * name,TF_Status * status,const char * description)315 TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
316                                                       TF_Status* status,
317                                                       const char* description) {
318   auto* result = new TFE_MonitoringBoolGauge0({name, description});
319   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
320   if (!result->gauge->GetStatus().ok()) {
321     delete result;
322     return nullptr;
323   }
324   return result;
325 }
326 
TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)327 void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
328   delete gauge;
329 }
330 
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)331 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
332     TFE_MonitoringBoolGauge0* gauge) {
333   return static_cast<TFE_MonitoringBoolGaugeCell*>(
334       static_cast<void*>(gauge->gauge->GetCell()));
335 }
336 
TFE_MonitoringNewBoolGauge1(const char * name,TF_Status * status,const char * description,const char * label1)337 TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
338                                                       TF_Status* status,
339                                                       const char* description,
340                                                       const char* label1) {
341   auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
342   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
343   if (!result->gauge->GetStatus().ok()) {
344     delete result;
345     return nullptr;
346   }
347   return result;
348 }
349 
TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1 * gauge)350 void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
351   delete gauge;
352 }
353 
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1 * gauge,const char * label1)354 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
355     TFE_MonitoringBoolGauge1* gauge, const char* label1) {
356   return static_cast<TFE_MonitoringBoolGaugeCell*>(
357       static_cast<void*>(gauge->gauge->GetCell(label1)));
358 }
359 
TFE_MonitoringNewBoolGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)360 TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
361                                                       TF_Status* status,
362                                                       const char* description,
363                                                       const char* label1,
364                                                       const char* label2) {
365   auto* result =
366       new TFE_MonitoringBoolGauge2({name, description, label1, label2});
367   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
368   if (!result->gauge->GetStatus().ok()) {
369     delete result;
370     return nullptr;
371   }
372   return result;
373 }
374 
TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2 * gauge)375 void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
376   delete gauge;
377 }
378 
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2 * gauge,const char * label1,const char * label2)379 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
380     TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
381   return static_cast<TFE_MonitoringBoolGaugeCell*>(
382       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
383 }
384 
TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell * cell,double value)385 void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
386                                   double value) {
387   cell->cell.Add(value);
388 }
389 
TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell * cell,TF_Buffer * buf)390 void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
391                                     TF_Buffer* buf) {
392   string content;
393   cell->cell.value().SerializeToString(&content);
394   void* data = tensorflow::port::Malloc(content.length());
395   content.copy(static_cast<char*>(data), content.length(), 0);
396   buf->data = data;
397   buf->length = content.length();
398   buf->data_deallocator = [](void* data, size_t length) {
399     tensorflow::port::Free(data);
400   };
401 }
402 
TFE_MonitoringNewExponentialBuckets(double scale,double growth_factor,int bucket_count)403 TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
404                                                            double growth_factor,
405                                                            int bucket_count) {
406   return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
407     return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
408                                                         bucket_count);
409   });
410 }
411 
TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets * buckets)412 void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
413   delete buckets;
414 }
415 
TFE_MonitoringNewSampler0(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description)416 TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
417     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
418     const char* description) {
419   auto* result = new TFE_MonitoringSampler0(
420       {name, buckets->create_buckets(), description});
421   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
422   if (!result->sampler->GetStatus().ok()) {
423     delete result;
424     return nullptr;
425   }
426   return result;
427 }
428 
TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0 * sampler)429 void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
430   delete sampler;
431 }
432 
TFE_MonitoringGetCellSampler0(TFE_MonitoringSampler0 * sampler)433 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
434     TFE_MonitoringSampler0* sampler) {
435   return static_cast<TFE_MonitoringSamplerCell*>(
436       static_cast<void*>(sampler->sampler->GetCell()));
437 }
438 
TFE_MonitoringNewSampler1(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1)439 TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
440     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
441     const char* description, const char* label1) {
442   auto* result = new TFE_MonitoringSampler1(
443       {name, buckets->create_buckets(), description, label1});
444   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
445   if (!result->sampler->GetStatus().ok()) {
446     delete result;
447     return nullptr;
448   }
449   return result;
450 }
451 
TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1 * sampler)452 void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
453   delete sampler;
454 }
455 
TFE_MonitoringGetCellSampler1(TFE_MonitoringSampler1 * sampler,const char * label1)456 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
457     TFE_MonitoringSampler1* sampler, const char* label1) {
458   return static_cast<TFE_MonitoringSamplerCell*>(
459       static_cast<void*>(sampler->sampler->GetCell(label1)));
460 }
461 
TFE_MonitoringNewSampler2(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1,const char * label2)462 TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
463     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
464     const char* description, const char* label1, const char* label2) {
465   auto* result = new TFE_MonitoringSampler2(
466       {name, buckets->create_buckets(), description, label1, label2});
467   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
468   if (!result->sampler->GetStatus().ok()) {
469     delete result;
470     return nullptr;
471   }
472   return result;
473 }
474 
TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2 * sampler)475 void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
476   delete sampler;
477 }
478 
TFE_MonitoringGetCellSampler2(TFE_MonitoringSampler2 * sampler,const char * label1,const char * label2)479 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
480     TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
481   return static_cast<TFE_MonitoringSamplerCell*>(
482       static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
483 }
484 
TFE_ContextOptionsSetTfrt(TFE_ContextOptions * options,bool use_tfrt)485 void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
486   options->use_tfrt = use_tfrt;
487 }
488 
TFE_NewCancellationManager()489 TFE_CancellationManager* TFE_NewCancellationManager() {
490   return tensorflow::wrap(new tensorflow::CancellationManager);
491 }
492 
TFE_CancellationManagerStartCancel(TFE_CancellationManager * cancellation_manager)493 void TFE_CancellationManagerStartCancel(
494     TFE_CancellationManager* cancellation_manager) {
495   tensorflow::unwrap(cancellation_manager)->StartCancel();
496 }
497 
TFE_CancellationManagerIsCancelled(TFE_CancellationManager * cancellation_manager)498 bool TFE_CancellationManagerIsCancelled(
499     TFE_CancellationManager* cancellation_manager) {
500   return tensorflow::unwrap(cancellation_manager)->IsCancelled();
501 }
502 
TFE_DeleteCancellationManager(TFE_CancellationManager * cancellation_manager)503 void TFE_DeleteCancellationManager(
504     TFE_CancellationManager* cancellation_manager) {
505   delete tensorflow::unwrap(cancellation_manager);
506 }
507 
TFE_OpSetCancellationManager(TFE_Op * op,TFE_CancellationManager * cancellation_manager,TF_Status * status)508 void TFE_OpSetCancellationManager(TFE_Op* op,
509                                   TFE_CancellationManager* cancellation_manager,
510                                   TF_Status* status) {
511   tensorflow::EagerOperation* operation =
512       tensorflow::OperationFromInterface(tensorflow::unwrap(op));
513   operation->SetCancellationManager(tensorflow::unwrap(cancellation_manager));
514   status->status = tensorflow::Status::OK();
515 }
516 
TFE_NewExecutor(bool is_async)517 TFE_Executor* TFE_NewExecutor(bool is_async) {
518   return new TFE_Executor(is_async);
519 }
520 
TFE_DeleteExecutor(TFE_Executor * executor)521 void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
522 
TFE_ExecutorIsAsync(TFE_Executor * executor)523 bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
524   return executor->executor()->Async();
525 }
526 
TFE_ExecutorWaitForAllPendingNodes(TFE_Executor * executor,TF_Status * status)527 void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
528                                         TF_Status* status) {
529   status->status = executor->executor()->WaitForAllPendingNodes();
530 }
531 
TFE_ExecutorClearError(TFE_Executor * executor)532 void TFE_ExecutorClearError(TFE_Executor* executor) {
533   executor->executor()->ClearError();
534 }
535 
TFE_ContextSetExecutorForThread(TFE_Context * ctx,TFE_Executor * executor)536 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
537   tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
538 }
539 
TFE_ContextGetExecutorForThread(TFE_Context * ctx)540 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
541   return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
542 }
543 
TFE_HostAddressSpace(TFE_Context * ctx,TF_Buffer * buf)544 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
545   auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
546       tensorflow::unwrap(ctx)->HostCPUParsedName());
547   auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
548   void* data = tensorflow::port::Malloc(str.length());
549   str.copy(static_cast<char*>(data), str.length(), 0);
550   buf->data = data;
551   buf->length = str.length();
552   buf->data_deallocator = [](void* data, size_t length) {
553     tensorflow::port::Free(data);
554   };
555 }
556 
TFE_ContextGetFunctionDef(TFE_Context * ctx,const char * function_name,TF_Buffer * buf,TF_Status * status)557 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
558                                TF_Buffer* buf, TF_Status* status) {
559   auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
560   if (function_def == nullptr) {
561     status->status = tensorflow::errors::NotFound(
562         "Unable to find FunctionDef with name: ", function_name);
563     return;
564   }
565   string str = function_def->SerializeAsString();
566   void* data = tensorflow::port::Malloc(str.length());
567   str.copy(static_cast<char*>(data), str.length(), 0);
568   buf->data = data;
569   buf->length = str.length();
570   buf->data_deallocator = [](void* data, size_t length) {
571     tensorflow::port::Free(data);
572   };
573   status->status = tensorflow::Status::OK();
574 }
575 
TFE_AllocateHostTensor(TFE_Context * ctx,TF_DataType dtype,const int64_t * dims,int num_dims,TF_Status * status)576 TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
577                                   const int64_t* dims, int num_dims,
578                                   TF_Status* status) {
579   std::vector<tensorflow::int64> dimvec(num_dims);
580   for (int i = 0; i < num_dims; ++i) {
581     dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
582   }
583 
584   if (ctx == nullptr) {
585     status->status = tensorflow::errors::InvalidArgument("Invalid Context");
586     return nullptr;
587   }
588 
589   tensorflow::AbstractTensorInterface* t =
590       tensorflow::unwrap(ctx)->CreateTensor(
591           static_cast<tensorflow::DataType>(dtype), dimvec);
592 
593   if (t == nullptr) {
594     status->status =
595         tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
596     return nullptr;
597   }
598 
599   return new TF_Tensor{t};
600 }
601 
TFE_NewTensorHandleFromTensor(TFE_Context * ctx,TF_Tensor * t,TF_Status * status)602 TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
603                                                 TF_Status* status) {
604   return tensorflow::wrap(
605       tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
606 }
607 
TFE_CreatePackedTensorHandle(TFE_Context * ctx,TFE_TensorHandle ** handles,int * num_handles,TF_Status * status)608 TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
609                                                TFE_TensorHandle** handles,
610                                                int* num_handles,
611                                                TF_Status* status) {
612   std::vector<tensorflow::TensorHandle*> tensor_handles;
613   tensor_handles.reserve(*num_handles);
614   for (int i = 0; i < *num_handles; ++i) {
615     tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
616         tensorflow::unwrap(handles[i]);
617     if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
618       // One of the inputs we're trying to pack is on a custom device. We'll let
619       // the first custom device we see handle all of the packing.
620       auto* custom_device_handle =
621           tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
622               unwrapped_handle);
623       tensorflow::ImmediateExecutionTensorHandle* result;
624       status->status = custom_device_handle->device()->Pack(
625           absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
626               tensorflow::unwrap(handles), *num_handles),
627           &result);
628       return tensorflow::wrap(result);
629     }
630     tensor_handles.push_back(
631         tensorflow::TensorHandleFromInterface(unwrapped_handle));
632   }
633   tensorflow::EagerContext* context =
634       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
635   tensorflow::TensorHandle* handle = nullptr;
636   status->status = tensorflow::TensorHandle::CreatePackedHandle(
637       std::move(tensor_handles), context, &handle);
638   return tensorflow::wrap(handle);
639 }
640 
TFE_ContextSetSoftDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)641 void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
642                                        TF_Status* status) {
643   tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
644 }
645 
TFE_ContextSetLogDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)646 void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
647                                       TF_Status* status) {
648   tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
649 }
650 
TFE_TensorHandleDeviceType(TFE_TensorHandle * h,TF_Status * status)651 const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
652   if (h == nullptr) {
653     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
654     return nullptr;
655   }
656   return tensorflow::unwrap(h)->DeviceType(&status->status);
657 }
658 
TFE_TensorHandleDeviceID(TFE_TensorHandle * h,TF_Status * status)659 int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
660   if (h == nullptr) {
661     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
662     return -1;
663   }
664   return tensorflow::unwrap(h)->DeviceId(&status->status);
665 }
666 
TFE_GetExecutedOpNames(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)667 void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
668                             TF_Status* status) {
669   const std::vector<std::string>& op_names =
670       tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
671 
672   std::ostringstream op_names_oss;
673   for (const auto& op : op_names) {
674     op_names_oss << op << ", ";
675   }
676   const std::string& op_names_str = op_names_oss.str();
677   void* data = tensorflow::port::Malloc(op_names_str.length());
678   op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
679   buf->data = data;
680   buf->length = op_names_str.length();
681   buf->data_deallocator = [](void* data, size_t length) {
682     tensorflow::port::Free(data);
683   };
684   status->status = tensorflow::Status::OK();
685 }
686