1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
16
17 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
18
19 namespace stream_executor {
20 namespace test_util {
21
22 /*** Functions for creating SP_StreamExecutor ***/
// Device-memory stub: performs no allocation and leaves `*mem` untouched.
void Allocate(const SP_Device* const device, uint64_t size,
              int64_t memory_space, SP_DeviceMemoryBase* const mem) {
  // Intentionally empty.
}

// Device-memory stub: releases nothing.
void Deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
  // Intentionally empty.
}

// Host-memory stub: never allocates and always yields a null pointer.
void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
  void* const no_allocation = nullptr;
  return no_allocation;
}

// Host-memory stub: ignores `mem`.
void HostMemoryDeallocate(const SP_Device* const device, void* mem) {
  // Intentionally empty.
}
GetAllocatorStats(const SP_Device * const device,SP_AllocatorStats * const stats)31 TF_Bool GetAllocatorStats(const SP_Device* const device,
32 SP_AllocatorStats* const stats) {
33 return true;
34 }
DeviceMemoryUsage(const SP_Device * const device,int64_t * const free,int64_t * const total)35 TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
36 int64_t* const total) {
37 return true;
38 }
CreateStream(const SP_Device * const device,SP_Stream * stream,TF_Status * const status)39 void CreateStream(const SP_Device* const device, SP_Stream* stream,
40 TF_Status* const status) {
41 *stream = nullptr;
42 }
DestroyStream(const SP_Device * const device,SP_Stream stream)43 void DestroyStream(const SP_Device* const device, SP_Stream stream) {}
CreateStreamDependency(const SP_Device * const device,SP_Stream dependent,SP_Stream other,TF_Status * const status)44 void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent,
45 SP_Stream other, TF_Status* const status) {}
GetStreamStatus(const SP_Device * const device,SP_Stream stream,TF_Status * const status)46 void GetStreamStatus(const SP_Device* const device, SP_Stream stream,
47 TF_Status* const status) {}
CreateEvent(const SP_Device * const device,SP_Event * event,TF_Status * const status)48 void CreateEvent(const SP_Device* const device, SP_Event* event,
49 TF_Status* const status) {
50 *event = nullptr;
51 }
DestroyEvent(const SP_Device * const device,SP_Event event)52 void DestroyEvent(const SP_Device* const device, SP_Event event) {}
GetEventStatus(const SP_Device * const device,SP_Event event)53 SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
54 return SE_EVENT_UNKNOWN;
55 }
RecordEvent(const SP_Device * const device,SP_Stream stream,SP_Event event,TF_Status * const status)56 void RecordEvent(const SP_Device* const device, SP_Stream stream,
57 SP_Event event, TF_Status* const status) {}
WaitForEvent(const SP_Device * const device,SP_Stream stream,SP_Event event,TF_Status * const status)58 void WaitForEvent(const SP_Device* const device, SP_Stream stream,
59 SP_Event event, TF_Status* const status) {}
CreateTimer(const SP_Device * const device,SP_Timer * timer,TF_Status * const status)60 void CreateTimer(const SP_Device* const device, SP_Timer* timer,
61 TF_Status* const status) {}
// Timer-destruction stub: ignores `timer`.
void DestroyTimer(const SP_Device* const device, SP_Timer timer) {
  // Intentionally empty.
}

// Timer-start stub: leaves `timer` and `status` untouched.
void StartTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
                TF_Status* const status) {
  // Intentionally empty.
}

// Timer-stop stub: leaves `timer` and `status` untouched.
void StopTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
               TF_Status* const status) {
  // Intentionally empty.
}
// Async device-to-host copy stub: copies no bytes and leaves `status`
// untouched.
void MemcpyDToH(const SP_Device* const device, SP_Stream stream, void* host_dst,
                const SP_DeviceMemoryBase* const device_src, uint64_t size,
                TF_Status* const status) {
  // Intentionally empty.
}

// Async host-to-device copy stub: copies no bytes and leaves `status`
// untouched.
void MemcpyHToD(const SP_Device* const device, SP_Stream stream,
                SP_DeviceMemoryBase* const device_dst, const void* host_src,
                uint64_t size, TF_Status* const status) {
  // Intentionally empty.
}

// Synchronous device-to-host copy stub: copies no bytes and leaves `status`
// untouched.
void SyncMemcpyDToH(const SP_Device* const device, void* host_dst,
                    const SP_DeviceMemoryBase* const device_src, uint64_t size,
                    TF_Status* const status) {
  // Intentionally empty.
}

// Synchronous host-to-device copy stub: copies no bytes and leaves `status`
// untouched.
void SyncMemcpyHToD(const SP_Device* const device,
                    SP_DeviceMemoryBase* const device_dst, const void* host_src,
                    uint64_t size, TF_Status* const status) {
  // Intentionally empty.
}
BlockHostForEvent(const SP_Device * const device,SP_Event event,TF_Status * const status)79 void BlockHostForEvent(const SP_Device* const device, SP_Event event,
80 TF_Status* const status) {}
SynchronizeAllActivity(const SP_Device * const device,TF_Status * const status)81 void SynchronizeAllActivity(const SP_Device* const device,
82 TF_Status* const status) {}
HostCallback(const SP_Device * const device,SP_Stream stream,SE_StatusCallbackFn const callback_fn,void * const callback_arg)83 TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
84 SE_StatusCallbackFn const callback_fn,
85 void* const callback_arg) {
86 return true;
87 }
88
PopulateDefaultStreamExecutor(SP_StreamExecutor * se)89 void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
90 *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
91 se->allocate = Allocate;
92 se->deallocate = Deallocate;
93 se->host_memory_allocate = HostMemoryAllocate;
94 se->host_memory_deallocate = HostMemoryDeallocate;
95 se->get_allocator_stats = GetAllocatorStats;
96 se->device_memory_usage = DeviceMemoryUsage;
97 se->create_stream = CreateStream;
98 se->destroy_stream = DestroyStream;
99 se->create_stream_dependency = CreateStreamDependency;
100 se->get_stream_status = GetStreamStatus;
101 se->create_event = CreateEvent;
102 se->destroy_event = DestroyEvent;
103 se->get_event_status = GetEventStatus;
104 se->record_event = RecordEvent;
105 se->wait_for_event = WaitForEvent;
106 se->create_timer = CreateTimer;
107 se->destroy_timer = DestroyTimer;
108 se->start_timer = StartTimer;
109 se->stop_timer = StopTimer;
110 se->memcpy_dtoh = MemcpyDToH;
111 se->memcpy_htod = MemcpyHToD;
112 se->sync_memcpy_dtoh = SyncMemcpyDToH;
113 se->sync_memcpy_htod = SyncMemcpyHToD;
114 se->block_host_for_event = BlockHostForEvent;
115 se->synchronize_all_activity = SynchronizeAllActivity;
116 se->host_callback = HostCallback;
117 }
118
PopulateDefaultDeviceFns(SP_DeviceFns * device_fns)119 void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
120 *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
121 }
122
123 /*** Functions for creating SP_TimerFns ***/
Nanoseconds(SP_Timer timer)124 uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
125
PopulateDefaultTimerFns(SP_TimerFns * timer_fns)126 void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
127 timer_fns->nanoseconds = Nanoseconds;
128 }
129
130 /*** Functions for creating SP_Platform ***/
CreateTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns,TF_Status * status)131 void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
132 TF_Status* status) {
133 TF_SetStatus(status, TF_OK, "");
134 PopulateDefaultTimerFns(timer_fns);
135 }
DestroyTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns)136 void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
137
CreateStreamExecutor(const SP_Platform * platform,SE_CreateStreamExecutorParams * params,TF_Status * status)138 void CreateStreamExecutor(const SP_Platform* platform,
139 SE_CreateStreamExecutorParams* params,
140 TF_Status* status) {
141 TF_SetStatus(status, TF_OK, "");
142 PopulateDefaultStreamExecutor(params->stream_executor);
143 }
DestroyStreamExecutor(const SP_Platform * platform,SP_StreamExecutor * se)144 void DestroyStreamExecutor(const SP_Platform* platform, SP_StreamExecutor* se) {
145 }
GetDeviceCount(const SP_Platform * platform,int * device_count,TF_Status * status)146 void GetDeviceCount(const SP_Platform* platform, int* device_count,
147 TF_Status* status) {
148 TF_SetStatus(status, TF_OK, "");
149 *device_count = kDeviceCount;
150 }
CreateDevice(const SP_Platform * platform,SE_CreateDeviceParams * params,TF_Status * status)151 void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
152 TF_Status* status) {
153 TF_SetStatus(status, TF_OK, "");
154 params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
155 }
DestroyDevice(const SP_Platform * platform,SP_Device * device)156 void DestroyDevice(const SP_Platform* platform, SP_Device* device) {}
157
CreateDeviceFns(const SP_Platform * platform,SE_CreateDeviceFnsParams * params,TF_Status * status)158 void CreateDeviceFns(const SP_Platform* platform,
159 SE_CreateDeviceFnsParams* params, TF_Status* status) {
160 TF_SetStatus(status, TF_OK, "");
161 params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
162 }
DestroyDeviceFns(const SP_Platform * platform,SP_DeviceFns * device_fns)163 void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {}
164
PopulateDefaultPlatform(SP_Platform * platform,SP_PlatformFns * platform_fns)165 void PopulateDefaultPlatform(SP_Platform* platform,
166 SP_PlatformFns* platform_fns) {
167 *platform = {SP_PLATFORM_STRUCT_SIZE};
168 platform->name = kDeviceName;
169 platform->type = kDeviceType;
170 platform_fns->get_device_count = GetDeviceCount;
171 platform_fns->create_device = CreateDevice;
172 platform_fns->destroy_device = DestroyDevice;
173 platform_fns->create_device_fns = CreateDeviceFns;
174 platform_fns->destroy_device_fns = DestroyDeviceFns;
175 platform_fns->create_stream_executor = CreateStreamExecutor;
176 platform_fns->destroy_stream_executor = DestroyStreamExecutor;
177 platform_fns->create_timer_fns = CreateTimerFns;
178 platform_fns->destroy_timer_fns = DestroyTimerFns;
179 }
180
181 /*** Functions for creating SE_PlatformRegistrationParams ***/
DestroyPlatform(SP_Platform * platform)182 void DestroyPlatform(SP_Platform* platform) {}
DestroyPlatformFns(SP_PlatformFns * platform_fns)183 void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
184
PopulateDefaultPlatformRegistrationParams(SE_PlatformRegistrationParams * const params)185 void PopulateDefaultPlatformRegistrationParams(
186 SE_PlatformRegistrationParams* const params) {
187 PopulateDefaultPlatform(params->platform, params->platform_fns);
188 params->destroy_platform = DestroyPlatform;
189 params->destroy_platform_fns = DestroyPlatformFns;
190 }
191
192 } // namespace test_util
193 } // namespace stream_executor
194