/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_H_

#include <memory>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace stream_executor {
namespace gpu {

// Opaque and unique identifier for the ROCM platform plugin.
// This is needed so that plugins can refer to/identify this platform without
// instantiating a ROCmPlatform object.
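//
// For example, code that only holds a generic Platform* could identify the
// ROCm platform like this (a hedged sketch; `platform` is a hypothetical
// Platform* obtained elsewhere, not something defined in this header):
//
//   if (platform->id() == kROCmPlatformId) {
//     // Take the ROCm-specific path.
//   }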
extern const Platform::Id kROCmPlatformId;

// ROCm-specific platform plugin, registered as a singleton value via module
// initializer.
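//
// Example usage (a hedged sketch, not taken from this header; it assumes the
// platform registers itself under the name "ROCM" and that device ordinal 0
// is present on the machine):
//
//   port::StatusOr<Platform*> platform =
//       MultiPlatformManager::PlatformWithName("ROCM");
//   if (platform.ok()) {
//     port::StatusOr<StreamExecutor*> executor =
//         platform.ValueOrDie()->ExecutorForDevice(0);
//   }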
class ROCmPlatform : public Platform {
 public:
  ROCmPlatform();
  ~ROCmPlatform() override;

  // ROCmPlatform-specific functionality
  // Returns the number of distinct buses / NUMA nodes on the machine.
  int BusCount();

  // Returns the bus/NUMA node for the specified device ordinal.
  int DeviceToBus(int device_ordinal);

  // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
  port::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);

  // Platform interface implementation:
  // Returns the same value as kROCmPlatformId above.
  Platform::Id id() const override;

  // Returns -1 as a sentinel on internal failure (and logs the error).
  int VisibleDeviceCount() const override;

  const string& Name() const override;

  port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
      int ordinal) const override;

  port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;

  port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
      int ordinal, const PluginConfig& config) override;

  port::StatusOr<StreamExecutor*> GetExecutor(
      const StreamExecutorConfig& config) override;

  port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
      const StreamExecutorConfig& config) override;

  void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;

  void UnregisterTraceListener(TraceListener* listener) override;

 private:
  // Determines the number of NUMA nodes and the assignment of executors to
  // each.
  void InspectNumaNodes();

  // This platform's name.
  string name_;

  // mutex that guards internal state.
  mutable absl::Mutex mu_;

  // Cache of created executors.
  ExecutorCache executor_cache_;

  // The smallest NUMA node value for any device managed by this machine
  // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
  // ordinals. The NUMA node space occupied by GPUs is assumed to be dense.
  int min_numa_node_;

  // Larger than the NUMA node value for any device managed by this machine
  // manager.
  int limit_numa_node_;
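
  // Note: this header does not define the conversion itself; presumably the
  // implementation computes bus_ordinal = device_numa_node - min_numa_node_
  // and BusCount() as limit_numa_node_ - min_numa_node_, but that is an
  // assumption about the .cc file, not something guaranteed here.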

  SE_DISALLOW_COPY_AND_ASSIGN(ROCmPlatform);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_H_