• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api_experimental.h"
17 
18 #include <vector>
19 
20 #include "absl/strings/match.h"
21 #include "tensorflow/c/c_api.h"
22 #include "tensorflow/c/eager/c_api_internal.h"
23 #include "tensorflow/c/eager/tfe_context_internal.h"
24 #include "tensorflow/c/eager/tfe_op_internal.h"
25 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
26 #include "tensorflow/c/tf_status_helper.h"
27 #include "tensorflow/core/common_runtime/composite_device.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
30 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
31 #include "tensorflow/core/lib/monitoring/counter.h"
32 #include "tensorflow/core/lib/monitoring/gauge.h"
33 #include "tensorflow/core/lib/monitoring/sampler.h"
34 #include "tensorflow/core/platform/casts.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/strcat.h"
38 
39 using tensorflow::string;
40 
TFE_OpReset(TFE_Op * op_to_reset,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)41 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
42                  const char* raw_device_name, TF_Status* status) {
43   if (op_to_reset) {
44     tensorflow::ImmediateExecutionOperation* op =
45         tensorflow::unwrap(op_to_reset);
46     op->Clear();
47     status->status = op->Reset(op_or_function_name, raw_device_name);
48   } else {
49     TF_SetStatus(status, TF_INVALID_ARGUMENT,
50                  "op_to_reset should not be nullptr");
51   }
52 }
53 
TFE_ContextEnableGraphCollection(TFE_Context * ctx)54 void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
55   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
56 }
57 
TFE_ContextDisableGraphCollection(TFE_Context * ctx)58 void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
59   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
60 }
61 
TFE_GetContextId(TFE_Context * ctx)62 uint64_t TFE_GetContextId(TFE_Context* ctx) {
63   tensorflow::EagerContext* context =
64       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
65   return context->GetContextId();
66 }
67 
TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell * cell,int64_t value)68 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
69                                           int64_t value) {
70   cell->cell.IncrementBy(value);
71 }
72 
TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell * cell)73 int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
74   return cell->cell.value();
75 }
76 
TFE_MonitoringNewCounter0(const char * name,TF_Status * status,const char * description)77 TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
78                                                   TF_Status* status,
79                                                   const char* description) {
80   auto* result = new TFE_MonitoringCounter0({name, description});
81   Set_TF_Status_from_Status(status, result->counter->GetStatus());
82   if (!result->counter->GetStatus().ok()) {
83     delete result;
84     return nullptr;
85   }
86   return result;
87 }
88 
TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0 * counter)89 void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
90   delete counter;
91 }
92 
TFE_MonitoringGetCellCounter0(TFE_MonitoringCounter0 * counter)93 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
94     TFE_MonitoringCounter0* counter) {
95   return static_cast<TFE_MonitoringCounterCell*>(
96       static_cast<void*>(counter->counter->GetCell()));
97 }
98 
TFE_MonitoringNewCounter1(const char * name,TF_Status * status,const char * description,const char * label1)99 TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
100                                                   TF_Status* status,
101                                                   const char* description,
102                                                   const char* label1) {
103   auto* result = new TFE_MonitoringCounter1({name, description, label1});
104   Set_TF_Status_from_Status(status, result->counter->GetStatus());
105   if (!result->counter->GetStatus().ok()) {
106     delete result;
107     return nullptr;
108   }
109   return result;
110 }
111 
TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1 * counter)112 void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
113   delete counter;
114 }
115 
TFE_MonitoringGetCellCounter1(TFE_MonitoringCounter1 * counter,const char * label1)116 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
117     TFE_MonitoringCounter1* counter, const char* label1) {
118   return static_cast<TFE_MonitoringCounterCell*>(
119       static_cast<void*>(counter->counter->GetCell(label1)));
120 }
121 
TFE_MonitoringNewCounter2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)122 TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
123                                                   TF_Status* status,
124                                                   const char* description,
125                                                   const char* label1,
126                                                   const char* label2) {
127   auto* result =
128       new TFE_MonitoringCounter2({name, description, label1, label2});
129   Set_TF_Status_from_Status(status, result->counter->GetStatus());
130   if (!result->counter->GetStatus().ok()) {
131     delete result;
132     return nullptr;
133   }
134   return result;
135 }
136 
TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2 * counter)137 void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
138   delete counter;
139 }
140 
TFE_MonitoringGetCellCounter2(TFE_MonitoringCounter2 * counter,const char * label1,const char * label2)141 TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
142     TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
143   return static_cast<TFE_MonitoringCounterCell*>(
144       static_cast<void*>(counter->counter->GetCell(label1, label2)));
145 }
146 
TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell * cell,int64_t value)147 void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
148                                    int64_t value) {
149   cell->cell.Set(value);
150 }
151 
TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell * cell)152 int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
153   return cell->cell.value();
154 }
155 
TFE_MonitoringNewIntGauge0(const char * name,TF_Status * status,const char * description)156 TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
157                                                     TF_Status* status,
158                                                     const char* description) {
159   auto* result = new TFE_MonitoringIntGauge0({name, description});
160   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
161   if (!result->gauge->GetStatus().ok()) {
162     delete result;
163     return nullptr;
164   }
165   return result;
166 }
167 
TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0 * gauge)168 void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
169   delete gauge;
170 }
171 
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0 * gauge)172 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
173     TFE_MonitoringIntGauge0* gauge) {
174   return static_cast<TFE_MonitoringIntGaugeCell*>(
175       static_cast<void*>(gauge->gauge->GetCell()));
176 }
177 
TFE_MonitoringNewIntGauge1(const char * name,TF_Status * status,const char * description,const char * label1)178 TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
179                                                     TF_Status* status,
180                                                     const char* description,
181                                                     const char* label1) {
182   auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
183   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
184   if (!result->gauge->GetStatus().ok()) {
185     delete result;
186     return nullptr;
187   }
188   return result;
189 }
190 
TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1 * gauge)191 void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
192   delete gauge;
193 }
194 
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1 * gauge,const char * label1)195 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
196     TFE_MonitoringIntGauge1* gauge, const char* label1) {
197   return static_cast<TFE_MonitoringIntGaugeCell*>(
198       static_cast<void*>(gauge->gauge->GetCell(label1)));
199 }
200 
TFE_MonitoringNewIntGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)201 TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
202                                                     TF_Status* status,
203                                                     const char* description,
204                                                     const char* label1,
205                                                     const char* label2) {
206   auto* result =
207       new TFE_MonitoringIntGauge2({name, description, label1, label2});
208   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
209   if (!result->gauge->GetStatus().ok()) {
210     delete result;
211     return nullptr;
212   }
213   return result;
214 }
215 
TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2 * gauge)216 void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
217   delete gauge;
218 }
219 
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2 * gauge,const char * label1,const char * label2)220 TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
221     TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
222   return static_cast<TFE_MonitoringIntGaugeCell*>(
223       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
224 }
225 
TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell * cell,const char * value)226 void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
227                                       const char* value) {
228   cell->cell.Set({value});
229 }
230 
TFE_MonitoringStringGaugeCellValue(TFE_MonitoringStringGaugeCell * cell,TF_Buffer * buf)231 const void TFE_MonitoringStringGaugeCellValue(
232     TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
233   tensorflow::string value = cell->cell.value();
234   void* data = tensorflow::port::Malloc(value.length());
235   value.copy(static_cast<char*>(data), value.length(), 0);
236   buf->data = data;
237   buf->length = value.length();
238   buf->data_deallocator = [](void* data, size_t length) {
239     tensorflow::port::Free(data);
240   };
241 }
242 
TFE_MonitoringNewStringGauge0(const char * name,TF_Status * status,const char * description)243 TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
244     const char* name, TF_Status* status, const char* description) {
245   auto* result = new TFE_MonitoringStringGauge0({name, description});
246   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
247   if (!result->gauge->GetStatus().ok()) {
248     delete result;
249     return nullptr;
250   }
251   return result;
252 }
253 
TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0 * gauge)254 void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
255   delete gauge;
256 }
257 
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0 * gauge)258 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
259     TFE_MonitoringStringGauge0* gauge) {
260   return static_cast<TFE_MonitoringStringGaugeCell*>(
261       static_cast<void*>(gauge->gauge->GetCell()));
262 }
263 
TFE_MonitoringNewStringGauge1(const char * name,TF_Status * status,const char * description,const char * label1)264 TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
265     const char* name, TF_Status* status, const char* description,
266     const char* label1) {
267   auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
268   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
269   if (!result->gauge->GetStatus().ok()) {
270     delete result;
271     return nullptr;
272   }
273   return result;
274 }
275 
TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1 * gauge)276 void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
277   delete gauge;
278 }
279 
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1 * gauge,const char * label1)280 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
281     TFE_MonitoringStringGauge1* gauge, const char* label1) {
282   return static_cast<TFE_MonitoringStringGaugeCell*>(
283       static_cast<void*>(gauge->gauge->GetCell(label1)));
284 }
285 
TFE_MonitoringNewStringGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)286 TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
287     const char* name, TF_Status* status, const char* description,
288     const char* label1, const char* label2) {
289   auto* result =
290       new TFE_MonitoringStringGauge2({name, description, label1, label2});
291   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
292   if (!result->gauge->GetStatus().ok()) {
293     delete result;
294     return nullptr;
295   }
296   return result;
297 }
298 
TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2 * gauge)299 void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
300   delete gauge;
301 }
302 
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2 * gauge,const char * label1,const char * label2)303 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
304     TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
305   return static_cast<TFE_MonitoringStringGaugeCell*>(
306       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
307 }
308 
TFE_MonitoringNewStringGauge3(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2,const char * label3)309 TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3(
310     const char* name, TF_Status* status, const char* description,
311     const char* label1, const char* label2, const char* label3) {
312   auto* result = new TFE_MonitoringStringGauge3(
313       {name, description, label1, label2, label3});
314   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
315   if (!result->gauge->GetStatus().ok()) {
316     delete result;
317     return nullptr;
318   }
319   return result;
320 }
321 
TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3 * gauge)322 void TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3* gauge) {
323   delete gauge;
324 }
325 
TFE_MonitoringGetCellStringGauge3(TFE_MonitoringStringGauge3 * gauge,const char * label1,const char * label2,const char * label3)326 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge3(
327     TFE_MonitoringStringGauge3* gauge, const char* label1, const char* label2,
328     const char* label3) {
329   return static_cast<TFE_MonitoringStringGaugeCell*>(
330       static_cast<void*>(gauge->gauge->GetCell(label1, label2, label3)));
331 }
332 
TFE_MonitoringNewStringGauge4(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2,const char * label3,const char * label4)333 TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4(
334     const char* name, TF_Status* status, const char* description,
335     const char* label1, const char* label2, const char* label3,
336     const char* label4) {
337   auto* result = new TFE_MonitoringStringGauge4(
338       {name, description, label1, label2, label3, label4});
339   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
340   if (!result->gauge->GetStatus().ok()) {
341     delete result;
342     return nullptr;
343   }
344   return result;
345 }
346 
TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4 * gauge)347 void TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4* gauge) {
348   delete gauge;
349 }
350 
TFE_MonitoringGetCellStringGauge4(TFE_MonitoringStringGauge4 * gauge,const char * label1,const char * label2,const char * label3,const char * label4)351 TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge4(
352     TFE_MonitoringStringGauge4* gauge, const char* label1, const char* label2,
353     const char* label3, const char* label4) {
354   return static_cast<TFE_MonitoringStringGaugeCell*>(static_cast<void*>(
355       gauge->gauge->GetCell(label1, label2, label3, label4)));
356 }
357 
TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell * cell,bool value)358 void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
359                                     bool value) {
360   cell->cell.Set(value);
361 }
362 
TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell * cell)363 bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
364   return cell->cell.value();
365 }
366 
TFE_MonitoringNewBoolGauge0(const char * name,TF_Status * status,const char * description)367 TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
368                                                       TF_Status* status,
369                                                       const char* description) {
370   auto* result = new TFE_MonitoringBoolGauge0({name, description});
371   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
372   if (!result->gauge->GetStatus().ok()) {
373     delete result;
374     return nullptr;
375   }
376   return result;
377 }
378 
TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)379 void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
380   delete gauge;
381 }
382 
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0 * gauge)383 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
384     TFE_MonitoringBoolGauge0* gauge) {
385   return static_cast<TFE_MonitoringBoolGaugeCell*>(
386       static_cast<void*>(gauge->gauge->GetCell()));
387 }
388 
TFE_MonitoringNewBoolGauge1(const char * name,TF_Status * status,const char * description,const char * label1)389 TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
390                                                       TF_Status* status,
391                                                       const char* description,
392                                                       const char* label1) {
393   auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
394   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
395   if (!result->gauge->GetStatus().ok()) {
396     delete result;
397     return nullptr;
398   }
399   return result;
400 }
401 
TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1 * gauge)402 void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
403   delete gauge;
404 }
405 
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1 * gauge,const char * label1)406 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
407     TFE_MonitoringBoolGauge1* gauge, const char* label1) {
408   return static_cast<TFE_MonitoringBoolGaugeCell*>(
409       static_cast<void*>(gauge->gauge->GetCell(label1)));
410 }
411 
TFE_MonitoringNewBoolGauge2(const char * name,TF_Status * status,const char * description,const char * label1,const char * label2)412 TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
413                                                       TF_Status* status,
414                                                       const char* description,
415                                                       const char* label1,
416                                                       const char* label2) {
417   auto* result =
418       new TFE_MonitoringBoolGauge2({name, description, label1, label2});
419   Set_TF_Status_from_Status(status, result->gauge->GetStatus());
420   if (!result->gauge->GetStatus().ok()) {
421     delete result;
422     return nullptr;
423   }
424   return result;
425 }
426 
TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2 * gauge)427 void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
428   delete gauge;
429 }
430 
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2 * gauge,const char * label1,const char * label2)431 TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
432     TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
433   return static_cast<TFE_MonitoringBoolGaugeCell*>(
434       static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
435 }
436 
TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell * cell,double value)437 void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
438                                   double value) {
439   cell->cell.Add(value);
440 }
441 
TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell * cell,TF_Buffer * buf)442 void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
443                                     TF_Buffer* buf) {
444   string content;
445   cell->cell.value().SerializeToString(&content);
446   void* data = tensorflow::port::Malloc(content.length());
447   content.copy(static_cast<char*>(data), content.length(), 0);
448   buf->data = data;
449   buf->length = content.length();
450   buf->data_deallocator = [](void* data, size_t length) {
451     tensorflow::port::Free(data);
452   };
453 }
454 
TFE_MonitoringNewExponentialBuckets(double scale,double growth_factor,int bucket_count)455 TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
456                                                            double growth_factor,
457                                                            int bucket_count) {
458   return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
459     return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
460                                                         bucket_count);
461   });
462 }
463 
TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets * buckets)464 void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
465   delete buckets;
466 }
467 
TFE_MonitoringNewSampler0(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description)468 TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
469     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
470     const char* description) {
471   auto* result = new TFE_MonitoringSampler0(
472       {name, buckets->create_buckets(), description});
473   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
474   if (!result->sampler->GetStatus().ok()) {
475     delete result;
476     return nullptr;
477   }
478   return result;
479 }
480 
TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0 * sampler)481 void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
482   delete sampler;
483 }
484 
TFE_MonitoringGetCellSampler0(TFE_MonitoringSampler0 * sampler)485 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
486     TFE_MonitoringSampler0* sampler) {
487   return static_cast<TFE_MonitoringSamplerCell*>(
488       static_cast<void*>(sampler->sampler->GetCell()));
489 }
490 
TFE_MonitoringNewSampler1(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1)491 TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
492     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
493     const char* description, const char* label1) {
494   auto* result = new TFE_MonitoringSampler1(
495       {name, buckets->create_buckets(), description, label1});
496   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
497   if (!result->sampler->GetStatus().ok()) {
498     delete result;
499     return nullptr;
500   }
501   return result;
502 }
503 
TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1 * sampler)504 void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
505   delete sampler;
506 }
507 
TFE_MonitoringGetCellSampler1(TFE_MonitoringSampler1 * sampler,const char * label1)508 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
509     TFE_MonitoringSampler1* sampler, const char* label1) {
510   return static_cast<TFE_MonitoringSamplerCell*>(
511       static_cast<void*>(sampler->sampler->GetCell(label1)));
512 }
513 
TFE_MonitoringNewSampler2(const char * name,TFE_MonitoringBuckets * buckets,TF_Status * status,const char * description,const char * label1,const char * label2)514 TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
515     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
516     const char* description, const char* label1, const char* label2) {
517   auto* result = new TFE_MonitoringSampler2(
518       {name, buckets->create_buckets(), description, label1, label2});
519   Set_TF_Status_from_Status(status, result->sampler->GetStatus());
520   if (!result->sampler->GetStatus().ok()) {
521     delete result;
522     return nullptr;
523   }
524   return result;
525 }
526 
TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2 * sampler)527 void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
528   delete sampler;
529 }
530 
TFE_MonitoringGetCellSampler2(TFE_MonitoringSampler2 * sampler,const char * label1,const char * label2)531 TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
532     TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
533   return static_cast<TFE_MonitoringSamplerCell*>(
534       static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
535 }
536 
TFE_ContextOptionsSetTfrt(TFE_ContextOptions * options,bool use_tfrt)537 void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
538   options->use_tfrt = use_tfrt;
539 }
540 
TFE_ContextOptionsSetTfrtDistributedRuntime(TFE_ContextOptions * options,bool use_tfrt_distributed_runtime)541 void TFE_ContextOptionsSetTfrtDistributedRuntime(
542     TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) {
543   options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime;
544 }
545 
TFE_NewCancellationManager()546 TFE_CancellationManager* TFE_NewCancellationManager() {
547   return tensorflow::wrap(new tensorflow::CancellationManager);
548 }
549 
TFE_CancellationManagerStartCancel(TFE_CancellationManager * cancellation_manager)550 void TFE_CancellationManagerStartCancel(
551     TFE_CancellationManager* cancellation_manager) {
552   tensorflow::unwrap(cancellation_manager)->StartCancel();
553 }
554 
TFE_CancellationManagerIsCancelled(TFE_CancellationManager * cancellation_manager)555 bool TFE_CancellationManagerIsCancelled(
556     TFE_CancellationManager* cancellation_manager) {
557   return tensorflow::unwrap(cancellation_manager)->IsCancelled();
558 }
559 
TFE_DeleteCancellationManager(TFE_CancellationManager * cancellation_manager)560 void TFE_DeleteCancellationManager(
561     TFE_CancellationManager* cancellation_manager) {
562   delete tensorflow::unwrap(cancellation_manager);
563 }
564 
TFE_OpSetCancellationManager(TFE_Op * op,TFE_CancellationManager * cancellation_manager,TF_Status * status)565 void TFE_OpSetCancellationManager(TFE_Op* op,
566                                   TFE_CancellationManager* cancellation_manager,
567                                   TF_Status* status) {
568   tensorflow::unwrap(op)->SetCancellationManager(
569       tensorflow::unwrap(cancellation_manager));
570   status->status = ::tensorflow::OkStatus();
571 }
572 
TFE_NewExecutor(bool is_async,bool enable_streaming_enqueue)573 TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) {
574   return new TFE_Executor(is_async, enable_streaming_enqueue);
575 }
576 
TFE_DeleteExecutor(TFE_Executor * executor)577 void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
578 
TFE_ExecutorIsAsync(TFE_Executor * executor)579 bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
580   return executor->executor()->Async();
581 }
582 
TFE_ExecutorWaitForAllPendingNodes(TFE_Executor * executor,TF_Status * status)583 void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
584                                         TF_Status* status) {
585   status->status = executor->executor()->WaitForAllPendingNodes();
586 }
587 
TFE_ExecutorClearError(TFE_Executor * executor)588 void TFE_ExecutorClearError(TFE_Executor* executor) {
589   executor->executor()->ClearError();
590 }
591 
TFE_ContextSetExecutorForThread(TFE_Context * ctx,TFE_Executor * executor)592 void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
593   tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
594 }
595 
TFE_ContextGetExecutorForThread(TFE_Context * ctx)596 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
597   return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
598 }
599 
TFE_HostAddressSpace(TFE_Context * ctx,TF_Buffer * buf)600 void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
601   auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
602       tensorflow::unwrap(ctx)->HostCPUParsedName());
603   auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
604   void* data = tensorflow::port::Malloc(str.length());
605   str.copy(static_cast<char*>(data), str.length(), 0);
606   buf->data = data;
607   buf->length = str.length();
608   buf->data_deallocator = [](void* data, size_t length) {
609     tensorflow::port::Free(data);
610   };
611 }
612 
TFE_ContextGetFunctionDef(TFE_Context * ctx,const char * function_name,TF_Buffer * buf,TF_Status * status)613 void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
614                                TF_Buffer* buf, TF_Status* status) {
615   auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
616   if (function_def == nullptr) {
617     status->status = tensorflow::errors::NotFound(
618         "Unable to find FunctionDef with name: ", function_name);
619     return;
620   }
621   string str = function_def->SerializeAsString();
622   void* data = tensorflow::port::Malloc(str.length());
623   str.copy(static_cast<char*>(data), str.length(), 0);
624   buf->data = data;
625   buf->length = str.length();
626   buf->data_deallocator = [](void* data, size_t length) {
627     tensorflow::port::Free(data);
628   };
629   status->status = ::tensorflow::OkStatus();
630 }
631 
TFE_AllocateHostTensor(TFE_Context * ctx,TF_DataType dtype,const int64_t * dims,int num_dims,TF_Status * status)632 TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
633                                   const int64_t* dims, int num_dims,
634                                   TF_Status* status) {
635   std::vector<int64_t> dimvec(num_dims);
636   for (int i = 0; i < num_dims; ++i) {
637     dimvec[i] = static_cast<int64_t>(dims[i]);
638   }
639 
640   if (ctx == nullptr) {
641     status->status = tensorflow::errors::InvalidArgument("Invalid Context");
642     return nullptr;
643   }
644 
645   tensorflow::AbstractTensorInterface* t =
646       tensorflow::unwrap(ctx)->CreateTensor(
647           static_cast<tensorflow::DataType>(dtype), dimvec);
648 
649   if (t == nullptr) {
650     status->status =
651         tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
652     return nullptr;
653   }
654 
655   return new TF_Tensor{t};
656 }
657 
TFE_NewTensorHandleFromTensor(TFE_Context * ctx,TF_Tensor * t,TF_Status * status)658 TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
659                                                 TF_Status* status) {
660   return tensorflow::wrap(
661       tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
662 }
663 
TFE_CreatePackedTensorHandle(TFE_Context * ctx,TFE_TensorHandle ** handles,int * num_handles,TF_Status * status)664 TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
665                                                TFE_TensorHandle** handles,
666                                                int* num_handles,
667                                                TF_Status* status) {
668   std::vector<tensorflow::TensorHandle*> tensor_handles;
669   tensor_handles.reserve(*num_handles);
670   for (int i = 0; i < *num_handles; ++i) {
671     tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
672         tensorflow::unwrap(handles[i]);
673     if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
674       // One of the inputs we're trying to pack is on a custom device. We'll let
675       // the first custom device we see handle all of the packing.
676       auto* custom_device_handle =
677           tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
678               unwrapped_handle);
679       tensorflow::ImmediateExecutionTensorHandle* result;
680       status->status = custom_device_handle->device()->Pack(
681           absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
682               tensorflow::unwrap(handles), *num_handles),
683           &result);
684       return tensorflow::wrap(result);
685     }
686     tensor_handles.push_back(
687         tensorflow::TensorHandleFromInterface(unwrapped_handle));
688   }
689   tensorflow::EagerContext* context =
690       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
691   tensorflow::TensorHandle* handle = nullptr;
692   status->status = tensorflow::TensorHandle::CreatePackedHandle(
693       std::move(tensor_handles), context, &handle);
694   return tensorflow::wrap(handle);
695 }
696 
TFE_ContextSetSoftDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)697 void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
698                                        TF_Status* status) {
699   tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
700 }
701 
TFE_ContextSetLogDevicePlacement(TFE_Context * ctx,unsigned char enable,TF_Status * status)702 void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
703                                       TF_Status* status) {
704   tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
705 }
706 
TFE_ContextSetRunEagerOpAsFunction(TFE_Context * ctx,unsigned char enable,TF_Status * status)707 void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx, unsigned char enable,
708                                         TF_Status* status) {
709   tensorflow::unwrap(ctx)->SetRunEagerOpAsFunction(enable);
710 }
711 
TFE_ContextSetJitCompileRewrite(TFE_Context * ctx,unsigned char enable,TF_Status * status)712 void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx, unsigned char enable,
713                                      TF_Status* status) {
714   tensorflow::unwrap(ctx)->SetJitCompileRewrite(enable);
715 }
716 
TFE_TensorHandleDeviceType(TFE_TensorHandle * h,TF_Status * status)717 const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
718   if (h == nullptr) {
719     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
720     return nullptr;
721   }
722   return tensorflow::unwrap(h)->DeviceType(&status->status);
723 }
724 
TFE_TensorHandleDeviceID(TFE_TensorHandle * h,TF_Status * status)725 int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
726   if (h == nullptr) {
727     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
728     return -1;
729   }
730   return tensorflow::unwrap(h)->DeviceId(&status->status);
731 }
732 
TFE_TensorHandleGetStatus(TFE_TensorHandle * h,TF_Status * status)733 TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h,
734                                                      TF_Status* status) {
735   status->status = tensorflow::unwrap(h)->TensorHandleStatus();
736 }
737 
TFE_GetExecutedOpNames(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)738 void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
739                             TF_Status* status) {
740   const std::vector<std::string>& op_names =
741       tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
742 
743   std::ostringstream op_names_oss;
744   for (const auto& op : op_names) {
745     op_names_oss << op << ", ";
746   }
747   const std::string& op_names_str = op_names_oss.str();
748   void* data = tensorflow::port::Malloc(op_names_str.length());
749   op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
750   buf->data = data;
751   buf->length = op_names_str.length();
752   buf->data_deallocator = [](void* data, size_t length) {
753     tensorflow::port::Free(data);
754   };
755   status->status = ::tensorflow::OkStatus();
756 }
757 
TFE_SetLogicalCpuDevices(TFE_Context * ctx,int num_cpus,const char * prefix,TF_Status * status)758 void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus,
759                               const char* prefix, TF_Status* status) {
760   std::vector<std::unique_ptr<tensorflow::Device>> devices;
761 
762   if (prefix == nullptr || strlen(prefix) == 0)
763     prefix = "/job:localhost/replica:0/task:0";
764 
765   tensorflow::SessionOptions sess_options;
766   (*sess_options.config.mutable_device_count())["CPU"] = num_cpus;
767   status->status =
768       tensorflow::DeviceFactory::AddCpuDevices(sess_options, prefix, &devices);
769 
770   // Remove the device that has the host device name since host device is alreay
771   // in an initialized context.
772   for (auto d = devices.begin(); d != devices.end();) {
773     if (absl::StrContains(d->get()->name(), "CPU:0")) {
774       d = devices.erase(d);
775     } else {
776       ++d;
777     }
778   }
779 
780   status->status = tensorflow::unwrap(ctx)->AddDevices(std::move(devices));
781 }
782 
TFE_InsertConfigKeyValue(TFE_Context * ctx,const char * key,const char * value,TF_Status * status)783 void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key,
784                               const char* value, TF_Status* status) {
785   tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
786       tensorflow::unwrap(ctx)->GetDistributedManager();
787   tensorflow::CoordinationServiceAgent* coord_agent =
788       dist_mgr->GetCoordinationServiceAgent();
789   if (coord_agent == nullptr) {
790     status->status = tensorflow::errors::FailedPrecondition(
791         "Coordination service agent is not enabled.");
792     return;
793   }
794   status->status = coord_agent->InsertKeyValue(key, value);
795 }
796 
TFE_GetConfigKeyValue(TFE_Context * ctx,const char * key,TF_Buffer * value_buf,TF_Status * status)797 void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key,
798                            TF_Buffer* value_buf, TF_Status* status) {
799   tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
800       tensorflow::unwrap(ctx)->GetDistributedManager();
801   tensorflow::CoordinationServiceAgent* coord_agent =
802       dist_mgr->GetCoordinationServiceAgent();
803   if (coord_agent == nullptr) {
804     status->status = tensorflow::errors::FailedPrecondition(
805         "Coordination service is not enabled.");
806     return;
807   }
808   auto status_or_value = coord_agent->GetKeyValue(key);
809   status->status = status_or_value.status();
810   if (!status_or_value.ok()) return;
811 
812   const std::string& value_string = status_or_value.ValueOrDie();
813   void* data = tensorflow::port::Malloc(value_string.length());
814   value_string.copy(static_cast<char*>(data), value_string.length(), 0);
815   value_buf->data = data;
816   value_buf->length = value_string.length();
817   value_buf->data_deallocator = [](void* data, size_t length) {
818     tensorflow::port::Free(data);
819   };
820 }
821 
TFE_DeleteConfigKeyValue(TFE_Context * ctx,const char * key,TF_Status * status)822 void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key,
823                               TF_Status* status) {
824   tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
825       tensorflow::unwrap(ctx)->GetDistributedManager();
826   tensorflow::CoordinationServiceAgent* coord_agent =
827       dist_mgr->GetCoordinationServiceAgent();
828   if (coord_agent == nullptr) {
829     status->status = tensorflow::errors::FailedPrecondition(
830         "Coordination service is not enabled.");
831     return;
832   }
833   status->status = coord_agent->DeleteKeyValue(key);
834 }
835 
TFE_ReportErrorToCluster(TFE_Context * ctx,int error_code,const char * error_message,TF_Status * status)836 void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code,
837                               const char* error_message, TF_Status* status) {
838   tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
839       tensorflow::unwrap(ctx)->GetDistributedManager();
840   tensorflow::CoordinationServiceAgent* coord_agent =
841       dist_mgr->GetCoordinationServiceAgent();
842   if (coord_agent == nullptr) {
843     status->status = tensorflow::errors::FailedPrecondition(
844         "Coordination service is not enabled.");
845     return;
846   }
847   tensorflow::Status s(static_cast<tensorflow::error::Code>(error_code),
848                        error_message);
849   status->status = coord_agent->ReportError(s);
850 }
851