• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DenseLayer JIT compilation on the CPU and GPU devices."""
16
17import os
18
19import numpy as np
20
21from tensorflow.compiler.tests import test_utils
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.python.compiler.xla import jit
24from tensorflow.python.framework import ops
25from tensorflow.python.layers import layers
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import variables
28from tensorflow.python.platform import test
29
# Short alias for XLA's experimental JIT scope. The tests below enter it via
# `with jit_scope():` so that the dense layer built inside is marked for
# compilation.
jit_scope = jit.experimental_jit_scope
31
def GetRunMetadataLabels(run_metadata):
  """Collects the timeline label of every node recorded in `run_metadata`.

  Args:
    run_metadata: a `RunMetadata` proto whose `step_stats.dev_stats` entries
      each hold a repeated `node_stats` field.

  Returns:
    A list of `timeline_label` strings, ordered by device stats first and then
    by node stats within each device.
  """
  return [
      node_stat.timeline_label
      for device_stat in run_metadata.step_stats.dev_stats
      for node_stat in device_stat.node_stats
  ]
39
40
def InLabels(labels, substr):
  """Reports whether `substr` occurs inside at least one entry of `labels`.

  Args:
    labels: an iterable of label strings.
    substr: the substring to search for.

  Returns:
    True as soon as a label containing `substr` is found, otherwise False.
  """
  for label in labels:
    if substr in label:
      return True
  return False
44
45
class DenseLayerTest(test.TestCase):
  """Checks how a dense layer is clustered into XlaCompile/XlaRun op pairs."""

  def countXlaOps(self, labels):
    """Counts the XlaCompile/XlaRun op pairs present in `labels`.

    Args:
      labels: timeline label strings, e.g. as returned by
        `GetRunMetadataLabels`.

    Returns:
      The number of labels containing "XlaRun(". The test fails if that number
      differs from the number of labels containing "XlaCompile(".
    """
    xla_compile_count = sum("XlaCompile(" in x for x in labels)
    xla_run_count = sum("XlaRun(" in x for x in labels)
    # Each compiled cluster is expected to contribute one op of each kind.
    self.assertEqual(xla_compile_count, xla_run_count)
    return xla_run_count

  def _prependXlaFlags(self, flags):
    """Prepends `flags` to the TF_XLA_FLAGS environment variable.

    The previous value is restored when the current test finishes, or the
    variable is removed if it was previously unset. This keeps the change from
    leaking into other tests that run in the same process.

    Args:
      flags: space-separated XLA flag string to put in front of any existing
        TF_XLA_FLAGS value.
    """
    previous = os.environ.get("TF_XLA_FLAGS")

    def restore():
      if previous is None:
        os.environ.pop("TF_XLA_FLAGS", None)
      else:
        os.environ["TF_XLA_FLAGS"] = previous

    self.addCleanup(restore)
    os.environ["TF_XLA_FLAGS"] = flags + " " + (previous or "")

  def _runDenseLayerAndGetLabels(self, sess, x, y):
    """Initializes variables, runs `y` with a fixed feed and returns labels.

    Call this inside the `with self.session(...)` block that owns `sess`, so
    that `self.evaluate` runs the variable initializer in that session.

    Args:
      sess: the active session.
      x: the placeholder to feed; it receives a 2x2x3 array.
      y: the tensor to run.

    Returns:
      The timeline labels collected from the RunMetadata that
      `test_utils.RunWithWarmup` fills in under FULL_TRACE.
    """
    self.evaluate(variables.global_variables_initializer())
    run_metadata = config_pb2.RunMetadata()
    test_utils.RunWithWarmup(
        sess,
        y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
        run_metadata=run_metadata,
        options=config_pb2.RunOptions(
            trace_level=config_pb2.RunOptions.FULL_TRACE))
    return GetRunMetadataLabels(run_metadata)

  def testDenseLayerAutoJit(self):
    """Tests dense layer compilation in auto-jit mode.

    Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
    auto-jit mode.
    """

    self._prependXlaFlags("--tf_xla_cpu_global_jit")
    config = config_pb2.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = (
        config_pb2.OptimizerOptions.ON_1)

    with self.session(config=config) as sess:
      x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
      y = layers.dense(x, 3)
      labels = self._runDenseLayerAndGetLabels(sess, x, y)

    self.assertEqual(1, self.countXlaOps(labels))
    # The op type is "MatMul". The previously checked substring "MatMult" could
    # never match, so that assertion was vacuous.
    self.assertFalse(InLabels(labels, "MatMul"))

  def testDenseLayerJitScopeDefinedShape(self):
    """Tests that the dense layer node is properly compiled in jit scope.

    Dense layer with static shape input tensor should be compiled into a single
    XlaCompile/XlaRun op pair by XLA.
    """

    with self.session() as sess:
      x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
      with jit_scope():
        y = layers.dense(x, 3)
      labels = self._runDenseLayerAndGetLabels(sess, x, y)

    self.assertEqual(1, self.countXlaOps(labels))
    # No need to check whether ListDiff is compiled or not because ListDiff op
    # is not used when input tensor shape is fully defined.

  def testDenseLayerJitScopeUndefinedShape(self):
    """Tests that the dense layer node is properly compiled in jit scope.

    Dense layer with a partially unknown input shape should still be compiled
    into a single XlaCompile/XlaRun op pair, with no standalone MatMul left.
    """

    with self.session() as sess:
      x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
      with jit_scope():
        y = layers.dense(x, 3)
      labels = self._runDenseLayerAndGetLabels(sess, x, y)

    self.assertEqual(1, self.countXlaOps(labels))
    # Check for the real op type "MatMul"; "MatMult" never matches anything.
    self.assertFalse(InLabels(labels, "MatMul"))
133
134
if __name__ == "__main__":
  # Put the lazy-compilation flag in front of any XLA flags already present in
  # the environment before the test runner starts.
  existing_xla_flags = os.environ.get("TF_XLA_FLAGS", "")
  os.environ["TF_XLA_FLAGS"] = (
      "--tf_xla_enable_lazy_compilation=true " + existing_xla_flags)
  # These tests drive TensorFlow through graph-mode sessions, which cannot be
  # used while eager execution is enabled.
  ops.disable_eager_execution()
  test.main()
142