1 // Copyright 2021 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "utils/WireHelper.h" 16 17 #include "common/Assert.h" 18 #include "common/Log.h" 19 #include "common/SystemUtils.h" 20 #include "dawn/dawn_proc.h" 21 #include "dawn_native/DawnNative.h" 22 #include "dawn_wire/WireClient.h" 23 #include "dawn_wire/WireServer.h" 24 #include "utils/TerribleCommandBuffer.h" 25 26 #include <algorithm> 27 #include <cstring> 28 #include <fstream> 29 #include <iomanip> 30 #include <set> 31 #include <sstream> 32 33 namespace utils { 34 35 namespace { 36 37 class WireServerTraceLayer : public dawn_wire::CommandHandler { 38 public: WireServerTraceLayer(const char * dir,dawn_wire::CommandHandler * handler)39 WireServerTraceLayer(const char* dir, dawn_wire::CommandHandler* handler) 40 : dawn_wire::CommandHandler(), mDir(dir), mHandler(handler) { 41 const char* sep = GetPathSeparator(); 42 if (mDir.size() > 0 && mDir.back() != *sep) { 43 mDir += sep; 44 } 45 } 46 BeginWireTrace(const char * name)47 void BeginWireTrace(const char* name) { 48 std::string filename = name; 49 // Replace slashes in gtest names with underscores so everything is in one 50 // directory. 51 std::replace(filename.begin(), filename.end(), '/', '_'); 52 std::replace(filename.begin(), filename.end(), '\\', '_'); 53 54 // Prepend the filename with the directory. 55 filename = mDir + filename; 56 57 ASSERT(!mFile.is_open()); 58 mFile.open(filename, 59 std::ios_base::out | std::ios_base::binary | std::ios_base::trunc); 60 61 // Write the initial 8 bytes. This means the fuzzer should never inject an 62 // error. 63 const uint64_t injectedErrorIndex = 0xFFFF'FFFF'FFFF'FFFF; 64 mFile.write(reinterpret_cast<const char*>(&injectedErrorIndex), 65 sizeof(injectedErrorIndex)); 66 } 67 HandleCommands(const volatile char * commands,size_t size)68 const volatile char* HandleCommands(const volatile char* commands, 69 size_t size) override { 70 if (mFile.is_open()) { 71 mFile.write(const_cast<const char*>(commands), size); 72 } 73 return mHandler->HandleCommands(commands, size); 74 } 75 76 private: 77 std::string mDir; 78 dawn_wire::CommandHandler* mHandler; 79 std::ofstream mFile; 80 }; 81 82 class WireHelperDirect : public WireHelper { 83 public: WireHelperDirect()84 WireHelperDirect() { 85 dawnProcSetProcs(&dawn_native::GetProcs()); 86 } 87 RegisterDevice(WGPUDevice backendDevice)88 std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override { 89 ASSERT(backendDevice != nullptr); 90 return std::make_pair(wgpu::Device::Acquire(backendDevice), backendDevice); 91 } 92 BeginWireTrace(const char * name)93 void BeginWireTrace(const char* name) override { 94 } 95 FlushClient()96 bool FlushClient() override { 97 return true; 98 } 99 FlushServer()100 bool FlushServer() override { 101 return true; 102 } 103 }; 104 105 class WireHelperProxy : public WireHelper { 106 public: WireHelperProxy(const char * wireTraceDir)107 explicit WireHelperProxy(const char* wireTraceDir) { 108 mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>(); 109 mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>(); 110 111 dawn_wire::WireServerDescriptor serverDesc = {}; 112 serverDesc.procs = &dawn_native::GetProcs(); 113 serverDesc.serializer = mS2cBuf.get(); 114 115 mWireServer.reset(new dawn_wire::WireServer(serverDesc)); 116 mC2sBuf->SetHandler(mWireServer.get()); 117 118 if (wireTraceDir != nullptr && strlen(wireTraceDir) > 0) { 119 mWireServerTraceLayer.reset( 120 new WireServerTraceLayer(wireTraceDir, mWireServer.get())); 121 mC2sBuf->SetHandler(mWireServerTraceLayer.get()); 122 } 123 124 dawn_wire::WireClientDescriptor clientDesc = {}; 125 clientDesc.serializer = mC2sBuf.get(); 126 127 mWireClient.reset(new dawn_wire::WireClient(clientDesc)); 128 mS2cBuf->SetHandler(mWireClient.get()); 129 dawnProcSetProcs(&dawn_wire::client::GetProcs()); 130 } 131 RegisterDevice(WGPUDevice backendDevice)132 std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override { 133 ASSERT(backendDevice != nullptr); 134 135 auto reservation = mWireClient->ReserveDevice(); 136 mWireServer->InjectDevice(backendDevice, reservation.id, reservation.generation); 137 dawn_native::GetProcs().deviceRelease(backendDevice); 138 139 return std::make_pair(wgpu::Device::Acquire(reservation.device), backendDevice); 140 } 141 BeginWireTrace(const char * name)142 void BeginWireTrace(const char* name) override { 143 if (mWireServerTraceLayer) { 144 return mWireServerTraceLayer->BeginWireTrace(name); 145 } 146 } 147 FlushClient()148 bool FlushClient() override { 149 return mC2sBuf->Flush(); 150 } 151 FlushServer()152 bool FlushServer() override { 153 return mS2cBuf->Flush(); 154 } 155 156 private: 157 std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf; 158 std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf; 159 std::unique_ptr<WireServerTraceLayer> mWireServerTraceLayer; 160 std::unique_ptr<dawn_wire::WireServer> mWireServer; 161 std::unique_ptr<dawn_wire::WireClient> mWireClient; 162 }; 163 164 } // anonymous namespace 165 CreateWireHelper(bool useWire,const char * wireTraceDir)166 std::unique_ptr<WireHelper> CreateWireHelper(bool useWire, const char* wireTraceDir) { 167 if (useWire) { 168 return std::unique_ptr<WireHelper>(new WireHelperProxy(wireTraceDir)); 169 } else { 170 return std::unique_ptr<WireHelper>(new WireHelperDirect()); 171 } 172 } 173 ~WireHelper()174 WireHelper::~WireHelper() { 175 dawnProcSetProcs(nullptr); 176 } 177 178 } // namespace utils 179