1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Helper functions for TPU rewrite passes.
17
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"

#include <algorithm>
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/device_name_utils.h"
25
26 namespace tensorflow {
27
28 // LINT.IfChange
GetSystemDevice(const string & system_spec_string,const DeviceSet & device_set,DeviceNameUtils::ParsedName * system_spec,Device ** system_device)29 Status DistributedTPURewriteHelpers::GetSystemDevice(
30 const string& system_spec_string, const DeviceSet& device_set,
31 DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
32 if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
33 system_spec->Clear();
34 }
35
36 // Callers may have relied on an Op only being registered on TPU_SYSTEM
37 // devices to ensure the Op is placed there. Augment the device spec to make
38 // the device type explicit.
39 if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
40 system_spec->type = DEVICE_TPU_SYSTEM;
41 system_spec->has_type = true;
42 system_spec->id = 0;
43 system_spec->has_id = true;
44 }
45
46 std::vector<Device*> system_devices;
47 device_set.FindMatchingDevices(*system_spec, &system_devices);
48 if (system_devices.empty()) {
49 if (system_spec_string.empty()) {
50 return errors::InvalidArgument(
51 "No TPU_SYSTEM device found. Please ensure that you're connected to "
52 "a host with a TPU_SYSTEM device.");
53 }
54 return errors::InvalidArgument("No matching devices found for '",
55 system_spec_string, "'");
56 } else if (system_devices.size() > 1) {
57 // Validate that all system devices are part of the same job.
58 std::unordered_set<string> job_names;
59 for (auto device : system_devices) {
60 const auto& parsed_name = device->parsed_name();
61 TF_RET_CHECK(parsed_name.has_job);
62 job_names.insert(parsed_name.job);
63 }
64 if (job_names.size() > 1) {
65 return errors::InvalidArgument(
66 "System devices cannot be part "
67 "of multiple different jobs. Found: ",
68 str_util::Join(job_names, ","));
69 }
70
71 // Identify the lexicographically first device from the list of
72 // valid TPU SYSTEM devices, so that every process in the same
73 // 'cluster' definition uses the same system device.
74 std::sort(system_devices.begin(), system_devices.end(),
75 [](Device* i, Device* j) {
76 auto i_name = i->parsed_name();
77 auto j_name = j->parsed_name();
78 if (i_name.replica != j_name.replica) {
79 return i_name.replica < j_name.replica;
80 }
81 return i_name.task < j_name.task;
82 });
83 }
84
85 *system_device = system_devices[0];
86 if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
87 return errors::InvalidArgument("Unable to re-parse system device name ",
88 (*system_device)->name(),
89 " as a device spec.");
90 }
91 return Status::OK();
92 }
93 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
94
95 // LINT.IfChange
GetHostSystemDevices(const DeviceNameUtils::ParsedName & system_spec,const DeviceSet & device_set,std::vector<Device * > * host_system_devices)96 Status DistributedTPURewriteHelpers::GetHostSystemDevices(
97 const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
98 std::vector<Device*>* host_system_devices) {
99 DeviceNameUtils::ParsedName host_spec;
100 if (system_spec.has_job) {
101 // The system Op has been explicitly assigned to a job, so we want
102 // all the hosts in that job.
103 CHECK(DeviceNameUtils::ParseFullName(
104 strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
105 ":0"),
106 &host_spec));
107 } else {
108 // The system Op has not been explicitly assigned to a
109 // job, so take all hosts in the system. There will be a runtime
110 // error if some of those hosts don't contain TPU devices.
111 CHECK(DeviceNameUtils::ParseFullName(
112 strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
113 }
114 device_set.FindMatchingDevices(host_spec, host_system_devices);
115
116 TF_RET_CHECK(!host_system_devices->empty())
117 << "No hosts found matching device spec "
118 << DeviceNameUtils::ParsedNameToString(host_spec);
119
120 // Check that all the devices belong to the same job.
121 TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
122 const string& job_name = (*host_system_devices)[0]->parsed_name().job;
123 int replica = (*host_system_devices)[0]->parsed_name().replica;
124 for (const auto host_device : *host_system_devices) {
125 const auto& parsed_name = host_device->parsed_name();
126 TF_RET_CHECK(parsed_name.has_job);
127 if (parsed_name.job != job_name) {
128 return errors::InvalidArgument(
129 "All TPU host devices must be in the same job");
130 }
131 TF_RET_CHECK(parsed_name.has_replica);
132 if (parsed_name.replica != replica) {
133 return errors::InvalidArgument(
134 "All TPU host devices must be in the same replica");
135 }
136 }
137
138 // Sort the devices by replica and then task.
139 std::sort(host_system_devices->begin(), host_system_devices->end(),
140 [](Device* i, Device* j) {
141 auto i_name = i->parsed_name();
142 auto j_name = j->parsed_name();
143 return i_name.task < j_name.task;
144 });
145 return Status::OK();
146 }
147 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
148
149 // LINT.IfChange
GetTPUDevices(const DeviceNameUtils::ParsedName & system_spec,const DeviceSet & device_set,int * num_tpus_per_host,std::vector<std::vector<Device * >> * tpu_devices)150 Status DistributedTPURewriteHelpers::GetTPUDevices(
151 const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
152 int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
153 // GetHostSystemDevices returns the CPU device on each host that is
154 // going to be used for executing TPU code.
155 std::vector<Device*> host_system_devices;
156 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
157 system_spec, device_set, &host_system_devices));
158
159 // Enumerate all the physical devices. Enumerate devices on task 0,
160 // then task 1, etc.
161 std::sort(host_system_devices.begin(), host_system_devices.end(),
162 [](Device* i, Device* j) {
163 return i->parsed_name().task < j->parsed_name().task;
164 });
165
166 *num_tpus_per_host = 0;
167 tpu_devices->clear();
168 tpu_devices->reserve(host_system_devices.size());
169 for (const auto device : host_system_devices) {
170 // Make a copy of the parsed name because we are going to change it.
171 DeviceNameUtils::ParsedName device_spec = device->parsed_name();
172 device_spec.has_type = true;
173 device_spec.type = "TPU";
174 // Enumerate all the available TPUs.
175 device_spec.has_id = false;
176 std::vector<Device*> host_tpu_devices;
177 device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
178 // Sort the devices by device id.
179 std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
180 [](Device* i, Device* j) {
181 return i->parsed_name().id < j->parsed_name().id;
182 });
183 if (tpu_devices->empty()) {
184 // First iteration: set *num_tpus_per_host to the number of TPUs on the
185 // first host.
186 *num_tpus_per_host = host_tpu_devices.size();
187 } else if (*num_tpus_per_host != host_tpu_devices.size()) {
188 // Subsequent iterations: check the number of TPUs match the number on
189 // the first host.
190 return errors::InvalidArgument(
191 "Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
192 " vs. ", host_tpu_devices.size());
193 }
194 tpu_devices->push_back(std::move(host_tpu_devices));
195 }
196 return Status::OK();
197 }
198 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
199
ForConfigurationNodeMatchingType(const string & node_type,Graph * graph,const DeviceSet & device_set,const std::function<Status (const NodeDef & configuration_node_def,const string & configuration_device_name,const std::vector<Device * > & host_devices,const std::vector<Node * > & input_dependencies,const std::vector<OutputDependency> & output_dependencies,Graph * graph)> & action)200 Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
201 const string& node_type, Graph* graph, const DeviceSet& device_set,
202 const std::function<
203 Status(const NodeDef& configuration_node_def,
204 const string& configuration_device_name,
205 const std::vector<Device*>& host_devices,
206 const std::vector<Node*>& input_dependencies,
207 const std::vector<OutputDependency>& output_dependencies,
208 Graph* graph)>& action) {
209 // Find all the matching nodes before mutating the graph.
210 std::vector<Node*> nodes;
211 for (Node* node : graph->nodes()) {
212 if (node->type_string() == node_type) {
213 nodes.push_back(node);
214 }
215 }
216
217 for (Node* node : nodes) {
218 string spec_string = node->requested_device();
219 DeviceNameUtils::ParsedName spec;
220 Device* device;
221 TF_RETURN_IF_ERROR(
222 GetSystemDevice(spec_string, device_set, &spec, &device));
223 const string& device_name = device->name();
224
225 std::vector<Device*> host_devices;
226 TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));
227
228 std::vector<Node*> input_dependencies;
229 for (const Edge* edge : node->in_edges()) {
230 // Config ops have no inputs, so all edges must be control edges.
231 CHECK(edge->IsControlEdge());
232 input_dependencies.push_back(edge->src());
233 }
234 std::vector<OutputDependency> output_dependencies;
235 for (const Edge* edge : node->out_edges()) {
236 OutputDependency dep;
237 dep.src_output = edge->src_output();
238 dep.dst = edge->dst();
239 dep.dst_input = edge->dst_input();
240 output_dependencies.push_back(dep);
241 }
242 NodeDef node_def = node->def();
243
244 // Remove the node now so we can insert a new node with the same
245 // name inside the action.
246 graph->RemoveNode(node);
247
248 TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
249 input_dependencies, output_dependencies, graph));
250 }
251
252 return Status::OK();
253 }
254
255 } // namespace tensorflow
256