/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines types and declares functions for identifying and extracting
// information about the types of platforms and supporting libraries for which
// StreamExecutor implementations exist.
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {

class StreamExecutor;
class DeviceDescription;

// Describes the platform for a StreamExecutor instantiation to act upon.
//
// Implementors: if you add a value here be sure to update PlatformKindString
// and CheckPlatformKindIsValid.
enum class PlatformKind {
  // Value returned by PlatformKindFromString when no platform matches.
  kInvalid,
  kCuda,
  kROCm,
  kOpenCL,
  // PlatformIsRunnableOnDevice is documented to return false for this kind.
  kHost,
  // PlatformIsRunnable and PlatformIsRunnableOnDevice are both documented to
  // return false for this kind.
  kMock,
  // NOTE(review): presumably a count sentinel rather than a real platform, in
  // which case it must remain the last enumerator - confirm.
  kSize,
};

// Returns true if kind represents a valid platform capable of enqueuing items
// on a stream, but not necessarily on an accelerator device.
// Returns false for kMock and any invalid PlatformKind values.
bool PlatformIsRunnable(PlatformKind kind);

// Returns true if kind represents a valid platform capable of running kernels
// on an accelerator device. Returns false for kHost*, kMock and any invalid
// PlatformKind values.
//
// NOTE(review): "kHost*" reads as a wildcard, but kHost is the only host
// enumerator declared in PlatformKind - confirm whether the asterisk is stale.
bool PlatformIsRunnableOnDevice(PlatformKind kind);

// Returns a printable description of a PlatformKind.
std::string PlatformKindString(PlatformKind kind);

// Returns the PlatformKind corresponding to the input string; returns kInvalid
// in the case of no match.
//
// The argument is taken by value. The set of accepted spellings is not visible
// from this header; presumably they are the strings produced by
// PlatformKindString - confirm against the definition.
PlatformKind PlatformKindFromString(std::string platform_string);

// Checks that kind takes on a valid value.
//
// The function returns void, so a failed check is not reported through the
// return value; how it is reported is not visible from this header.
void CheckPlatformKindIsValid(PlatformKind kind);

// StreamExecutorConfig encapsulates the set of options for constructing a
// StreamExecutor for a given platform.
struct StreamExecutorConfig {
  // Sets members to defaults: -1 for ordinal (must be changed), and default
  // PluginConfig and DeviceOptions.
  StreamExecutorConfig();

  // Simple ordinal-setting constructor. Marked explicit so that an int is
  // never implicitly converted into a StreamExecutorConfig.
  explicit StreamExecutorConfig(int ordinal);

  // The ordinal of the device to be managed by the returned StreamExecutor.
  int ordinal;

  // The PluginConfig for the returned StreamExecutor.
  PluginConfig plugin_config;

  // The DeviceOptions for the returned StreamExecutor.
  DeviceOptions device_options;
};

// Abstract base class for a platform registered with the MultiPlatformManager.
class Platform {
 public:
  virtual ~Platform();

  // A platform ID is a unique identifier for each registered platform type -
  // each platform is required to expose an ID to ensure unique registration and
  // as a target against which plugins can register.
  //
  // The ID is an opaque pointer value. The PLATFORM_DEFINE_ID macro below is
  // provided to help generate a process-unique identifier.
  using Id = void*;

// Helper macro to define a platform ID named ID_VAR_NAME. To be used only
// inside platform/plugin implementation files. Works by "reserving" an
// address/value (guaranteed to be unique) inside a process space: it defines an
// int in an anonymous namespace and uses that int's address as the ID.
//
// Usage constraints that follow from the expansion:
//  - The helper int always has the same name (plugin_id_value), so the macro
//    can be expanded at most once per translation unit.
//  - The expansion opens a namespace, so it must appear at namespace scope.
//  - The expansion already ends with a semicolon.
#define PLATFORM_DEFINE_ID(ID_VAR_NAME) \
  namespace {                           \
  int plugin_id_value;                  \
  }                                     \
  const ::stream_executor::Platform::Id ID_VAR_NAME = &plugin_id_value;

  // Returns a key uniquely identifying this platform.
  virtual Id id() const = 0;

  // Name of this platform.
  virtual const std::string& Name() const = 0;

  // Returns the number of devices accessible on this platform.
  //
  // Note that, though these devices are visible, if there is only one userspace
  // context allowed for the device at a time and another process is using this
  // device, a call to ExecutorForDevice may return an error status.
  virtual int VisibleDeviceCount() const = 0;

  // Returns true iff the platform has been initialized.
  //
  // Not pure virtual: subclasses may rely on the base-class implementation,
  // which is defined outside this header.
  virtual bool Initialized() const;

  // Initializes the platform with a custom set of options. The platform must be
  // initialized before obtaining StreamExecutor objects. The interpretation of
  // the platform_options argument is implementation specific. This method may
  // return an error if unrecognized options are provided. If using
  // MultiPlatformManager, this method will be called automatically by
  // InitializePlatformWithId/InitializePlatformWithName.
  //
  // Not pure virtual: a base-class implementation is defined outside this
  // header.
  virtual port::Status Initialize(
      const std::map<std::string, std::string>& platform_options);

  // Returns a populated DeviceDescription for the device at the given ordinal.
  // This should not require device initialization. Note that not all platforms
  // may support acquiring the DeviceDescription indirectly.
  //
  // On success the caller owns the returned description (it is handed back in
  // a std::unique_ptr).
  //
  // Alternatively callers may call GetDeviceDescription() on the StreamExecutor
  // which returns a cached instance specific to the initialized StreamExecutor.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  DescriptionForDevice(int ordinal) const = 0;

  // Returns a device with the given ordinal on this platform with a default
  // plugin configuration or, if none can be found with the given ordinal or
  // there is an error in opening a context to communicate with the device, an
  // error status is returned.
  //
  // Ownership of the executor is NOT transferred to the caller --
  // the Platform owns the executors in a singleton-like fashion.
  virtual port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) = 0;

  // Returns a device or error, as for ExecutorForDevice, but with the
  // specified plugins.
  //
  // Ownership of the executor is NOT transferred to the caller.
  virtual port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
      int ordinal, const PluginConfig& plugin_config) = 0;

  // Returns a device constructed with the options specified in "config".
  // Ownership of the executor is NOT transferred to the caller.
  virtual port::StatusOr<StreamExecutor*> GetExecutor(
      const StreamExecutorConfig& config) = 0;

  // Returns a device constructed with the options specified in "config" without
  // looking in or storing to the Platform's executor cache.
  // Ownership IS transferred to the caller.
  virtual port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
      const StreamExecutorConfig& config) = 0;

  // Warning: this is a dangerous API and should be used with caution.
  //
  // Forces the platform to delete executor instances, releasing their
  // associated device contexts. There must be no held instances of the executor
  // and there must be no outstanding activity on the devices for this platform.
  //
  // This is only useful on platforms which bind a device to a single process
  // that has obtained the device context. May return UNIMPLEMENTED on platforms
  // that have no reason to destroy device contexts.
  //
  // The platform must be reinitialized after this is called.
  //
  // Not pure virtual: a base-class implementation is defined outside this
  // header.
  virtual port::Status ForceExecutorShutdown();

  // Registers a TraceListener to listen to all StreamExecutors for this
  // platform.
  // Takes ownership of listener.
  virtual void RegisterTraceListener(
      std::unique_ptr<TraceListener> listener) = 0;

  // Removes the specified TraceListener from all StreamExecutors. The listener
  // is identified by a raw pointer.
  virtual void UnregisterTraceListener(TraceListener* listener) = 0;

  // Maps an ordered (first, second) pair of ints identifying two executors to a
  // boolean that indicates whether the first executor can access the second's
  // memory.
  // NOTE(review): the ints are presumably device ordinals - confirm.
  using PeerAccessMap = std::map<std::pair<int, int>, bool>;

  // Returns a matrix indicating which executors can access which other
  // executors' memory. The caller owns the returned map.
  //
  // Not pure virtual: a base-class implementation is defined outside this
  // header.
  virtual std::unique_ptr<PeerAccessMap> GetPeerAccessMap();

  // Attempts to enable all peer-to-peer access links described by the result of
  // GetPeerAccessMap(). Note that calling this routine will force the creation
  // of a default-argument (see StreamExecutorConfig) StreamExecutor object for
  // each device ordinal in the system, should any not yet exist.
  //
  // Not pure virtual: a base-class implementation is defined outside this
  // header.
  virtual port::Status EnablePeerAccess();

 protected:
  // SE_DISALLOW_COPY_AND_ASSIGN declares a constructor, which suppresses the
  // presence of the default constructor. This statement re-enables it, which
  // simplifies subclassing. It is protected, so only subclasses can use it.
  Platform() = default;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(Platform);
};

}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_