# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_function helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop


class TPUContextTest(test.TestCase):

  @test_util.deprecated_graph_mode_only
  def testIsInContext(self):
    """Test that control_flow_util can check that we're in a TPU context."""
    z1 = array_ops.identity(1)
    pivot = control_flow_ops.no_op()
    context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
    context.Enter()
    z2 = array_ops.identity(1)
    context.Exit()
    self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
    self.assertTrue(control_flow_util.IsInXLAContext(z2.op))


class TPULayerRewriteTest(test.TestCase):

  @test_util.deprecated_graph_mode_only
  def testUsingInfeedQueueWithRegularizer(self):
    """Test that Layer regularizers can reference data created in loops."""

    def make_regularizer(scale):
      return lambda inputs: scale * math_ops.reduce_sum(math_ops.square(inputs))

    def training_step(inputs, scale):
      outputs = convolutional.conv2d(
          inputs,
          filters=16,
          kernel_size=(3, 3),
          data_format="channels_first",
          kernel_regularizer=make_regularizer(scale))
      loss = math_ops.reduce_mean(math_ops.square(outputs))
      return loss.op

    inputs = array_ops.zeros(shape=(128, 32, 32, 16))
    scale = array_ops.ones(shape=())
    infeed = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.float32],
        tuple_shapes=[inputs.shape, scale.shape])

    def loop():
      return training_loop.repeat(5, training_step, infeed_queue=infeed)

    # This should not throw an error.
    tpu.rewrite(loop)


class TPUGraphPruneTest(test.TestCase):

  def test_prune_unconnected_ops(self):
    with ops.Graph().as_default():
      a = array_ops.placeholder(dtype=dtypes.float32, name="a")
      b = array_ops.placeholder(dtype=dtypes.float32, name="b")
      constant_op.constant(1.0, name="constant")
      x = variable_scope.get_variable(
          name="x",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(2.0))
      y = variable_scope.get_variable(
          name="y",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(3.0))
      math_ops.add(a, b)
      math_ops.add(x, y)
      graph_def = ops.get_default_graph().as_graph_def()

    for node in graph_def.node:
      # Attach a TPU_REPLICATE_ATTR to each node.
      node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0"
      # Rewire placeholder "a" and variable "y" leaving them unconnected.
      for (input_index, node_input) in enumerate(node.input):
        if node_input == "b":
          node.input[input_index] = "constant"
        if node_input == "y":
          node.input[input_index] = "x"

    with ops.Graph().as_default() as graph:
      # Reimport the graph and prune unconnected ops.
      importer.import_graph_def(graph_def)
      tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())

      # Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR.
      a = graph.get_operation_by_name("import/a").get_attr(
          tpu._TPU_REPLICATE_ATTR)
      self.assertEqual(b"0", a)
      x = graph.get_operation_by_name("import/x").get_attr(
          tpu._TPU_REPLICATE_ATTR)
      self.assertEqual(b"0", x)
      # Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed.
      with self.assertRaisesRegexp(
          ValueError,
          "Operation \'import/b\' has no attr named \'_tpu_replicate\'"):
        graph.get_operation_by_name("import/b").get_attr(
            tpu._TPU_REPLICATE_ATTR)
      with self.assertRaisesRegexp(
          ValueError,
          "Operation \'import/y\' has no attr named \'_tpu_replicate\'"):
        graph.get_operation_by_name("import/y").get_attr(
            tpu._TPU_REPLICATE_ATTR)


def do_einsum():
  """Builds a simple einsum computation between two placeholders."""
  a = array_ops.placeholder(dtype=dtypes.float32, name="a", shape=[2, 3, 4])
  b = array_ops.placeholder(dtype=dtypes.float32, name="b", shape=[2, 4, 5])
  return special_math_ops.einsum("abc,acd->abd", a, b)


def find_einsum(g):
  """Returns True if graph `g` contains an Einsum op."""
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "Einsum":
      return True
  return False


def find_xla_einsum(g):
  """Returns True if graph `g` contains an XlaEinsum op."""
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "XlaEinsum":
      return True
  return False


class TPUXlaEinsumTest(test.TestCase):

  def test_tpu_rewrite_uses_xla_einsum(self):
    with ops.Graph().as_default() as g:
      tpu.rewrite(do_einsum)
      self.assertTrue(find_einsum(g) or find_xla_einsum(g))

  def test_default_does_not_use_xla_einsum(self):
    with ops.Graph().as_default() as g:
      do_einsum()
      self.assertFalse(find_xla_einsum(g))


if __name__ == "__main__":
  test.main()