• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
16 
17 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
18 
19 namespace stream_executor {
20 namespace test_util {
21 
22 /*** Functions for creating SP_StreamExecutor ***/
Allocate(const SP_Device * const device,uint64_t size,int64_t memory_space,SP_DeviceMemoryBase * const mem)23 void Allocate(const SP_Device* const device, uint64_t size,
24               int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
// No-op `deallocate` callback: `*mem` is not inspected or modified.
void Deallocate(const SP_Device* const device,
                SP_DeviceMemoryBase* const mem) {}
// `host_memory_allocate` callback that never allocates: always returns null,
// regardless of `size`.
void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
  void* const no_memory = nullptr;
  return no_memory;
}
// No-op `host_memory_deallocate` callback: `mem` is ignored.
void HostMemoryDeallocate(const SP_Device* const device, void* mem) {
  // Intentionally empty.
}
GetAllocatorStats(const SP_Device * const device,SP_AllocatorStats * const stats)31 TF_Bool GetAllocatorStats(const SP_Device* const device,
32                           SP_AllocatorStats* const stats) {
33   return true;
34 }
DeviceMemoryUsage(const SP_Device * const device,int64_t * const free,int64_t * const total)35 TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
36                           int64_t* const total) {
37   return true;
38 }
CreateStream(const SP_Device * const device,SP_Stream * stream,TF_Status * const status)39 void CreateStream(const SP_Device* const device, SP_Stream* stream,
40                   TF_Status* const status) {
41   *stream = nullptr;
42 }
// No-op `destroy_stream` callback.
void DestroyStream(const SP_Device* const device, SP_Stream stream) {
  // Intentionally empty.
}
// No-op `create_stream_dependency` callback: no ordering is recorded between
// `dependent` and `other`, and `status` is not touched.
void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent,
                            SP_Stream other, TF_Status* const status) {
  // Intentionally empty.
}
// `get_stream_status` callback that leaves `status` exactly as the caller
// provided it.
void GetStreamStatus(const SP_Device* const device, SP_Stream stream,
                     TF_Status* const status) {
  // Intentionally empty.
}
CreateEvent(const SP_Device * const device,SP_Event * event,TF_Status * const status)48 void CreateEvent(const SP_Device* const device, SP_Event* event,
49                  TF_Status* const status) {
50   *event = nullptr;
51 }
// No-op `destroy_event` callback.
void DestroyEvent(const SP_Device* const device, SP_Event event) {
  // Intentionally empty.
}
GetEventStatus(const SP_Device * const device,SP_Event event)53 SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
54   return SE_EVENT_UNKNOWN;
55 }
// No-op `record_event` callback; `status` is not touched.
void RecordEvent(const SP_Device* const device, SP_Stream stream,
                 SP_Event event, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `wait_for_event` callback; returns immediately without touching
// `status`.
void WaitForEvent(const SP_Device* const device, SP_Stream stream,
                  SP_Event event, TF_Status* const status) {
  // Intentionally empty.
}
CreateTimer(const SP_Device * const device,SP_Timer * timer,TF_Status * const status)60 void CreateTimer(const SP_Device* const device, SP_Timer* timer,
61                  TF_Status* const status) {}
// No-op `destroy_timer` callback.
void DestroyTimer(const SP_Device* const device, SP_Timer timer) {
  // Intentionally empty.
}
// No-op `start_timer` callback; `status` is not touched.
void StartTimer(const SP_Device* const device, SP_Stream stream,
                SP_Timer timer, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `stop_timer` callback; `status` is not touched.
void StopTimer(const SP_Device* const device, SP_Stream stream,
               SP_Timer timer, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `memcpy_dtoh` callback: no bytes are written to `host_dst`.
void MemcpyDToH(const SP_Device* const device, SP_Stream stream,
                void* host_dst, const SP_DeviceMemoryBase* const device_src,
                uint64_t size, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `memcpy_htod` callback: `*device_dst` is not modified.
void MemcpyHToD(const SP_Device* const device, SP_Stream stream,
                SP_DeviceMemoryBase* const device_dst, const void* host_src,
                uint64_t size, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `sync_memcpy_dtoh` callback: no bytes are written to `host_dst`.
void SyncMemcpyDToH(const SP_Device* const device, void* host_dst,
                    const SP_DeviceMemoryBase* const device_src,
                    uint64_t size, TF_Status* const status) {
  // Intentionally empty.
}
// No-op `sync_memcpy_htod` callback: `*device_dst` is not modified.
void SyncMemcpyHToD(const SP_Device* const device,
                    SP_DeviceMemoryBase* const device_dst,
                    const void* host_src, uint64_t size,
                    TF_Status* const status) {
  // Intentionally empty.
}
// `block_host_for_event` callback that returns immediately without touching
// `status`.
void BlockHostForEvent(const SP_Device* const device, SP_Event event,
                       TF_Status* const status) {
  // Intentionally empty.
}
// `synchronize_all_activity` callback that returns immediately without
// touching `status`.
void SynchronizeAllActivity(const SP_Device* const device,
                            TF_Status* const status) {
  // Intentionally empty.
}
HostCallback(const SP_Device * const device,SP_Stream stream,SE_StatusCallbackFn const callback_fn,void * const callback_arg)83 TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
84                      SE_StatusCallbackFn const callback_fn,
85                      void* const callback_arg) {
86   return true;
87 }
88 
PopulateDefaultStreamExecutor(SP_StreamExecutor * se)89 void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
90   *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
91   se->allocate = Allocate;
92   se->deallocate = Deallocate;
93   se->host_memory_allocate = HostMemoryAllocate;
94   se->host_memory_deallocate = HostMemoryDeallocate;
95   se->get_allocator_stats = GetAllocatorStats;
96   se->device_memory_usage = DeviceMemoryUsage;
97   se->create_stream = CreateStream;
98   se->destroy_stream = DestroyStream;
99   se->create_stream_dependency = CreateStreamDependency;
100   se->get_stream_status = GetStreamStatus;
101   se->create_event = CreateEvent;
102   se->destroy_event = DestroyEvent;
103   se->get_event_status = GetEventStatus;
104   se->record_event = RecordEvent;
105   se->wait_for_event = WaitForEvent;
106   se->create_timer = CreateTimer;
107   se->destroy_timer = DestroyTimer;
108   se->start_timer = StartTimer;
109   se->stop_timer = StopTimer;
110   se->memcpy_dtoh = MemcpyDToH;
111   se->memcpy_htod = MemcpyHToD;
112   se->sync_memcpy_dtoh = SyncMemcpyDToH;
113   se->sync_memcpy_htod = SyncMemcpyHToD;
114   se->block_host_for_event = BlockHostForEvent;
115   se->synchronize_all_activity = SynchronizeAllActivity;
116   se->host_callback = HostCallback;
117 }
118 
PopulateDefaultDeviceFns(SP_DeviceFns * device_fns)119 void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
120   *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
121 }
122 
123 /*** Functions for creating SP_TimerFns ***/
Nanoseconds(SP_Timer timer)124 uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
125 
PopulateDefaultTimerFns(SP_TimerFns * timer_fns)126 void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
127   timer_fns->nanoseconds = Nanoseconds;
128 }
129 
130 /*** Functions for creating SP_Platform ***/
CreateTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns,TF_Status * status)131 void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
132                     TF_Status* status) {
133   TF_SetStatus(status, TF_OK, "");
134   PopulateDefaultTimerFns(timer_fns);
135 }
// No-op `destroy_timer_fns` platform callback.
void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {
  // Intentionally empty.
}

CreateStreamExecutor(const SP_Platform * platform,SE_CreateStreamExecutorParams * params,TF_Status * status)138 void CreateStreamExecutor(const SP_Platform* platform,
139                           SE_CreateStreamExecutorParams* params,
140                           TF_Status* status) {
141   TF_SetStatus(status, TF_OK, "");
142   PopulateDefaultStreamExecutor(params->stream_executor);
143 }
// No-op `destroy_stream_executor` platform callback.
void DestroyStreamExecutor(const SP_Platform* platform,
                           SP_StreamExecutor* se) {}
GetDeviceCount(const SP_Platform * platform,int * device_count,TF_Status * status)146 void GetDeviceCount(const SP_Platform* platform, int* device_count,
147                     TF_Status* status) {
148   TF_SetStatus(status, TF_OK, "");
149   *device_count = kDeviceCount;
150 }
CreateDevice(const SP_Platform * platform,SE_CreateDeviceParams * params,TF_Status * status)151 void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
152                   TF_Status* status) {
153   TF_SetStatus(status, TF_OK, "");
154   params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
155 }
// No-op `destroy_device` platform callback.
void DestroyDevice(const SP_Platform* platform, SP_Device* device) {
  // Intentionally empty.
}

CreateDeviceFns(const SP_Platform * platform,SE_CreateDeviceFnsParams * params,TF_Status * status)158 void CreateDeviceFns(const SP_Platform* platform,
159                      SE_CreateDeviceFnsParams* params, TF_Status* status) {
160   TF_SetStatus(status, TF_OK, "");
161   params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
162 }
// No-op `destroy_device_fns` platform callback.
void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
  // Intentionally empty.
}

PopulateDefaultPlatform(SP_Platform * platform,SP_PlatformFns * platform_fns)165 void PopulateDefaultPlatform(SP_Platform* platform,
166                              SP_PlatformFns* platform_fns) {
167   *platform = {SP_PLATFORM_STRUCT_SIZE};
168   platform->name = kDeviceName;
169   platform->type = kDeviceType;
170   platform_fns->get_device_count = GetDeviceCount;
171   platform_fns->create_device = CreateDevice;
172   platform_fns->destroy_device = DestroyDevice;
173   platform_fns->create_device_fns = CreateDeviceFns;
174   platform_fns->destroy_device_fns = DestroyDeviceFns;
175   platform_fns->create_stream_executor = CreateStreamExecutor;
176   platform_fns->destroy_stream_executor = DestroyStreamExecutor;
177   platform_fns->create_timer_fns = CreateTimerFns;
178   platform_fns->destroy_timer_fns = DestroyTimerFns;
179 }
180 
181 /*** Functions for creating SE_PlatformRegistrationParams ***/
DestroyPlatform(SP_Platform * platform)182 void DestroyPlatform(SP_Platform* platform) {}
DestroyPlatformFns(SP_PlatformFns * platform_fns)183 void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
184 
PopulateDefaultPlatformRegistrationParams(SE_PlatformRegistrationParams * const params)185 void PopulateDefaultPlatformRegistrationParams(
186     SE_PlatformRegistrationParams* const params) {
187   PopulateDefaultPlatform(params->platform, params->platform_fns);
188   params->destroy_platform = DestroyPlatform;
189   params->destroy_platform_fns = DestroyPlatformFns;
190 }
191 
192 }  // namespace test_util
193 }  // namespace stream_executor
194