1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Classes and utilities that work with StreamExecutor C API for internal use. 16 // This includes functions used for device registration and interfaces needed 17 // for testing. 18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 20 21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h" 22 #include "tensorflow/c/tf_status_helper.h" 23 #include "tensorflow/stream_executor/executor_cache.h" 24 #include "tensorflow/stream_executor/lib/status.h" 25 #include "tensorflow/stream_executor/platform.h" 26 27 namespace stream_executor { 28 29 // Plugin initialization function that a device plugin 30 // must define. 31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, 32 TF_Status* const); 33 34 // Registers StreamExecutor platform. `device_type` and `platform_name` are 35 // output parameters. 36 port::Status InitStreamExecutorPlugin(void* dso_handle, 37 std::string* device_type, 38 std::string* platform_name); 39 40 // Allow registering a StreamExecutor plugin using a function (used for 41 // testing). 42 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, 43 std::string* device_type, 44 std::string* platform_name); 45 46 // This file implements core stream executor base classes in terms of 47 // the C API defined in stream_executor.h. A class "CSomething" represents a 48 // "Something" that can be manipulated via calls in the C interface. 49 class CPlatform : public Platform { 50 public: 51 explicit CPlatform(SP_Platform platform, 52 void (*destroy_platform)(SP_Platform*), 53 SP_PlatformFns platform_fns, 54 void (*destroy_platform_fns)(SP_PlatformFns*), 55 SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, 56 SP_TimerFns timer_fns); 57 ~CPlatform() override; 58 id()59 Id id() const override { return const_cast<int*>(&plugin_id_value_); } Name()60 const std::string& Name() const override { return name_; } VisibleDeviceCount()61 int VisibleDeviceCount() const override { 62 int visible_device_count = 0; 63 tensorflow::TF_StatusPtr c_status(TF_NewStatus()); 64 platform_fns_.get_device_count(&platform_, &visible_device_count, 65 c_status.get()); 66 if (TF_GetCode(c_status.get()) != TF_OK) { 67 LOG(ERROR) << TF_Message(c_status.get()); 68 return 0; 69 } 70 return visible_device_count; 71 } UseBfcAllocator()72 bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } ForceMemoryGrowth()73 bool ForceMemoryGrowth() const { return platform_.force_memory_growth; } 74 port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice( 75 int ordinal) const override; 76 port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override; 77 port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig( 78 int ordinal, const PluginConfig& plugin_config) override; 79 port::StatusOr<StreamExecutor*> GetExecutor( 80 const StreamExecutorConfig& config) override; 81 port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor( 82 const StreamExecutorConfig& config) override; 83 84 // Trace listener is not supported RegisterTraceListener(std::unique_ptr<TraceListener> listener)85 void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override { 86 LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device"; 87 } UnregisterTraceListener(TraceListener * listener)88 void UnregisterTraceListener(TraceListener* listener) override {} 89 DestroyAllExecutors()90 void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } 91 92 private: 93 SP_Platform platform_; 94 void (*destroy_platform_)(SP_Platform*); 95 SP_PlatformFns platform_fns_; 96 void (*destroy_platform_fns_)(SP_PlatformFns*); 97 SP_DeviceFns device_fns_; 98 SP_StreamExecutor stream_executor_; 99 SP_TimerFns timer_fns_; 100 const std::string name_; 101 int plugin_id_value_; 102 stream_executor::ExecutorCache executor_cache_; 103 }; 104 105 class CStream : public internal::StreamInterface { 106 public: CStream(SP_Device * device,SP_StreamExecutor * stream_executor)107 CStream(SP_Device* device, SP_StreamExecutor* stream_executor) 108 : device_(device), 109 stream_executor_(stream_executor), 110 stream_handle_(nullptr) {} ~CStream()111 ~CStream() override { Destroy(); } 112 Create()113 port::Status Create() { 114 tensorflow::TF_StatusPtr c_status(TF_NewStatus()); 115 stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); 116 port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); 117 return s; 118 } 119 Destroy()120 void Destroy() { 121 if (stream_handle_ != nullptr) { 122 stream_executor_->destroy_stream(device_, stream_handle_); 123 stream_handle_ = nullptr; 124 } 125 } 126 Handle()127 SP_Stream Handle() { return stream_handle_; } 128 129 private: 130 SP_Device* device_; 131 SP_StreamExecutor* stream_executor_; 132 SP_Stream stream_handle_; 133 }; 134 135 class CEvent : public internal::EventInterface { 136 public: CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)137 CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) 138 : device_(device), 139 stream_executor_(stream_executor), 140 event_handle_(nullptr) {} ~CEvent()141 ~CEvent() override { Destroy(); } 142 Create()143 port::Status Create() { 144 tensorflow::TF_StatusPtr c_status(TF_NewStatus()); 145 stream_executor_->create_event(device_, &event_handle_, c_status.get()); 146 return tensorflow::StatusFromTF_Status(c_status.get()); 147 } 148 Record(SP_Stream stream_handle)149 port::Status Record(SP_Stream stream_handle) { 150 tensorflow::TF_StatusPtr c_status(TF_NewStatus()); 151 stream_executor_->record_event(device_, stream_handle, event_handle_, 152 c_status.get()); 153 return tensorflow::StatusFromTF_Status(c_status.get()); 154 } 155 Destroy()156 void Destroy() { 157 if (event_handle_ != nullptr) { 158 stream_executor_->destroy_event(device_, event_handle_); 159 event_handle_ = nullptr; 160 } 161 } 162 Handle()163 SP_Event Handle() { return event_handle_; } 164 165 private: 166 SP_Device* device_; 167 SP_StreamExecutor* stream_executor_; 168 SP_Event event_handle_; 169 }; 170 171 class CTimer : public internal::TimerInterface { 172 public: CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)173 CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, 174 SP_TimerFns* timer_fns) 175 : device_(device), 176 stream_executor_(stream_executor), 177 timer_handle_(nullptr), 178 timer_fns_(timer_fns) {} ~CTimer()179 ~CTimer() override { Destroy(); } 180 Create()181 port::Status Create() { 182 tensorflow::TF_StatusPtr c_status(TF_NewStatus()); 183 stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); 184 return tensorflow::StatusFromTF_Status(c_status.get()); 185 } 186 Destroy()187 void Destroy() { 188 if (timer_handle_ != nullptr) { 189 stream_executor_->destroy_timer(device_, timer_handle_); 190 timer_handle_ = nullptr; 191 } 192 } 193 Handle()194 SP_Timer Handle() { return timer_handle_; } 195 Microseconds()196 uint64 Microseconds() const override { 197 return timer_fns_->nanoseconds(timer_handle_) / 1000; 198 } 199 Nanoseconds()200 uint64 Nanoseconds() const override { 201 return timer_fns_->nanoseconds(timer_handle_); 202 } 203 204 private: 205 SP_Device* device_; 206 SP_StreamExecutor* stream_executor_; 207 SP_Timer timer_handle_; 208 SP_TimerFns* timer_fns_; 209 }; 210 211 } // namespace stream_executor 212 #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 213