# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DenseLayer JIT compilation on the CPU and GPU devices."""

import os

import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import ops
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

jit_scope = jit.experimental_jit_scope

# Value fed to the dense layer's input placeholder in every test: a [2, 2, 3]
# array, compatible with both the static [2, 2, 3] and the dynamic
# [None, None, 3] placeholder shapes used below.
_INPUT_VALUE = np.array(
    [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)


def GetRunMetadataLabels(run_metadata):
  """Returns all timeline labels recorded in run_metadata.

  Args:
    run_metadata: A `config_pb2.RunMetadata` filled in by a traced run.

  Returns:
    A list containing the `timeline_label` of every node stat on every device.
  """
  return [
      node_stats.timeline_label
      for dev_stats in run_metadata.step_stats.dev_stats
      for node_stats in dev_stats.node_stats
  ]


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


class DenseLayerTest(test.TestCase):

  def countXlaOps(self, labels):
    """Count how many XlaCompile/XlaRun labels are present.

    Also asserts that every XlaCompile op is paired with an XlaRun op.

    Args:
      labels: Timeline labels, e.g. from `GetRunMetadataLabels`.

    Returns:
      The number of XlaRun ops found in `labels`.
    """
    xla_compile_count = sum("XlaCompile(" in x for x in labels)
    xla_run_count = sum("XlaRun(" in x for x in labels)
    self.assertEqual(xla_compile_count, xla_run_count)
    return xla_run_count

  def _runDenseLayerAndGetLabels(self, sess, x, y):
    """Runs `y` with full tracing and returns the recorded timeline labels.

    Must be called inside the `with self.session(...)` block that owns `sess`,
    because variable initialization goes through `self.evaluate`.

    Args:
      sess: The active session.
      x: The input placeholder; it is fed `_INPUT_VALUE`.
      y: The dense layer output tensor to run.

    Returns:
      The list of timeline labels from the traced run.
    """
    self.evaluate(variables.global_variables_initializer())
    run_metadata = config_pb2.RunMetadata()
    test_utils.RunWithWarmup(
        sess,
        y, {x: _INPUT_VALUE},
        run_metadata=run_metadata,
        options=config_pb2.RunOptions(
            trace_level=config_pb2.RunOptions.FULL_TRACE))
    return GetRunMetadataLabels(run_metadata)

  def testDenseLayerAutoJit(self):
    """Tests dense layer compilation in auto-jit mode.

    Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
    auto-jit mode.
    """
    # Enable CPU global jit for this test only; the previous TF_XLA_FLAGS value
    # is restored afterwards so the change does not leak into other tests.
    old_flags = os.environ.get("TF_XLA_FLAGS")
    os.environ["TF_XLA_FLAGS"] = (
        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
    try:
      config = config_pb2.ConfigProto()
      config.graph_options.optimizer_options.global_jit_level = (
          config_pb2.OptimizerOptions.ON_1)

      with self.session(config=config) as sess:
        x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
        y = layers.dense(x, 3)
        labels = self._runDenseLayerAndGetLabels(sess, x, y)

      self.assertEqual(1, self.countXlaOps(labels))
      # The MatMul must run inside the XLA cluster, not as a standalone TF op.
      # (The previous check looked for "MatMult", which can never match.)
      self.assertFalse(InLabels(labels, "MatMul("))
    finally:
      if old_flags is None:
        os.environ.pop("TF_XLA_FLAGS", None)
      else:
        os.environ["TF_XLA_FLAGS"] = old_flags

  def testDenseLayerJitScopeDefinedShape(self):
    """Tests that the dense layer node is properly compiled in jit scope.

    Dense layer with static shape input tensor should be compiled into a single
    XlaCompile/XlaRun op pair by XLA.
    """
    with self.session() as sess:
      x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
      with jit_scope():
        y = layers.dense(x, 3)
      labels = self._runDenseLayerAndGetLabels(sess, x, y)

    self.assertEqual(1, self.countXlaOps(labels))
    # No need to check whether ListDiff is compiled or not because ListDiff op
    # is not used when input tensor shape is fully defined.

  def testDenseLayerJitScopeUndefinedShape(self):
    """Tests that the dense layer node is properly compiled in jit scope.

    Dense layer with a partially unknown input shape should still be compiled
    into a single XlaCompile/XlaRun op pair, with no MatMul run outside XLA.
    """
    with self.session() as sess:
      x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
      with jit_scope():
        y = layers.dense(x, 3)
      labels = self._runDenseLayerAndGetLabels(sess, x, y)

    self.assertEqual(1, self.countXlaOps(labels))
    # The MatMul must run inside the XLA cluster, not as a standalone TF op.
    self.assertFalse(InLabels(labels, "MatMul("))


if __name__ == "__main__":
  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
                                os.environ.get("TF_XLA_FLAGS", ""))
  # This test is using Tensorflow sessions which are not compatible with eager
  # mode.
  ops.disable_eager_execution()
  test.main()