1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/host/host_platform.h"
17
18 #include <thread>
19
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/stream_executor/host/host_gpu_executor.h"
23 #include "tensorflow/stream_executor/host/host_platform_id.h"
24 #include "tensorflow/stream_executor/lib/error.h"
25 #include "tensorflow/stream_executor/lib/initialize.h"
26 #include "tensorflow/stream_executor/lib/status.h"
27 #include "tensorflow/stream_executor/lib/status_macros.h"
28
29 namespace stream_executor {
30 namespace host {
31
HostPlatform()32 HostPlatform::HostPlatform() : name_("Host") {}
33
~HostPlatform()34 HostPlatform::~HostPlatform() {}
35
id() const36 Platform::Id HostPlatform::id() const { return kHostPlatformId; }
37
VisibleDeviceCount() const38 int HostPlatform::VisibleDeviceCount() const {
39 return std::thread::hardware_concurrency();
40 }
41
Name() const42 const std::string& HostPlatform::Name() const { return name_; }
43
44 port::StatusOr<std::unique_ptr<DeviceDescription>>
DescriptionForDevice(int ordinal) const45 HostPlatform::DescriptionForDevice(int ordinal) const {
46 return HostExecutor::CreateDeviceDescription(ordinal);
47 }
48
ExecutorForDevice(int ordinal)49 port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDevice(int ordinal) {
50 StreamExecutorConfig config;
51 config.ordinal = ordinal;
52 config.plugin_config = PluginConfig();
53 config.device_options = DeviceOptions::Default();
54 return GetExecutor(config);
55 }
56
ExecutorForDeviceWithPluginConfig(int device_ordinal,const PluginConfig & plugin_config)57 port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDeviceWithPluginConfig(
58 int device_ordinal, const PluginConfig& plugin_config) {
59 StreamExecutorConfig config;
60 config.ordinal = device_ordinal;
61 config.plugin_config = plugin_config;
62 config.device_options = DeviceOptions::Default();
63 return GetExecutor(config);
64 }
65
GetExecutor(const StreamExecutorConfig & config)66 port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
67 const StreamExecutorConfig& config) {
68 return executor_cache_.GetOrCreate(
69 config, [&]() { return GetUncachedExecutor(config); });
70 }
71
72 port::StatusOr<std::unique_ptr<StreamExecutor>>
GetUncachedExecutor(const StreamExecutorConfig & config)73 HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
74 auto executor = absl::make_unique<StreamExecutor>(
75 this, absl::make_unique<HostExecutor>(config.plugin_config),
76 config.ordinal);
77 auto init_status = executor->Init(config.device_options);
78 if (!init_status.ok()) {
79 return port::Status(
80 port::error::INTERNAL,
81 absl::StrFormat(
82 "failed initializing StreamExecutor for device ordinal %d: %s",
83 config.ordinal, init_status.ToString().c_str()));
84 }
85
86 return std::move(executor);
87 }
88
RegisterTraceListener(std::unique_ptr<TraceListener> listener)89 void HostPlatform::RegisterTraceListener(
90 std::unique_ptr<TraceListener> listener) {
91 LOG(FATAL) << "not yet implemented: register host trace listener";
92 }
93
UnregisterTraceListener(TraceListener * listener)94 void HostPlatform::UnregisterTraceListener(TraceListener* listener) {
95 LOG(FATAL) << "not yet implemented: unregister host trace listener";
96 }
97
InitializeHostPlatform()98 static void InitializeHostPlatform() {
99 std::unique_ptr<Platform> platform(new host::HostPlatform);
100 SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
101 }
102
103 } // namespace host
104 } // namespace stream_executor
105
106 REGISTER_MODULE_INITIALIZER(host_platform,
107 stream_executor::host::InitializeHostPlatform());
108
109 // Note that module initialization sequencing is not supported in the
110 // open-source project, so this will be a no-op there.
111 REGISTER_MODULE_INITIALIZER_SEQUENCE(host_platform, multi_platform_manager);
112 REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
113 host_platform);
114