1 // Copyright 2019 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "dawn/mock_webgpu.h" 16 #include "gtest/gtest.h" 17 18 #include <memory> 19 20 // Definition of a "Lambda predicate matcher" for GMock to allow checking deep structures 21 // are passed correctly by the wire. 22 23 // Helper templates to extract the argument type of a lambda. 24 template <typename T> 25 struct MatcherMethodArgument; 26 27 template <typename Lambda, typename Arg> 28 struct MatcherMethodArgument<bool (Lambda::*)(Arg) const> { 29 using Type = Arg; 30 }; 31 32 template <typename Lambda> 33 using MatcherLambdaArgument = typename MatcherMethodArgument<decltype(&Lambda::operator())>::Type; 34 35 // The matcher itself, unfortunately it isn't able to return detailed information like other 36 // matchers do. 37 template <typename Lambda, typename Arg> 38 class LambdaMatcherImpl : public testing::MatcherInterface<Arg> { 39 public: 40 explicit LambdaMatcherImpl(Lambda lambda) : mLambda(lambda) { 41 } 42 43 void DescribeTo(std::ostream* os) const override { 44 *os << "with a custom matcher"; 45 } 46 47 bool MatchAndExplain(Arg value, testing::MatchResultListener* listener) const override { 48 if (!mLambda(value)) { 49 *listener << "which doesn't satisfy the custom predicate"; 50 return false; 51 } 52 return true; 53 } 54 55 private: 56 Lambda mLambda; 57 }; 58 59 // Use the MatchesLambda as follows: 60 // 61 // EXPECT_CALL(foo, Bar(MatchesLambda([](ArgType arg) -> bool { 62 // return CheckPredicateOnArg(arg); 63 // }))); 64 template <typename Lambda> 65 inline testing::Matcher<MatcherLambdaArgument<Lambda>> MatchesLambda(Lambda lambda) { 66 return MakeMatcher(new LambdaMatcherImpl<Lambda, MatcherLambdaArgument<Lambda>>(lambda)); 67 } 68 69 class StringMessageMatcher : public testing::MatcherInterface<const char*> { 70 public: 71 explicit StringMessageMatcher() { 72 } 73 74 bool MatchAndExplain(const char* message, 75 testing::MatchResultListener* listener) const override { 76 if (message == nullptr) { 77 *listener << "missing error message"; 78 return false; 79 } 80 if (std::strlen(message) <= 1) { 81 *listener << "message is truncated"; 82 return false; 83 } 84 return true; 85 } 86 87 void DescribeTo(std::ostream* os) const override { 88 *os << "valid error message"; 89 } 90 91 void DescribeNegationTo(std::ostream* os) const override { 92 *os << "invalid error message"; 93 } 94 }; 95 96 inline testing::Matcher<const char*> ValidStringMessage() { 97 return MakeMatcher(new StringMessageMatcher()); 98 } 99 100 namespace dawn_wire { 101 class WireClient; 102 class WireServer; 103 namespace client { 104 class MemoryTransferService; 105 } // namespace client 106 namespace server { 107 class MemoryTransferService; 108 } // namespace server 109 } // namespace dawn_wire 110 111 namespace utils { 112 class TerribleCommandBuffer; 113 } 114 115 class WireTest : public testing::Test { 116 protected: 117 WireTest(); 118 ~WireTest() override; 119 120 void SetUp() override; 121 void TearDown() override; 122 123 void FlushClient(bool success = true); 124 void FlushServer(bool success = true); 125 126 void DefaultApiDeviceWasReleased(); 127 128 testing::StrictMock<MockProcTable> api; 129 WGPUDevice apiDevice; 130 WGPUQueue apiQueue; 131 WGPUDevice device; 132 WGPUQueue queue; 133 134 dawn_wire::WireServer* GetWireServer(); 135 dawn_wire::WireClient* GetWireClient(); 136 137 void DeleteServer(); 138 void DeleteClient(); 139 140 private: 141 void SetupIgnoredCallExpectations(); 142 143 virtual dawn_wire::client::MemoryTransferService* GetClientMemoryTransferService(); 144 virtual dawn_wire::server::MemoryTransferService* GetServerMemoryTransferService(); 145 146 std::unique_ptr<dawn_wire::WireServer> mWireServer; 147 std::unique_ptr<dawn_wire::WireClient> mWireClient; 148 std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf; 149 std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf; 150 }; 151