• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Classes and utilities that work with StreamExecutor C API for internal use.
16 // This includes functions used for device registration and interfaces needed
17 // for testing.
18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
20 
21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/stream_executor/executor_cache.h"
24 #include "tensorflow/stream_executor/lib/status.h"
25 #include "tensorflow/stream_executor/platform.h"
26 
27 namespace stream_executor {
28 
29 // Plugin initialization function that a device plugin
30 // must define.
31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
32                                TF_Status* const);
33 
34 // Registers StreamExecutor platform. `device_type` and `platform_name` are
35 // output parameters.
36 port::Status InitStreamExecutorPlugin(void* dso_handle,
37                                       std::string* device_type,
38                                       std::string* platform_name);
39 
40 // Allow registering a StreamExecutor plugin using a function (used for
41 // testing).
42 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
43                                       std::string* device_type,
44                                       std::string* platform_name);
45 
46 struct TFStatusDeleter {
operatorTFStatusDeleter47   void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
48 };
49 
50 // This file implements core stream executor base classes in terms of
51 // the C API defined in stream_executor.h. A class "CSomething" represents a
52 // "Something" that can be manipulated via calls in the C interface.
53 class CPlatform : public Platform {
54  public:
55   explicit CPlatform(SP_Platform platform,
56                      void (*destroy_platform)(SP_Platform*),
57                      SP_PlatformFns platform_fns,
58                      void (*destroy_platform_fns)(SP_PlatformFns*),
59                      SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
60                      SP_TimerFns timer_fns);
61   ~CPlatform() override;
62 
id()63   Id id() const override { return const_cast<int*>(&plugin_id_value_); }
Name()64   const std::string& Name() const override { return name_; }
VisibleDeviceCount()65   int VisibleDeviceCount() const override {
66     int visible_device_count = 0;
67     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
68     platform_fns_.get_device_count(&platform_, &visible_device_count,
69                                    c_status.get());
70     if (TF_GetCode(c_status.get()) != TF_OK) {
71       LOG(ERROR) << TF_Message(c_status.get());
72       return 0;
73     }
74     return visible_device_count;
75   }
UseBfcAllocator()76   bool UseBfcAllocator() const { return platform_.use_bfc_allocator; }
77   port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
78       int ordinal) const override;
79   port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
80   port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
81       int ordinal, const PluginConfig& plugin_config) override;
82   port::StatusOr<StreamExecutor*> GetExecutor(
83       const StreamExecutorConfig& config) override;
84   port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
85       const StreamExecutorConfig& config) override;
86 
87   // Trace listener is not supported
RegisterTraceListener(std::unique_ptr<TraceListener> listener)88   void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
89     LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
90   }
UnregisterTraceListener(TraceListener * listener)91   void UnregisterTraceListener(TraceListener* listener) override {}
92 
DestroyAllExecutors()93   void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
94 
95  private:
96   SP_Platform platform_;
97   void (*destroy_platform_)(SP_Platform*);
98   SP_PlatformFns platform_fns_;
99   void (*destroy_platform_fns_)(SP_PlatformFns*);
100   SP_DeviceFns device_fns_;
101   SP_StreamExecutor stream_executor_;
102   SP_TimerFns timer_fns_;
103   const std::string name_;
104   int plugin_id_value_;
105   stream_executor::ExecutorCache executor_cache_;
106 };
107 
108 class CStream : public internal::StreamInterface {
109  public:
CStream(SP_Device * device,SP_StreamExecutor * stream_executor)110   CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
111       : device_(device),
112         stream_executor_(stream_executor),
113         stream_handle_(nullptr) {}
~CStream()114   ~CStream() override { Destroy(); }
115 
Create()116   port::Status Create() {
117     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
118     stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
119     port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
120     return s;
121   }
122 
Destroy()123   void Destroy() {
124     if (stream_handle_ != nullptr) {
125       stream_executor_->destroy_stream(device_, stream_handle_);
126       stream_handle_ = nullptr;
127     }
128   }
129 
Handle()130   SP_Stream Handle() { return stream_handle_; }
131 
132  private:
133   SP_Device* device_;
134   SP_StreamExecutor* stream_executor_;
135   SP_Stream stream_handle_;
136 };
137 
138 class CEvent : public internal::EventInterface {
139  public:
CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)140   CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
141       : device_(device),
142         stream_executor_(stream_executor),
143         event_handle_(nullptr) {}
~CEvent()144   ~CEvent() override { Destroy(); }
145 
Create()146   port::Status Create() {
147     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
148     stream_executor_->create_event(device_, &event_handle_, c_status.get());
149     return tensorflow::StatusFromTF_Status(c_status.get());
150   }
151 
Record(SP_Stream stream_handle)152   port::Status Record(SP_Stream stream_handle) {
153     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
154     stream_executor_->record_event(device_, stream_handle, event_handle_,
155                                    c_status.get());
156     return tensorflow::StatusFromTF_Status(c_status.get());
157   }
158 
Destroy()159   void Destroy() {
160     if (event_handle_ != nullptr) {
161       stream_executor_->destroy_event(device_, event_handle_);
162       event_handle_ = nullptr;
163     }
164   }
165 
Handle()166   SP_Event Handle() { return event_handle_; }
167 
168  private:
169   SP_Device* device_;
170   SP_StreamExecutor* stream_executor_;
171   SP_Event event_handle_;
172 };
173 
174 class CTimer : public internal::TimerInterface {
175  public:
CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)176   CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
177          SP_TimerFns* timer_fns)
178       : device_(device),
179         stream_executor_(stream_executor),
180         timer_handle_(nullptr),
181         timer_fns_(timer_fns) {}
~CTimer()182   ~CTimer() override { Destroy(); }
183 
Create()184   port::Status Create() {
185     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
186     stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
187     return tensorflow::StatusFromTF_Status(c_status.get());
188   }
189 
Destroy()190   void Destroy() {
191     if (timer_handle_ != nullptr) {
192       stream_executor_->destroy_timer(device_, timer_handle_);
193       timer_handle_ = nullptr;
194     }
195   }
196 
Handle()197   SP_Timer Handle() { return timer_handle_; }
198 
Microseconds()199   uint64 Microseconds() const override {
200     return timer_fns_->nanoseconds(timer_handle_) / 1000;
201   }
202 
Nanoseconds()203   uint64 Nanoseconds() const override {
204     return timer_fns_->nanoseconds(timer_handle_);
205   }
206 
207  private:
208   SP_Device* device_;
209   SP_StreamExecutor* stream_executor_;
210   SP_Timer timer_handle_;
211   SP_TimerFns* timer_fns_;
212 };
213 
214 }  // namespace stream_executor
215 #endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
216