# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


jit_scope = jit.experimental_jit_scope


# Disable rewrites to make sure we don't end up having to update this test
# whenever we implement new ones.
def NoRewriteSessionConfig():
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


def MetadataHasXlaRunOp(run_metadata):
  """Returns true if there are XlaRun kernels in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaRun")


class JitLaunchTest(test.TestCase):

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # the same number of tensors as arguments that are in 'args', and must return
  # a tuple of output tensors.
  #
  # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
  # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
  # ops to be constant-folded away, so the check is optional.
  def _compare(self,
               fn,
               args,
               require_kernel_launch=True,
               name=None,
               noinline=None):
    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(
          fn, *placeholders, name=name, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = test_utils.RunWithWarmup(
          sess, compiled_op, feeds,
          config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE),
          run_metadata)
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assert_(MetadataHasXlaRunOp(run_metadata))

      direct = sess.run(direct_op, feeds)
      print("Direct Result {}".format(direct))

      if (isinstance(compiled, (tuple, list)) and
          (isinstance(direct, (tuple, list)))):
        for (x, y) in zip(compiled, direct):
          self.assertAllClose(x, y, rtol=1e-1)
      else:
        self.assertAllClose(compiled, direct, rtol=1e-2)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      test_utils.RunWithWarmup(sess, call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

    XLA returns aliased buffers if outputs are identical. Tests that
    we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) which calls another function
    # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
    # to symbolically execute Bar correctly regardless of whether Bar is inlined
    # or not.

    # Tests compiled=True and noinline=True.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_inline",
        noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_noinline",
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

    On a GPU, int32 tensors will be placed in host memory.
    """

    def AddToSelf(x):
      return math_ops.add(x, x)

    self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])

  def testMandatoryConstantInput(self):
    """Tests an operator that has a mandatory-constant shape input."""

    def FillWithFloat(x):
      return array_ops.fill(x, 9.5)

    self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])

  def testMnistForwardFunc(self):
    """Compute inference function from MNIST beginners tutorial."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    # Define a TensorFlow function to compute the forward pass.
    def MnistForward(w, b, x):
      return nn_ops.softmax(math_ops.matmul(x, w) + b)

    w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
    b = np.random.random_sample((num_classes)).astype(np.float32)
    x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
    self._compare(MnistForward, [w, b, x])

  def testExplicitMarking(self):
    """Test explicit marking of operators to compile."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      w = array_ops.placeholder(dtypes.float32)
      b = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y1 = math_ops.matmul(x, w)
        y2 = math_ops.add(y1, b)
      with jit_scope():
        y = math_ops.square(y2)

      dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
      db = np.random.random_sample((num_classes)).astype(np.float32)
      dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
      with session_lib.Session() as sess:
        run_metadata = config_pb2.RunMetadata()
        output = test_utils.RunWithWarmup(
            sess,
            y, {
                x: dx,
                w: dw,
                b: db
            },
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))

        # TODO(phawkins): really we would like to test that there were exactly
        # two kernel launches. However, we have no reliable way to determine
        # that.
        self.assert_(MetadataHasXlaRunOp(run_metadata))

        expected = np.square(np.dot(dx, dw) + db)
        self.assertAllClose(expected, output, rtol=1e-1)


class XlaCompilationTest(test.TestCase):
  """Tests for auto-compilation on CPU/GPU devices."""

  def testReshape(self):
    """Tests an operator with compile-time constant and non-constant inputs."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        # Reshape's first argument is non-constant in the JIT, but its second
        # (shape) argument will be treated as a compile-time constant for
        # each JIT compilation.
        # We do not use a tf.const() argument since we want to ensure the
        # shape is still a run-time argument to the JIT, and not
        # statically known as part of the JIT compilation's input graph.
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          z, {
              x: np.array([1, 2, 3, 4, 5, 6], np.float32),
              y: [-1, 3]
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          t, {
              x: np.int32(7),
              y: np.int32(404)
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = test_utils.RunWithWarmup(
          session,
          t, {
              x: np.float32(2),
              y: np.float32(4),
              c: True
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.session(
        config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for bug that caused deadlocks in graphs with loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
      result = session.run(u, {x: np.float32(2)})
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = NoRewriteSessionConfig()
      cfg.graph_options.optimizer_options.opt_level = (
          config_pb2.OptimizerOptions.L1)
      cfg.graph_options.optimizer_options.do_function_inlining = True
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = test_utils.RunWithWarmup(
            sess,
            dx,
            feed_dict={x: 100.},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(dx_val, 0.01)
      return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "XlaCompile"))
    self.assertFalse(InLabels(labels, "XlaRun"))

    # Compile the backprop. One XlaCompile/XlaRun pair.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "XlaCompile"))
    self.assertTrue(InLabels(labels, "XlaRun"))


class ElementWiseFusionTest(test.TestCase):

  # Runs a simple test with the input jit_level and fusion_only flag.
  def simpleTest(self, arg0, arg1, global_jit_level):
    config = config_pb2.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = global_jit_level

    with session_lib.Session(config=config) as sess:
      a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
      a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
      # Two element-wise ops. We need at least two ops since single
      # element clusters are not passed to XLA in fusion_only mode.
      a3 = a1 * a2
      a4 = a3 + a1
      # A matmul to break XLA clustering.
      a5 = math_ops.matmul(a4, a1)
      # Two more element-wise ops.
      a6 = a5 - a4
      a7 = a6 + a2

      run_metadata = config_pb2.RunMetadata()
      output = test_utils.RunWithWarmup(
          sess,
          a7, {
              a1: arg0,
              a2: arg1
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))

      labels = RunMetadataLabels(run_metadata)

      xla_compile_count = sum("XlaCompile(" in x for x in labels)
      xla_run_count = sum("XlaRun(" in x for x in labels)
      self.assertEqual(xla_compile_count, xla_run_count)

      return output, xla_run_count


class LazyCompilationTest(test.TestCase):

  def testLazyCompilation(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # The very first run of the cluster is always compiled (non-lazily).
      run_metadata_for_first_run = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 19., 77., 100.]},
          run_metadata=run_metadata_for_first_run,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_first_run), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_for_first_run), "_XlaRun"))

      run_metadata_before_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_before_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun"))

      # We compile when we see the same shape a second time.

      run_metadata_after_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_after_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun"))

      run_metadata_for_new_shape = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 12.]},
          run_metadata=run_metadata_for_new_shape,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))

  def testIsMegamorphic(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Make the cluster go megamorphic by running it with lots of shape
      # signatures where the cluster is executed with each signature only a few
      # times. Then check that we don't compile the cluster ever again.

      for shape in range(10, 50):
        for _ in range(0, 49):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 50):
        run_metadata = config_pb2.RunMetadata()
        sess.run(
            y,
            feed_dict={x: [0.] * 60},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
        self.assertTrue(
            InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
        self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))

  def testIsNotMegamorphic(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Run the cluster with lots of shape signatures, but in a way that it
      # isn't megamorphic (i.e. each shape signature sees a lot of executions).
      # Then check that the cluster has not been marked as megamorphic.

      for shape in range(10, 50):
        for _ in range(0, 1000):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 10):
        sess.run(y, feed_dict={x: [0.] * 60})

      run_metadata = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [0.] * 60},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))


if __name__ == "__main__":
  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
                                os.environ.get("TF_XLA_FLAGS", ""))
  test.main()