//===- AMDGPUMetadataVerifier.cpp - Metadata Verifier -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
15 #include "llvm/Support/AMDGPUMetadata.h"
16
17 namespace llvm {
18 namespace AMDGPU {
19 namespace HSAMD {
20 namespace V3 {
21
verifyScalar(msgpack::DocNode & Node,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)22 bool MetadataVerifier::verifyScalar(
23 msgpack::DocNode &Node, msgpack::Type SKind,
24 function_ref<bool(msgpack::DocNode &)> verifyValue) {
25 if (!Node.isScalar())
26 return false;
27 if (Node.getKind() != SKind) {
28 if (Strict)
29 return false;
30 // If we are not strict, we interpret string values as "implicitly typed"
31 // and attempt to coerce them to the expected type here.
32 if (Node.getKind() != msgpack::Type::String)
33 return false;
34 StringRef StringValue = Node.getString();
35 Node.fromString(StringValue);
36 if (Node.getKind() != SKind)
37 return false;
38 }
39 if (verifyValue)
40 return verifyValue(Node);
41 return true;
42 }
43
verifyInteger(msgpack::DocNode & Node)44 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
45 if (!verifyScalar(Node, msgpack::Type::UInt))
46 if (!verifyScalar(Node, msgpack::Type::Int))
47 return false;
48 return true;
49 }
50
verifyArray(msgpack::DocNode & Node,function_ref<bool (msgpack::DocNode &)> verifyNode,Optional<size_t> Size)51 bool MetadataVerifier::verifyArray(
52 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
53 Optional<size_t> Size) {
54 if (!Node.isArray())
55 return false;
56 auto &Array = Node.getArray();
57 if (Size && Array.size() != *Size)
58 return false;
59 for (auto &Item : Array)
60 if (!verifyNode(Item))
61 return false;
62
63 return true;
64 }
65
verifyEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,function_ref<bool (msgpack::DocNode &)> verifyNode)66 bool MetadataVerifier::verifyEntry(
67 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
68 function_ref<bool(msgpack::DocNode &)> verifyNode) {
69 auto Entry = MapNode.find(Key);
70 if (Entry == MapNode.end())
71 return !Required;
72 return verifyNode(Entry->second);
73 }
74
verifyScalarEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required,msgpack::Type SKind,function_ref<bool (msgpack::DocNode &)> verifyValue)75 bool MetadataVerifier::verifyScalarEntry(
76 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
77 msgpack::Type SKind,
78 function_ref<bool(msgpack::DocNode &)> verifyValue) {
79 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
80 return verifyScalar(Node, SKind, verifyValue);
81 });
82 }
83
verifyIntegerEntry(msgpack::MapDocNode & MapNode,StringRef Key,bool Required)84 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
85 StringRef Key, bool Required) {
86 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
87 return verifyInteger(Node);
88 });
89 }
90
verifyKernelArgs(msgpack::DocNode & Node)91 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
92 if (!Node.isMap())
93 return false;
94 auto &ArgsMap = Node.getMap();
95
96 if (!verifyScalarEntry(ArgsMap, ".name", false,
97 msgpack::Type::String))
98 return false;
99 if (!verifyScalarEntry(ArgsMap, ".type_name", false,
100 msgpack::Type::String))
101 return false;
102 if (!verifyIntegerEntry(ArgsMap, ".size", true))
103 return false;
104 if (!verifyIntegerEntry(ArgsMap, ".offset", true))
105 return false;
106 if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
107 msgpack::Type::String,
108 [](msgpack::DocNode &SNode) {
109 return StringSwitch<bool>(SNode.getString())
110 .Case("by_value", true)
111 .Case("global_buffer", true)
112 .Case("dynamic_shared_pointer", true)
113 .Case("sampler", true)
114 .Case("image", true)
115 .Case("pipe", true)
116 .Case("queue", true)
117 .Case("hidden_global_offset_x", true)
118 .Case("hidden_global_offset_y", true)
119 .Case("hidden_global_offset_z", true)
120 .Case("hidden_none", true)
121 .Case("hidden_printf_buffer", true)
122 .Case("hidden_hostcall_buffer", true)
123 .Case("hidden_default_queue", true)
124 .Case("hidden_completion_action", true)
125 .Case("hidden_multigrid_sync_arg", true)
126 .Default(false);
127 }))
128 return false;
129 if (!verifyScalarEntry(ArgsMap, ".value_type", true,
130 msgpack::Type::String,
131 [](msgpack::DocNode &SNode) {
132 return StringSwitch<bool>(SNode.getString())
133 .Case("struct", true)
134 .Case("i8", true)
135 .Case("u8", true)
136 .Case("i16", true)
137 .Case("u16", true)
138 .Case("f16", true)
139 .Case("i32", true)
140 .Case("u32", true)
141 .Case("f32", true)
142 .Case("i64", true)
143 .Case("u64", true)
144 .Case("f64", true)
145 .Default(false);
146 }))
147 return false;
148 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
149 return false;
150 if (!verifyScalarEntry(ArgsMap, ".address_space", false,
151 msgpack::Type::String,
152 [](msgpack::DocNode &SNode) {
153 return StringSwitch<bool>(SNode.getString())
154 .Case("private", true)
155 .Case("global", true)
156 .Case("constant", true)
157 .Case("local", true)
158 .Case("generic", true)
159 .Case("region", true)
160 .Default(false);
161 }))
162 return false;
163 if (!verifyScalarEntry(ArgsMap, ".access", false,
164 msgpack::Type::String,
165 [](msgpack::DocNode &SNode) {
166 return StringSwitch<bool>(SNode.getString())
167 .Case("read_only", true)
168 .Case("write_only", true)
169 .Case("read_write", true)
170 .Default(false);
171 }))
172 return false;
173 if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
174 msgpack::Type::String,
175 [](msgpack::DocNode &SNode) {
176 return StringSwitch<bool>(SNode.getString())
177 .Case("read_only", true)
178 .Case("write_only", true)
179 .Case("read_write", true)
180 .Default(false);
181 }))
182 return false;
183 if (!verifyScalarEntry(ArgsMap, ".is_const", false,
184 msgpack::Type::Boolean))
185 return false;
186 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
187 msgpack::Type::Boolean))
188 return false;
189 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
190 msgpack::Type::Boolean))
191 return false;
192 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
193 msgpack::Type::Boolean))
194 return false;
195
196 return true;
197 }
198
verifyKernel(msgpack::DocNode & Node)199 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
200 if (!Node.isMap())
201 return false;
202 auto &KernelMap = Node.getMap();
203
204 if (!verifyScalarEntry(KernelMap, ".name", true,
205 msgpack::Type::String))
206 return false;
207 if (!verifyScalarEntry(KernelMap, ".symbol", true,
208 msgpack::Type::String))
209 return false;
210 if (!verifyScalarEntry(KernelMap, ".language", false,
211 msgpack::Type::String,
212 [](msgpack::DocNode &SNode) {
213 return StringSwitch<bool>(SNode.getString())
214 .Case("OpenCL C", true)
215 .Case("OpenCL C++", true)
216 .Case("HCC", true)
217 .Case("HIP", true)
218 .Case("OpenMP", true)
219 .Case("Assembler", true)
220 .Default(false);
221 }))
222 return false;
223 if (!verifyEntry(
224 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
225 return verifyArray(
226 Node,
227 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
228 }))
229 return false;
230 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
231 return verifyArray(Node, [this](msgpack::DocNode &Node) {
232 return verifyKernelArgs(Node);
233 });
234 }))
235 return false;
236 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
237 [this](msgpack::DocNode &Node) {
238 return verifyArray(Node,
239 [this](msgpack::DocNode &Node) {
240 return verifyInteger(Node);
241 },
242 3);
243 }))
244 return false;
245 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
246 [this](msgpack::DocNode &Node) {
247 return verifyArray(Node,
248 [this](msgpack::DocNode &Node) {
249 return verifyInteger(Node);
250 },
251 3);
252 }))
253 return false;
254 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
255 msgpack::Type::String))
256 return false;
257 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
258 msgpack::Type::String))
259 return false;
260 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
261 return false;
262 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
263 return false;
264 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
265 return false;
266 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
267 return false;
268 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
269 return false;
270 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
271 return false;
272 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
273 return false;
274 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
275 return false;
276 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
277 return false;
278 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
279 return false;
280
281 return true;
282 }
283
verify(msgpack::DocNode & HSAMetadataRoot)284 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
285 if (!HSAMetadataRoot.isMap())
286 return false;
287 auto &RootMap = HSAMetadataRoot.getMap();
288
289 if (!verifyEntry(
290 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
291 return verifyArray(
292 Node,
293 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
294 }))
295 return false;
296 if (!verifyEntry(
297 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
298 return verifyArray(Node, [this](msgpack::DocNode &Node) {
299 return verifyScalar(Node, msgpack::Type::String);
300 });
301 }))
302 return false;
303 if (!verifyEntry(RootMap, "amdhsa.kernels", true,
304 [this](msgpack::DocNode &Node) {
305 return verifyArray(Node, [this](msgpack::DocNode &Node) {
306 return verifyKernel(Node);
307 });
308 }))
309 return false;
310
311 return true;
312 }
313
314 } // end namespace V3
315 } // end namespace HSAMD
316 } // end namespace AMDGPU
317 } // end namespace llvm
318