# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for JIT compilation on the CPU and GPU devices."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24
25from tensorflow.compiler.tests import test_utils
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.core.protobuf import rewriter_config_pb2
28from tensorflow.python.client import session as session_lib
29from tensorflow.python.compiler.xla import jit
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import function
33from tensorflow.python.framework import ops
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gradients_impl
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import nn_ops
39from tensorflow.python.platform import test
40
41
42jit_scope = jit.experimental_jit_scope
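# Ops built inside a `with jit_scope():` block are marked for compilation by
# the XLA JIT (see testExplicitMarking below).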


# Disable rewrites to make sure we don't end up having to update this test
# whenever new rewrites are implemented.
def NoRewriteSessionConfig():
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)
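
# Illustrative usage (with hypothetical tensors `t1` and `t2`):
#   y = CompiledKernel(lambda a, b: a + b, t1, t2, name="Add")
# wraps the lambda in a compiled Defun and calls it on `t1` and `t2`.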


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


def MetadataHasXlaRunOp(run_metadata):
  """Returns true if there are XlaRun kernels in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaRun")


class JitLaunchTest(test.TestCase):
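  """Tests for launching functions compiled as XLA kernels."""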

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # the same number of tensor arguments as there are entries in 'args', and
  # must return a tuple of output tensors.
  #
  # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
  # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
  # ops to be constant-folded away, so the check is optional.
  def _compare(self,
               fn,
               args,
               require_kernel_launch=True,
               name=None,
               noinline=None):
    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(
          fn, *placeholders, name=name, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = test_utils.RunWithWarmup(
          sess, compiled_op, feeds,
          config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE),
          run_metadata)
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assertTrue(MetadataHasXlaRunOp(run_metadata))

        direct = sess.run(direct_op, feeds)
        print("Direct Result {}".format(direct))

        if (isinstance(compiled, (tuple, list)) and
            isinstance(direct, (tuple, list))):
          for (x, y) in zip(compiled, direct):
            self.assertAllClose(x, y, rtol=1e-1)
        else:
          self.assertAllClose(compiled, direct, rtol=1e-2)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      test_utils.RunWithWarmup(sess, call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

       XLA returns aliased buffers if outputs are identical. Tests that
       we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) which calls another function
    # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
    # to symbolically execute Bar correctly regardless of whether Bar is inlined
    # or not.

    # Tests compiled=True and noinline=True.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_noinline",
        noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_inline",
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero-element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

       On a GPU, int32 tensors will be placed in host memory.
    """

    def AddToSelf(x):
      return math_ops.add(x, x)

    self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])

  def testMandatoryConstantInput(self):
    """Tests an operator that has a mandatory-constant shape input."""

    def FillWithFloat(x):
      return array_ops.fill(x, 9.5)

    self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])

  def testMnistForwardFunc(self):
    """Computes the inference function from the MNIST beginners tutorial."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    # Define a TensorFlow function to compute the forward pass.
    def MnistForward(w, b, x):
      return nn_ops.softmax(math_ops.matmul(x, w) + b)

    w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
    b = np.random.random_sample((num_classes,)).astype(np.float32)
    x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
    self._compare(MnistForward, [w, b, x])

  def testExplicitMarking(self):
    """Test explicit marking of operators to compile."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      w = array_ops.placeholder(dtypes.float32)
      b = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y1 = math_ops.matmul(x, w)
      y2 = math_ops.add(y1, b)
      with jit_scope():
        y = math_ops.square(y2)

      dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
      db = np.random.random_sample((num_classes,)).astype(np.float32)
      dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
      with session_lib.Session() as sess:
        run_metadata = config_pb2.RunMetadata()
        output = test_utils.RunWithWarmup(
            sess,
            y, {
                x: dx,
                w: dw,
                b: db
            },
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))

        # TODO(phawkins): really we would like to test that there were exactly
        # two kernel launches. However, we have no reliable way to determine
        # that.
        self.assertTrue(MetadataHasXlaRunOp(run_metadata))

        expected = np.square(np.dot(dx, dw) + db)
        self.assertAllClose(expected, output, rtol=1e-1)


class XlaCompilationTest(test.TestCase):
  """Tests for auto-compilation on CPU/GPU devices."""

  def testReshape(self):
    """Tests an operator with compile-time constant and non-constant inputs."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        # Reshape's first argument is non-constant in the JIT, but its second
        # (shape) argument will be treated as a compile-time constant for
        # each JIT compilation.
        # We do not use a tf.constant() argument since we want to ensure the
        # shape is still a run-time argument to the JIT, and not
        # statically known as part of the JIT compilation's input graph.
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          z, {
              x: np.array([1, 2, 3, 4, 5, 6], np.float32),
              y: [-1, 3]
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          t, {
              x: np.int32(7),
              y: np.int32(404)
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
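      # t = (x + x) + (x + x) = 4 * x = 28 for x = 7; y does not affect t.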
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
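      # Five iterations of x -> x * 2.0 + 1.0 starting from 2 yield 95.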
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = test_utils.RunWithWarmup(
          session,
          t, {
              x: np.float32(2),
              y: np.float32(4),
              c: True
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
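      # With c = True: z = x + 1.0 = 3, w = z, so t = z + w = 6.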
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
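    """Tests that a compiled function can call other compiled functions."""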
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.session(
        config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
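      # Foo(x) = x**2 + 2 * x**3, so [2., 10.] maps to [20., 2100.].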
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for a deadlock bug in graphs with loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
      result = session.run(u, {x: np.float32(2)})
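      # For x = 2: y = 3, and five loop iterations leave w = 60, so u = 63.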
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = NoRewriteSessionConfig()
      cfg.graph_options.optimizer_options.opt_level = (
          config_pb2.OptimizerOptions.L1)
      cfg.graph_options.optimizer_options.do_function_inlining = True
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = test_utils.RunWithWarmup(
            sess,
            dx,
            feed_dict={x: 100.},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
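      # d/dx log(x) = 1/x, so the gradient at x = 100 is 0.01.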
      self.assertAllClose(dx_val, 0.01)
      return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "XlaCompile"))
    self.assertFalse(InLabels(labels, "XlaRun"))

    # Compile the backprop. We should see one XlaCompile/XlaRun pair.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "XlaCompile"))
    self.assertTrue(InLabels(labels, "XlaRun"))


class ElementWiseFusionTest(test.TestCase):
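  """Tests for element-wise fusion clustering behavior."""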

  # Runs a simple test graph with the given global_jit_level; returns the
  # output and the number of XlaRun ops observed in the trace.
  def simpleTest(self, arg0, arg1, global_jit_level):
    config = config_pb2.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = global_jit_level

    with session_lib.Session(config=config) as sess:
      a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
      a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
      # Two element-wise ops. We need at least two ops since single
      # element clusters are not passed to XLA in fusion_only mode.
      a3 = a1 * a2
      a4 = a3 + a1
      # A matmul to break XLA clustering.
      a5 = math_ops.matmul(a4, a1)
      # Two more element-wise ops.
      a6 = a5 - a4
      a7 = a6 + a2

      run_metadata = config_pb2.RunMetadata()
      output = test_utils.RunWithWarmup(
          sess,
          a7, {
              a1: arg0,
              a2: arg1
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))

      labels = RunMetadataLabels(run_metadata)

      xla_compile_count = sum("XlaCompile(" in x for x in labels)
      xla_run_count = sum("XlaRun(" in x for x in labels)
      self.assertEqual(xla_compile_count, xla_run_count)

      return output, xla_run_count


class LazyCompilationTest(test.TestCase):
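  """Tests for lazy compilation of XLA clusters."""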

  def testLazyCompilation(self):
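    """Tests that clusters are compiled lazily, per shape signature."""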

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # The very first run of the cluster is always compiled (non-lazily).
      run_metadata_for_first_run = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 19., 77., 100.]},
          run_metadata=run_metadata_for_first_run,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_first_run), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_for_first_run), "_XlaRun"))

      run_metadata_before_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_before_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun"))

      # We compile when we see the same shape a second time.

      run_metadata_after_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_after_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun"))

      run_metadata_for_new_shape = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 12.]},
          run_metadata=run_metadata_for_new_shape,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))

  def testIsMegamorphic(self):
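    """Tests that rarely-repeated shapes mark a cluster megamorphic."""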

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Make the cluster go megamorphic by running it with lots of shape
      # signatures where the cluster is executed with each signature only a few
      # times.  Then check that we don't compile the cluster ever again.

      for shape in range(10, 50):
        for _ in range(0, 49):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 50):
        run_metadata = config_pb2.RunMetadata()
        sess.run(
            y,
            feed_dict={x: [0.] * 60},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
        self.assertTrue(
            InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
        self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))

  def testIsNotMegamorphic(self):
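    """Tests that a cluster with well-exercised shapes is not megamorphic."""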

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Run the cluster with lots of shape signatures, but in a way that it
      # isn't megamorphic (i.e. each shape signature sees a lot of executions).
      # Then check that the cluster has not been marked as megamorphic.

      for shape in range(10, 50):
        for _ in range(0, 1000):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 10):
        sess.run(y, feed_dict={x: [0.] * 60})

      run_metadata = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [0.] * 60},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))


if __name__ == "__main__":
  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
                                os.environ.get("TF_XLA_FLAGS", ""))
  test.main()