1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Classes and utilities that work with StreamExecutor C API for internal use. 16 // This includes functions used for device registration and interfaces needed 17 // for testing. 18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 20 21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h" 22 #include "tensorflow/c/tf_status_helper.h" 23 #include "tensorflow/stream_executor/executor_cache.h" 24 #include "tensorflow/stream_executor/lib/status.h" 25 #include "tensorflow/stream_executor/platform.h" 26 27 namespace stream_executor { 28 29 // Plugin initialization function that a device plugin 30 // must define. 31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, 32 TF_Status* const); 33 34 // Registers StreamExecutor platform. 35 port::Status InitStreamExecutorPlugin(void* dso_handle); 36 37 // Allow registering a StreamExecutor plugin using a function (used for 38 // testing). 39 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn); 40 41 struct TFStatusDeleter { operatorTFStatusDeleter42 void operator()(TF_Status* s) const { TF_DeleteStatus(s); } 43 }; 44 45 // This file implements core stream executor base classes in terms of 46 // the C API defined in stream_executor.h. A class "CSomething" represents a 47 // "Something" that can be manipulated via calls in the C interface. 48 class CPlatform : public Platform { 49 public: 50 explicit CPlatform(SP_Platform platform, 51 void (*destroy_platform)(SP_Platform*), 52 SP_PlatformFns platform_fns, 53 void (*destroy_platform_fns)(SP_PlatformFns*), 54 SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, 55 SP_TimerFns timer_fns); 56 ~CPlatform() override; 57 id()58 Id id() const override { return const_cast<int*>(&plugin_id_value_); } Name()59 const std::string& Name() const override { return name_; } VisibleDeviceCount()60 int VisibleDeviceCount() const override { 61 return platform_.visible_device_count; 62 } 63 port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice( 64 int ordinal) const override; 65 port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override; 66 port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig( 67 int ordinal, const PluginConfig& plugin_config) override; 68 port::StatusOr<StreamExecutor*> GetExecutor( 69 const StreamExecutorConfig& config) override; 70 port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor( 71 const StreamExecutorConfig& config) override; 72 73 // Trace listener is not supported RegisterTraceListener(std::unique_ptr<TraceListener> listener)74 void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override { 75 LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device"; 76 } UnregisterTraceListener(TraceListener * listener)77 void UnregisterTraceListener(TraceListener* listener) override {} 78 DestroyAllExecutors()79 void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } 80 81 private: 82 SP_Platform platform_; 83 void (*destroy_platform_)(SP_Platform*); 84 SP_PlatformFns platform_fns_; 85 void (*destroy_platform_fns_)(SP_PlatformFns*); 86 SP_DeviceFns device_fns_; 87 SP_StreamExecutor stream_executor_; 88 SP_TimerFns timer_fns_; 89 const std::string name_; 90 int plugin_id_value_; 91 stream_executor::ExecutorCache executor_cache_; 92 }; 93 94 class CStream : public internal::StreamInterface { 95 public: CStream(SP_Device * device,SP_StreamExecutor * stream_executor)96 CStream(SP_Device* device, SP_StreamExecutor* stream_executor) 97 : device_(device), 98 stream_executor_(stream_executor), 99 stream_handle_(nullptr) {} ~CStream()100 ~CStream() override { Destroy(); } 101 Create()102 port::Status Create() { 103 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 104 stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); 105 port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); 106 return s; 107 } 108 Destroy()109 void Destroy() { 110 if (stream_handle_ != nullptr) { 111 stream_executor_->destroy_stream(device_, stream_handle_); 112 stream_handle_ = nullptr; 113 } 114 } 115 Handle()116 SP_Stream Handle() { return stream_handle_; } 117 118 private: 119 SP_Device* device_; 120 SP_StreamExecutor* stream_executor_; 121 SP_Stream stream_handle_; 122 }; 123 124 class CEvent : public internal::EventInterface { 125 public: CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)126 CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) 127 : device_(device), 128 stream_executor_(stream_executor), 129 event_handle_(nullptr) {} ~CEvent()130 ~CEvent() override { Destroy(); } 131 Create()132 port::Status Create() { 133 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 134 stream_executor_->create_event(device_, &event_handle_, c_status.get()); 135 return tensorflow::StatusFromTF_Status(c_status.get()); 136 } 137 Record(SP_Stream stream_handle)138 port::Status Record(SP_Stream stream_handle) { 139 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 140 stream_executor_->record_event(device_, stream_handle, event_handle_, 141 c_status.get()); 142 return tensorflow::StatusFromTF_Status(c_status.get()); 143 } 144 Destroy()145 void Destroy() { 146 if (event_handle_ != nullptr) { 147 stream_executor_->destroy_event(device_, event_handle_); 148 event_handle_ = nullptr; 149 } 150 } 151 Handle()152 SP_Event Handle() { return event_handle_; } 153 154 private: 155 SP_Device* device_; 156 SP_StreamExecutor* stream_executor_; 157 SP_Event event_handle_; 158 }; 159 160 class CTimer : public internal::TimerInterface { 161 public: CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)162 CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, 163 SP_TimerFns* timer_fns) 164 : device_(device), 165 stream_executor_(stream_executor), 166 timer_handle_(nullptr), 167 timer_fns_(timer_fns) {} ~CTimer()168 ~CTimer() override { Destroy(); } 169 Create()170 port::Status Create() { 171 std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus()); 172 stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); 173 return tensorflow::StatusFromTF_Status(c_status.get()); 174 } 175 Destroy()176 void Destroy() { 177 if (timer_handle_ != nullptr) { 178 stream_executor_->destroy_timer(device_, timer_handle_); 179 timer_handle_ = nullptr; 180 } 181 } 182 Handle()183 SP_Timer Handle() { return timer_handle_; } 184 Microseconds()185 uint64 Microseconds() const override { 186 return timer_fns_->nanoseconds(timer_handle_) / 1000; 187 } 188 Nanoseconds()189 uint64 Nanoseconds() const override { 190 return timer_fns_->nanoseconds(timer_handle_); 191 } 192 193 private: 194 SP_Device* device_; 195 SP_StreamExecutor* stream_executor_; 196 SP_Timer timer_handle_; 197 SP_TimerFns* timer_fns_; 198 }; 199 200 } // namespace stream_executor 201 #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ 202