1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
6
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
14
15 using namespace google::protobuf::io;
16
17 namespace gcm {
18
namespace {

// # of bytes a MCS version packet consumes.
const int kVersionPacketLen = 1;
// # of bytes a tag packet consumes.
const int kTagPacketLen = 1;
// The message size is sent as a varint; these are the minimum and maximum
// # of bytes that size packet may consume.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// The current MCS protocol version.
const int kMCSVersion = 41;

}  // namespace
33
ConnectionHandlerImpl(base::TimeDelta read_timeout,const ProtoReceivedCallback & read_callback,const ProtoSentCallback & write_callback,const ConnectionChangedCallback & connection_callback)34 ConnectionHandlerImpl::ConnectionHandlerImpl(
35 base::TimeDelta read_timeout,
36 const ProtoReceivedCallback& read_callback,
37 const ProtoSentCallback& write_callback,
38 const ConnectionChangedCallback& connection_callback)
39 : read_timeout_(read_timeout),
40 socket_(NULL),
41 handshake_complete_(false),
42 message_tag_(0),
43 message_size_(0),
44 read_callback_(read_callback),
45 write_callback_(write_callback),
46 connection_callback_(connection_callback),
47 weak_ptr_factory_(this) {
48 }
49
~ConnectionHandlerImpl()50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
51 }
52
Init(const mcs_proto::LoginRequest & login_request,net::StreamSocket * socket)53 void ConnectionHandlerImpl::Init(
54 const mcs_proto::LoginRequest& login_request,
55 net::StreamSocket* socket) {
56 DCHECK(!read_callback_.is_null());
57 DCHECK(!write_callback_.is_null());
58 DCHECK(!connection_callback_.is_null());
59
60 // Invalidate any previously outstanding reads.
61 weak_ptr_factory_.InvalidateWeakPtrs();
62
63 handshake_complete_ = false;
64 message_tag_ = 0;
65 message_size_ = 0;
66 socket_ = socket;
67 input_stream_.reset(new SocketInputStream(socket_));
68 output_stream_.reset(new SocketOutputStream(socket_));
69
70 Login(login_request);
71 }
72
Reset()73 void ConnectionHandlerImpl::Reset() {
74 CloseConnection();
75 }
76
CanSendMessage() const77 bool ConnectionHandlerImpl::CanSendMessage() const {
78 return handshake_complete_ && output_stream_.get() &&
79 output_stream_->GetState() == SocketOutputStream::EMPTY;
80 }
81
SendMessage(const google::protobuf::MessageLite & message)82 void ConnectionHandlerImpl::SendMessage(
83 const google::protobuf::MessageLite& message) {
84 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
85 DCHECK(handshake_complete_);
86
87 {
88 CodedOutputStream coded_output_stream(output_stream_.get());
89 DVLOG(1) << "Writing proto of size " << message.ByteSize();
90 int tag = GetMCSProtoTag(message);
91 DCHECK_NE(tag, -1);
92 coded_output_stream.WriteRaw(&tag, 1);
93 coded_output_stream.WriteVarint32(message.ByteSize());
94 message.SerializeToCodedStream(&coded_output_stream);
95 }
96
97 if (output_stream_->Flush(
98 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
99 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
100 OnMessageSent();
101 }
102 }
103
Login(const google::protobuf::MessageLite & login_request)104 void ConnectionHandlerImpl::Login(
105 const google::protobuf::MessageLite& login_request) {
106 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
107
108 const char version_byte[1] = {kMCSVersion};
109 const char login_request_tag[1] = {kLoginRequestTag};
110 {
111 CodedOutputStream coded_output_stream(output_stream_.get());
112 coded_output_stream.WriteRaw(version_byte, 1);
113 coded_output_stream.WriteRaw(login_request_tag, 1);
114 coded_output_stream.WriteVarint32(login_request.ByteSize());
115 login_request.SerializeToCodedStream(&coded_output_stream);
116 }
117
118 if (output_stream_->Flush(
119 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
120 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
121 base::MessageLoop::current()->PostTask(
122 FROM_HERE,
123 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
124 weak_ptr_factory_.GetWeakPtr()));
125 }
126
127 read_timeout_timer_.Start(FROM_HERE,
128 read_timeout_,
129 base::Bind(&ConnectionHandlerImpl::OnTimeout,
130 weak_ptr_factory_.GetWeakPtr()));
131 WaitForData(MCS_VERSION_TAG_AND_SIZE);
132 }
133
OnMessageSent()134 void ConnectionHandlerImpl::OnMessageSent() {
135 if (!output_stream_.get()) {
136 // The connection has already been closed. Just return.
137 DCHECK(!input_stream_.get());
138 DCHECK(!read_timeout_timer_.IsRunning());
139 return;
140 }
141
142 if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
143 int last_error = output_stream_->last_error();
144 CloseConnection();
145 // If the socket stream had an error, plumb it up, else plumb up FAILED.
146 if (last_error == net::OK)
147 last_error = net::ERR_FAILED;
148 connection_callback_.Run(last_error);
149 return;
150 }
151
152 write_callback_.Run();
153 }
154
GetNextMessage()155 void ConnectionHandlerImpl::GetNextMessage() {
156 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
157 SocketInputStream::READY == input_stream_->GetState());
158 message_tag_ = 0;
159 message_size_ = 0;
160
161 WaitForData(MCS_TAG_AND_SIZE);
162 }
163
WaitForData(ProcessingState state)164 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
165 DVLOG(1) << "Waiting for MCS data: state == " << state;
166
167 if (!input_stream_) {
168 // The connection has already been closed. Just return.
169 DCHECK(!output_stream_.get());
170 DCHECK(!read_timeout_timer_.IsRunning());
171 return;
172 }
173
174 if (input_stream_->GetState() != SocketInputStream::EMPTY &&
175 input_stream_->GetState() != SocketInputStream::READY) {
176 // An error occurred.
177 int last_error = output_stream_->last_error();
178 CloseConnection();
179 // If the socket stream had an error, plumb it up, else plumb up FAILED.
180 if (last_error == net::OK)
181 last_error = net::ERR_FAILED;
182 connection_callback_.Run(last_error);
183 return;
184 }
185
186 // Used to determine whether a Socket::Read is necessary.
187 size_t min_bytes_needed = 0;
188 // Used to limit the size of the Socket::Read.
189 size_t max_bytes_needed = 0;
190
191 switch(state) {
192 case MCS_VERSION_TAG_AND_SIZE:
193 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
194 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
195 break;
196 case MCS_TAG_AND_SIZE:
197 min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
198 max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
199 break;
200 case MCS_FULL_SIZE:
201 // If in this state, the minimum size packet length must already have been
202 // insufficient, so set both to the max length.
203 min_bytes_needed = kSizePacketLenMax;
204 max_bytes_needed = kSizePacketLenMax;
205 break;
206 case MCS_PROTO_BYTES:
207 read_timeout_timer_.Reset();
208 // No variability in the message size, set both to the same.
209 min_bytes_needed = message_size_;
210 max_bytes_needed = message_size_;
211 break;
212 default:
213 NOTREACHED();
214 }
215 DCHECK_GE(max_bytes_needed, min_bytes_needed);
216
217 size_t unread_byte_count = input_stream_->UnreadByteCount();
218 if (min_bytes_needed > unread_byte_count &&
219 input_stream_->Refresh(
220 base::Bind(&ConnectionHandlerImpl::WaitForData,
221 weak_ptr_factory_.GetWeakPtr(),
222 state),
223 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
224 return;
225 }
226
227 // Check for refresh errors.
228 if (input_stream_->GetState() != SocketInputStream::READY) {
229 // An error occurred.
230 int last_error = input_stream_->last_error();
231 CloseConnection();
232 // If the socket stream had an error, plumb it up, else plumb up FAILED.
233 if (last_error == net::OK)
234 last_error = net::ERR_FAILED;
235 connection_callback_.Run(last_error);
236 return;
237 }
238
239 // Check whether read is complete, or needs to be continued (
240 // SocketInputStream::Refresh can finish without reading all the data).
241 if (input_stream_->UnreadByteCount() < min_bytes_needed) {
242 DVLOG(1) << "Socket read finished prematurely. Waiting for "
243 << min_bytes_needed - input_stream_->UnreadByteCount()
244 << " more bytes.";
245 base::MessageLoop::current()->PostTask(
246 FROM_HERE,
247 base::Bind(&ConnectionHandlerImpl::WaitForData,
248 weak_ptr_factory_.GetWeakPtr(),
249 MCS_PROTO_BYTES));
250 return;
251 }
252
253 // Received enough bytes, process them.
254 DVLOG(1) << "Processing MCS data: state == " << state;
255 switch(state) {
256 case MCS_VERSION_TAG_AND_SIZE:
257 OnGotVersion();
258 break;
259 case MCS_TAG_AND_SIZE:
260 OnGotMessageTag();
261 break;
262 case MCS_FULL_SIZE:
263 OnGotMessageSize();
264 break;
265 case MCS_PROTO_BYTES:
266 OnGotMessageBytes();
267 break;
268 default:
269 NOTREACHED();
270 }
271 }
272
OnGotVersion()273 void ConnectionHandlerImpl::OnGotVersion() {
274 uint8 version = 0;
275 {
276 CodedInputStream coded_input_stream(input_stream_.get());
277 coded_input_stream.ReadRaw(&version, 1);
278 }
279 // TODO(zea): remove this when the server is ready.
280 if (version < kMCSVersion && version != 38) {
281 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
282 connection_callback_.Run(net::ERR_FAILED);
283 return;
284 }
285
286 input_stream_->RebuildBuffer();
287
288 // Process the LoginResponse message tag.
289 OnGotMessageTag();
290 }
291
OnGotMessageTag()292 void ConnectionHandlerImpl::OnGotMessageTag() {
293 if (input_stream_->GetState() != SocketInputStream::READY) {
294 LOG(ERROR) << "Failed to receive protobuf tag.";
295 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
296 return;
297 }
298
299 {
300 CodedInputStream coded_input_stream(input_stream_.get());
301 coded_input_stream.ReadRaw(&message_tag_, 1);
302 }
303
304 DVLOG(1) << "Received proto of type "
305 << static_cast<unsigned int>(message_tag_);
306
307 if (!read_timeout_timer_.IsRunning()) {
308 read_timeout_timer_.Start(FROM_HERE,
309 read_timeout_,
310 base::Bind(&ConnectionHandlerImpl::OnTimeout,
311 weak_ptr_factory_.GetWeakPtr()));
312 }
313 OnGotMessageSize();
314 }
315
OnGotMessageSize()316 void ConnectionHandlerImpl::OnGotMessageSize() {
317 if (input_stream_->GetState() != SocketInputStream::READY) {
318 LOG(ERROR) << "Failed to receive message size.";
319 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
320 return;
321 }
322
323 bool need_another_byte = false;
324 int prev_byte_count = input_stream_->ByteCount();
325 {
326 CodedInputStream coded_input_stream(input_stream_.get());
327 if (!coded_input_stream.ReadVarint32(&message_size_))
328 need_another_byte = true;
329 }
330
331 if (need_another_byte) {
332 DVLOG(1) << "Expecting another message size byte.";
333 if (prev_byte_count >= kSizePacketLenMax) {
334 // Already had enough bytes, something else went wrong.
335 LOG(ERROR) << "Failed to process message size.";
336 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
337 return;
338 }
339 // Back up by the amount read (should always be 1 byte).
340 int bytes_read = prev_byte_count - input_stream_->ByteCount();
341 DCHECK_EQ(bytes_read, 1);
342 input_stream_->BackUp(bytes_read);
343 WaitForData(MCS_FULL_SIZE);
344 return;
345 }
346
347 DVLOG(1) << "Proto size: " << message_size_;
348
349 if (message_size_ > 0)
350 WaitForData(MCS_PROTO_BYTES);
351 else
352 OnGotMessageBytes();
353 }
354
OnGotMessageBytes()355 void ConnectionHandlerImpl::OnGotMessageBytes() {
356 read_timeout_timer_.Stop();
357 scoped_ptr<google::protobuf::MessageLite> protobuf(
358 BuildProtobufFromTag(message_tag_));
359 // Messages with no content are valid; just use the default protobuf for
360 // that tag.
361 if (protobuf.get() && message_size_ == 0) {
362 base::MessageLoop::current()->PostTask(
363 FROM_HERE,
364 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
365 weak_ptr_factory_.GetWeakPtr()));
366 read_callback_.Run(protobuf.Pass());
367 return;
368 }
369
370 if (!protobuf.get() ||
371 input_stream_->GetState() != SocketInputStream::READY) {
372 LOG(ERROR) << "Failed to extract protobuf bytes of type "
373 << static_cast<unsigned int>(message_tag_);
374 // Reset the connection.
375 connection_callback_.Run(net::ERR_FAILED);
376 return;
377 }
378
379 {
380 CodedInputStream coded_input_stream(input_stream_.get());
381 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
382 LOG(ERROR) << "Unable to parse GCM message of type "
383 << static_cast<unsigned int>(message_tag_);
384 // Reset the connection.
385 connection_callback_.Run(net::ERR_FAILED);
386 return;
387 }
388 }
389
390 input_stream_->RebuildBuffer();
391 base::MessageLoop::current()->PostTask(
392 FROM_HERE,
393 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
394 weak_ptr_factory_.GetWeakPtr()));
395 if (message_tag_ == kLoginResponseTag) {
396 if (handshake_complete_) {
397 LOG(ERROR) << "Unexpected login response.";
398 } else {
399 handshake_complete_ = true;
400 DVLOG(1) << "GCM Handshake complete.";
401 connection_callback_.Run(net::OK);
402 }
403 }
404 read_callback_.Run(protobuf.Pass());
405 }
406
OnTimeout()407 void ConnectionHandlerImpl::OnTimeout() {
408 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
409 CloseConnection();
410 connection_callback_.Run(net::ERR_TIMED_OUT);
411 }
412
CloseConnection()413 void ConnectionHandlerImpl::CloseConnection() {
414 DVLOG(1) << "Closing connection.";
415 read_timeout_timer_.Stop();
416 if (socket_)
417 socket_->Disconnect();
418 socket_ = NULL;
419 handshake_complete_ = false;
420 message_tag_ = 0;
421 message_size_ = 0;
422 input_stream_.reset();
423 output_stream_.reset();
424 weak_ptr_factory_.InvalidateWeakPtrs();
425 }
426
427 } // namespace gcm
428