1 /*===--------------------------------------------------------------------------
2 * ATMI (Asynchronous Task and Memory Interface)
3 *
4 * This file is distributed under the MIT License. See LICENSE.txt for details.
5 *===------------------------------------------------------------------------*/
6 #include "atmi_interop_hsa.h"
7 #include "internal.h"
8
9 using core::atl_is_atmi_initialized;
10
atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,const char * symbol,void ** var_addr,unsigned int * var_size)11 atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
12 const char *symbol,
13 void **var_addr,
14 unsigned int *var_size) {
15 /*
16 // Typical usage:
17 void *var_addr;
18 size_t var_size;
19 atmi_interop_hsa_get_symbol_addr(gpu_place, "symbol_name", &var_addr,
20 &var_size);
21 atmi_memcpy(signal, host_add, var_addr, var_size);
22 */
23
24 if (!atl_is_atmi_initialized())
25 return ATMI_STATUS_ERROR;
26 atmi_machine_t *machine = atmi_machine_get_info();
27 if (!symbol || !var_addr || !var_size || !machine)
28 return ATMI_STATUS_ERROR;
29 if (place.dev_id < 0 ||
30 place.dev_id >= machine->device_count_by_type[place.dev_type])
31 return ATMI_STATUS_ERROR;
32
33 // get the symbol info
34 std::string symbolStr = std::string(symbol);
35 if (SymbolInfoTable[place.dev_id].find(symbolStr) !=
36 SymbolInfoTable[place.dev_id].end()) {
37 atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr];
38 *var_addr = reinterpret_cast<void *>(info.addr);
39 *var_size = info.size;
40 return ATMI_STATUS_SUCCESS;
41 } else {
42 *var_addr = NULL;
43 *var_size = 0;
44 return ATMI_STATUS_ERROR;
45 }
46 }
47
atmi_interop_hsa_get_kernel_info(atmi_mem_place_t place,const char * kernel_name,hsa_executable_symbol_info_t kernel_info,uint32_t * value)48 atmi_status_t atmi_interop_hsa_get_kernel_info(
49 atmi_mem_place_t place, const char *kernel_name,
50 hsa_executable_symbol_info_t kernel_info, uint32_t *value) {
51 /*
52 // Typical usage:
53 uint32_t value;
54 atmi_interop_hsa_get_kernel_addr(gpu_place, "kernel_name",
55 HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
56 &val);
57 */
58
59 if (!atl_is_atmi_initialized())
60 return ATMI_STATUS_ERROR;
61 atmi_machine_t *machine = atmi_machine_get_info();
62 if (!kernel_name || !value || !machine)
63 return ATMI_STATUS_ERROR;
64 if (place.dev_id < 0 ||
65 place.dev_id >= machine->device_count_by_type[place.dev_type])
66 return ATMI_STATUS_ERROR;
67
68 atmi_status_t status = ATMI_STATUS_SUCCESS;
69 // get the kernel info
70 std::string kernelStr = std::string(kernel_name);
71 if (KernelInfoTable[place.dev_id].find(kernelStr) !=
72 KernelInfoTable[place.dev_id].end()) {
73 atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr];
74 switch (kernel_info) {
75 case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE:
76 *value = info.group_segment_size;
77 break;
78 case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE:
79 *value = info.private_segment_size;
80 break;
81 case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE:
82 // return the size for non-implicit args
83 *value = info.kernel_segment_size - sizeof(atmi_implicit_args_t);
84 break;
85 default:
86 *value = 0;
87 status = ATMI_STATUS_ERROR;
88 break;
89 }
90 } else {
91 *value = 0;
92 status = ATMI_STATUS_ERROR;
93 }
94
95 return status;
96 }
97