• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Defines types and declares functions for identifying and extracting
17 // information about the types of platforms and supporting libraries for which
18 // StreamExecutor implementations exist.
19 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_PLATFORM_H_
20 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_PLATFORM_H_
21 
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/stream_executor/device_description.h"
#include "tensorflow/compiler/xla/stream_executor/device_options.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/plugin.h"
#include "tensorflow/compiler/xla/stream_executor/trace_listener.h"
31 
32 namespace stream_executor {
33 
34 class StreamExecutor;
35 class DeviceDescription;
36 
// Identifies the platform that a StreamExecutor instantiation acts upon.
//
// Implementors: a new enumerator must also be handled by PlatformKindString
// and CheckPlatformKindIsValid. The explicit values below are exactly the
// ones the compiler would assign implicitly (0, 1, 2, ...); keep them
// consecutive and add new platforms immediately before kSize.
enum class PlatformKind {
  kInvalid = 0,
  kCuda = 1,
  kROCm = 2,
  kOpenCL = 3,
  kHost = 4,
  kMock = 5,
  // Equals the number of enumerators declared before it. Presumably a
  // count/sentinel rather than a real platform -- TODO confirm.
  kSize = 6,
};
50 
// Returns true if `kind` represents a valid platform capable of enqueuing
// items on a stream, though not necessarily on an accelerator device.
// Returns false for kMock and for any invalid PlatformKind value.
bool PlatformIsRunnable(PlatformKind kind);
55 
// Returns true if `kind` represents a valid platform capable of running
// kernels on an accelerator device. Returns false for kHost, kMock and any
// invalid PlatformKind value.
bool PlatformIsRunnableOnDevice(PlatformKind kind);
60 
// Returns a printable description of `kind`.
std::string PlatformKindString(PlatformKind kind);
63 
// Returns the PlatformKind corresponding to `platform_string`, or
// PlatformKind::kInvalid if there is no match. The accepted spellings are
// defined by the implementation (not visible here); presumably they mirror
// PlatformKindString -- TODO confirm.
PlatformKind PlatformKindFromString(std::string platform_string);
67 
// Checks that `kind` holds a valid PlatformKind value.
// NOTE(review): what happens on failure is defined by the implementation,
// which is not visible here (likely a fatal check) -- confirm.
void CheckPlatformKindIsValid(PlatformKind kind);
70 
71 // StreamExecutorConfig encapsulates the set of options for constructing a
72 // StreamExecutor for a given platform.
73 struct StreamExecutorConfig {
74   // Sets members to defaults: -1 for ordinal (must be changed), and default
75   // PluginConfig and DeviceOptions.
76   StreamExecutorConfig();
77 
78   // Simple ordinal-setting constructor.
79   explicit StreamExecutorConfig(int ordinal);
80 
81   // The GPU stream for which we are searching the executor.
82   // If this field is specified for the search, others will be ignored.
83   void* gpu_stream = nullptr;
84 
85   // The ordinal of the device to be managed by the returned StreamExecutor.
86   int ordinal;
87 
88   // The PluginConfig for the returned StreamExecutor.
89   PluginConfig plugin_config;
90 
91   // The DeviceOptions for the returned StreamExecutor.
92   DeviceOptions device_options;
93 };
94 
95 // Abstract base class for a platform registered with the MultiPlatformManager.
96 class Platform {
97  public:
98   virtual ~Platform();
99 
100   // A platform ID is a unique identifier for each registered platform type -
101   // each platform is required to expose an ID to ensure unique registration and
102   // as a target against which plugins can register.
103   //
104   // The macro below is provided to help generate a [process-unique] identifier.
105   using Id = void*;
106 
107 // Helper macro to define a plugin ID. To be used only inside plugin
108 // implementation files. Works by "reserving" an address/value (guaranteed to be
109 // unique) inside a process space.
110 #define PLATFORM_DEFINE_ID(ID_VAR_NAME) \
111   namespace {                           \
112   int plugin_id_value;                  \
113   }                                     \
114   const ::stream_executor::Platform::Id ID_VAR_NAME = &plugin_id_value;
115 
116   // Returns a key uniquely identifying this platform.
117   virtual Id id() const = 0;
118 
119   // Name of this platform.
120   virtual const std::string& Name() const = 0;
121 
122   // Returns the number of devices accessible on this platform.
123   //
124   // Note that, though these devices are visible, if there is only one userspace
125   // context allowed for the device at a time and another process is using this
126   // device, a call to ExecutorForDevice may return an error status.
127   virtual int VisibleDeviceCount() const = 0;
128 
129   // Returns true iff the platform has been initialized.
130   virtual bool Initialized() const;
131 
132   // Initializes the platform with a custom set of options. The platform must be
133   // initialized before obtaining StreamExecutor objects.  The interpretation of
134   // the platform_options argument is implementation specific.  This method may
135   // return an error if unrecognized options are provided.  If using
136   // MultiPlatformManager, this method will be called automatically by
137   // InitializePlatformWithId/InitializePlatformWithName.
138   virtual port::Status Initialize(
139       const std::map<std::string, std::string>& platform_options);
140 
141   // Returns a populated DeviceDescription for the device at the given ordinal.
142   // This should not require device initialization. Note that not all platforms
143   // may support acquiring the DeviceDescription indirectly.
144   //
145   // Alternatively callers may call GetDeviceDescription() on the StreamExecutor
146   // which returns a cached instance specific to the initialized StreamExecutor.
147   virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
148   DescriptionForDevice(int ordinal) const = 0;
149 
150   // Returns a device with the given ordinal on this platform with a default
151   // plugin configuration or, if none can be found with the given ordinal or
152   // there is an error in opening a context to communicate with the device, an
153   // error status is returned.
154   //
155   // Ownership of the executor is NOT transferred to the caller --
156   // the Platform owns the executors in a singleton-like fashion.
157   virtual port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) = 0;
158 
159   // Returns a device or error, as above, with the specified plugins.
160   //
161   // Ownership of the executor is NOT transferred to the caller.
162   virtual port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
163       int ordinal, const PluginConfig& plugin_config) = 0;
164 
165   // Returns a device constructed with the options specified in "config".
166   // Ownership of the executor is NOT transferred to the caller.
167   virtual port::StatusOr<StreamExecutor*> GetExecutor(
168       const StreamExecutorConfig& config) = 0;
169 
170   // Returns a device constructed with the options specified in "config" without
171   // looking in or storing to the Platform's executor cache.
172   // Ownership IS transferred to the caller.
173   virtual port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
174       const StreamExecutorConfig& config) = 0;
175 
176   // Warning: this is a dangerous API and should be used with caution.
177   //
178   // Forces the platform to delete executor instances, releasing their
179   // associated device contexts. There must be no held instances of the executor
180   // and there must be no outstanding activity on the devices for this platform.
181   //
182   // This is only useful on platforms which bind a device to a single process
183   // that has obtained the device context. May return UNIMPLEMENTED on platforms
184   // that have no reason to destroy device contexts.
185   //
186   // The platform must be reinitialized after this is called.
187   virtual port::Status ForceExecutorShutdown();
188 
189   // Registers a TraceListener to listen to all StreamExecutors for this
190   // platform.
191   // Takes ownership of listener.
192   virtual void RegisterTraceListener(
193       std::unique_ptr<TraceListener> listener) = 0;
194 
195   // Removes the specified TraceListener from all StreamExecutors.
196   virtual void UnregisterTraceListener(TraceListener* listener) = 0;
197 
198   // Map of executor-to-executor coordinate and boolean, indicating if the first
199   // executor can access the second's memory.
200   using PeerAccessMap = std::map<std::pair<int, int>, bool>;
201 
202   // Returns a matrix indicating which executors can access which other
203   // executors' memory.
204   virtual std::unique_ptr<PeerAccessMap> GetPeerAccessMap();
205 
206   // Attempts to enable all peer-to-peer access links described by the result of
207   // GetPeerAccessMap(). Note that calling this routine will force the creation
208   // of a default-argument (see StreamExecutorConfig) StreamExecutor object for
209   // each device ordinal in the system, should any not yet exist.
210   virtual port::Status EnablePeerAccess();
211 
212  protected:
213   // SE_DISALLOW_COPY_AND_ASSIGN declares a constructor, which suppresses the
214   // presence of the default constructor. This statement re-enables it, which
215   // simplifies subclassing.
216   Platform() = default;
217 
218  private:
219   SE_DISALLOW_COPY_AND_ASSIGN(Platform);
220 };
221 
222 }  // namespace stream_executor
223 
224 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_PLATFORM_H_
225