1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Helper functions for TPU rewrite passes. 17 18 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ 19 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ 20 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/core/common_runtime/device_set.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 #include "tensorflow/core/framework/resource_mgr.h" 25 #include "tensorflow/core/util/device_name_utils.h" 26 27 namespace tensorflow { 28 29 class DistributedTPURewriteHelpers { 30 public: 31 // Given a user-assigned device string, system_spec_string, parse it into 32 // system_spec. Verify that the device type is either TPU_SYSTEM or 33 // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the 34 // type, verify that the spec matches a unique device in device_set, and 35 // return that device in system_device. The normal use case is for 36 // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the 37 // job that contains the TPU hardware. 38 // TODO(b/110910013): Possibly remove the tpu system device. 39 static Status GetSystemDevice(const string& system_spec_string, 40 const DeviceSet& device_set, 41 DeviceNameUtils::ParsedName* system_spec, 42 Device** system_device); 43 44 // Given a parsed system spec (e.g., the one returned above from 45 // GetSystemDeviceName), return in host_devices the TPU_SYSTEM:0 device on 46 // every host in the spec's job. If the spec does not include an explicit job, 47 // "localhost" is used. Returns an error if system_spec matches devices from 48 // a multiple jobs or replicas. 49 static Status GetHostSystemDevices( 50 const DeviceNameUtils::ParsedName& system_spec, 51 const DeviceSet& device_set, std::vector<Device*>* host_system_devices); 52 53 // Given a parsed system spec (e.g., the one returned above from 54 // GetSystemDeviceName), sets `*tpu_devices` to a per-host vector of the TPU 55 // devices on every host in the spec's job. If the spec does not include an 56 // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number 57 // of TPU devices in each host, and verifies that each host in the job has 58 // the same number of TPU devices. 59 // Returns an error if system_spec matches devices from a multiple jobs or 60 // replicas. 61 static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec, 62 const DeviceSet& device_set, 63 int* num_tpus_per_host, 64 std::vector<std::vector<Device*>>* tpu_devices); 65 66 // Perform 'action' on every node in 'graph' of type 67 // 'node_type'. This function is designed for use with configuration 68 // Ops that have no inputs or outputs. The arguments passed to 'action' are: 69 // 'configuration_node_name': the name of the node that matched 70 // 'configuration_device_name': the name of the device that the 71 // matching node is placed on 72 // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are 73 // in the same system as the node that matched. 74 // 'input_dependencies': the set of nodes that have control edges to 75 // the matching node. 76 // 'output_dependencies': the set of output port, destination node, input port 77 // triples that have edges from the matching node. Input port is 78 // Graph::kControlSlot for a control edge. 79 // 'graph': the graph being mutated. 80 struct OutputDependency { 81 int src_output; 82 Node* dst; 83 int dst_input; 84 }; 85 static Status ForConfigurationNodeMatchingType( 86 const string& node_type, Graph* graph, const DeviceSet& device_set, 87 const std::function< 88 Status(const NodeDef& configuration_node_def, 89 const string& configuration_device_name, 90 const std::vector<Device*>& host_devices, 91 const std::vector<Node*>& input_dependencies, 92 const std::vector<OutputDependency>& output_dependencies, 93 Graph* graph)>& action); 94 }; 95 96 } // namespace tensorflow 97 98 #endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ 99