1 /*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <condition_variable>
20
21 #include <grpcpp/channel.h>
22
23 #include "src/proto/grpc/testing/echo.grpc.pb.h"
24 #include "test/cpp/util/string_ref_helper.h"
25
26 #include <gtest/gtest.h>
27
28 namespace grpc {
29 namespace testing {
30 /* This interceptor does nothing. Just keeps a global count on the number of
31 * times it was invoked. */
32 class DummyInterceptor : public experimental::Interceptor {
33 public:
DummyInterceptor()34 DummyInterceptor() {}
35
Intercept(experimental::InterceptorBatchMethods * methods)36 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
37 if (methods->QueryInterceptionHookPoint(
38 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
39 num_times_run_++;
40 } else if (methods->QueryInterceptionHookPoint(
41 experimental::InterceptionHookPoints::
42 POST_RECV_INITIAL_METADATA)) {
43 num_times_run_reverse_++;
44 } else if (methods->QueryInterceptionHookPoint(
45 experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
46 num_times_cancel_++;
47 }
48 methods->Proceed();
49 }
50
Reset()51 static void Reset() {
52 num_times_run_.store(0);
53 num_times_run_reverse_.store(0);
54 num_times_cancel_.store(0);
55 }
56
GetNumTimesRun()57 static int GetNumTimesRun() {
58 EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
59 return num_times_run_.load();
60 }
61
GetNumTimesCancel()62 static int GetNumTimesCancel() { return num_times_cancel_.load(); }
63
64 private:
65 static std::atomic<int> num_times_run_;
66 static std::atomic<int> num_times_run_reverse_;
67 static std::atomic<int> num_times_cancel_;
68 };
69
70 class DummyInterceptorFactory
71 : public experimental::ClientInterceptorFactoryInterface,
72 public experimental::ServerInterceptorFactoryInterface {
73 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)74 virtual experimental::Interceptor* CreateClientInterceptor(
75 experimental::ClientRpcInfo* /*info*/) override {
76 return new DummyInterceptor();
77 }
78
CreateServerInterceptor(experimental::ServerRpcInfo *)79 virtual experimental::Interceptor* CreateServerInterceptor(
80 experimental::ServerRpcInfo* /*info*/) override {
81 return new DummyInterceptor();
82 }
83 };
84
85 /* This interceptor factory returns nullptr on interceptor creation */
86 class NullInterceptorFactory
87 : public experimental::ClientInterceptorFactoryInterface,
88 public experimental::ServerInterceptorFactoryInterface {
89 public:
CreateClientInterceptor(experimental::ClientRpcInfo *)90 virtual experimental::Interceptor* CreateClientInterceptor(
91 experimental::ClientRpcInfo* /*info*/) override {
92 return nullptr;
93 }
94
CreateServerInterceptor(experimental::ServerRpcInfo *)95 virtual experimental::Interceptor* CreateServerInterceptor(
96 experimental::ServerRpcInfo* /*info*/) override {
97 return nullptr;
98 }
99 };
100
101 class EchoTestServiceStreamingImpl : public EchoTestService::Service {
102 public:
~EchoTestServiceStreamingImpl()103 ~EchoTestServiceStreamingImpl() override {}
104
Echo(ServerContext * context,const EchoRequest * request,EchoResponse * response)105 Status Echo(ServerContext* context, const EchoRequest* request,
106 EchoResponse* response) override {
107 auto client_metadata = context->client_metadata();
108 for (const auto& pair : client_metadata) {
109 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
110 }
111 response->set_message(request->message());
112 return Status::OK;
113 }
114
BidiStream(ServerContext * context,grpc::ServerReaderWriter<EchoResponse,EchoRequest> * stream)115 Status BidiStream(
116 ServerContext* context,
117 grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
118 EchoRequest req;
119 EchoResponse resp;
120 auto client_metadata = context->client_metadata();
121 for (const auto& pair : client_metadata) {
122 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
123 }
124
125 while (stream->Read(&req)) {
126 resp.set_message(req.message());
127 EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
128 }
129 return Status::OK;
130 }
131
RequestStream(ServerContext * context,ServerReader<EchoRequest> * reader,EchoResponse * resp)132 Status RequestStream(ServerContext* context,
133 ServerReader<EchoRequest>* reader,
134 EchoResponse* resp) override {
135 auto client_metadata = context->client_metadata();
136 for (const auto& pair : client_metadata) {
137 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
138 }
139
140 EchoRequest req;
141 string response_str = "";
142 while (reader->Read(&req)) {
143 response_str += req.message();
144 }
145 resp->set_message(response_str);
146 return Status::OK;
147 }
148
ResponseStream(ServerContext * context,const EchoRequest * req,ServerWriter<EchoResponse> * writer)149 Status ResponseStream(ServerContext* context, const EchoRequest* req,
150 ServerWriter<EchoResponse>* writer) override {
151 auto client_metadata = context->client_metadata();
152 for (const auto& pair : client_metadata) {
153 context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
154 }
155
156 EchoResponse resp;
157 resp.set_message(req->message());
158 for (int i = 0; i < 10; i++) {
159 EXPECT_TRUE(writer->Write(resp));
160 }
161 return Status::OK;
162 }
163 };
164
165 constexpr int kNumStreamingMessages = 10;
166
167 void MakeCall(const std::shared_ptr<Channel>& channel);
168
169 void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);
170
171 void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);
172
173 void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);
174
175 void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);
176
177 void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);
178
179 void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);
180
181 void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);
182
183 void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
184
185 bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
186 const string& key, const string& value);
187
188 bool CheckMetadata(const std::multimap<std::string, std::string>& map,
189 const string& key, const string& value);
190
191 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
192 CreateDummyClientInterceptors();
193
// Encodes a small integer as an opaque completion-queue tag pointer. The
// pointer is never dereferenced; it is only compared and decoded by detag().
inline void* tag(int i) {
  return reinterpret_cast<void*>(static_cast<intptr_t>(i));
}

// Inverse of tag(): recovers the integer a tag pointer was built from.
inline int detag(void* p) {
  return static_cast<int>(reinterpret_cast<intptr_t>(p));
}
198
199 class Verifier {
200 public:
Verifier()201 Verifier() : lambda_run_(false) {}
202 // Expect sets the expected ok value for a specific tag
Expect(int i,bool expect_ok)203 Verifier& Expect(int i, bool expect_ok) {
204 return ExpectUnless(i, expect_ok, false);
205 }
206 // ExpectUnless sets the expected ok value for a specific tag
207 // unless the tag was already marked seen (as a result of ExpectMaybe)
ExpectUnless(int i,bool expect_ok,bool seen)208 Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
209 if (!seen) {
210 expectations_[tag(i)] = expect_ok;
211 }
212 return *this;
213 }
214 // ExpectMaybe sets the expected ok value for a specific tag, but does not
215 // require it to appear
216 // If it does, sets *seen to true
ExpectMaybe(int i,bool expect_ok,bool * seen)217 Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
218 if (!*seen) {
219 maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
220 }
221 return *this;
222 }
223
224 // Next waits for 1 async tag to complete, checks its
225 // expectations, and returns the tag
Next(CompletionQueue * cq,bool ignore_ok)226 int Next(CompletionQueue* cq, bool ignore_ok) {
227 bool ok;
228 void* got_tag;
229 EXPECT_TRUE(cq->Next(&got_tag, &ok));
230 GotTag(got_tag, ok, ignore_ok);
231 return detag(got_tag);
232 }
233
234 template <typename T>
DoOnceThenAsyncNext(CompletionQueue * cq,void ** got_tag,bool * ok,T deadline,std::function<void (void)> lambda)235 CompletionQueue::NextStatus DoOnceThenAsyncNext(
236 CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
237 std::function<void(void)> lambda) {
238 if (lambda_run_) {
239 return cq->AsyncNext(got_tag, ok, deadline);
240 } else {
241 lambda_run_ = true;
242 return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
243 }
244 }
245
246 // Verify keeps calling Next until all currently set
247 // expected tags are complete
Verify(CompletionQueue * cq)248 void Verify(CompletionQueue* cq) { Verify(cq, false); }
249
250 // This version of Verify allows optionally ignoring the
251 // outcome of the expectation
Verify(CompletionQueue * cq,bool ignore_ok)252 void Verify(CompletionQueue* cq, bool ignore_ok) {
253 GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
254 while (!expectations_.empty()) {
255 Next(cq, ignore_ok);
256 }
257 }
258
259 // This version of Verify stops after a certain deadline, and uses the
260 // DoThenAsyncNext API
261 // to call the lambda
Verify(CompletionQueue * cq,std::chrono::system_clock::time_point deadline,const std::function<void (void)> & lambda)262 void Verify(CompletionQueue* cq,
263 std::chrono::system_clock::time_point deadline,
264 const std::function<void(void)>& lambda) {
265 if (expectations_.empty()) {
266 bool ok;
267 void* got_tag;
268 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
269 CompletionQueue::TIMEOUT);
270 } else {
271 while (!expectations_.empty()) {
272 bool ok;
273 void* got_tag;
274 EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
275 CompletionQueue::GOT_EVENT);
276 GotTag(got_tag, ok, false);
277 }
278 }
279 }
280
281 private:
GotTag(void * got_tag,bool ok,bool ignore_ok)282 void GotTag(void* got_tag, bool ok, bool ignore_ok) {
283 auto it = expectations_.find(got_tag);
284 if (it != expectations_.end()) {
285 if (!ignore_ok) {
286 EXPECT_EQ(it->second, ok);
287 }
288 expectations_.erase(it);
289 } else {
290 auto it2 = maybe_expectations_.find(got_tag);
291 if (it2 != maybe_expectations_.end()) {
292 if (it2->second.seen != nullptr) {
293 EXPECT_FALSE(*it2->second.seen);
294 *it2->second.seen = true;
295 }
296 if (!ignore_ok) {
297 EXPECT_EQ(it2->second.ok, ok);
298 }
299 } else {
300 gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
301 abort();
302 }
303 }
304 }
305
306 struct MaybeExpect {
307 bool ok;
308 bool* seen;
309 };
310
311 std::map<void*, bool> expectations_;
312 std::map<void*, MaybeExpect> maybe_expectations_;
313 bool lambda_run_;
314 };
315
316 } // namespace testing
317 } // namespace grpc
318