• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
16 #define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
17 
18 #include "tensorflow/c/c_api.h"
19 #include "tensorflow/c/eager/c_api.h"
20 
21 #ifdef __cplusplus
22 extern "C" {
23 #endif
24 
25 // Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
26 // is for performance optimization by reusing an existing unused op rather than
27 // creating a new op every time. If `raw_device_name` is `NULL` or empty, it
28 // does not set the device name. If it's not `NULL`, then it attempts to parse
29 // and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
30 // than separately calling it because if the existing op has the same
31 // `raw_device_name`, it skips parsing and just leaves it as is.
32 TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
33                                        const char* op_or_function_name,
34                                        const char* raw_device_name,
35                                        TF_Status* status);
36 
37 // Enables only graph collection in RunMetadata on the functions executed from
38 // this context.
39 TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
40 
41 // Disables only graph collection in RunMetadata on the functions executed from
42 // this context.
43 TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
44 
45 // TODO(fishx): Move these monitoring APIs into a separate file.
46 // -----------------------------------------------------------------------------
47 // Monitoring Counter APIs.
48 // These APIs are de-templated versions of the monitoring Counter for SWIG.
49 
50 typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
51 
52 // Atomically increments the value of the cell. The value must be non-negative.
53 TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
54     TFE_MonitoringCounterCell* cell, int64_t value);
55 
56 // Retrieves the current value of the cell.
57 TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
58     TFE_MonitoringCounterCell* cell);
59 
60 // APIs for Counter without label.
61 typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
62 // Returns a new Counter metric object. The caller should manage lifetime of
63 // the object. Using duplicate metric name will crash the program with fatal
64 // error.
65 TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
66     const char* name, TF_Status* status, const char* description);
67 // Deletes the Counter object.
68 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
69     TFE_MonitoringCounter0* counter);
70 // Retrieves the cell from the Counter object. The Counter object will manage
71 // lifetime of the cell.
72 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
73     TFE_MonitoringCounter0* counter);
74 
75 // APIs for Counter with 1 label.
76 typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
77 TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
78     const char* name, TF_Status* status, const char* description,
79     const char* label1);
80 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
81     TFE_MonitoringCounter1* counter);
82 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
83     TFE_MonitoringCounter1* counter, const char* label1);
84 
85 // APIs for Counter with 2 labels.
86 typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
87 TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
88     const char* name, TF_Status* status, const char* description,
89     const char* label1, const char* label2);
90 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
91     TFE_MonitoringCounter2* counter);
92 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
93     TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
94 
95 // -----------------------------------------------------------------------------
96 // Monitoring Gauge APIs.
97 // These APIs are de-templated versions of the monitoring Gauge for SWIG.
98 
99 typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
100 
101 // Atomically sets the value of the cell.
102 TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
103     TFE_MonitoringIntGaugeCell* cell, int64_t value);
104 
105 // Retrieves the current value of the cell.
106 TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
107     TFE_MonitoringIntGaugeCell* cell);
108 
109 // APIs for Int Gauge without label.
110 typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
111 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
112     const char* name, TF_Status* out_status, const char* description);
113 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
114     TFE_MonitoringIntGauge0* gauge);
115 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
116 TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
117 
118 // APIs for Int Gauge with 1 label.
119 typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
120 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
121     const char* name, TF_Status* out_status, const char* description,
122     const char* label1);
123 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
124     TFE_MonitoringIntGauge1* gauge);
125 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
126 TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
127                                const char* label1);
128 
129 // APIs for Int Gauge with 2 labels.
130 typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
131 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
132     const char* name, TF_Status* out_status, const char* description,
133     const char* label1, const char* label2);
134 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
135     TFE_MonitoringIntGauge2* gauge);
136 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
137 TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
138                                const char* label1, const char* label2);
139 
140 typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
141 TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
142     TFE_MonitoringStringGaugeCell* cell, const char* value);
143 // Retrieves the string value and saves it in buffer.
144 TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
145     TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
146 
147 // APIs for String Gauge without label.
148 typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
149 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
150     const char* name, TF_Status* out_status, const char* description);
151 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
152     TFE_MonitoringStringGauge0* gauge);
153 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
154 TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
155 
156 // APIs for String Gauge with 1 label.
157 typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
158 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
159     const char* name, TF_Status* out_status, const char* description,
160     const char* label1);
161 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
162     TFE_MonitoringStringGauge1* gauge);
163 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
164 TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
165                                   const char* label1);
166 
167 // APIs for String Gauge with 2 labels.
168 typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
169 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
170     const char* name, TF_Status* out_status, const char* description,
171     const char* label1, const char* label2);
172 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
173     TFE_MonitoringStringGauge2* gauge);
174 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
175 TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
176                                   const char* label1, const char* label2);
177 
178 typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
179 TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
180     TFE_MonitoringBoolGaugeCell* cell, bool value);
181 TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
182     TFE_MonitoringBoolGaugeCell* cell);
183 
184 // APIs for Bool Gauge without label.
185 typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
186 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
187     const char* name, TF_Status* out_status, const char* description);
188 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
189     TFE_MonitoringBoolGauge0* gauge);
190 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
191 TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
192 
193 // APIs for Bool Gauge with 1 label.
194 typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
195 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
196     const char* name, TF_Status* out_status, const char* description,
197     const char* label1);
198 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
199     TFE_MonitoringBoolGauge1* gauge);
200 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
201 TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
202                                 const char* label1);
203 
204 // APIs for Bool Gauge with 2 labels.
205 typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
206 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
207     const char* name, TF_Status* out_status, const char* description,
208     const char* label1, const char* label2);
209 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
210     TFE_MonitoringBoolGauge2* gauge);
211 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
212 TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
213                                 const char* label1, const char* label2);
214 
215 // -----------------------------------------------------------------------------
216 // Monitoring Sampler APIs.
217 // These APIs are de-templated versions of the monitoring Sampler for SWIG.
218 
219 typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
220 
221 // Atomically adds the value to the cell.
222 TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
223     TFE_MonitoringSamplerCell* cell, double value);
224 
225 // Retrieves the current value of the cell. The return value is a HistogramProto
226 // saved in buffer.
227 TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
228     TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
229 
230 // APIs for sampler buckets
231 typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
232 TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
233 TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
234                                     int bucket_count);
235 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
236     TFE_MonitoringBuckets* buckets);
237 
238 // APIs for Sampler without label.
239 typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
240 TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
241     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
242     const char* description);
243 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
244     TFE_MonitoringSampler0* sampler);
245 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
246     TFE_MonitoringSampler0* sampler);
247 
248 // APIs for Sampler with 1 label.
249 typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
250 TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
251     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
252     const char* description, const char* label1);
253 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
254     TFE_MonitoringSampler1* sampler);
255 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
256     TFE_MonitoringSampler1* sampler, const char* label1);
257 
258 // APIs for Sampler with 2 labels.
259 typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
260 TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
261     const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
262     const char* description, const char* label1, const char* label2);
263 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
264     TFE_MonitoringSampler2* sampler);
265 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
266     TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
267 
268 // Sets whether to use TFRT
269 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
270                                                      bool use_tfrt);
271 
272 // Returns the context_id from the EagerContext which is used by the
273 // EagerService to maintain consistency between client and worker. The
274 // context_id is initialized with a dummy value and is later set when the worker
275 // is initialized (either locally or remotely). The context_id can change during
276 // the process lifetime although this should cause the worker to be
277 // reinitialized (e.g. cleared caches) as well.
278 TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
279 
280 // -----------------------------------------------------------------------------
281 // Cancellation APIs.
282 
283 typedef struct TFE_CancellationManager TFE_CancellationManager;
284 TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager();
285 TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled(
286     TFE_CancellationManager*);
287 TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel(
288     TFE_CancellationManager*);
289 TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager(
290     TFE_CancellationManager*);
291 
292 // Associates the given `cancellation_manager` with `op`, so that invoking
293 // `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the
294 // execution of `op`.
295 typedef struct TFE_CancellationManager TFE_CancellationManager;
296 TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager(
297     TFE_Op* op, TFE_CancellationManager* cancellation_manager,
298     TF_Status* status);
299 
300 // -----------------------------------------------------------------------------
301 // Eager Executor APIs.
302 typedef struct TFE_Executor TFE_Executor;
303 
304 // Creates a new eager Executor. Nodes in one executor are guaranteed to be
305 // executed in sequence. Assigning nodes to different executors allows executing
306 // nodes in parallel.
307 TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async);
308 
309 // Deletes the eager Executor without waiting for enqueued nodes. Please call
310 // TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
311 // make sure all nodes are finished.
312 TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*);
313 
314 // Returns true if the executor is in async mode.
315 TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*);
316 
317 // Causes the calling thread to block till all ops dispatched in this executor
318 // have been executed. Note that "execution" here refers to kernel execution /
319 // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
320 // that lower level device queues (like GPU streams) have been flushed.
321 //
322 // This call may not block for execution of ops enqueued concurrently with this
323 // call.
324 TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes(
325     TFE_Executor*, TF_Status* status);
326 
327 // When an error happens, any pending operations are discarded and newly issued
328 // ops return an error. This call clears the error state and re-enables
329 // execution of newly issued ops.
330 //
331 // Note that outputs of discarded ops remain in a corrupt state and should not
332 // be used for future calls.
333 // TODO(agarwal): mark the affected handles and raise errors if they are used.
334 TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*);
335 
336 // Sets a custom Executor for current thread. All nodes created by this thread
337 // will be added to this Executor. It will override current executor.
338 TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*,
339                                                            TFE_Executor*);
340 
341 // Returns the Executor for current thread.
342 TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread(
343     TFE_Context*);
344 
345 // -----------------------------------------------------------------------------
346 // Dynamic cluster API.
347 
348 // Update an existing context with a new set of servers defined in a ServerDef
349 // proto. Servers can be added to and removed from the list of remote workers
350 // in the context. New set of servers identified by the ServerDef must be up
351 // when the context is updated.
352 //
353 // This API is for experimental usage and may be subject to change.
354 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
355                                                       int keep_alive_secs,
356                                                       const void* proto,
357                                                       size_t proto_len,
358                                                       TF_Status* status);
359 
360 // Checks whether a remote worker is alive or not. This will return true even if
361 // the context doesn't exist on the remote worker.
362 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
363                                                  const char* worker_name,
364                                                  TF_Status* status);
365 
366 // Sync pending nodes in local executors (including the context default executor
367 // and thread executors) and streaming requests to remote executors, and get the
368 // combined status.
369 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
370                                                 TF_Status* status);
371 
372 // This function will block till the operation that produces `h` has
373 // completed. This is only valid on local TFE_TensorHandles. The pointer
374 // returned will be on the device in which the TFE_TensorHandle resides (so e.g.
375 // for a GPU tensor this will return a pointer to GPU memory). The pointer is
376 // only guaranteed to be valid until TFE_DeleteTensorHandle is called on this
377 // TensorHandle. Only supports POD data types.
378 TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*,
379                                                           TF_Status*);
380 
381 // This function will block till the operation that produces `h` has
382 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
383 // bytes of the memory pointed to by the device pointer returned above.
384 TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*,
385                                                               TF_Status*);
386 
387 // Creates a new TensorHandle from memory residing in device_name. Takes
388 // ownership of the memory, and will call `deallocator` to release it after TF
389 // no longer needs it or in case of error.
390 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
391     TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims,
392     int num_dims, void* data, size_t len,
393     void (*deallocator)(void* data, size_t len, void* arg),
394     void* deallocator_arg, TF_Status* status);
395 
396 // Retrieves the address space (i.e. job, replica, task) of the local host and
397 // saves it in the buffer.
398 TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
399                                                 TF_Buffer* buf);
400 
401 // APIs for generically dealing with op attributes (e.g. when forwarding them
402 // through custom device implementations).
403 //
404 // TODO(allenl): Currently these are black boxes, but we should have some way to
405 // inspect values. This would let people e.g. copy over most attributes and then
406 // modify some based on their values.
407 
408 // A reference to an op's name -> attribute mapping
409 typedef struct TFE_OpAttrs TFE_OpAttrs;
410 
411 // Fetch a reference to `op`'s attributes. The returned reference is only valid
412 // while `op` is alive.
413 TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
414 // Add attributes in `attrs` to `op`.
415 //
416 // Does not overwrite or update existing attributes, but adds new ones.
417 TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
418 
419 // Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
420 // containing the op name and a map of its attributes.
421 TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
422                                                 TF_Buffer* buf,
423                                                 TF_Status* status);
424 
425 // Set an op's attribute from a serialized AttrValue protocol buffer.
426 //
427 // Analogous to TF_SetAttrValueProto for building graph operations.
428 TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
429                                                    const char* attr_name,
430                                                    const void* proto,
431                                                    size_t proto_len,
432                                                    TF_Status* status);
433 
434 // TODO(b/166642410): It would be nice, for custom devices and for other users,
435 // to have a non-string representation of devices (TF_Device) extracted from
436 // tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
437 
438 #define TFE_CUSTOM_DEVICE_VERSION 4
439 
440 // Struct to be filled in. Functions are required except where indicated.
441 typedef struct TFE_CustomDevice {
442   int version = TFE_CUSTOM_DEVICE_VERSION;
443   // Method to copy a tensor to the custom device.
444   TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
445                                              TFE_TensorHandle* tensor,
446                                              TF_Status* status,
447                                              void* device_info);
448 
449   // Method to copy a tensor from the custom device to a target device.
450   TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
451                                                TFE_TensorHandle* tensor,
452                                                const char* target_device_name,
453                                                TF_Status* status,
454                                                void* device_info);
455 
456   // Method to execute an operation.
457   //
458   // Arguments provide enough information to reconstruct the original `TFE_Op`,
459   // or construct a transformed version, by inspecting the passed `op`.
460   //
461   // TFE_OpGetDevice(op) records the original placement of the operation. It may
462   // be an empty string if no device was explicitly requested, but will
463   // otherwise be the name of this custom device. Ops are placed onto a custom
464   // device if any of their inputs are on that custom device, but custom devices
465   // are free to set a bad status in order to require explicit placement.
466   void (*execute)(const TFE_Op* op, int* num_outputs,
467                   TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
468 
469   // Method to delete a device.
470   void (*delete_device)(void* device_info);
471 
472   // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this
473   // custom device.
474   //
475   // Many devices will want to simply return an "unimplemented" status
476   // here. This is the default behavior if `pack` is null when passed to
477   // TFE_RegisterCustomDevice.
478   TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles,
479                             int num_handles, TF_Status* s,
480                             void* device_info) = nullptr;
481 } TFE_CustomDevice;
482 
483 // Registers a custom device for use with eager execution.
484 //
485 // Eager operations may be placed on this device, e.g.  `with
486 // tf.device("CUSTOM"):` from Python if `device_name` for this call is
487 // "/job:localhost/replica:0/task:0/device:CUSTOM:0".
488 //
489 // The custom device defines copy operations for moving TensorHandles on and
490 // off, and an execution operation for named operations. Often execution will
491 // simply wrap op execution on one or more physical devices.
492 //
493 // device_info is an opaque caller-defined type stored with the custom device
494 // which is passed to the functions referenced in the TFE_CustomDevice struct
495 // `device` (execute, delete_device, etc.). It can for example contain the
496 // names of wrapped devices.
497 //
498 // There are currently no graph semantics implemented for registered custom
499 // devices, so executing tf.functions which contain operations placed on custom
500 // devices will fail.
501 //
502 // `device_name` must not name an existing physical or custom device. It must
503 // follow the format:
504 //
505 //    /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
506 //
507 // If the device is successfully registered, `status` is set to TF_OK. Otherwise
508 // the device is not usable. In case of a bad status, `device.delete_device` is
509 // still called on `device_info` (i.e. the caller does not retain ownership).
510 //
511 // This API is highly experimental, and in particular is expected to change when
512 // it starts supporting operations with attributes and when tf.function support
513 // is added.
514 TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
515                                                     TFE_CustomDevice device,
516                                                     const char* device_name,
517                                                     void* device_info,
518                                                     TF_Status* status);
519 
520 // Creates a new TensorHandle from memory residing in a custom device. Takes
521 // ownership of the memory, and will call `deallocator` to release it after TF
522 // no longer needs it or in case of error.
523 //
524 // `num_dims_callback` is a callback computing the rank of the tensor, and
525 // `dim_callback` computes the axis length at `dim_index`. Shapes are specified
526 // via callbacks because retrieving the shape of a tensor is a blocking
527 // operation for async eager; custom devices should avoid retrieving shapes of
528 // tensors they wrap until the custom device tensor's shape is explicitly
529 // requested where possible.
530 //
531 // `arg` is passed to the callbacks unmodified for any extra information the
532 // caller needs to provide them.
533 //
534 // This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but does not
535 // require blocking waiting for exact shapes.
536 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
537     TFE_Context* ctx, const char* device_name, TF_DataType, void* data,
538     int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
539     int64_t (*dim_callback)(void* data, int dim_index, void* arg,
540                             TF_Status* status),
541     void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status);
542 
543 TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
544                                                      const char* function_name,
545                                                      TF_Buffer* buf,
546                                                      TF_Status* status);
547 
548 // Allocate and return a new Tensor on the host.
549 //
550 // The caller must set the Tensor values by writing them to the pointer returned
551 // by TF_TensorData with length TF_TensorByteSize.
552 TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
553                                                         TF_DataType dtype,
554                                                         const int64_t* dims,
555                                                         int num_dims,
556                                                         TF_Status* status);
557 
558 // Given a Tensor, wrap it with a TensorHandle
559 //
560 // Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
561 // The context should be identical to that of the Tensor.
562 TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
563     TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
564 
565 // Create a packed TensorHandle with the given list of TensorHandles.
566 // If `handles` are on the same device, assign the same device to the packed
567 // handle; if `handles` are on different devices, assign a CompositeDevice to
568 // it.
569 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
570     TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
571     TF_Status* status);
572 
573 // Configure soft device placement policy for the eager executor. Note this
574 // policy is applied to any subsequent op executions.
575 TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
576                                                       unsigned char enable,
577                                                       TF_Status* status);
578 
579 // Configure device placement policy logging for the eager executor. Note this
580 // policy is applied to any subsequent op executions.
581 TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
582                                                      unsigned char enable,
583                                                      TF_Status* status);
584 
585 // Returns the device type of the operation that produced `h`.
586 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
587     TFE_TensorHandle* h, TF_Status* status);
588 
589 // Returns the device ID of the operation that produced `h`.
590 TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
591                                                    TF_Status* status);
592 
593 // Get a comma-separated list of op names executed in graph functions dispatched
594 // to `ctx`. This feature is currently only enabled for TFRT debug builds, for
595 // performance and simplicity reasons.
596 TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx,
597                                                   TF_Buffer* buf,
598                                                   TF_Status* status);
599 
600 #ifdef __cplusplus
601 } /* end extern "C" */
602 #endif
603 
604 #endif  // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
605