# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for device utilities."""

from absl.testing import parameterized

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


class DeviceUtilTest(test.TestCase, parameterized.TestCase):
  """Exercises device_util.current(), canonicalize() and resolve()."""

  def setUp(self):
    super().setUp()
    # Reset the global context before every test case (presumably so that
    # state such as collective ops enabled by one test cannot affect the
    # next one).
    context._reset_context()  # pylint: disable=protected-access

  @combinations.generate(combinations.combine(mode="graph"))
  def testCurrentDeviceWithGlobalGraph(self):
    # A lone device scope is reported with no job component.
    with ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

    # An outer job scope is merged with an inner device scope.
    with ops.device("/job:worker"), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/job:worker/device:CPU:0")

    # The innermost device scope determines the reported device.
    with ops.device("/cpu:0"), ops.device("/gpu:0"):
      self.assertEqual(device_util.current(), "/device:GPU:0")

  def testCurrentDeviceWithNonGlobalGraph(self):
    # Same behavior when a freshly created graph is installed as default.
    with ops.Graph().as_default(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(), "/device:CPU:0")

  def testCurrentDeviceWithEager(self):
    # Under eager execution the fully qualified local device name is expected.
    with context.eager_mode(), ops.device("/cpu:0"):
      self.assertEqual(device_util.current(),
                       "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(combinations.combine(mode=["graph", "eager"]))
  def testCanonicalizeWithoutDefaultDevice(self, mode):
    # A bare "/cpu:0" gains the localhost job only in eager mode; in graph
    # mode it is canonicalized without a job component.
    bare_cpu_expected = ("/replica:0/task:0/device:CPU:0"
                         if mode == "graph" else
                         "/job:localhost/replica:0/task:0/device:CPU:0")
    cases = (
        ("/cpu:0", bare_cpu_expected),
        ("/job:worker/cpu:0", "/job:worker/replica:0/task:0/device:CPU:0"),
        ("/job:worker/task:1/cpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, expected in cases:
      self.assertEqual(device_util.canonicalize(spec), expected)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testCanonicalizeWithoutDefaultDeviceCollectiveEnabled(self):
    # Describe a one-worker cluster (no chief, ps or evaluator) and register
    # this process as task 0 of the "worker" job.
    spec_dict = multi_worker_test_base.create_cluster_spec(
        has_chief=False, num_workers=1, num_ps=0, has_eval=False)
    cluster_def = server_lib.ClusterSpec(spec_dict).as_cluster_def()
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def,
        job_name="worker",
        task_index=0,
        protocol="grpc",
        port=0)
    context.context().enable_collective_ops(server_def)

    # With collectives enabled, a bare device spec is expected to be placed
    # under the worker job rather than localhost.
    self.assertEqual(device_util.canonicalize("/cpu:0"),
                     "/job:worker/replica:0/task:0/device:CPU:0")

  def testCanonicalizeWithDefaultDevice(self):
    # Each entry: (device spec, default device, expected canonical name).
    cases = (
        # Fields given in the spec win over those in the default...
        ("/job:worker/task:1/cpu:0", "/gpu:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
        # ...and fields missing from the spec are taken from the default.
        ("/job:worker/task:1", "/gpu:0",
         "/job:worker/replica:0/task:1/device:GPU:0"),
        ("/cpu:0", "/job:worker",
         "/job:worker/replica:0/task:0/device:CPU:0"),
        # A fully specified spec is left as-is regardless of the default.
        ("/job:worker/replica:0/task:1/device:CPU:0",
         "/job:chief/replica:0/task:1/device:CPU:0",
         "/job:worker/replica:0/task:1/device:CPU:0"),
    )
    for spec, default_device, expected in cases:
      self.assertEqual(
          device_util.canonicalize(spec, default=default_device), expected)

  def testResolveWithDeviceScope(self):
    # Inside a GPU scope, an explicit CPU in the spec is kept, while a spec
    # lacking a device picks up the scope's GPU.
    with ops.device("/gpu:0"):
      gpu_scope_cases = (
          ("/job:worker/task:1/cpu:0",
           "/job:worker/replica:0/task:1/device:CPU:0"),
          ("/job:worker/task:1",
           "/job:worker/replica:0/task:1/device:GPU:0"),
      )
      for spec, expected in gpu_scope_cases:
        self.assertEqual(device_util.resolve(spec), expected)

    # Inside a job scope, a bare device spec picks up the scope's job.
    with ops.device("/job:worker"):
      self.assertEqual(device_util.resolve("/cpu:0"),
                       "/job:worker/replica:0/task:0/device:CPU:0")


if __name__ == "__main__":
  test.main()