• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Classes and utilities that work with StreamExecutor C API for internal use.
16 // This includes functions used for device registration and interfaces needed
17 // for testing.
18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
20 
21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/stream_executor/executor_cache.h"
24 #include "tensorflow/stream_executor/lib/status.h"
25 #include "tensorflow/stream_executor/platform.h"
26 
27 namespace stream_executor {
28 
29 // Plugin initialization function that a device plugin
30 // must define.
31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
32                                TF_Status* const);
33 
34 // Registers StreamExecutor platform. `device_type` and `platform_name` are
35 // output parameters.
36 port::Status InitStreamExecutorPlugin(void* dso_handle,
37                                       std::string* device_type,
38                                       std::string* platform_name);
39 
40 // Allow registering a StreamExecutor plugin using a function (used for
41 // testing).
42 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
43                                       std::string* device_type,
44                                       std::string* platform_name);
45 
46 // This file implements core stream executor base classes in terms of
47 // the C API defined in stream_executor.h. A class "CSomething" represents a
48 // "Something" that can be manipulated via calls in the C interface.
49 class CPlatform : public Platform {
50  public:
51   explicit CPlatform(SP_Platform platform,
52                      void (*destroy_platform)(SP_Platform*),
53                      SP_PlatformFns platform_fns,
54                      void (*destroy_platform_fns)(SP_PlatformFns*),
55                      SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
56                      SP_TimerFns timer_fns);
57   ~CPlatform() override;
58 
id()59   Id id() const override { return const_cast<int*>(&plugin_id_value_); }
Name()60   const std::string& Name() const override { return name_; }
VisibleDeviceCount()61   int VisibleDeviceCount() const override {
62     int visible_device_count = 0;
63     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
64     platform_fns_.get_device_count(&platform_, &visible_device_count,
65                                    c_status.get());
66     if (TF_GetCode(c_status.get()) != TF_OK) {
67       LOG(ERROR) << TF_Message(c_status.get());
68       return 0;
69     }
70     return visible_device_count;
71   }
UseBfcAllocator()72   bool UseBfcAllocator() const { return platform_.use_bfc_allocator; }
ForceMemoryGrowth()73   bool ForceMemoryGrowth() const { return platform_.force_memory_growth; }
74   port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
75       int ordinal) const override;
76   port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
77   port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
78       int ordinal, const PluginConfig& plugin_config) override;
79   port::StatusOr<StreamExecutor*> GetExecutor(
80       const StreamExecutorConfig& config) override;
81   port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
82       const StreamExecutorConfig& config) override;
83 
84   // Trace listener is not supported
RegisterTraceListener(std::unique_ptr<TraceListener> listener)85   void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
86     LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
87   }
UnregisterTraceListener(TraceListener * listener)88   void UnregisterTraceListener(TraceListener* listener) override {}
89 
DestroyAllExecutors()90   void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
91 
92  private:
93   SP_Platform platform_;
94   void (*destroy_platform_)(SP_Platform*);
95   SP_PlatformFns platform_fns_;
96   void (*destroy_platform_fns_)(SP_PlatformFns*);
97   SP_DeviceFns device_fns_;
98   SP_StreamExecutor stream_executor_;
99   SP_TimerFns timer_fns_;
100   const std::string name_;
101   int plugin_id_value_;
102   stream_executor::ExecutorCache executor_cache_;
103 };
104 
105 class CStream : public internal::StreamInterface {
106  public:
CStream(SP_Device * device,SP_StreamExecutor * stream_executor)107   CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
108       : device_(device),
109         stream_executor_(stream_executor),
110         stream_handle_(nullptr) {}
~CStream()111   ~CStream() override { Destroy(); }
112 
Create()113   port::Status Create() {
114     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
115     stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
116     port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
117     return s;
118   }
119 
Destroy()120   void Destroy() {
121     if (stream_handle_ != nullptr) {
122       stream_executor_->destroy_stream(device_, stream_handle_);
123       stream_handle_ = nullptr;
124     }
125   }
126 
Handle()127   SP_Stream Handle() { return stream_handle_; }
128 
129  private:
130   SP_Device* device_;
131   SP_StreamExecutor* stream_executor_;
132   SP_Stream stream_handle_;
133 };
134 
135 class CEvent : public internal::EventInterface {
136  public:
CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)137   CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
138       : device_(device),
139         stream_executor_(stream_executor),
140         event_handle_(nullptr) {}
~CEvent()141   ~CEvent() override { Destroy(); }
142 
Create()143   port::Status Create() {
144     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
145     stream_executor_->create_event(device_, &event_handle_, c_status.get());
146     return tensorflow::StatusFromTF_Status(c_status.get());
147   }
148 
Record(SP_Stream stream_handle)149   port::Status Record(SP_Stream stream_handle) {
150     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
151     stream_executor_->record_event(device_, stream_handle, event_handle_,
152                                    c_status.get());
153     return tensorflow::StatusFromTF_Status(c_status.get());
154   }
155 
Destroy()156   void Destroy() {
157     if (event_handle_ != nullptr) {
158       stream_executor_->destroy_event(device_, event_handle_);
159       event_handle_ = nullptr;
160     }
161   }
162 
Handle()163   SP_Event Handle() { return event_handle_; }
164 
165  private:
166   SP_Device* device_;
167   SP_StreamExecutor* stream_executor_;
168   SP_Event event_handle_;
169 };
170 
171 class CTimer : public internal::TimerInterface {
172  public:
CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)173   CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
174          SP_TimerFns* timer_fns)
175       : device_(device),
176         stream_executor_(stream_executor),
177         timer_handle_(nullptr),
178         timer_fns_(timer_fns) {}
~CTimer()179   ~CTimer() override { Destroy(); }
180 
Create()181   port::Status Create() {
182     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
183     stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
184     return tensorflow::StatusFromTF_Status(c_status.get());
185   }
186 
Destroy()187   void Destroy() {
188     if (timer_handle_ != nullptr) {
189       stream_executor_->destroy_timer(device_, timer_handle_);
190       timer_handle_ = nullptr;
191     }
192   }
193 
Handle()194   SP_Timer Handle() { return timer_handle_; }
195 
Microseconds()196   uint64 Microseconds() const override {
197     return timer_fns_->nanoseconds(timer_handle_) / 1000;
198   }
199 
Nanoseconds()200   uint64 Nanoseconds() const override {
201     return timer_fns_->nanoseconds(timer_handle_);
202   }
203 
204  private:
205   SP_Device* device_;
206   SP_StreamExecutor* stream_executor_;
207   SP_Timer timer_handle_;
208   SP_TimerFns* timer_fns_;
209 };
210 
211 }  // namespace stream_executor
212 #endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
213