• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Classes and utilities that work with StreamExecutor C API for internal use.
16 // This includes functions used for device registration and interfaces needed
17 // for testing.
18 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
19 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
20 
21 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/stream_executor/executor_cache.h"
24 #include "tensorflow/stream_executor/lib/status.h"
25 #include "tensorflow/stream_executor/platform.h"
26 
27 namespace stream_executor {
28 
29 // Plugin initialization function that a device plugin
30 // must define.
31 typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
32                                TF_Status* const);
33 
34 // Registers StreamExecutor platform.
35 port::Status InitStreamExecutorPlugin(void* dso_handle);
36 
37 // Allow registering a StreamExecutor plugin using a function (used for
38 // testing).
39 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
40 
41 struct TFStatusDeleter {
operatorTFStatusDeleter42   void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
43 };
44 
45 // This file implements core stream executor base classes in terms of
46 // the C API defined in stream_executor.h. A class "CSomething" represents a
47 // "Something" that can be manipulated via calls in the C interface.
48 class CPlatform : public Platform {
49  public:
50   explicit CPlatform(SP_Platform platform,
51                      void (*destroy_platform)(SP_Platform*),
52                      SP_PlatformFns platform_fns,
53                      void (*destroy_platform_fns)(SP_PlatformFns*),
54                      SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
55                      SP_TimerFns timer_fns);
56   ~CPlatform() override;
57 
id()58   Id id() const override { return const_cast<int*>(&plugin_id_value_); }
Name()59   const std::string& Name() const override { return name_; }
VisibleDeviceCount()60   int VisibleDeviceCount() const override {
61     return platform_.visible_device_count;
62   }
63   port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
64       int ordinal) const override;
65   port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
66   port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
67       int ordinal, const PluginConfig& plugin_config) override;
68   port::StatusOr<StreamExecutor*> GetExecutor(
69       const StreamExecutorConfig& config) override;
70   port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
71       const StreamExecutorConfig& config) override;
72 
73   // Trace listener is not supported
RegisterTraceListener(std::unique_ptr<TraceListener> listener)74   void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
75     LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
76   }
UnregisterTraceListener(TraceListener * listener)77   void UnregisterTraceListener(TraceListener* listener) override {}
78 
DestroyAllExecutors()79   void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
80 
81  private:
82   SP_Platform platform_;
83   void (*destroy_platform_)(SP_Platform*);
84   SP_PlatformFns platform_fns_;
85   void (*destroy_platform_fns_)(SP_PlatformFns*);
86   SP_DeviceFns device_fns_;
87   SP_StreamExecutor stream_executor_;
88   SP_TimerFns timer_fns_;
89   const std::string name_;
90   int plugin_id_value_;
91   stream_executor::ExecutorCache executor_cache_;
92 };
93 
94 class CStream : public internal::StreamInterface {
95  public:
CStream(SP_Device * device,SP_StreamExecutor * stream_executor)96   CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
97       : device_(device),
98         stream_executor_(stream_executor),
99         stream_handle_(nullptr) {}
~CStream()100   ~CStream() override { Destroy(); }
101 
Create()102   port::Status Create() {
103     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
104     stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
105     port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
106     return s;
107   }
108 
Destroy()109   void Destroy() {
110     if (stream_handle_ != nullptr) {
111       stream_executor_->destroy_stream(device_, stream_handle_);
112       stream_handle_ = nullptr;
113     }
114   }
115 
Handle()116   SP_Stream Handle() { return stream_handle_; }
117 
118  private:
119   SP_Device* device_;
120   SP_StreamExecutor* stream_executor_;
121   SP_Stream stream_handle_;
122 };
123 
124 class CEvent : public internal::EventInterface {
125  public:
CEvent(SP_Device * device,SP_StreamExecutor * stream_executor)126   CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
127       : device_(device),
128         stream_executor_(stream_executor),
129         event_handle_(nullptr) {}
~CEvent()130   ~CEvent() override { Destroy(); }
131 
Create()132   port::Status Create() {
133     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
134     stream_executor_->create_event(device_, &event_handle_, c_status.get());
135     return tensorflow::StatusFromTF_Status(c_status.get());
136   }
137 
Record(SP_Stream stream_handle)138   port::Status Record(SP_Stream stream_handle) {
139     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
140     stream_executor_->record_event(device_, stream_handle, event_handle_,
141                                    c_status.get());
142     return tensorflow::StatusFromTF_Status(c_status.get());
143   }
144 
Destroy()145   void Destroy() {
146     if (event_handle_ != nullptr) {
147       stream_executor_->destroy_event(device_, event_handle_);
148       event_handle_ = nullptr;
149     }
150   }
151 
Handle()152   SP_Event Handle() { return event_handle_; }
153 
154  private:
155   SP_Device* device_;
156   SP_StreamExecutor* stream_executor_;
157   SP_Event event_handle_;
158 };
159 
160 class CTimer : public internal::TimerInterface {
161  public:
CTimer(SP_Device * device,SP_StreamExecutor * stream_executor,SP_TimerFns * timer_fns)162   CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
163          SP_TimerFns* timer_fns)
164       : device_(device),
165         stream_executor_(stream_executor),
166         timer_handle_(nullptr),
167         timer_fns_(timer_fns) {}
~CTimer()168   ~CTimer() override { Destroy(); }
169 
Create()170   port::Status Create() {
171     std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
172     stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
173     return tensorflow::StatusFromTF_Status(c_status.get());
174   }
175 
Destroy()176   void Destroy() {
177     if (timer_handle_ != nullptr) {
178       stream_executor_->destroy_timer(device_, timer_handle_);
179       timer_handle_ = nullptr;
180     }
181   }
182 
Handle()183   SP_Timer Handle() { return timer_handle_; }
184 
Microseconds()185   uint64 Microseconds() const override {
186     return timer_fns_->nanoseconds(timer_handle_) / 1000;
187   }
188 
Nanoseconds()189   uint64 Nanoseconds() const override {
190     return timer_fns_->nanoseconds(timer_handle_);
191   }
192 
193  private:
194   SP_Device* device_;
195   SP_StreamExecutor* stream_executor_;
196   SP_Timer timer_handle_;
197   SP_TimerFns* timer_fns_;
198 };
199 
200 }  // namespace stream_executor
201 #endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
202