• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for device function for replicated training."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import resource_variable_ops
24from tensorflow.python.ops import variables
25from tensorflow.python.platform import test
26from tensorflow.python.training import device_setter
27from tensorflow.python.training import server_lib
28
29
class DeviceSetterTest(test.TestCase):
  """Tests placement decisions made by `device_setter.replica_device_setter`.

  The assertions below expect variables (and their initializer ops) to be
  spread across parameter-server tasks in creation order, while ordinary ops
  are placed on the worker device.
  """

  _cluster_spec = server_lib.ClusterSpec({
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
  })

  # Cluster with non-default job names; "moon" (2 tasks) is used as the PS job
  # and "sun" as the worker job by the tests that pass explicit devices.
  _sun_moon_cluster_spec = server_lib.ClusterSpec({
      "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
      "moon": ["moon0:2222", "moon1:2222"]
  })

  def _assertVariableDevice(self, expected_device, variable):
    """Asserts that `variable` and its initializer op are on one device.

    Args:
      expected_device: Device string that both the variable and its
        initializer op must be equal to.
      variable: The variable whose placement is checked.
    """
    self.assertDeviceEqual(expected_device, variable.device)
    self.assertDeviceEqual(expected_device, variable.initializer.device)

  def _assertRoundRobinOverPS2Tasks(self, cluster):
    """Checks default placement against the two-task "ps" job of `cluster`.

    Shared by the tests that pass the same cluster as a `ClusterSpec`, a dict
    and a `ClusterDef`; every form is expected to yield identical placement.

    Args:
      cluster: `self._cluster_spec` in any of the forms accepted by the
        `cluster` argument of `replica_device_setter`.
    """
    with ops.device(device_setter.replica_device_setter(cluster=cluster)):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      # Consecutive variables go to consecutive PS tasks.
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      # Non-variable ops land on the default worker job.
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testCPUOverride(self):
    """An inner "/cpu:0" scope is merged into the setter's placement."""
    with ops.device(
        device_setter.replica_device_setter(cluster=self._cluster_spec)):
      with ops.device("/cpu:0"):
        v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      with ops.device("/cpu:0"):
        a = v + w
      self._assertVariableDevice("/job:ps/task:0/cpu:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      self.assertDeviceEqual("/job:worker/cpu:0", a.device)

  @test_util.run_deprecated_v1
  def testResource(self):
    """Resource variables are placed on a PS task like regular variables."""
    with ops.device(
        device_setter.replica_device_setter(cluster=self._cluster_spec)):
      v = resource_variable_ops.ResourceVariable([1, 2])
      self.assertDeviceEqual("/job:ps/task:0", v.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecClass(self):
    """The cluster may be given as a `ClusterSpec` instance."""
    self._assertRoundRobinOverPS2Tasks(self._cluster_spec)

  @test_util.run_deprecated_v1
  def testPS2TasksPinVariableToJob(self):
    """Explicit job scopes interact with PS task assignment."""
    with ops.device(
        device_setter.replica_device_setter(cluster=self._cluster_spec)):
      v = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        w = variables.Variable([2, 1])
        with ops.device("/job:ps"):  # Explicit PS job will get task set.
          x = variables.Variable([0, 1])
      a = v + w + x
      self._assertVariableDevice("/job:ps/task:0", v)
      # A non-PS job pin is kept as-is, with no task added.
      self._assertVariableDevice("/job:moon", w)
      # w did not take a PS task, so x receives the next one.
      self._assertVariableDevice("/job:ps/task:1", x)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksUseCpuForPS(self):
    """A device-only `ps_device` is used for variables when no cluster is set."""
    # NOTE(review): despite the test name, this configures a single PS task
    # (ps_tasks=1); the name is kept so the test id stays stable.
    with ops.device(
        device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")):
      v = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/cpu:0", v)
      # The outer job pin is combined with the "/cpu:0" ps_device.
      self._assertVariableDevice("/job:moon/cpu:0", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksNoMerging(self):
    """With merge_devices=False an explicit "/job:ps" scope is not refined."""
    with ops.device(
        device_setter.replica_device_setter(
            cluster=self._cluster_spec, merge_devices=False)):
      v = variables.Variable([1, 2])
      with ops.device("/job:ps"):  # Won't assign task when merge_devices=False.
        w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecDict(self):
    """The cluster may be given as a plain dict (`ClusterSpec.as_dict()`)."""
    self._assertRoundRobinOverPS2Tasks(self._cluster_spec.as_dict())

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterDef(self):
    """The cluster may be given as a `ClusterDef` proto."""
    self._assertRoundRobinOverPS2Tasks(self._cluster_spec.as_cluster_def())

  @test_util.run_deprecated_v1
  def testPS2TasksWithDevice(self):
    """Custom ps/worker job names are honored for placement."""
    with ops.device(
        device_setter.replica_device_setter(
            ps_device="/job:moon",
            worker_device="/job:sun",
            cluster=self._sun_moon_cluster_spec.as_cluster_def())):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:moon/task:0", v)
      self._assertVariableDevice("/job:moon/task:1", w)
      self.assertDeviceEqual("/job:sun", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithCPUConstraint(self):
    """A device constraint in `ps_device` is kept after the task is chosen."""
    with ops.device(
        device_setter.replica_device_setter(
            ps_device="/job:moon/cpu:0",
            worker_device="/job:sun",
            cluster=self._sun_moon_cluster_spec.as_cluster_def())):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:moon/task:0/cpu:0", v)
      self._assertVariableDevice("/job:moon/task:1/cpu:0", w)
      self.assertDeviceEqual("/job:sun", a.device)
188
189
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  test.main()
192