# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for device function for replicated training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import server_lib


class DeviceSetterTest(test.TestCase):
  """Checks device strings of ops built under `replica_device_setter`."""

  # Cluster shared by most tests: two "ps" tasks and three "worker" tasks.
  _cluster_spec = server_lib.ClusterSpec({
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
  })

  def _assertVariableDevice(self, expected_device, variable):
    """Checks `variable` and its initializer against `expected_device`."""
    self.assertDeviceEqual(expected_device, variable.device)
    self.assertDeviceEqual(expected_device, variable.initializer.device)

  @test_util.run_deprecated_v1
  def testCPUOverride(self):
    """An explicit "/cpu:0" scope shows up on the variable and on the add."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      with ops.device("/cpu:0"):
        cpu_var = variables.Variable([1, 2])
      plain_var = variables.Variable([2, 1])
      with ops.device("/cpu:0"):
        total = cpu_var + plain_var
    self._assertVariableDevice("/job:ps/task:0/cpu:0", cpu_var)
    self._assertVariableDevice("/job:ps/task:1", plain_var)
    self.assertDeviceEqual("/job:worker/cpu:0", total.device)

  @test_util.run_deprecated_v1
  def testResource(self):
    """A `ResourceVariable` is expected on "/job:ps/task:0"."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      resource_var = resource_variable_ops.ResourceVariable([1, 2])
    self.assertDeviceEqual("/job:ps/task:0", resource_var.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecClass(self):
    """Cluster given as a `ClusterSpec`: variables on ps tasks 0 and 1."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
    self._assertVariableDevice("/job:ps/task:0", first_var)
    self._assertVariableDevice("/job:ps/task:1", second_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksPinVariableToJob(self):
    """A "/job:moon" variable stays there; a "/job:ps" one gets a task."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        moon_var = variables.Variable([2, 1])
        with ops.device("/job:ps"):  # Explicit PS job will get task set.
          ps_var = variables.Variable([0, 1])
      total = first_var + moon_var + ps_var
    self._assertVariableDevice("/job:ps/task:0", first_var)
    self._assertVariableDevice("/job:moon", moon_var)
    self._assertVariableDevice("/job:ps/task:1", ps_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksUseCpuForPS(self):
    """With `ps_device="/cpu:0"`, variables are expected to carry "/cpu:0"."""
    # NOTE: despite the test name, this passes ps_tasks=1 and no cluster.
    setter = device_setter.replica_device_setter(
        ps_tasks=1, ps_device="/cpu:0")
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        moon_var = variables.Variable([2, 1])
      total = first_var + moon_var
    self._assertVariableDevice("/cpu:0", first_var)
    self._assertVariableDevice("/job:moon/cpu:0", moon_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksNoMerging(self):
    """With `merge_devices=False`, a "/job:ps" variable keeps "/job:ps"."""
    setter = device_setter.replica_device_setter(
        cluster=self._cluster_spec, merge_devices=False)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      # Won't assign task when merge_devices=False.
      with ops.device("/job:ps"):
        ps_var = variables.Variable([2, 1])
      total = first_var + ps_var
    self._assertVariableDevice("/job:ps/task:0", first_var)
    self._assertVariableDevice("/job:ps", ps_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecDict(self):
    """Cluster given as the dict returned by `ClusterSpec.as_dict()`."""
    cluster_dict = self._cluster_spec.as_dict()
    setter = device_setter.replica_device_setter(cluster=cluster_dict)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
    self._assertVariableDevice("/job:ps/task:0", first_var)
    self._assertVariableDevice("/job:ps/task:1", second_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterDef(self):
    """Cluster given as the result of `ClusterSpec.as_cluster_def()`."""
    cluster_def = self._cluster_spec.as_cluster_def()
    setter = device_setter.replica_device_setter(cluster=cluster_def)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
    self._assertVariableDevice("/job:ps/task:0", first_var)
    self._assertVariableDevice("/job:ps/task:1", second_var)
    self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithDevice(self):
    """Custom job names: `ps_device="/job:moon"`, `worker_device="/job:sun"`."""
    sun_moon_cluster = server_lib.ClusterSpec({
        "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
        "moon": ["moon0:2222", "moon1:2222"]
    })

    setter = device_setter.replica_device_setter(
        ps_device="/job:moon",
        worker_device="/job:sun",
        cluster=sun_moon_cluster.as_cluster_def())
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
    self._assertVariableDevice("/job:moon/task:0", first_var)
    self._assertVariableDevice("/job:moon/task:1", second_var)
    self.assertDeviceEqual("/job:sun", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithCPUConstraint(self):
    """`ps_device="/job:moon/cpu:0"`: variables get a task and keep cpu:0."""
    sun_moon_cluster = server_lib.ClusterSpec({
        "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
        "moon": ["moon0:2222", "moon1:2222"]
    })

    setter = device_setter.replica_device_setter(
        ps_device="/job:moon/cpu:0",
        worker_device="/job:sun",
        cluster=sun_moon_cluster.as_cluster_def())
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
    self._assertVariableDevice("/job:moon/task:0/cpu:0", first_var)
    self._assertVariableDevice("/job:moon/task:1/cpu:0", second_var)
    self.assertDeviceEqual("/job:sun", total.device)


if __name__ == "__main__":
  test.main()