• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <fstream>
#include <iterator>
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <type_traits>
#include <utility>
#include <vector>

#include "tools/src/cmd/remote-compile/compile.h"
#include "tools/src/cmd/remote-compile/socket.h"
27 
28 namespace {
29 
/// Debug logging to stdout.
/// Compiled out by default. Build with -DTINT_REMOTE_COMPILE_DEBUG to enable;
/// DEBUG(fmt, ...) then prints the printf-style message followed by a newline.
#ifdef TINT_REMOTE_COMPILE_DEBUG
#define DEBUG(msg, ...) printf(msg "\n", ##__VA_ARGS__)
#else
// Expands to a single statement so `DEBUG(...);` is always well-formed and
// does not trigger empty-statement warnings.
#define DEBUG(...) \
  do {             \
  } while (false)
#endif
35 
/// Prints the tool's command-line usage to stdout, then terminates the
/// process with exit code 1. This function never returns to its caller.
void ShowUsage() {
  constexpr const char* kToolName = "tint-remote-compile";
  printf(R"(%s is a tool for compiling a shader on a remote machine

usage as server:
  %s -s [-p port-number]

usage as client:
  %s [-p port-number] [server-address] shader-file-path

  [server-address] can be omitted if the TINT_REMOTE_COMPILE_ADDRESS environment
  variable is set.
  Alternatively, you can pass xcrun arguments so %s can be used as a
  drop-in replacement.
)",
         kToolName, kToolName, kToolName, kToolName);
  exit(1);
}
55 
/// Version of the client/server wire protocol. The server rejects clients
/// whose version differs, so bump this whenever the protocol changes.
constexpr uint32_t kProtocolVersion{1};

/// Shader source languages that a client may submit for compilation.
enum SourceLanguage {
  MSL,  ///< Metal Shading Language
};
63 
64 /// Stream is a serialization wrapper around a socket
65 struct Stream {
66   /// The underlying socket
67   Socket* const socket;
68   /// Error state
69   std::string error;
70 
71   /// Writes a uint32_t to the socket
operator <<__anon42263a490111::Stream72   Stream operator<<(uint32_t v) {
73     if (error.empty()) {
74       Write(&v, sizeof(v));
75     }
76     return *this;
77   }
78 
79   /// Reads a uint32_t from the socket
operator >>__anon42263a490111::Stream80   Stream operator>>(uint32_t& v) {
81     if (error.empty()) {
82       Read(&v, sizeof(v));
83     }
84     return *this;
85   }
86 
87   /// Writes a std::string to the socket
operator <<__anon42263a490111::Stream88   Stream operator<<(const std::string& v) {
89     if (error.empty()) {
90       uint32_t count = static_cast<uint32_t>(v.size());
91       *this << count;
92       if (count) {
93         Write(v.data(), count);
94       }
95     }
96     return *this;
97   }
98 
99   /// Reads a std::string from the socket
operator >>__anon42263a490111::Stream100   Stream operator>>(std::string& v) {
101     uint32_t count = 0;
102     *this >> count;
103     if (count) {
104       std::vector<char> buf(count);
105       if (Read(buf.data(), count)) {
106         v = std::string(buf.data(), buf.size());
107       }
108     } else {
109       v.clear();
110     }
111     return *this;
112   }
113 
114   /// Writes an enum value to the socket
115   template <typename T>
operator <<__anon42263a490111::Stream116   std::enable_if_t<std::is_enum<T>::value, Stream> operator<<(T e) {
117     return *this << static_cast<uint32_t>(e);
118   }
119 
120   /// Reads an enum value from the socket
121   template <typename T>
operator >>__anon42263a490111::Stream122   std::enable_if_t<std::is_enum<T>::value, Stream> operator>>(T& e) {
123     uint32_t v;
124     *this >> v;
125     e = static_cast<T>(v);
126     return *this;
127   }
128 
129  private:
Write__anon42263a490111::Stream130   bool Write(const void* data, size_t size) {
131     if (error.empty()) {
132       if (!socket->Write(data, size)) {
133         error = "Socket::Write() failed";
134       }
135     }
136     return error.empty();
137   }
138 
Read__anon42263a490111::Stream139   bool Read(void* data, size_t size) {
140     auto buf = reinterpret_cast<uint8_t*>(data);
141     while (size > 0 && error.empty()) {
142       if (auto n = socket->Read(buf, size)) {
143         if (n > size) {
144           error = "Socket::Read() returned more bytes than requested";
145           return false;
146         }
147         size -= n;
148         buf += n;
149       }
150     }
151     return error.empty();
152   }
153 };
154 
155 ////////////////////////////////////////////////////////////////////////////////
156 // Messages
157 ////////////////////////////////////////////////////////////////////////////////
158 
/// Common base of every protocol message.
///
/// `type` is serialized ahead of each message body. The receiver can then
/// verify that it is reading the kind of message it expects.
struct Message {
  /// Discriminator for the concrete message kind. These numeric values go
  /// on the wire, so changing them requires bumping kProtocolVersion.
  enum class Type {
    ConnectionRequest = 0,
    ConnectionResponse = 1,
    CompileRequest = 2,
    CompileResponse = 3,
  };

  explicit Message(Type message_type) : type{message_type} {}

  /// The kind of this message, fixed at construction.
  const Type type;
};
173 
174 struct ConnectionResponse : Message {  // Server -> Client
ConnectionResponse__anon42263a490111::ConnectionResponse175   ConnectionResponse() : Message(Type::ConnectionResponse) {}
176 
177   template <typename T>
Serialize__anon42263a490111::ConnectionResponse178   void Serialize(T&& f) {
179     f(error);
180   }
181 
182   std::string error;
183 };
184 
185 struct ConnectionRequest : Message {  // Client -> Server
186   using Response = ConnectionResponse;
187 
ConnectionRequest__anon42263a490111::ConnectionRequest188   explicit ConnectionRequest(uint32_t proto_ver = kProtocolVersion)
189       : Message(Type::ConnectionRequest), protocol_version(proto_ver) {}
190 
191   template <typename T>
Serialize__anon42263a490111::ConnectionRequest192   void Serialize(T&& f) {
193     f(protocol_version);
194   }
195 
196   uint32_t protocol_version;
197 };
198 
199 struct CompileResponse : Message {  //  Server -> Client
CompileResponse__anon42263a490111::CompileResponse200   CompileResponse() : Message(Type::CompileResponse) {}
201 
202   template <typename T>
Serialize__anon42263a490111::CompileResponse203   void Serialize(T&& f) {
204     f(error);
205   }
206 
207   std::string error;
208 };
209 
210 struct CompileRequest : Message {  // Client -> Server
211   using Response = CompileResponse;
212 
CompileRequest__anon42263a490111::CompileRequest213   CompileRequest() : Message(Type::CompileRequest) {}
CompileRequest__anon42263a490111::CompileRequest214   CompileRequest(SourceLanguage lang, std::string src)
215       : Message(Type::CompileRequest), language(lang), source(src) {}
216 
217   template <typename T>
Serialize__anon42263a490111::CompileRequest218   void Serialize(T&& f) {
219     f(language);
220     f(source);
221   }
222 
223   SourceLanguage language;
224   std::string source;
225 };
226 
227 /// Writes the message `m` to the stream `s`
228 template <typename MESSAGE>
operator <<(Stream & s,const MESSAGE & m)229 std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator<<(
230     Stream& s,
231     const MESSAGE& m) {
232   s << m.type;
233   const_cast<MESSAGE&>(m).Serialize([&s](const auto& value) { s << value; });
234   return s;
235 }
236 
237 /// Reads the message `m` from the stream `s`
238 template <typename MESSAGE>
operator >>(Stream & s,MESSAGE & m)239 std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>(
240     Stream& s,
241     MESSAGE& m) {
242   Message::Type ty;
243   s >> ty;
244   if (ty == m.type) {
245     m.Serialize([&s](auto& value) { s >> value; });
246   } else {
247     std::stringstream ss;
248     ss << "expected message type " << static_cast<int>(m.type) << ", got "
249        << static_cast<int>(ty);
250     s.error = ss.str();
251   }
252   return s;
253 }
254 
255 /// Writes the request message `req` to the stream `s`, then reads and returns
256 /// the response message from the same stream.
257 template <typename REQUEST, typename RESPONSE = typename REQUEST::Response>
Send(Stream & s,const REQUEST & req)258 RESPONSE Send(Stream& s, const REQUEST& req) {
259   s << req;
260   if (s.error.empty()) {
261     RESPONSE resp;
262     s >> resp;
263     if (s.error.empty()) {
264       return resp;
265     }
266   }
267   return {};
268 }
269 
270 }  // namespace
271 
// The two modes of operation selected by main(); both are defined below it.
auto RunServer(std::string port) -> bool;
auto RunClient(std::string address, std::string port, std::string file) -> bool;
274 
main(int argc,char * argv[])275 int main(int argc, char* argv[]) {
276   bool run_server = false;
277   std::string port = "19000";
278 
279   std::vector<std::string> args;
280   for (int i = 1; i < argc; i++) {
281     std::string arg = argv[i];
282     if (arg == "-s" || arg == "--server") {
283       run_server = true;
284       continue;
285     }
286     if (arg == "-p" || arg == "--port") {
287       if (i < argc - 1) {
288         i++;
289         port = argv[i];
290       } else {
291         printf("expected port number");
292         exit(1);
293       }
294       continue;
295     }
296 
297     // xcrun flags are ignored so this executable can be used as a replacement
298     // for xcrun.
299     if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
300       i++;
301       continue;
302     }
303     if (arg == "metal") {
304       for (; i < argc; i++) {
305         if (std::string(argv[i]) == "-c") {
306           break;
307         }
308       }
309       continue;
310     }
311 
312     args.emplace_back(arg);
313   }
314 
315   bool success = false;
316 
317   if (run_server) {
318     success = RunServer(port);
319   } else {
320     std::string address;
321     std::string file;
322     switch (args.size()) {
323       case 1:
324         if (auto* addr = getenv("TINT_REMOTE_COMPILE_ADDRESS")) {
325           address = addr;
326         }
327         file = args[0];
328         break;
329       case 2:
330         address = args[0];
331         file = args[1];
332         break;
333     }
334     if (address.empty() || file.empty()) {
335       ShowUsage();
336     }
337     success = RunClient(address, port, file);
338   }
339 
340   if (!success) {
341     exit(1);
342   }
343 
344   return 0;
345 }
346 
RunServer(std::string port)347 bool RunServer(std::string port) {
348   auto server_socket = Socket::Listen("", port.c_str());
349   if (!server_socket) {
350     printf("Failed to listen on port %s\n", port.c_str());
351     return false;
352   }
353   printf("Listening on port %s...\n", port.c_str());
354   while (auto conn = server_socket->Accept()) {
355     std::thread([=] {
356       DEBUG("Client connected...");
357       Stream stream{conn.get()};
358 
359       {
360         ConnectionRequest req;
361         stream >> req;
362         if (!stream.error.empty()) {
363           printf("%s\n", stream.error.c_str());
364           return;
365         }
366         ConnectionResponse resp;
367         if (req.protocol_version != kProtocolVersion) {
368           DEBUG("Protocol version mismatch");
369           resp.error = "Protocol version mismatch";
370           stream << resp;
371           return;
372         }
373         stream << resp;
374       }
375       DEBUG("Connection established");
376       {
377         CompileRequest req;
378         stream >> req;
379         if (!stream.error.empty()) {
380           printf("%s\n", stream.error.c_str());
381           return;
382         }
383 #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
384         if (req.language == SourceLanguage::MSL) {
385           auto result = CompileMslUsingMetalAPI(req.source);
386           CompileResponse resp;
387           if (!result.success) {
388             resp.error = result.output;
389           }
390           stream << resp;
391           return;
392         }
393 #endif
394         CompileResponse resp;
395         resp.error = "server cannot compile this type of shader";
396         stream << resp;
397       }
398     }).detach();
399   }
400   return true;
401 }
402 
RunClient(std::string address,std::string port,std::string file)403 bool RunClient(std::string address, std::string port, std::string file) {
404   // Read the file
405   std::ifstream input(file, std::ios::binary);
406   if (!input) {
407     printf("Couldn't open '%s'\n", file.c_str());
408     return false;
409   }
410   std::string source((std::istreambuf_iterator<char>(input)),
411                      std::istreambuf_iterator<char>());
412 
413   constexpr const int timeout_ms = 10000;
414   DEBUG("Connecting to %s:%s...", address.c_str(), port.c_str());
415   auto conn = Socket::Connect(address.c_str(), port.c_str(), timeout_ms);
416   if (!conn) {
417     printf("Connection failed\n");
418     return false;
419   }
420 
421   Stream stream{conn.get()};
422 
423   DEBUG("Sending connection request...");
424   auto conn_resp = Send(stream, ConnectionRequest{kProtocolVersion});
425   if (!stream.error.empty()) {
426     printf("%s\n", stream.error.c_str());
427     return false;
428   }
429   if (!conn_resp.error.empty()) {
430     printf("%s\n", conn_resp.error.c_str());
431     return false;
432   }
433   DEBUG("Connection established. Requesting compile...");
434   auto comp_resp = Send(stream, CompileRequest{SourceLanguage::MSL, source});
435   if (!stream.error.empty()) {
436     printf("%s\n", stream.error.c_str());
437     return false;
438   }
439   if (!comp_resp.error.empty()) {
440     printf("%s\n", comp_resp.error.c_str());
441     return false;
442   }
443   DEBUG("Compilation successful");
444   return true;
445 }
446