• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "utils/WireHelper.h"
16 
17 #include "common/Assert.h"
18 #include "common/Log.h"
19 #include "common/SystemUtils.h"
20 #include "dawn/dawn_proc.h"
21 #include "dawn_native/DawnNative.h"
22 #include "dawn_wire/WireClient.h"
23 #include "dawn_wire/WireServer.h"
24 #include "utils/TerribleCommandBuffer.h"
25 
26 #include <algorithm>
27 #include <cstring>
28 #include <fstream>
29 #include <iomanip>
30 #include <set>
31 #include <sstream>
32 
33 namespace utils {
34 
35     namespace {
36 
37         class WireServerTraceLayer : public dawn_wire::CommandHandler {
38           public:
WireServerTraceLayer(const char * dir,dawn_wire::CommandHandler * handler)39             WireServerTraceLayer(const char* dir, dawn_wire::CommandHandler* handler)
40                 : dawn_wire::CommandHandler(), mDir(dir), mHandler(handler) {
41                 const char* sep = GetPathSeparator();
42                 if (mDir.size() > 0 && mDir.back() != *sep) {
43                     mDir += sep;
44                 }
45             }
46 
BeginWireTrace(const char * name)47             void BeginWireTrace(const char* name) {
48                 std::string filename = name;
49                 // Replace slashes in gtest names with underscores so everything is in one
50                 // directory.
51                 std::replace(filename.begin(), filename.end(), '/', '_');
52                 std::replace(filename.begin(), filename.end(), '\\', '_');
53 
54                 // Prepend the filename with the directory.
55                 filename = mDir + filename;
56 
57                 ASSERT(!mFile.is_open());
58                 mFile.open(filename,
59                            std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
60 
61                 // Write the initial 8 bytes. This means the fuzzer should never inject an
62                 // error.
63                 const uint64_t injectedErrorIndex = 0xFFFF'FFFF'FFFF'FFFF;
64                 mFile.write(reinterpret_cast<const char*>(&injectedErrorIndex),
65                             sizeof(injectedErrorIndex));
66             }
67 
HandleCommands(const volatile char * commands,size_t size)68             const volatile char* HandleCommands(const volatile char* commands,
69                                                 size_t size) override {
70                 if (mFile.is_open()) {
71                     mFile.write(const_cast<const char*>(commands), size);
72                 }
73                 return mHandler->HandleCommands(commands, size);
74             }
75 
76           private:
77             std::string mDir;
78             dawn_wire::CommandHandler* mHandler;
79             std::ofstream mFile;
80         };
81 
82         class WireHelperDirect : public WireHelper {
83           public:
WireHelperDirect()84             WireHelperDirect() {
85                 dawnProcSetProcs(&dawn_native::GetProcs());
86             }
87 
RegisterDevice(WGPUDevice backendDevice)88             std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override {
89                 ASSERT(backendDevice != nullptr);
90                 return std::make_pair(wgpu::Device::Acquire(backendDevice), backendDevice);
91             }
92 
BeginWireTrace(const char * name)93             void BeginWireTrace(const char* name) override {
94             }
95 
FlushClient()96             bool FlushClient() override {
97                 return true;
98             }
99 
FlushServer()100             bool FlushServer() override {
101                 return true;
102             }
103         };
104 
105         class WireHelperProxy : public WireHelper {
106           public:
WireHelperProxy(const char * wireTraceDir)107             explicit WireHelperProxy(const char* wireTraceDir) {
108                 mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
109                 mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
110 
111                 dawn_wire::WireServerDescriptor serverDesc = {};
112                 serverDesc.procs = &dawn_native::GetProcs();
113                 serverDesc.serializer = mS2cBuf.get();
114 
115                 mWireServer.reset(new dawn_wire::WireServer(serverDesc));
116                 mC2sBuf->SetHandler(mWireServer.get());
117 
118                 if (wireTraceDir != nullptr && strlen(wireTraceDir) > 0) {
119                     mWireServerTraceLayer.reset(
120                         new WireServerTraceLayer(wireTraceDir, mWireServer.get()));
121                     mC2sBuf->SetHandler(mWireServerTraceLayer.get());
122                 }
123 
124                 dawn_wire::WireClientDescriptor clientDesc = {};
125                 clientDesc.serializer = mC2sBuf.get();
126 
127                 mWireClient.reset(new dawn_wire::WireClient(clientDesc));
128                 mS2cBuf->SetHandler(mWireClient.get());
129                 dawnProcSetProcs(&dawn_wire::client::GetProcs());
130             }
131 
RegisterDevice(WGPUDevice backendDevice)132             std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override {
133                 ASSERT(backendDevice != nullptr);
134 
135                 auto reservation = mWireClient->ReserveDevice();
136                 mWireServer->InjectDevice(backendDevice, reservation.id, reservation.generation);
137                 dawn_native::GetProcs().deviceRelease(backendDevice);
138 
139                 return std::make_pair(wgpu::Device::Acquire(reservation.device), backendDevice);
140             }
141 
BeginWireTrace(const char * name)142             void BeginWireTrace(const char* name) override {
143                 if (mWireServerTraceLayer) {
144                     return mWireServerTraceLayer->BeginWireTrace(name);
145                 }
146             }
147 
FlushClient()148             bool FlushClient() override {
149                 return mC2sBuf->Flush();
150             }
151 
FlushServer()152             bool FlushServer() override {
153                 return mS2cBuf->Flush();
154             }
155 
156           private:
157             std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
158             std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf;
159             std::unique_ptr<WireServerTraceLayer> mWireServerTraceLayer;
160             std::unique_ptr<dawn_wire::WireServer> mWireServer;
161             std::unique_ptr<dawn_wire::WireClient> mWireClient;
162         };
163 
164     }  // anonymous namespace
165 
CreateWireHelper(bool useWire,const char * wireTraceDir)166     std::unique_ptr<WireHelper> CreateWireHelper(bool useWire, const char* wireTraceDir) {
167         if (useWire) {
168             return std::unique_ptr<WireHelper>(new WireHelperProxy(wireTraceDir));
169         } else {
170             return std::unique_ptr<WireHelper>(new WireHelperDirect());
171         }
172     }
173 
~WireHelper()174     WireHelper::~WireHelper() {
175         dawnProcSetProcs(nullptr);
176     }
177 
178 }  // namespace utils
179