1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ 16 #define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ 17 18 #include "tensorflow/c/c_api.h" 19 #include "tensorflow/c/eager/c_api.h" 20 21 #ifdef __cplusplus 22 extern "C" { 23 #endif 24 25 // Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This 26 // is for performance optimization by reusing an exiting unused op rather than 27 // creating a new op every time. If `raw_device_name` is `NULL` or empty, it 28 // does not set the device name. If it's not `NULL`, then it attempts to parse 29 // and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster 30 // than separately calling it because if the existing op has the same 31 // `raw_device_name`, it skips parsing and just leave as it is. 32 TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset, 33 const char* op_or_function_name, 34 const char* raw_device_name, 35 TF_Status* status); 36 37 // Enables only graph collection in RunMetadata on the functions executed from 38 // this context. 39 TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); 40 41 // Disables only graph collection in RunMetadata on the functions executed from 42 // this context. 43 TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); 44 45 // TODO(fishx): Move these monitoring APIs into a separate file. 46 // ----------------------------------------------------------------------------- 47 // Monitoring Counter APIs. 48 // These APIs de-templated monitoring Counter for swig. 49 50 typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell; 51 52 // Atomically increments the value of the cell. The value must be non-negative. 53 TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy( 54 TFE_MonitoringCounterCell* cell, int64_t value); 55 56 // Retrieves the current value of the cell. 57 TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue( 58 TFE_MonitoringCounterCell* cell); 59 60 // APIs for Counter without label. 61 typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0; 62 // Returns a new Counter metric object. The caller should manage lifetime of 63 // the object. Using duplicate metric name will crash the program with fatal 64 // error. 65 TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0( 66 const char* name, TF_Status* status, const char* description); 67 // Deletes the Counter object. 68 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0( 69 TFE_MonitoringCounter0* counter); 70 // Retrieves the cell from the Counter object. The Counter object will manage 71 // lifetime of the cell. 72 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( 73 TFE_MonitoringCounter0* counter); 74 75 // APIs for Counter with 1 label. 76 typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1; 77 TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1( 78 const char* name, TF_Status* status, const char* description, 79 const char* label1); 80 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1( 81 TFE_MonitoringCounter1* counter); 82 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( 83 TFE_MonitoringCounter1* counter, const char* label1); 84 85 // APIs for Counter with 2 labels. 86 typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2; 87 TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2( 88 const char* name, TF_Status* status, const char* description, 89 const char* label1, const char* label2); 90 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2( 91 TFE_MonitoringCounter2* counter); 92 TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( 93 TFE_MonitoringCounter2* counter, const char* label1, const char* label2); 94 95 // ----------------------------------------------------------------------------- 96 // Monitoring Gauge APIs. 97 // These APIs de-templated monitoring Gauge for swig. 98 99 typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell; 100 101 // Atomically set the value of the cell. 102 TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet( 103 TFE_MonitoringIntGaugeCell* cell, int64_t value); 104 105 // Retrieves the current value of the cell. 106 TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue( 107 TFE_MonitoringIntGaugeCell* cell); 108 109 // APIs for Int Gauge without label. 110 typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0; 111 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0( 112 const char* name, TF_Status* out_status, const char* description); 113 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0( 114 TFE_MonitoringIntGauge0* gauge); 115 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* 116 TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge); 117 118 // APIs for Int Gauge with 1 label. 119 typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1; 120 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1( 121 const char* name, TF_Status* out_status, const char* description, 122 const char* label1); 123 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1( 124 TFE_MonitoringIntGauge1* gauge); 125 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* 126 TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge, 127 const char* label1); 128 129 // APIs for Int Gauge with 2 label. 130 typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2; 131 TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2( 132 const char* name, TF_Status* out_status, const char* description, 133 const char* label1, const char* label2); 134 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2( 135 TFE_MonitoringIntGauge2* gauge); 136 TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* 137 TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge, 138 const char* label1, const char* label2); 139 140 typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell; 141 TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet( 142 TFE_MonitoringStringGaugeCell* cell, const char* value); 143 // Retrieves the string value and saves it in buffer. 144 TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue( 145 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf); 146 147 // APIs for String Gauge without label. 148 typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0; 149 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( 150 const char* name, TF_Status* out_status, const char* description); 151 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0( 152 TFE_MonitoringStringGauge0* gauge); 153 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* 154 TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge); 155 156 // APIs for String Gauge with 1 label. 157 typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1; 158 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( 159 const char* name, TF_Status* out_status, const char* description, 160 const char* label1); 161 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1( 162 TFE_MonitoringStringGauge1* gauge); 163 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* 164 TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge, 165 const char* label1); 166 167 // APIs for String Gauge with 2 label. 168 typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2; 169 TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( 170 const char* name, TF_Status* out_status, const char* description, 171 const char* label1, const char* label2); 172 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2( 173 TFE_MonitoringStringGauge2* gauge); 174 TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* 175 TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge, 176 const char* label1, const char* label2); 177 178 typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell; 179 TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet( 180 TFE_MonitoringBoolGaugeCell* cell, bool value); 181 TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue( 182 TFE_MonitoringBoolGaugeCell* cell); 183 184 // APIs for Bool Gauge without label. 185 typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0; 186 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0( 187 const char* name, TF_Status* out_status, const char* description); 188 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0( 189 TFE_MonitoringBoolGauge0* gauge); 190 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* 191 TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge); 192 193 // APIs for Bool Gauge with 1 label. 194 typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1; 195 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1( 196 const char* name, TF_Status* out_status, const char* description, 197 const char* label1); 198 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1( 199 TFE_MonitoringBoolGauge1* gauge); 200 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* 201 TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge, 202 const char* label1); 203 204 // APIs for Bool Gauge with 2 label. 205 typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2; 206 TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2( 207 const char* name, TF_Status* out_status, const char* description, 208 const char* label1, const char* label2); 209 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2( 210 TFE_MonitoringBoolGauge2* gauge); 211 TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* 212 TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge, 213 const char* label1, const char* label2); 214 215 // ----------------------------------------------------------------------------- 216 // Monitoring Sampler APIs. 217 // These APIs de-templated monitoring Sampler for swig. 218 219 typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell; 220 221 // Atomically add the value of the cell. 222 TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd( 223 TFE_MonitoringSamplerCell* cell, double value); 224 225 // Retrieves the current value of the cell. The return value is a HistogramProto 226 // saved in buffer. 227 TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue( 228 TFE_MonitoringSamplerCell* cell, TF_Buffer* buf); 229 230 // APIs for sampler buckets 231 typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets; 232 TF_CAPI_EXPORT extern TFE_MonitoringBuckets* 233 TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor, 234 int bucket_count); 235 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets( 236 TFE_MonitoringBuckets* buckets); 237 238 // APIs for Sampler without label. 239 typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0; 240 TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( 241 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, 242 const char* description); 243 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0( 244 TFE_MonitoringSampler0* sampler); 245 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( 246 TFE_MonitoringSampler0* sampler); 247 248 // APIs for Sampler with 1 label. 249 typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1; 250 TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( 251 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, 252 const char* description, const char* label1); 253 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1( 254 TFE_MonitoringSampler1* sampler); 255 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( 256 TFE_MonitoringSampler1* sampler, const char* label1); 257 258 // APIs for Sampler with 2 label. 259 typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2; 260 TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( 261 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, 262 const char* description, const char* label1, const char* label2); 263 TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( 264 TFE_MonitoringSampler2* sampler); 265 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( 266 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); 267 268 // Sets whether to use TFRT 269 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, 270 bool use_tfrt); 271 272 // Returns the context_id from the EagerContext which is used by the 273 // EagerService to maintain consistency between client and worker. The 274 // context_id is initialized with a dummy value and is later set when the worker 275 // is initialized (either locally or remotely). The context_id can change during 276 // the process lifetime although this should cause the worker to be 277 // reinitialized (e.g. cleared caches) as well. 278 TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx); 279 280 // ----------------------------------------------------------------------------- 281 // Cancellation APIs. 282 283 typedef struct TFE_CancellationManager TFE_CancellationManager; 284 TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager(); 285 TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled( 286 TFE_CancellationManager*); 287 TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel( 288 TFE_CancellationManager*); 289 TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager( 290 TFE_CancellationManager*); 291 292 // Associates the given `cancellation_manager` with `op`, so that invoking 293 // `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the 294 // execution of `op`. 295 typedef struct TFE_CancellationManager TFE_CancellationManager; 296 TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager( 297 TFE_Op* op, TFE_CancellationManager* cancellation_manager, 298 TF_Status* status); 299 300 // ----------------------------------------------------------------------------- 301 // Eager Executor APIs. 302 typedef struct TFE_Executor TFE_Executor; 303 304 // Creates a new eager Executor. Nodes in one executor are guaranteed to be 305 // executed in sequence. Assigning nodes to different executors allows executing 306 // nodes in parallel. 307 TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async); 308 309 // Deletes the eager Executor without waiting for enqueued nodes. Please call 310 // TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to 311 // make sure all nodes are finished. 312 TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*); 313 314 // Returns true if the executor is in async mode. 315 TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*); 316 317 // Causes the calling thread to block till all ops dispatched in this executor 318 // have been executed. Note that "execution" here refers to kernel execution / 319 // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee 320 // that lower level device queues (like GPU streams) have been flushed. 321 // 322 // This call may not block for execution of ops enqueued concurrently with this 323 // call. 324 TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( 325 TFE_Executor*, TF_Status* status); 326 327 // When an error happens, any pending operations are discarded and newly issued 328 // ops return an error. This call clears the error state and re-enables 329 // execution of newly issued ops. 330 // 331 // Note that outputs of discarded ops remain in a corrupt state and should not 332 // be used for future calls. 333 // TODO(agarwal): mark the affected handles and raise errors if they are used. 334 TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*); 335 336 // Sets a custom Executor for current thread. All nodes created by this thread 337 // will be added to this Executor. It will override current executor. 338 TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*, 339 TFE_Executor*); 340 341 // Returns the Executor for current thread. 342 TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( 343 TFE_Context*); 344 345 // ----------------------------------------------------------------------------- 346 // Dynamic cluster API. 347 348 // Update an existing context with a new set of servers defined in a ServerDef 349 // proto. Servers can be added to and removed from the list of remote workers 350 // in the context. New set of servers identified by the ServerDef must be up 351 // when the context is updated. 352 // 353 // This API is for experimental usage and may be subject to change. 354 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, 355 int keep_alive_secs, 356 const void* proto, 357 size_t proto_len, 358 TF_Status* status); 359 360 // Checks whether a remote worker is alive or not. This will return true even if 361 // the context doesn't exist on the remote worker. 362 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, 363 const char* worker_name, 364 TF_Status* status); 365 366 // Sync pending nodes in local executors (including the context default executor 367 // and thread executors) and streaming requests to remote executors, and get the 368 // combined status. 369 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, 370 TF_Status* status); 371 372 // This function will block till the operation that produces `h` has 373 // completed. This is only valid on local TFE_TensorHandles. The pointer 374 // returned will be on the device in which the TFE_TensorHandle resides (so e.g. 375 // for a GPU tensor this will return a pointer to GPU memory). The pointer is 376 // only guaranteed to be valid until TFE_DeleteTensorHandle is called on this 377 // TensorHandle. Only supports POD data types. 378 TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*, 379 TF_Status*); 380 381 // This function will block till the operation that produces `h` has 382 // completed. This is only valid on local TFE_TensorHandles. Returns the size in 383 // bytes of the memory pointed to by the device pointer returned above. 384 TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*, 385 TF_Status*); 386 387 // Creates a new TensorHandle from memory residing in device_name. Takes 388 // ownership of the memory, and will call deleter to release it after TF 389 // no longer needs it or in case of error. 390 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( 391 TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims, 392 int num_dims, void* data, size_t len, 393 void (*deallocator)(void* data, size_t len, void* arg), 394 void* deallocator_arg, TF_Status* status); 395 396 // Retrieves the address space (i.e. job, replia, task) of the local host and 397 // saves it in the buffer. 398 TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, 399 TF_Buffer* buf); 400 401 // APIs for generically dealing with op attributes (e.g. when forwarding them 402 // through custom device implementations). 403 // 404 // TODO(allenl): Currently these are black boxes, but we should have some way to 405 // inspect values. This would let people e.g. copy over most attributes and then 406 // modify some based on their values. 407 408 // A reference to an op's name -> attribute mapping 409 typedef struct TFE_OpAttrs TFE_OpAttrs; 410 411 // Fetch a reference to `op`'s attributes. The returned reference is only valid 412 // while `op` is alive. 413 TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); 414 // Add attributes in `attrs` to `op`. 415 // 416 // Does not overwrite or update existing attributes, but adds new ones. 417 TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs); 418 419 // Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`), 420 // containing the op name and a map of its attributes. 421 TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, 422 TF_Buffer* buf, 423 TF_Status* status); 424 425 // Set an op's attribute from a serialized AttrValue protocol buffer. 426 // 427 // Analogous to TF_SetAttrValueProto for building graph operations. 428 TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, 429 const char* attr_name, 430 const void* proto, 431 size_t proto_len, 432 TF_Status* status); 433 434 // TODO(b/166642410): It would be nice, for custom devices and for other users, 435 // to have a non-string representation of devices (TF_Device) extracted from 436 // tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. 437 438 #define TFE_CUSTOM_DEVICE_VERSION 4 439 440 // Struct to be filled in. Functions are required except where indicated. 441 typedef struct TFE_CustomDevice { 442 int version = TFE_CUSTOM_DEVICE_VERSION; 443 // Method to copy a tensor to the custom device. 444 TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, 445 TFE_TensorHandle* tensor, 446 TF_Status* status, 447 void* device_info); 448 449 // Method to copy a tensor from the custom device to a target device. 450 TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, 451 TFE_TensorHandle* tensor, 452 const char* target_device_name, 453 TF_Status* status, 454 void* device_info); 455 456 // Method to execute an operation. 457 // 458 // Arguments provide enough information to reconstruct the original `TFE_Op`, 459 // or construct a transformed version, by inspecting the passed `op`. 460 // 461 // TFE_OpGetDevice(op) records the original placement of the operation. It may 462 // be an empty string if no device was explicitly requested, but will 463 // otherwise be the name of this custom device. Ops are placed onto a custom 464 // device if any of their inputs are on that custom device, but custom devices 465 // are free to set a bad status in order to require explicit placement. 466 void (*execute)(const TFE_Op* op, int* num_outputs, 467 TFE_TensorHandle** outputs, TF_Status* s, void* device_info); 468 469 // Method to delete a device. 470 void (*delete_device)(void* device_info); 471 472 // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this 473 // custom device. 474 // 475 // Many devices will want to simply return an "unimplemented" status 476 // here. This is the default behavior if `pack` is null when passed to 477 // TFE_RegisterCustomDevice. 478 TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles, 479 int num_handles, TF_Status* s, 480 void* device_info) = nullptr; 481 } TFE_CustomDevice; 482 483 // Registers a custom device for use with eager execution. 484 // 485 // Eager operations may be placed on this device, e.g. `with 486 // tf.device("CUSTOM"):` from Python if `device_name` for this call is 487 // "/job:localhost/replica:0/task:0/device:CUSTOM:0". 488 // 489 // The custom device defines copy operations for moving TensorHandles on and 490 // off, and an execution operation for named operations. Often execution will 491 // simply wrap op execution on one or more physical devices. 492 // 493 // device_info is an opaque caller-defined type stored with the custom device 494 // which is passed to the functions referenced in the TFE_CustomDevice struct 495 // `device` (execute, delete_device, etc.). It can for example contain the 496 // names of wrapped devices. 497 // 498 // There are currently no graph semantics implemented for registered custom 499 // devices, so executing tf.functions which contain operations placed on custom 500 // devices will fail. 501 // 502 // `device_name` must not name an existing physical or custom device. It must 503 // follow the format: 504 // 505 // /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> 506 // 507 // If the device is successfully registered, `status` is set to TF_OK. Otherwise 508 // the device is not usable. In case of a bad status, `device.delete_device` is 509 // still called on `device_info` (i.e. the caller does not retain ownership). 510 // 511 // This API is highly experimental, and in particular is expected to change when 512 // it starts supporting operations with attributes and when tf.function support 513 // is added. 514 TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx, 515 TFE_CustomDevice device, 516 const char* device_name, 517 void* device_info, 518 TF_Status* status); 519 520 // Creates a new TensorHandle from memory residing in a custom device. Takes 521 // ownership of the memory, and will call `deallocator` to release it after TF 522 // no longer needs it or in case of error. 523 // 524 // `num_dims_callback` is a callback computing the rank of the tensor, and 525 // `dim_callback` computes the axis length at `dim_index`. Shapes are specified 526 // via callbacks because retrieving the shape of a tensor is a blocking 527 // operation for async eager; custom devices should avoid retrieving shapes of 528 // tensors they wrap until the custom device tensor's shape is explicitly 529 // requested where possible. 530 // 531 // `arg` is passed to the callbacks unmodified for any extra information the 532 // caller needs to provide them. 533 // 534 // This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but does not 535 // require blocking waiting for exact shapes. 536 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( 537 TFE_Context* ctx, const char* device_name, TF_DataType, void* data, 538 int (*num_dims_callback)(void* data, void* arg, TF_Status* status), 539 int64_t (*dim_callback)(void* data, int dim_index, void* arg, 540 TF_Status* status), 541 void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status); 542 543 TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, 544 const char* function_name, 545 TF_Buffer* buf, 546 TF_Status* status); 547 548 // Allocate and return a new Tensor on the host. 549 // 550 // The caller must set the Tensor values by writing them to the pointer returned 551 // by TF_TensorData with length TF_TensorByteSize. 552 TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, 553 TF_DataType dtype, 554 const int64_t* dims, 555 int num_dims, 556 TF_Status* status); 557 558 // Given a Tensor, wrap it with a TensorHandle 559 // 560 // Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context. 561 // The context should be identical to that of the Tensor. 562 TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( 563 TFE_Context* ctx, TF_Tensor* t, TF_Status* status); 564 565 // Create a packed TensorHandle with the given list of TensorHandles. 566 // If `handles` are on the same device, assign the same device to the packed 567 // handle; if `handles` are on different deivces, assign a CompositeDevice to 568 // it. 569 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( 570 TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, 571 TF_Status* status); 572 573 // Configure soft device placement policy for the eager executor. Note this 574 // policy is applied to any subsequent op executions. 575 TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, 576 unsigned char enable, 577 TF_Status* status); 578 579 // Configure device placement policy logging for the eager executor. Note this 580 // policy is applied to any subsequent op executions. 581 TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, 582 unsigned char enable, 583 TF_Status* status); 584 585 // Returns the device type of the operation that produced `h`. 586 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( 587 TFE_TensorHandle* h, TF_Status* status); 588 589 // Returns the device ID of the operation that produced `h`. 590 TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, 591 TF_Status* status); 592 593 // Get a comma-separated list of op names executed in graph functions dispatched 594 // to `ctx`. This feature is currently only enabled for TFRT debug builds, for 595 // performance and simplicity reasons. 596 TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx, 597 TF_Buffer* buf, 598 TF_Status* status); 599 600 #ifdef __cplusplus 601 } /* end extern "C" */ 602 #endif 603 604 #endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ 605