1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Experimental API for TensorFlow's "Eager" mode of execution.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import copy 24import random 25import threading 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.python import pywrap_tensorflow 29from tensorflow.python.framework import c_api_util 30from tensorflow.python.framework import device as pydev 31from tensorflow.python.framework import errors 32from tensorflow.python.util import compat 33from tensorflow.python.util import is_in_graph_mode 34from tensorflow.python.util import tf_contextlib 35 36GRAPH_MODE = 0 37EAGER_MODE = 1 38 39# Default execution mode. 40_default_mode = GRAPH_MODE 41 42# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 43# new_device_spec). 44# Note that we do not protect this with a lock and instead rely on python's GIL 45# and the idempotent nature of writes to provide thread safety. 46_device_parsing_cache = {} 47 48_MAXINT32 = 2**31 - 1 49 50DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT 51DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN 52DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT 53DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 54 pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 55 56 57# TODO(agarwal): better name ? 58class _EagerContext(threading.local): 59 """Thread local eager context.""" 60 61 def __init__(self): 62 super(_EagerContext, self).__init__() 63 self.device_spec = pydev.DeviceSpec.from_string( 64 "/job:localhost/replica:0/task:0/device:CPU:0") 65 self.device_name = self.device_spec.to_string() 66 self.mode = _default_mode 67 self.scope_name = "" 68 self.recording_summaries = False 69 self.summary_writer_resource = None 70 self.scalar_cache = {} 71 72 73ContextStackEntry = collections.namedtuple( 74 "ContextStackEntry", ["is_building_function", "enter_context_fn"]) 75 76 77class ContextStack(threading.local): 78 """A thread-local stack of context switches.""" 79 80 def __init__(self): 81 super(ContextStack, self).__init__() 82 self.stack = [] 83 84 def push(self, is_building_function, enter_context_fn): 85 """Push metadata about a context switch onto the stack. 86 87 A context switch can take one of two forms: installing a graph as the 88 default graph, or entering the eager context. 89 90 Args: 91 is_building_function: (bool.) Whether the context is building a function. 92 enter_context_fn: (function.) A callable that executes the context switch. 93 For example, `graph.as_default` or `eager_mode`. 94 """ 95 96 self.stack.append( 97 ContextStackEntry(is_building_function, enter_context_fn)) 98 99 def pop(self): 100 """Pop the stack.""" 101 102 self.stack.pop() 103 104 105context_stack = ContextStack() 106 107 108# TODO(agarwal): rename to EagerContext / EagerRuntime ? 109# TODO(agarwal): consider keeping the corresponding Graph here. 110class Context(object): 111 """Environment in which eager operations execute.""" 112 113 def __init__(self, config=None, device_policy=None): 114 """Creates a new Context. 115 116 Args: 117 config: (Optional.) A `ConfigProto` protocol buffer with configuration 118 options for the Context. Note that a lot of these options may be 119 currently unimplemented or irrelevant when eager execution is enabled. 120 device_policy: (Optional.) What policy to use when trying to run an 121 operation on a device with inputs which are not on that device. 122 Valid values: 123 tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 124 correct. 125 tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the 126 right device but raises a warning. 127 tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might 128 hide performance problems. 129 tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 130 raising errors on the other ones. 131 """ 132 self._eager_context = _EagerContext() 133 self._context_handle = None 134 self._context_devices = None 135 self._post_execution_callbacks = [] 136 self._config = config 137 self._seed = None 138 self._initialize_lock = threading.Lock() 139 self._device_policy = device_policy 140 141 def _set_global_seed(self, seed): 142 """Set a global eager mode seed for random ops.""" 143 self._seed = seed 144 self._rng = random.Random(self._seed) 145 # Also clear the kernel cache, to reset any existing seeds 146 if self._context_handle is not None: 147 pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle) 148 149 def _internal_operation_seed(self): 150 """Returns a fake operation seed. 151 152 In eager mode, user shouldn't set or depend on operation seed. 153 Here, we generate a random seed based on global seed to make 154 operation's randomness different and depend on the global seed. 155 156 Returns: 157 A fake operation seed based on global seed. 158 """ 159 return self._rng.randint(0, _MAXINT32) 160 161 def _initialize_handle_and_devices(self): 162 """Initialize handle and devices.""" 163 with self._initialize_lock: 164 if self._context_handle is not None: 165 return 166 assert self._context_devices is None 167 opts = pywrap_tensorflow.TFE_NewContextOptions() 168 try: 169 with errors.raise_exception_on_not_ok_status() as status: 170 if self._config is not None: 171 config_str = self._config.SerializeToString() 172 pywrap_tensorflow.TFE_ContextOptionsSetConfig( 173 opts, config_str, len(config_str), status) 174 if self._device_policy is not None: 175 pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( 176 opts, self._device_policy) 177 self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) 178 finally: 179 pywrap_tensorflow.TFE_DeleteContextOptions(opts) 180 # Store list of devices 181 self._context_devices = [] 182 with errors.raise_exception_on_not_ok_status() as status: 183 device_list = pywrap_tensorflow.TFE_ContextListDevices( 184 self._context_handle, status) 185 try: 186 self._num_gpus = 0 187 for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): 188 with errors.raise_exception_on_not_ok_status() as status: 189 dev_name = pywrap_tensorflow.TF_DeviceListName( 190 device_list, i, status) 191 self._context_devices.append(pydev.canonical_name(dev_name)) 192 with errors.raise_exception_on_not_ok_status() as status: 193 dev_type = pywrap_tensorflow.TF_DeviceListType( 194 device_list, i, status) 195 if dev_type == "GPU": 196 self._num_gpus += 1 197 198 finally: 199 pywrap_tensorflow.TF_DeleteDeviceList(device_list) 200 201 @property 202 def _handle(self): 203 ctx = self._context_handle 204 if ctx is None: 205 self._initialize_handle_and_devices() 206 return self._context_handle 207 else: 208 return ctx 209 210 @property 211 def _devices(self): 212 devices = self._context_devices 213 if devices is None: 214 self._initialize_handle_and_devices() 215 return self._context_devices 216 else: 217 return devices 218 219 def __str__(self): 220 if self._context_handle is None: 221 return "Eager TensorFlow Context. Devices currently uninitialized." 222 else: 223 devices = self._devices 224 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 225 for i, d in enumerate(devices): 226 lines.append(" Device %d: %s" % (i, d)) 227 return "\n".join(lines) 228 229 @tf_contextlib.contextmanager 230 def _mode(self, mode): 231 ctx = self._eager_context 232 old_mode = ctx.mode 233 ctx.mode = mode 234 if mode == EAGER_MODE: 235 context_stack.push(False, eager_mode) 236 try: 237 yield 238 finally: 239 ctx.mode = old_mode 240 if mode == EAGER_MODE: 241 context_stack.pop() 242 243 def in_graph_mode(self): 244 """Returns True if current thread is in GRAPH mode.""" 245 return self._eager_context.mode == GRAPH_MODE 246 247 def in_eager_mode(self): 248 """Returns True if current thread is in EAGER mode.""" 249 return self._eager_context.mode == EAGER_MODE 250 251 def scalar_cache(self): 252 """Per-device cache for scalars.""" 253 return self._eager_context.scalar_cache 254 255 @property 256 def scope_name(self): 257 """Returns scope name for the current thread.""" 258 return self._eager_context.scope_name 259 260 @scope_name.setter 261 def scope_name(self, s): 262 """Sets scope name for the current thread.""" 263 self._eager_context.scope_name = s 264 265 @property 266 def summary_writer_resource(self): 267 """Returns summary writer resource.""" 268 return self._eager_context.summary_writer_resource 269 270 @summary_writer_resource.setter 271 def summary_writer_resource(self, resource): 272 """Sets summary writer resource.""" 273 self._eager_context.summary_writer_resource = resource 274 275 @property 276 def device_name(self): 277 """Returns the device name for the current thread.""" 278 return self._eager_context.device_name 279 280 @property 281 def device_spec(self): 282 """Returns the device spec for the current thread.""" 283 return self._eager_context.device_spec 284 285 @tf_contextlib.contextmanager 286 def device(self, name): 287 """Context-manager to force placement of operations and Tensors on a device. 288 289 Args: 290 name: Name of the device or None to get default placement. 291 292 Yields: 293 Nothing. 294 295 Raises: 296 ValueError: If name is not a string or is an invalid device name. 297 """ 298 eager_context = self._eager_context 299 old_device_name = eager_context.device_name 300 old_device_spec = eager_context.device_spec 301 cache_key = (old_device_name, name) 302 try: 303 new_device_name, new_device_spec = _device_parsing_cache[cache_key] 304 except TypeError: 305 # Error while trying to compute the cache key. 306 raise ValueError("Expecting a string device name. Got %s(%s)" % 307 (type(name), name)) 308 except KeyError: 309 # Handle a cache miss. 310 if name is not None: 311 if not isinstance(name, str): 312 raise ValueError("Expecting a string device name. Got %s(%s)" % 313 (type(name), name)) 314 device_spec = pydev.DeviceSpec.from_string(name) 315 if old_device_name: 316 new_device_spec = copy.copy(old_device_spec) 317 else: 318 new_device_spec = pydev.DeviceSpec.from_string( 319 "/job:localhost/replica:0/task:0/device:CPU:0") 320 new_device_spec.merge_from(device_spec) 321 else: 322 new_device_spec = pydev.DeviceSpec.from_string("") 323 new_device_name = new_device_spec.to_string() 324 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) 325 326 try: 327 eager_context.device_name = new_device_name 328 eager_context.device_spec = new_device_spec 329 yield 330 finally: 331 eager_context.device_name = old_device_name 332 eager_context.device_spec = old_device_spec 333 334 def devices(self): 335 """List of the names of devices available to execute operations.""" 336 return self._devices 337 338 def num_gpus(self): 339 """The number of GPUs available to execute operations.""" 340 self._initialize_handle_and_devices() 341 return self._num_gpus 342 343 def add_function(self, fn): 344 """Add a function definition to the context. 345 346 Once added, the function (identified by its name) can be executed like any 347 other operation. 348 349 Args: 350 fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 351 """ 352 with errors.raise_exception_on_not_ok_status() as status: 353 pywrap_tensorflow.TFE_ContextAddFunction( 354 self._handle, # pylint: disable=protected-access 355 fn, 356 status) 357 358 def add_function_def(self, fdef): 359 """Add a function definition to the context. 360 361 Once added, the function (identified by its name) can be executed like any 362 other operation. 363 364 Args: 365 fdef: A FunctionDef protocol buffer message. 366 """ 367 fdef_string = fdef.SerializeToString() 368 with errors.raise_exception_on_not_ok_status() as status: 369 pywrap_tensorflow.TFE_ContextAddFunctionDef( 370 self._handle, # pylint: disable=protected-access 371 fdef_string, 372 len(fdef_string), 373 status) 374 375 def add_post_execution_callback(self, callback): 376 """Add a post-execution callback to the context. 377 378 A post-execution callback is invoked immediately after an eager operation or 379 function has finished execution, providing access to the op's type, name 380 input and output tensors. Multiple execution callbacks can be added, in 381 which case the callbacks will be invoked in the order in which they are 382 added. 383 384 Args: 385 callback: a callable of the signature 386 `f(op_type, op_name, attrs, inputs, outputs)`. 387 `op_type` is the type of the operation that was just executed (e.g., 388 `MatMul`). 389 `op_name` is the name of the operation that has was just executed. This 390 name is set by the client who created the operation and can be `None` if 391 it is unset. 392 `attrs` contains the attributes of the operation as a `tuple` of 393 alternating attribute names and attribute values. 394 `inputs` is the `list` of input `Tensor`(s) to the op. 395 `outputs` is the `list` of output `Tensor`(s) from the op. 396 Return value(s) from the callback are ignored. 397 """ 398 # TODO(cais): (b/64674139) Allow access to function-internal operations. 399 self._post_execution_callbacks.append(callback) 400 401 def clear_post_execution_callbacks(self): 402 """Clear all post-execution callbacks added to the context.""" 403 del self._post_execution_callbacks[:] 404 405 @property 406 def post_execution_callbacks(self): 407 """Get the list of post-execution callbacks added to the context.""" 408 return self._post_execution_callbacks 409 410 def enable_run_metadata(self): 411 """Enables tracing of op execution via RunMetadata. 412 413 To retrieve the accumulated metadata call context.export_run_metadata() 414 and to stop tracing call context.disable_run_metadata(). 415 """ 416 if not self._context_handle: 417 self._initialize_handle_and_devices() 418 pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) 419 420 @tf_contextlib.contextmanager 421 def device_policy(self, policy): 422 if not self._context_handle: 423 self._initialize_handle_and_devices() 424 old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( 425 self._context_handle) 426 pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( 427 self._handle, policy) 428 try: 429 yield 430 finally: 431 pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy( 432 self._handle, old) 433 434 def disable_run_metadata(self): 435 """Disables tracing of op execution via RunMetadata.""" 436 if not self._context_handle: 437 return 438 pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) 439 440 def export_run_metadata(self): 441 """Returns a RunMetadata proto with accumulated information. 442 443 The returned protocol buffer contains information since the most recent call 444 to either enable_run_metadata or export_run_metadata. 445 446 Returns: 447 A RunMetadata protocol buffer. Or None if not enabled. 448 """ 449 if not self._context_handle: 450 return None 451 with c_api_util.tf_buffer() as buffer_: 452 with errors.raise_exception_on_not_ok_status() as status: 453 pywrap_tensorflow.TFE_ContextExportRunMetadata( 454 self._context_handle, buffer_, status) 455 proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) 456 run_metadata = config_pb2.RunMetadata() 457 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 458 return run_metadata 459 460_context = None 461_context_lock = threading.Lock() 462 463 464def _initialize_context(): 465 global _context 466 with _context_lock: 467 if _context is None: 468 _context = Context() 469 470 471def context(): 472 """Returns a singleton context object.""" 473 if _context is None: 474 _initialize_context() 475 return _context 476 477 478# TODO(agarwal): remove this. 479def get_default_context(): 480 """Same as context.""" 481 if _context is None: 482 _initialize_context() 483 return _context 484 485 486def set_global_seed(seed): 487 """Sets the eager mode seed.""" 488 context()._set_global_seed(seed) # pylint: disable=protected-access 489 490 491def global_seed(): 492 """Returns the eager mode seed.""" 493 return context()._seed # pylint: disable=protected-access 494 495 496def internal_operation_seed(): 497 """Returns the operation seed generated based on global seed.""" 498 return context()._internal_operation_seed() # pylint: disable=protected-access 499 500 501def in_graph_mode(): 502 """Returns True if current thread is in GRAPH mode for default context.""" 503 return context().in_graph_mode() 504 505 506def in_eager_mode(): 507 """Returns True if current thread is in EAGER mode for default context.""" 508 return context().in_eager_mode() 509 510 511def graph_mode(): 512 """Context-manager to enable GRAPH mode for current thread.""" 513 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access 514 515 516def eager_mode(): 517 """Context-manager to enable EAGER mode for current thread.""" 518 return context()._mode(EAGER_MODE) # pylint: disable=protected-access 519 520 521# TODO(agarwal): get rid of this and use ops.name_scope instead. 522@contextlib.contextmanager 523def namescope(name): 524 """ContextManager for creating hierarchical name scopes.""" 525 ctx = context() 526 old_name = ctx.scope_name 527 ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name 528 try: 529 yield 530 finally: 531 ctx.scope_name = old_name 532 533 534def scope_name(): 535 """Name of the current scope.""" 536 return context().scope_name 537 538 539def device(name): 540 """Context-manager to force placement of operations and Tensors on a device. 541 542 Example: 543 ```python 544 with tfe.device('gpu:0'): 545 with tfe.device('cpu:0'): 546 shape = tf.constant([], dtype=tf.int32) 547 x = tf.truncated_normal(shape, tf.float32) 548 ``` 549 will ensure that the `shape` Tensor is on CPU but the `truncated_normal` 550 operation runs on GPU 0. 551 552 Args: 553 name: Name of the device (see context().devices()), or None to 554 perform automatic placement. 555 556 Returns: 557 Context manager for setting the device. 558 """ 559 return context().device(name) 560 561 562def list_devices(): 563 """List the names of the available devices. 564 565 Returns: 566 Names of the available devices, as a `list`. 567 """ 568 return context().devices() 569 570 571def num_gpus(): 572 """Get the number of available GPU devices. 573 574 Returns: 575 The number of available GPU devices. 576 """ 577 return context().num_gpus() 578 579 580def enable_run_metadata(): 581 """Enables tracing of op execution via RunMetadata. 582 583 To retrieve the accumulated metadata call context.export_run_metadata() 584 and to stop tracing call context.disable_run_metadata(). 585 """ 586 context().enable_run_metadata() 587 588 589def disable_run_metadata(): 590 """Disables tracing of op execution via RunMetadata.""" 591 context().disable_run_metadata() 592 593 594def export_run_metadata(): 595 """Returns a RunMetadata proto with accumulated information. 596 597 The returned protocol buffer contains information since the most recent call 598 to either enable_run_metadata or export_run_metadata. 599 600 Returns: 601 A RunMetadata protocol buffer. 602 """ 603 return context().export_run_metadata() 604 605 606# Not every user creates a Context via context.context() 607# (for example, enable_eager_execution in python/framework/ops.py), 608# but they do all import this file. Note that IS_IN_GRAPH_MODE and 609# in_graph_mode are both parameterless functions. 610is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode 611