1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library for controlling the Tensorflow/XLA JIT compiler.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.core.framework import attr_value_pb2 24from tensorflow.python.eager import context 25from tensorflow.python.framework import ops 26from tensorflow.python.util.tf_export import tf_export 27 28 29_XLA_SCOPE_KEY = ("__xla_scope",) 30 31 32class _XlaScope(object): 33 """Keeps track of previous XLA scope calls, and depth of current call.""" 34 35 def __init__(self, count, depth): 36 self.count = count 37 self.depth = depth 38 39 40@contextlib.contextmanager 41@tf_export("xla.experimental.jit_scope") 42def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False): 43 """Enable or disable JIT compilation of operators within the scope. 44 45 NOTE: This is an experimental feature. 46 47 The compilation is a hint and only supported on a best-effort basis. 48 49 Example usage: 50 51 ```python 52 with tf.xla.experimental.jit_scope(): 53 c = tf.matmul(a, b) # compiled 54 with tf.xla.experimental.jit_scope(compile_ops=False): 55 d = tf.matmul(a, c) # not compiled 56 with tf.xla.experimental.jit_scope( 57 compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): 58 e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. 59 ``` 60 61 Example of `separate_compiled_gradients`: 62 63 ```python 64 # In the example below, the computations for f, g and h will all be compiled 65 # in separate scopes. 66 with tf.xla.experimental.jit_scope( 67 separate_compiled_gradients=True): 68 f = tf.matmul(a, b) 69 g = tf.gradients([f], [a, b], name='mygrads1') 70 h = tf.gradients([f], [a, b], name='mygrads2') 71 ``` 72 73 Ops that are not in the scope may be clustered and compiled with ops in 74 the scope with `compile_ops=True`, while the ops in the scope with 75 `compile_ops=False` will never be compiled. 76 77 For example: 78 79 ```python 80 # In the example below, x and loss may be clustered and compiled together, 81 # while y will not be compiled. 82 with tf.xla.experimental.jit_scope(): 83 x = tf.matmul(a, b) 84 with tf.xla.experimental.jit_scope(compile_ops=False): 85 y = tf.matmul(c, d) 86 loss = x + y 87 ``` 88 89 If you want to only compile the ops in the scope with `compile_ops=True`, 90 consider adding an outer `jit_scope(compile_ops=False)`: 91 92 ```python 93 # In the example below, only x will be compiled. 94 with tf.xla.experimental.jit_scope(compile_ops=False): 95 with tf.xla.experimental.jit_scope(): 96 x = tf.matmul(a, b) 97 y = tf.matmul(c, d) 98 loss = x + y 99 ``` 100 101 Args: 102 compile_ops: Whether to enable or disable compilation in the scope. 103 Either a Python bool, or a callable that accepts the parameter 104 `node_def` and returns a python bool. 105 separate_compiled_gradients: If true put each gradient subgraph into a 106 separate compilation scope. This gives fine-grained control over which 107 portions of the graph will be compiled as a single unit. Compiling 108 gradients separately may yield better performance for some graphs. 109 The scope is named based on the scope of the forward computation as well 110 as the name of the gradients. As a result, the gradients will be compiled 111 in a scope that is separate from both the forward computation, and from 112 other gradients. 113 Raises: 114 RuntimeError: if called when eager execution is enabled. 115 Yields: 116 The current scope, enabling or disabling compilation. 117 """ 118 if context.executing_eagerly(): 119 raise RuntimeError("xla.experimental.jit_scope is not supported when eager " 120 "execution is enabled. Try use it inside tf.function.") 121 122 if callable(compile_ops): 123 def xla_compile(node_def): 124 return attr_value_pb2.AttrValue(b=compile_ops(node_def)) 125 else: 126 xla_compile = attr_value_pb2.AttrValue(b=compile_ops) 127 128 attrs = { 129 "_XlaCompile": 130 xla_compile, 131 "_XlaSeparateCompiledGradients": 132 attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients)) 133 } 134 135 # Find the singleton counter for the current scoped graph. If it 136 # doesn't exist, create one. 137 xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) 138 if not xla_scope_counter: 139 xla_scope_counter = _XlaScope(0, 0) 140 ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) 141 else: 142 xla_scope_counter = xla_scope_counter[0] 143 144 if xla_scope_counter.depth == 0: 145 # If we're at the root xla scope, we can increase the counter so 146 # future calls to jit_scope use a different scope value. 147 # If we're already within a scope, we'll be fusing using the scope 148 # controlled by the parent. 149 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 150 s=("jit_scope_%d" % xla_scope_counter.count).encode()) 151 xla_scope_counter.count += 1 152 153 xla_scope_counter.depth += 1 154 155 # pylint: disable=protected-access 156 with ops.get_default_graph()._attr_scope(attrs): 157 yield 158 # pylint: enable=protected-access 159 160 xla_scope_counter.depth -= 1 161