• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helper functions for TPU rewrite passes.
17 
18 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
19 
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/device_set.h"
23 #include "tensorflow/core/tpu/tpu_defs.h"
24 #include "tensorflow/core/util/device_name_utils.h"
25 
26 namespace tensorflow {
27 
28 // LINT.IfChange
GetSystemDevice(const string & system_spec_string,const DeviceSet & device_set,DeviceNameUtils::ParsedName * system_spec,Device ** system_device)29 Status DistributedTPURewriteHelpers::GetSystemDevice(
30     const string& system_spec_string, const DeviceSet& device_set,
31     DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
32   if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
33     system_spec->Clear();
34   }
35 
36   // Callers may have relied on an Op only being registered on TPU_SYSTEM
37   // devices to ensure the Op is placed there. Augment the device spec to make
38   // the device type explicit.
39   if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
40     system_spec->type = DEVICE_TPU_SYSTEM;
41     system_spec->has_type = true;
42     system_spec->id = 0;
43     system_spec->has_id = true;
44   }
45 
46   std::vector<Device*> system_devices;
47   device_set.FindMatchingDevices(*system_spec, &system_devices);
48   if (system_devices.empty()) {
49     if (system_spec_string.empty()) {
50       return errors::InvalidArgument(
51           "No TPU_SYSTEM device found. Please ensure that you're connected to "
52           "a host with a TPU_SYSTEM device.");
53     }
54     return errors::InvalidArgument("No matching devices found for '",
55                                    system_spec_string, "'");
56   } else if (system_devices.size() > 1) {
57     // Validate that all system devices are part of the same job.
58     std::unordered_set<string> job_names;
59     for (auto device : system_devices) {
60       const auto& parsed_name = device->parsed_name();
61       TF_RET_CHECK(parsed_name.has_job);
62       job_names.insert(parsed_name.job);
63     }
64     if (job_names.size() > 1) {
65       return errors::InvalidArgument(
66           "System devices cannot be part "
67           "of multiple different jobs.  Found: ",
68           str_util::Join(job_names, ","));
69     }
70 
71     // Identify the lexicographically first device from the list of
72     // valid TPU SYSTEM devices, so that every process in the same
73     // 'cluster' definition uses the same system device.
74     std::sort(system_devices.begin(), system_devices.end(),
75               [](Device* i, Device* j) {
76                 auto i_name = i->parsed_name();
77                 auto j_name = j->parsed_name();
78                 if (i_name.replica != j_name.replica) {
79                   return i_name.replica < j_name.replica;
80                 }
81                 return i_name.task < j_name.task;
82               });
83   }
84 
85   *system_device = system_devices[0];
86   if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
87     return errors::InvalidArgument("Unable to re-parse system device name ",
88                                    (*system_device)->name(),
89                                    " as a device spec.");
90   }
91   return OkStatus();
92 }
93 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
94 
95 // LINT.IfChange
GetHostSystemDevices(const DeviceNameUtils::ParsedName & system_spec,const DeviceSet & device_set,std::vector<Device * > * host_system_devices)96 Status DistributedTPURewriteHelpers::GetHostSystemDevices(
97     const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
98     std::vector<Device*>* host_system_devices) {
99   DeviceNameUtils::ParsedName host_spec;
100   if (system_spec.has_job) {
101     // The system Op has been explicitly assigned to a job, so we want
102     // all the hosts in that job.
103     CHECK(DeviceNameUtils::ParseFullName(
104         strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
105                         ":0"),
106         &host_spec));
107   } else {
108     // The system Op has not been explicitly assigned to a
109     // job, so take all hosts in the system. There will be a runtime
110     // error if some of those hosts don't contain TPU devices.
111     CHECK(DeviceNameUtils::ParseFullName(
112         strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
113   }
114   device_set.FindMatchingDevices(host_spec, host_system_devices);
115 
116   TF_RET_CHECK(!host_system_devices->empty())
117       << "No hosts found matching device spec "
118       << DeviceNameUtils::ParsedNameToString(host_spec);
119 
120   // Check that all the devices belong to the same job.
121   TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
122   const string& job_name = (*host_system_devices)[0]->parsed_name().job;
123   int replica = (*host_system_devices)[0]->parsed_name().replica;
124   for (const auto host_device : *host_system_devices) {
125     const auto& parsed_name = host_device->parsed_name();
126     TF_RET_CHECK(parsed_name.has_job);
127     if (parsed_name.job != job_name) {
128       return errors::InvalidArgument(
129           "All TPU host devices must be in the same job");
130     }
131     TF_RET_CHECK(parsed_name.has_replica);
132     if (parsed_name.replica != replica) {
133       return errors::InvalidArgument(
134           "All TPU host devices must be in the same replica");
135     }
136   }
137 
138   // Sort the devices by replica and then task.
139   std::sort(host_system_devices->begin(), host_system_devices->end(),
140             [](Device* i, Device* j) {
141               auto i_name = i->parsed_name();
142               auto j_name = j->parsed_name();
143               return i_name.task < j_name.task;
144             });
145   return OkStatus();
146 }
147 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
148 
149 // LINT.IfChange
GetTPUDevices(const DeviceNameUtils::ParsedName & system_spec,const DeviceSet & device_set,int * num_tpus_per_host,std::vector<std::vector<Device * >> * tpu_devices)150 Status DistributedTPURewriteHelpers::GetTPUDevices(
151     const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
152     int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
153   // GetHostSystemDevices returns the CPU device on each host that is
154   // going to be used for executing TPU code.
155   std::vector<Device*> host_system_devices;
156   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
157       system_spec, device_set, &host_system_devices));
158 
159   // Enumerate all the physical devices. Enumerate devices on task 0,
160   // then task 1, etc.
161   std::sort(host_system_devices.begin(), host_system_devices.end(),
162             [](Device* i, Device* j) {
163               return i->parsed_name().task < j->parsed_name().task;
164             });
165 
166   *num_tpus_per_host = 0;
167   tpu_devices->clear();
168   tpu_devices->reserve(host_system_devices.size());
169   for (const auto device : host_system_devices) {
170     // Make a copy of the parsed name because we are going to change it.
171     DeviceNameUtils::ParsedName device_spec = device->parsed_name();
172     device_spec.has_type = true;
173     device_spec.type = "TPU";
174     // Enumerate all the available TPUs.
175     device_spec.has_id = false;
176     std::vector<Device*> host_tpu_devices;
177     device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
178     // Sort the devices by device id.
179     std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
180               [](Device* i, Device* j) {
181                 return i->parsed_name().id < j->parsed_name().id;
182               });
183     if (tpu_devices->empty()) {
184       // First iteration: set *num_tpus_per_host to the number of TPUs on the
185       // first host.
186       *num_tpus_per_host = host_tpu_devices.size();
187     } else if (*num_tpus_per_host != host_tpu_devices.size()) {
188       // Subsequent iterations: check the number of TPUs match the number on
189       // the first host.
190       return errors::InvalidArgument(
191           "Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
192           " vs. ", host_tpu_devices.size());
193     }
194     tpu_devices->push_back(std::move(host_tpu_devices));
195   }
196   return OkStatus();
197 }
198 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
199 
ForConfigurationNodeMatchingType(const string & node_type,Graph * graph,const DeviceSet & device_set,const std::function<Status (const NodeDef & configuration_node_def,const string & configuration_device_name,const std::vector<Device * > & host_devices,const std::vector<Node * > & input_dependencies,const std::vector<OutputDependency> & output_dependencies,Graph * graph)> & action)200 Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
201     const string& node_type, Graph* graph, const DeviceSet& device_set,
202     const std::function<
203         Status(const NodeDef& configuration_node_def,
204                const string& configuration_device_name,
205                const std::vector<Device*>& host_devices,
206                const std::vector<Node*>& input_dependencies,
207                const std::vector<OutputDependency>& output_dependencies,
208                Graph* graph)>& action) {
209   // Find all the matching nodes before mutating the graph.
210   std::vector<Node*> nodes;
211   for (Node* node : graph->nodes()) {
212     if (node->type_string() == node_type) {
213       nodes.push_back(node);
214     }
215   }
216 
217   for (Node* node : nodes) {
218     string spec_string = node->requested_device();
219     DeviceNameUtils::ParsedName spec;
220     Device* device;
221     TF_RETURN_IF_ERROR(
222         GetSystemDevice(spec_string, device_set, &spec, &device));
223     const string& device_name = device->name();
224 
225     std::vector<Device*> host_devices;
226     TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));
227 
228     std::vector<Node*> input_dependencies;
229     for (const Edge* edge : node->in_edges()) {
230       // Config ops have no inputs, so all edges must be control edges.
231       CHECK(edge->IsControlEdge());
232       input_dependencies.push_back(edge->src());
233     }
234     std::vector<OutputDependency> output_dependencies;
235     for (const Edge* edge : node->out_edges()) {
236       OutputDependency dep;
237       dep.src_output = edge->src_output();
238       dep.dst = edge->dst();
239       dep.dst_input = edge->dst_input();
240       output_dependencies.push_back(dep);
241     }
242     NodeDef node_def = node->def();
243 
244     // Remove the node now so we can insert a new node with the same
245     // name inside the action.
246     graph->RemoveNode(node);
247 
248     TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
249                               input_dependencies, output_dependencies, graph));
250   }
251 
252   return OkStatus();
253 }
254 
255 }  // namespace tensorflow
256