/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_

#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"

namespace xla {

// Class that represents the device assignment for a set of XLA replicated
// computations. For R replicas and C computations, R * C devices are required
// to execute the computation in parallel. The assigned device ids can be
// accessed by assignment(replica, computation).
class DeviceAssignment : public Array2D<int> {
 public:
  DeviceAssignment() {}
  DeviceAssignment(int replica_count, int computation_count)
      : Array2D<int>(replica_count, computation_count, -1) {
    CHECK_GT(replica_count, 0);
    CHECK_GT(computation_count, 0);
  }

  int replica_count() const { return height(); }
  int computation_count() const { return width(); }

  // Finds the (replica ID, computation ID) pair for the given device.
  StatusOr<std::pair<int, int>> LogicalIdsForDevice(
      GlobalDeviceId device_id) const;
  // Finds the replica ID for the given device.
  StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;

  // Protocol buffer serialization and deserialization.
  Status Serialize(DeviceAssignmentProto* proto) const;

  // Return a std::unique_ptr<DeviceAssignment> instead of a DeviceAssignment
  // directly because one of the supported TF platforms (mac) does not compile
  // due to a StatusOr of an incomplete type (DeviceAssignment).
  static StatusOr<std::unique_ptr<DeviceAssignment>> Deserialize(
      const DeviceAssignmentProto& proto);

  string ToString() const;
};
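
// Example (illustrative sketch only; the device ids below are made up, and a
// real assignment normally comes from ComputationPlacer::AssignDevices):
//
//   DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
//   assignment(0, 0) = 42;  // Replica 0 of computation 0 runs on device 42.
//   assignment(1, 0) = 43;  // Replica 1 of computation 0 runs on device 43.
//   int device_id = assignment(1, 0);  // Reads 43 back via Array2D::operator().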

// A generic implementation of the XLA computation placer, which assigns device
// ids to a set of replicated computations.
class ComputationPlacer {
 public:
  ComputationPlacer() {}
  virtual ~ComputationPlacer() {}

  // Returns the device id assigned to the given replica and computation
  // instance for [replica_count x computation_count] setup. The returned
  // device id must match the assignment from PlaceReplicatedComputation().
  virtual StatusOr<int> DeviceId(int replica, int computation,
                                 int replica_count, int computation_count);

  // Returns the device ids assigned to a set of replicated computations, given
  // the number of replicas and the number of computations.
  virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
                                                   int computation_count);

  using ComputationPlacerCreationFunction =
      std::unique_ptr<ComputationPlacer> (*)();

  // Registers a computation placer creation function for a particular
  // platform.
  static void RegisterComputationPlacer(
      se::Platform::Id platform_id,
      ComputationPlacerCreationFunction creation_function);

  // Returns the computation placer singleton pointer if it is available for
  // the given platform, or an error status if it is not.
  static StatusOr<ComputationPlacer*> GetForPlatform(
      const se::Platform* platform);

 private:
  // The mutex that guards the platform-to-computation placer map.
  static tensorflow::mutex platform_computation_placer_mutex_;

  // State kept for each kind of ComputationPlacer. Registration functions set
  // up creation_function, and then we use that to lazily create "placer" the
  // first time GetForPlatform is invoked for a particular id.
  struct State {
    std::unique_ptr<ComputationPlacer> placer;
    ComputationPlacerCreationFunction creation_function = nullptr;
  };

  // Map from platform kind to computation placer singleton.
  static std::map<se::Platform::Id, State>* GetPlatformComputationPlacers();

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
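
// Registration and lookup sketch (illustrative; a backend would normally do
// this in a .cc file rather than in this header, and `my_platform_id` is a
// hypothetical placeholder for a real se::Platform::Id):
//
//   static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
//     return std::make_unique<xla::ComputationPlacer>();
//   }
//
//   static bool placer_registered = [] {
//     xla::ComputationPlacer::RegisterComputationPlacer(
//         my_platform_id, &CreateComputationPlacer);
//     return true;
//   }();
//
// Later, inside a function that returns a Status and has the se::Platform*
// for that platform:
//
//   TF_ASSIGN_OR_RETURN(xla::ComputationPlacer* placer,
//                       xla::ComputationPlacer::GetForPlatform(platform));
//   TF_ASSIGN_OR_RETURN(xla::DeviceAssignment assignment,
//                       placer->AssignDevices(replica_count,
//                                             computation_count));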