• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helper functions for TPU rewrite passes.
17 
18 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
19 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
20 
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/core/common_runtime/device_set.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/resource_mgr.h"
25 #include "tensorflow/core/util/device_name_utils.h"
26 
27 namespace tensorflow {
28 
29 class DistributedTPURewriteHelpers {  // Static helpers for TPU rewrite passes.
30  public:
31   // Given a user-assigned device string, system_spec_string, parse it into
32   // system_spec. Verify that the device type is either TPU_SYSTEM or
33   // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the
34   // type, verify that the spec matches a unique device in device_set, and
35   // return that device in system_device. The normal use case is for
36   // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the
37   // job that contains the TPU hardware.
38   // TODO(b/110910013): Possibly remove the tpu system device.
39   static Status GetSystemDevice(const string& system_spec_string,
40                                 const DeviceSet& device_set,
41                                 DeviceNameUtils::ParsedName* system_spec,
42                                 Device** system_device);
43 
44   // Given a parsed system spec (e.g., the one returned above from
45   // GetSystemDevice), return in host_system_devices the TPU_SYSTEM:0 device
46   // on every host in the spec's job. If the spec does not include an explicit
47   // job, "localhost" is used. Returns an error if system_spec matches devices
48   // from multiple jobs or replicas.
49   static Status GetHostSystemDevices(
50       const DeviceNameUtils::ParsedName& system_spec,
51       const DeviceSet& device_set, std::vector<Device*>* host_system_devices);
52 
53   // Given a parsed system spec (e.g., the one returned above from
54   // GetSystemDevice), sets `*tpu_devices` to a per-host vector of the TPU
55   // devices on every host in the spec's job. If the spec does not include an
56   // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number
57   // of TPU devices in each host, and verifies that each host in the job has
58   // the same number of TPU devices.
59   // Returns an error if system_spec matches devices from multiple jobs or
60   // replicas.
61   static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec,
62                               const DeviceSet& device_set,
63                               int* num_tpus_per_host,
64                               std::vector<std::vector<Device*>>* tpu_devices);
65 
66   // Perform 'action' on every node in 'graph' of type
67   // 'node_type'. This function is designed for use with configuration
68   // Ops that have no inputs or outputs. The arguments passed to 'action' are:
69   // 'configuration_node_def': the NodeDef of the node that matched
70   // 'configuration_device_name': the name of the device that the
71   // matching node is placed on
72   // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are
73   // in the same system as the node that matched.
74   // 'input_dependencies': the set of nodes that have control edges to
75   // the matching node.
76   // 'output_dependencies': the set of output port, destination node, input port
77   // triples that have edges from the matching node. Input port is
78   // Graph::kControlSlot for a control edge.
79   // 'graph': the graph being mutated.
80   struct OutputDependency {
81     int src_output;  // Output port on the matching node.
82     Node* dst;       // Destination node of the edge.
83     int dst_input;   // Input port on dst; Graph::kControlSlot if control edge.
84   };
85   static Status ForConfigurationNodeMatchingType(
86       const string& node_type, Graph* graph, const DeviceSet& device_set,
87       const std::function<
88           Status(const NodeDef& configuration_node_def,
89                  const string& configuration_device_name,
90                  const std::vector<Device*>& host_devices,
91                  const std::vector<Node*>& input_dependencies,
92                  const std::vector<OutputDependency>& output_dependencies,
93                  Graph* graph)>& action);
94 };
95 
96 }  // namespace tensorflow
97 
98 #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
99