"""Tests for tensor_utils.""" from absl.testing import absltest from absl.testing import parameterized import tensorflow as tf from google.protobuf import any_pb2 from fcp.artifact_building import tensor_utils class TensorUtilsTest(parameterized.TestCase, tf.test.TestCase): def test_bare_name(self): self.assertEqual(tensor_utils.bare_name('foo'), 'foo') self.assertEqual(tensor_utils.bare_name('foo:0'), 'foo') self.assertEqual(tensor_utils.bare_name('foo:1'), 'foo') self.assertEqual(tensor_utils.bare_name('^foo:1'), 'foo') self.assertEqual(tensor_utils.bare_name('^foo:output:2'), 'foo') with tf.Graph().as_default() as g: v = tf.Variable(0.0, name='foo') self.assertEqual(tensor_utils.bare_name(v), 'foo') @tf.function def foo(x): return tf.add(x, v.read_value(), 'add_op') foo(tf.constant(1.0)) # Exchange the input tensor names (the outputs of other nodes) in the graph # to ensure we can recover the original user-specified bare names. graph_def = g.as_graph_def() # Test that the graph def contains graph_def_str = str(graph_def) self.assertIn('add_op:z:0', graph_def_str) self.assertIn('Read/ReadVariableOp:value:0', graph_def_str) # Ensure that we can locate required_names = ['add_op', 'Read/ReadVariableOp'] for node in graph_def.library.function[0].node_def: for i in node.input: if tensor_utils.bare_name(i) in required_names: required_names.remove(tensor_utils.bare_name(i)) self.assertEmpty(required_names) def test_bare_name_with_scope(self): self.assertEqual(tensor_utils.bare_name('bar/foo:1'), 'bar/foo') with tf.Graph().as_default(): with tf.compat.v1.variable_scope('bar'): v = tf.Variable(0.0, name='foo') self.assertEqual(tensor_utils.bare_name(v), 'bar/foo') def test_name_or_str_with_named_variable(self): with tf.Graph().as_default(): v = tf.Variable(0.0, name='foo') self.assertEqual('foo:0', tensor_utils.name_or_str(v)) def test_name_or_str_with_unnamed_variable(self): with tf.Graph().as_default(): v = tf.Variable(0.0) self.assertEqual('Variable:0', tensor_utils.name_or_str(v)) def test_import_graph_def_from_any(self): with tf.Graph().as_default() as g: tf.constant(0.0) graph_def = g.as_graph_def() graph_def_any = any_pb2.Any() graph_def_any.Pack(graph_def) # Graph object doesn't have equality, so we check that the graph defs match. 
    self.assertEqual(
        tensor_utils.import_graph_def_from_any(graph_def_any), g.as_graph_def()
    )

  def test_save_and_restore_in_eager_mode(self):
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor = tf.constant(1.0)
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(filename, tensor_name, tensor.dtype)
    self.assertAllEqual(tensor, restored_tensor)

  @parameterized.named_parameters(
      ('scalar_tensor', tf.constant(1.0)),
      ('non_scalar_tensor', tf.constant([1.0, 2.0])),
  )
  def test_save_and_restore_with_shape_info_in_eager_mode(self, tensor):
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(
        filename, tensor_name, tensor.dtype, tensor.shape
    )
    self.assertAllEqual(tensor, restored_tensor)

  def _assert_op_in_graph(self, expected_op, graph):
    graph_def = graph.as_graph_def()
    node_ops = [node.op for node in graph_def.node]
    self.assertIn(expected_op, node_ops)

  def _get_shape_and_slices_value(self, graph):
    graph_def = graph.as_graph_def()
    node_name_to_value_dict = {node.name: node for node in graph_def.node}
    self.assertIn('restore/shape_and_slices', node_name_to_value_dict)
    return (
        node_name_to_value_dict['restore/shape_and_slices']
        .attr['value']
        .tensor.string_val[0]
    )

  def test_save_and_restore_in_graph_mode(self):
    temp_file = self.create_tempfile().full_path
    graph = tf.Graph()
    with graph.as_default():
      filename = tf.constant(temp_file)
      tensor_name = 'a'
      tensor = tf.constant(1.0)
      save_op = tensor_utils.save(filename, [tensor_name], [tensor])
      restored = tensor_utils.restore(filename, tensor_name, tensor.dtype)
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(save_op)
      expected_tensor, restored_tensor = sess.run([tensor, restored])
      self.assertAllEqual(expected_tensor, restored_tensor)
    self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
    self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
    self.assertEqual(b'', self._get_shape_and_slices_value(graph))

  @parameterized.named_parameters(
      ('scalar_tensor', lambda: tf.constant(1.0), b''),
      ('non_scalar_tensor', lambda: tf.constant([1.0, 2.0]), b'2 :-'),
  )
  def test_save_and_restore_with_shape_info_in_graph_mode(
      self, tensor_builder, expected_shape_and_slices_value
  ):
    temp_file = self.create_tempfile().full_path
    graph = tf.Graph()
    with graph.as_default():
      filename = tf.constant(temp_file)
      tensor_name = 'a'
      tensor = tensor_builder()
      save_op = tensor_utils.save(filename, [tensor_name], [tensor])
      restored = tensor_utils.restore(
          filename, tensor_name, tensor.dtype, tensor.shape
      )
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(save_op)
      expected_tensor, restored_tensor = sess.run([tensor, restored])
      self.assertAllEqual(expected_tensor, restored_tensor)
    self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
    self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
    self.assertEqual(
        expected_shape_and_slices_value,
        self._get_shape_and_slices_value(graph),
    )


if __name__ == '__main__':
  absltest.main()