1 // Copyright 2021 The Tint Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <stdio.h>
#include <stdlib.h>
#include <fstream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <thread>  // NOLINT

#include "tools/src/cmd/remote-compile/compile.h"
#include "tools/src/cmd/remote-compile/socket.h"
27
28 namespace {
29
// DEBUG(fmt, ...) prints a printf-style diagnostic line (with a trailing
// newline) to stdout when TINT_REMOTE_COMPILE_DEBUG is defined to a non-zero
// value, e.g. by building with -DTINT_REMOTE_COMPILE_DEBUG=1. By default it is
// disabled and expands to nothing, so its arguments are not evaluated.
#ifndef TINT_REMOTE_COMPILE_DEBUG
#define TINT_REMOTE_COMPILE_DEBUG 0
#endif
#if TINT_REMOTE_COMPILE_DEBUG
#define DEBUG(msg, ...) printf(msg "\n", ##__VA_ARGS__)
#else
#define DEBUG(...)
#endif
35
/// Prints the tool's usage text to stdout, then terminates the process with
/// exit code 1. This function never returns.
void ShowUsage() {
  constexpr const char* kToolName = "tint-remote-compile";
  // The usage text mentions the tool name four times; pass it once for each.
  printf(R"(%s is a tool for compiling a shader on a remote machine

usage as server:
%s -s [-p port-number]

usage as client:
%s [-p port-number] [server-address] shader-file-path

[server-address] can be omitted if the TINT_REMOTE_COMPILE_ADDRESS environment
variable is set.
Alternatively, you can pass xcrun arguments so %s can be used as a
drop-in replacement.
)",
         kToolName, kToolName, kToolName, kToolName);
  exit(1);
}
55
/// Version of the client/server wire protocol. The server rejects clients
/// whose version differs from its own, so bump this whenever the protocol
/// changes.
constexpr uint32_t kProtocolVersion{1};

/// Shader source languages that a compile request can carry. A value is sent
/// over the wire as a uint32_t.
enum SourceLanguage {
  MSL = 0,  ///< Metal Shading Language
};
63
64 /// Stream is a serialization wrapper around a socket
65 struct Stream {
66 /// The underlying socket
67 Socket* const socket;
68 /// Error state
69 std::string error;
70
71 /// Writes a uint32_t to the socket
operator <<__anon42263a490111::Stream72 Stream operator<<(uint32_t v) {
73 if (error.empty()) {
74 Write(&v, sizeof(v));
75 }
76 return *this;
77 }
78
79 /// Reads a uint32_t from the socket
operator >>__anon42263a490111::Stream80 Stream operator>>(uint32_t& v) {
81 if (error.empty()) {
82 Read(&v, sizeof(v));
83 }
84 return *this;
85 }
86
87 /// Writes a std::string to the socket
operator <<__anon42263a490111::Stream88 Stream operator<<(const std::string& v) {
89 if (error.empty()) {
90 uint32_t count = static_cast<uint32_t>(v.size());
91 *this << count;
92 if (count) {
93 Write(v.data(), count);
94 }
95 }
96 return *this;
97 }
98
99 /// Reads a std::string from the socket
operator >>__anon42263a490111::Stream100 Stream operator>>(std::string& v) {
101 uint32_t count = 0;
102 *this >> count;
103 if (count) {
104 std::vector<char> buf(count);
105 if (Read(buf.data(), count)) {
106 v = std::string(buf.data(), buf.size());
107 }
108 } else {
109 v.clear();
110 }
111 return *this;
112 }
113
114 /// Writes an enum value to the socket
115 template <typename T>
operator <<__anon42263a490111::Stream116 std::enable_if_t<std::is_enum<T>::value, Stream> operator<<(T e) {
117 return *this << static_cast<uint32_t>(e);
118 }
119
120 /// Reads an enum value from the socket
121 template <typename T>
operator >>__anon42263a490111::Stream122 std::enable_if_t<std::is_enum<T>::value, Stream> operator>>(T& e) {
123 uint32_t v;
124 *this >> v;
125 e = static_cast<T>(v);
126 return *this;
127 }
128
129 private:
Write__anon42263a490111::Stream130 bool Write(const void* data, size_t size) {
131 if (error.empty()) {
132 if (!socket->Write(data, size)) {
133 error = "Socket::Write() failed";
134 }
135 }
136 return error.empty();
137 }
138
Read__anon42263a490111::Stream139 bool Read(void* data, size_t size) {
140 auto buf = reinterpret_cast<uint8_t*>(data);
141 while (size > 0 && error.empty()) {
142 if (auto n = socket->Read(buf, size)) {
143 if (n > size) {
144 error = "Socket::Read() returned more bytes than requested";
145 return false;
146 }
147 size -= n;
148 buf += n;
149 }
150 }
151 return error.empty();
152 }
153 };
154
155 ////////////////////////////////////////////////////////////////////////////////
156 // Messages
157 ////////////////////////////////////////////////////////////////////////////////
158
/// Common base of every protocol message. It carries the message's type tag,
/// which is serialized ahead of the message's own fields.
struct Message {
  /// Identifies the kind of message. Enumerator values are sent over the
  /// wire, so their order is part of the protocol.
  enum class Type {
    ConnectionRequest,
    ConnectionResponse,
    CompileRequest,
    CompileResponse,
  };

  /// Constructs a message tagged with `t`.
  explicit Message(Type t) : type{t} {}

  /// The type tag of this message. It is fixed at construction.
  const Type type;
};
173
174 struct ConnectionResponse : Message { // Server -> Client
ConnectionResponse__anon42263a490111::ConnectionResponse175 ConnectionResponse() : Message(Type::ConnectionResponse) {}
176
177 template <typename T>
Serialize__anon42263a490111::ConnectionResponse178 void Serialize(T&& f) {
179 f(error);
180 }
181
182 std::string error;
183 };
184
185 struct ConnectionRequest : Message { // Client -> Server
186 using Response = ConnectionResponse;
187
ConnectionRequest__anon42263a490111::ConnectionRequest188 explicit ConnectionRequest(uint32_t proto_ver = kProtocolVersion)
189 : Message(Type::ConnectionRequest), protocol_version(proto_ver) {}
190
191 template <typename T>
Serialize__anon42263a490111::ConnectionRequest192 void Serialize(T&& f) {
193 f(protocol_version);
194 }
195
196 uint32_t protocol_version;
197 };
198
199 struct CompileResponse : Message { // Server -> Client
CompileResponse__anon42263a490111::CompileResponse200 CompileResponse() : Message(Type::CompileResponse) {}
201
202 template <typename T>
Serialize__anon42263a490111::CompileResponse203 void Serialize(T&& f) {
204 f(error);
205 }
206
207 std::string error;
208 };
209
210 struct CompileRequest : Message { // Client -> Server
211 using Response = CompileResponse;
212
CompileRequest__anon42263a490111::CompileRequest213 CompileRequest() : Message(Type::CompileRequest) {}
CompileRequest__anon42263a490111::CompileRequest214 CompileRequest(SourceLanguage lang, std::string src)
215 : Message(Type::CompileRequest), language(lang), source(src) {}
216
217 template <typename T>
Serialize__anon42263a490111::CompileRequest218 void Serialize(T&& f) {
219 f(language);
220 f(source);
221 }
222
223 SourceLanguage language;
224 std::string source;
225 };
226
227 /// Writes the message `m` to the stream `s`
228 template <typename MESSAGE>
operator <<(Stream & s,const MESSAGE & m)229 std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator<<(
230 Stream& s,
231 const MESSAGE& m) {
232 s << m.type;
233 const_cast<MESSAGE&>(m).Serialize([&s](const auto& value) { s << value; });
234 return s;
235 }
236
237 /// Reads the message `m` from the stream `s`
238 template <typename MESSAGE>
operator >>(Stream & s,MESSAGE & m)239 std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>(
240 Stream& s,
241 MESSAGE& m) {
242 Message::Type ty;
243 s >> ty;
244 if (ty == m.type) {
245 m.Serialize([&s](auto& value) { s >> value; });
246 } else {
247 std::stringstream ss;
248 ss << "expected message type " << static_cast<int>(m.type) << ", got "
249 << static_cast<int>(ty);
250 s.error = ss.str();
251 }
252 return s;
253 }
254
255 /// Writes the request message `req` to the stream `s`, then reads and returns
256 /// the response message from the same stream.
257 template <typename REQUEST, typename RESPONSE = typename REQUEST::Response>
Send(Stream & s,const REQUEST & req)258 RESPONSE Send(Stream& s, const REQUEST& req) {
259 s << req;
260 if (s.error.empty()) {
261 RESPONSE resp;
262 s >> resp;
263 if (s.error.empty()) {
264 return resp;
265 }
266 }
267 return {};
268 }
269
270 } // namespace
271
/// Runs the compile server on `port`. Returns false on failure. Defined below.
bool RunServer(std::string port);
/// Runs the client: sends the shader in `file` to the server at
/// `address`:`port` for compilation. Returns false on failure. Defined below.
bool RunClient(std::string address, std::string port, std::string file);
274
main(int argc,char * argv[])275 int main(int argc, char* argv[]) {
276 bool run_server = false;
277 std::string port = "19000";
278
279 std::vector<std::string> args;
280 for (int i = 1; i < argc; i++) {
281 std::string arg = argv[i];
282 if (arg == "-s" || arg == "--server") {
283 run_server = true;
284 continue;
285 }
286 if (arg == "-p" || arg == "--port") {
287 if (i < argc - 1) {
288 i++;
289 port = argv[i];
290 } else {
291 printf("expected port number");
292 exit(1);
293 }
294 continue;
295 }
296
297 // xcrun flags are ignored so this executable can be used as a replacement
298 // for xcrun.
299 if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
300 i++;
301 continue;
302 }
303 if (arg == "metal") {
304 for (; i < argc; i++) {
305 if (std::string(argv[i]) == "-c") {
306 break;
307 }
308 }
309 continue;
310 }
311
312 args.emplace_back(arg);
313 }
314
315 bool success = false;
316
317 if (run_server) {
318 success = RunServer(port);
319 } else {
320 std::string address;
321 std::string file;
322 switch (args.size()) {
323 case 1:
324 if (auto* addr = getenv("TINT_REMOTE_COMPILE_ADDRESS")) {
325 address = addr;
326 }
327 file = args[0];
328 break;
329 case 2:
330 address = args[0];
331 file = args[1];
332 break;
333 }
334 if (address.empty() || file.empty()) {
335 ShowUsage();
336 }
337 success = RunClient(address, port, file);
338 }
339
340 if (!success) {
341 exit(1);
342 }
343
344 return 0;
345 }
346
RunServer(std::string port)347 bool RunServer(std::string port) {
348 auto server_socket = Socket::Listen("", port.c_str());
349 if (!server_socket) {
350 printf("Failed to listen on port %s\n", port.c_str());
351 return false;
352 }
353 printf("Listening on port %s...\n", port.c_str());
354 while (auto conn = server_socket->Accept()) {
355 std::thread([=] {
356 DEBUG("Client connected...");
357 Stream stream{conn.get()};
358
359 {
360 ConnectionRequest req;
361 stream >> req;
362 if (!stream.error.empty()) {
363 printf("%s\n", stream.error.c_str());
364 return;
365 }
366 ConnectionResponse resp;
367 if (req.protocol_version != kProtocolVersion) {
368 DEBUG("Protocol version mismatch");
369 resp.error = "Protocol version mismatch";
370 stream << resp;
371 return;
372 }
373 stream << resp;
374 }
375 DEBUG("Connection established");
376 {
377 CompileRequest req;
378 stream >> req;
379 if (!stream.error.empty()) {
380 printf("%s\n", stream.error.c_str());
381 return;
382 }
383 #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
384 if (req.language == SourceLanguage::MSL) {
385 auto result = CompileMslUsingMetalAPI(req.source);
386 CompileResponse resp;
387 if (!result.success) {
388 resp.error = result.output;
389 }
390 stream << resp;
391 return;
392 }
393 #endif
394 CompileResponse resp;
395 resp.error = "server cannot compile this type of shader";
396 stream << resp;
397 }
398 }).detach();
399 }
400 return true;
401 }
402
RunClient(std::string address,std::string port,std::string file)403 bool RunClient(std::string address, std::string port, std::string file) {
404 // Read the file
405 std::ifstream input(file, std::ios::binary);
406 if (!input) {
407 printf("Couldn't open '%s'\n", file.c_str());
408 return false;
409 }
410 std::string source((std::istreambuf_iterator<char>(input)),
411 std::istreambuf_iterator<char>());
412
413 constexpr const int timeout_ms = 10000;
414 DEBUG("Connecting to %s:%s...", address.c_str(), port.c_str());
415 auto conn = Socket::Connect(address.c_str(), port.c_str(), timeout_ms);
416 if (!conn) {
417 printf("Connection failed\n");
418 return false;
419 }
420
421 Stream stream{conn.get()};
422
423 DEBUG("Sending connection request...");
424 auto conn_resp = Send(stream, ConnectionRequest{kProtocolVersion});
425 if (!stream.error.empty()) {
426 printf("%s\n", stream.error.c_str());
427 return false;
428 }
429 if (!conn_resp.error.empty()) {
430 printf("%s\n", conn_resp.error.c_str());
431 return false;
432 }
433 DEBUG("Connection established. Requesting compile...");
434 auto comp_resp = Send(stream, CompileRequest{SourceLanguage::MSL, source});
435 if (!stream.error.empty()) {
436 printf("%s\n", stream.error.c_str());
437 return false;
438 }
439 if (!comp_resp.error.empty()) {
440 printf("%s\n", comp_resp.error.c_str());
441 return false;
442 }
443 DEBUG("Compilation successful");
444 return true;
445 }
446