# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for controlling the TensorFlow/XLA JIT compiler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export


_XLA_SCOPE_KEY = ("__xla_scope",)


class _XlaScope(object):
  """Keeps track of previous XLA scope calls, and depth of current call."""

  def __init__(self, count, depth):
    self.count = count
    self.depth = depth


@tf_export("xla.experimental.jit_scope")
@contextlib.contextmanager
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
  """Enable or disable JIT compilation of operators within the scope.

  NOTE: This is an experimental feature.

  The compilation is a hint and only supported on a best-effort basis.

  Example usage:

    ```python
    with tf.xla.experimental.jit_scope():
      c = tf.matmul(a, b)  # compiled
    with tf.xla.experimental.jit_scope(compile_ops=False):
      d = tf.matmul(a, c)  # not compiled
    with tf.xla.experimental.jit_scope(
        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.
    ```

  Example of `separate_compiled_gradients`:

    ```python
    # In the example below, the computations for f, g and h will all be
    # compiled in separate scopes.
    with tf.xla.experimental.jit_scope(
        separate_compiled_gradients=True):
      f = tf.matmul(a, b)
    g = tf.gradients([f], [a, b], name='mygrads1')
    h = tf.gradients([f], [a, b], name='mygrads2')
    ```

  Ops that are not in the scope may still be clustered and compiled together
  with ops in a scope that has `compile_ops=True`, while ops in a scope with
  `compile_ops=False` will never be compiled.

  For example:

    ```python
    # In the example below, x and loss may be clustered and compiled together,
    # while y will not be compiled.
    with tf.xla.experimental.jit_scope():
      x = tf.matmul(a, b)
    with tf.xla.experimental.jit_scope(compile_ops=False):
      y = tf.matmul(c, d)
    loss = x + y
    ```

  If you want to compile only the ops in the scope with `compile_ops=True`,
  consider adding an outer `jit_scope(compile_ops=False)`:

    ```python
    # In the example below, only x will be compiled.
    with tf.xla.experimental.jit_scope(compile_ops=False):
      with tf.xla.experimental.jit_scope():
        x = tf.matmul(a, b)
      y = tf.matmul(c, d)
      loss = x + y
    ```

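  Since the scope annotates ops as they are added to a graph, under eager
  execution it should be applied inside a `tf.function`. A minimal sketch
  (`matmul_fn` is a placeholder name, not part of this module):

    ```python
    @tf.function
    def matmul_fn(a, b):
      with tf.xla.experimental.jit_scope():
        return tf.matmul(a, b)  # compiled on a best-effort basis
    ```
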
  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a Python bool.
    separate_compiled_gradients: If true, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation and other
      gradients.

  Raises:
    RuntimeError: if called when eager execution is enabled.

  Yields:
    The current scope, enabling or disabling compilation.
  """
  if context.executing_eagerly():
    raise RuntimeError("xla.experimental.jit_scope is not supported when eager "
                       "execution is enabled. Try using it inside tf.function.")

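  # The _XlaCompile attr is either computed per node, when compile_ops is a
  # callable that receives the NodeDef, or fixed to a constant bool for every
  # op created in the scope.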
  if callable(compile_ops):
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
  else:
    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)

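  # Attributes attached to each op created in the scope: _XlaCompile marks the
  # op as a candidate for JIT compilation, and _XlaSeparateCompiledGradients
  # asks for gradient subgraphs to be placed in their own compilation scopes.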
  attrs = {
      "_XlaCompile":
          xla_compile,
      "_XlaSeparateCompiledGradients":
          attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
  }

  # Find the singleton counter for the current scoped graph.  If it
  # doesn't exist, create one.
  xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
  if not xla_scope_counter:
    xla_scope_counter = _XlaScope(0, 0)
    ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
  else:
    xla_scope_counter = xla_scope_counter[0]

  if xla_scope_counter.depth == 0:
    # If we're at the root xla scope, we can increase the counter so
    # future calls to jit_scope use a different scope value.
    # If we're already within a scope, we'll be fusing using the scope
    # controlled by the parent.
    attrs["_XlaScope"] = attr_value_pb2.AttrValue(
        s=("jit_scope_%d" % xla_scope_counter.count).encode())
    xla_scope_counter.count += 1

  xla_scope_counter.depth += 1

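  # Graph._attr_scope attaches the attrs above to every op added to the
  # default graph while this context manager is active.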
  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
  # pylint: enable=protected-access

  xla_scope_counter.depth -= 1