1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library for controlling the Tensorflow/XLA JIT compiler.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.core.framework import attr_value_pb2 24from tensorflow.python.framework import ops 25 26 27_XLA_SCOPE_KEY = ("__xla_scope",) 28 29 30class _XlaScope(object): 31 """Keeps track of previous XLA scope calls, and depth of current call.""" 32 33 def __init__(self, count, depth): 34 self.count = count 35 self.depth = depth 36 37 38@contextlib.contextmanager 39def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False): 40 """Enable or disable JIT compilation of operators within the scope. 41 42 NOTE: This is an experimental feature. 43 44 The compilation is a hint and only supported on a best-effort basis. 45 46 Example usage: 47 with tf.contrib.compiler.experimental_jit_scope(): 48 c = tf.matmul(a, b) # compiled 49 with tf.contrib.compiler.experimental_jit_scope(compile_ops=False): 50 d = tf.matmul(a, c) # not compiled 51 with tf.contrib.compiler.experimental_jit_scope( 52 compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): 53 e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. 54 55 Example of separate_compiled_gradients: 56 # In the example below, the computations for f, g and h will all be compiled 57 # in separate scopes. 58 with tf.contrib.compiler.experimental_jit_scope( 59 separate_compiled_gradients=True): 60 f = tf.matmul(a, b) 61 g = tf.gradients([f], [a, b], name='mygrads1') 62 h = tf.gradients([f], [a, b], name='mygrads2') 63 64 Args: 65 compile_ops: Whether to enable or disable compilation in the scope. 66 Either a Python bool, or a callable that accepts the parameter 67 `node_def` and returns a python bool. 68 separate_compiled_gradients: If true put each gradient subgraph into a 69 separate compilation scope. This gives fine-grained control over which 70 portions of the graph will be compiled as a single unit. Compiling 71 gradients separately may yield better performance for some graphs. 72 The scope is named based on the scope of the forward computation as well 73 as the name of the gradients. As a result, the gradients will be compiled 74 in a scope that is separate from both the forward computation, and from 75 other gradients. 76 Yields: 77 The current scope, enabling or disabling compilation. 78 79 """ 80 if callable(compile_ops): 81 def xla_compile(node_def): 82 return attr_value_pb2.AttrValue(b=compile_ops(node_def)) 83 else: 84 xla_compile = attr_value_pb2.AttrValue(b=compile_ops) 85 86 attrs = { 87 "_XlaCompile": 88 xla_compile, 89 "_XlaSeparateCompiledGradients": 90 attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients)) 91 } 92 93 # Find the singleton counter for the current scoped graph. If it 94 # doesn't exist, create one. 95 xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) 96 if not xla_scope_counter: 97 xla_scope_counter = _XlaScope(0, 0) 98 ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) 99 else: 100 xla_scope_counter = xla_scope_counter[0] 101 102 if xla_scope_counter.depth == 0: 103 # If we're at the root xla scope, we can increase the counter so 104 # future calls to jit_scope use a different scope value. 105 # If we're already within a scope, we'll be fusing using the scope 106 # controlled by the parent. 107 attrs["_XlaScope"] = attr_value_pb2.AttrValue( 108 s=("jit_scope_%d" % xla_scope_counter.count).encode()) 109 xla_scope_counter.count += 1 110 111 xla_scope_counter.depth += 1 112 113 # pylint: disable=protected-access 114 with ops.get_default_graph()._attr_scope(attrs): 115 yield 116 # pylint: enable=protected-access 117 118 xla_scope_counter.depth -= 1 119