1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
6
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
14
15 using namespace google::protobuf::io;
16
17 namespace gcm {
18
namespace {

// # of bytes a MCS version packet consumes.
const int kVersionPacketLen = 1;
// # of bytes a tag packet consumes.
const int kTagPacketLen = 1;
// Min # of bytes a length packet consumes. The length is a varint, so a single
// byte covers message sizes 0..127.
const int kSizePacketLenMin = 1;
// Max # of bytes a length packet may consume before the size is rejected as
// invalid. A two-byte varint covers message sizes up to 16383.
const int kSizePacketLenMax = 2;

// The current MCS protocol version.
// TODO(zea): bump to 41 once the server supports it.
const int kMCSVersion = 38;

}  // namespace
34
ConnectionHandlerImpl(base::TimeDelta read_timeout,const ProtoReceivedCallback & read_callback,const ProtoSentCallback & write_callback,const ConnectionChangedCallback & connection_callback)35 ConnectionHandlerImpl::ConnectionHandlerImpl(
36 base::TimeDelta read_timeout,
37 const ProtoReceivedCallback& read_callback,
38 const ProtoSentCallback& write_callback,
39 const ConnectionChangedCallback& connection_callback)
40 : read_timeout_(read_timeout),
41 handshake_complete_(false),
42 message_tag_(0),
43 message_size_(0),
44 read_callback_(read_callback),
45 write_callback_(write_callback),
46 connection_callback_(connection_callback),
47 weak_ptr_factory_(this) {
48 }
49
~ConnectionHandlerImpl()50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
51 }
52
Init(const mcs_proto::LoginRequest & login_request,scoped_ptr<net::StreamSocket> socket)53 void ConnectionHandlerImpl::Init(
54 const mcs_proto::LoginRequest& login_request,
55 scoped_ptr<net::StreamSocket> socket) {
56 DCHECK(!read_callback_.is_null());
57 DCHECK(!write_callback_.is_null());
58 DCHECK(!connection_callback_.is_null());
59
60 // Invalidate any previously outstanding reads.
61 weak_ptr_factory_.InvalidateWeakPtrs();
62
63 handshake_complete_ = false;
64 message_tag_ = 0;
65 message_size_ = 0;
66 socket_ = socket.Pass();
67 input_stream_.reset(new SocketInputStream(socket_.get()));
68 output_stream_.reset(new SocketOutputStream(socket_.get()));
69
70 Login(login_request);
71 }
72
CanSendMessage() const73 bool ConnectionHandlerImpl::CanSendMessage() const {
74 return handshake_complete_ && output_stream_.get() &&
75 output_stream_->GetState() == SocketOutputStream::EMPTY;
76 }
77
SendMessage(const google::protobuf::MessageLite & message)78 void ConnectionHandlerImpl::SendMessage(
79 const google::protobuf::MessageLite& message) {
80 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
81 DCHECK(handshake_complete_);
82
83 {
84 CodedOutputStream coded_output_stream(output_stream_.get());
85 DVLOG(1) << "Writing proto of size " << message.ByteSize();
86 int tag = GetMCSProtoTag(message);
87 DCHECK_NE(tag, -1);
88 coded_output_stream.WriteRaw(&tag, 1);
89 coded_output_stream.WriteVarint32(message.ByteSize());
90 message.SerializeToCodedStream(&coded_output_stream);
91 }
92
93 if (output_stream_->Flush(
94 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
95 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
96 OnMessageSent();
97 }
98 }
99
Login(const google::protobuf::MessageLite & login_request)100 void ConnectionHandlerImpl::Login(
101 const google::protobuf::MessageLite& login_request) {
102 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
103
104 const char version_byte[1] = {kMCSVersion};
105 const char login_request_tag[1] = {kLoginRequestTag};
106 {
107 CodedOutputStream coded_output_stream(output_stream_.get());
108 coded_output_stream.WriteRaw(version_byte, 1);
109 coded_output_stream.WriteRaw(login_request_tag, 1);
110 coded_output_stream.WriteVarint32(login_request.ByteSize());
111 login_request.SerializeToCodedStream(&coded_output_stream);
112 }
113
114 if (output_stream_->Flush(
115 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
116 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
117 base::MessageLoop::current()->PostTask(
118 FROM_HERE,
119 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
120 weak_ptr_factory_.GetWeakPtr()));
121 }
122
123 read_timeout_timer_.Start(FROM_HERE,
124 read_timeout_,
125 base::Bind(&ConnectionHandlerImpl::OnTimeout,
126 weak_ptr_factory_.GetWeakPtr()));
127 WaitForData(MCS_VERSION_TAG_AND_SIZE);
128 }
129
OnMessageSent()130 void ConnectionHandlerImpl::OnMessageSent() {
131 if (!output_stream_.get()) {
132 // The connection has already been closed. Just return.
133 DCHECK(!input_stream_.get());
134 DCHECK(!read_timeout_timer_.IsRunning());
135 return;
136 }
137
138 if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
139 int last_error = output_stream_->last_error();
140 CloseConnection();
141 // If the socket stream had an error, plumb it up, else plumb up FAILED.
142 if (last_error == net::OK)
143 last_error = net::ERR_FAILED;
144 connection_callback_.Run(last_error);
145 return;
146 }
147
148 write_callback_.Run();
149 }
150
GetNextMessage()151 void ConnectionHandlerImpl::GetNextMessage() {
152 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
153 SocketInputStream::READY == input_stream_->GetState());
154 message_tag_ = 0;
155 message_size_ = 0;
156
157 WaitForData(MCS_TAG_AND_SIZE);
158 }
159
WaitForData(ProcessingState state)160 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
161 DVLOG(1) << "Waiting for MCS data: state == " << state;
162
163 if (!input_stream_) {
164 // The connection has already been closed. Just return.
165 DCHECK(!output_stream_.get());
166 DCHECK(!read_timeout_timer_.IsRunning());
167 return;
168 }
169
170 if (input_stream_->GetState() != SocketInputStream::EMPTY &&
171 input_stream_->GetState() != SocketInputStream::READY) {
172 // An error occurred.
173 int last_error = output_stream_->last_error();
174 CloseConnection();
175 // If the socket stream had an error, plumb it up, else plumb up FAILED.
176 if (last_error == net::OK)
177 last_error = net::ERR_FAILED;
178 connection_callback_.Run(last_error);
179 return;
180 }
181
182 // Used to determine whether a Socket::Read is necessary.
183 int min_bytes_needed = 0;
184 // Used to limit the size of the Socket::Read.
185 int max_bytes_needed = 0;
186
187 switch(state) {
188 case MCS_VERSION_TAG_AND_SIZE:
189 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
190 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
191 break;
192 case MCS_TAG_AND_SIZE:
193 min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
194 max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
195 break;
196 case MCS_FULL_SIZE:
197 // If in this state, the minimum size packet length must already have been
198 // insufficient, so set both to the max length.
199 min_bytes_needed = kSizePacketLenMax;
200 max_bytes_needed = kSizePacketLenMax;
201 break;
202 case MCS_PROTO_BYTES:
203 read_timeout_timer_.Reset();
204 // No variability in the message size, set both to the same.
205 min_bytes_needed = message_size_;
206 max_bytes_needed = message_size_;
207 break;
208 default:
209 NOTREACHED();
210 }
211 DCHECK_GE(max_bytes_needed, min_bytes_needed);
212
213 int byte_count = input_stream_->UnreadByteCount();
214 if (min_bytes_needed - byte_count > 0 &&
215 input_stream_->Refresh(
216 base::Bind(&ConnectionHandlerImpl::WaitForData,
217 weak_ptr_factory_.GetWeakPtr(),
218 state),
219 max_bytes_needed - byte_count) == net::ERR_IO_PENDING) {
220 return;
221 }
222
223 // Check for refresh errors.
224 if (input_stream_->GetState() != SocketInputStream::READY) {
225 // An error occurred.
226 int last_error = output_stream_->last_error();
227 CloseConnection();
228 // If the socket stream had an error, plumb it up, else plumb up FAILED.
229 if (last_error == net::OK)
230 last_error = net::ERR_FAILED;
231 connection_callback_.Run(last_error);
232 return;
233 }
234
235 // Received enough bytes, process them.
236 DVLOG(1) << "Processing MCS data: state == " << state;
237 switch(state) {
238 case MCS_VERSION_TAG_AND_SIZE:
239 OnGotVersion();
240 break;
241 case MCS_TAG_AND_SIZE:
242 OnGotMessageTag();
243 break;
244 case MCS_FULL_SIZE:
245 OnGotMessageSize();
246 break;
247 case MCS_PROTO_BYTES:
248 OnGotMessageBytes();
249 break;
250 default:
251 NOTREACHED();
252 }
253 }
254
OnGotVersion()255 void ConnectionHandlerImpl::OnGotVersion() {
256 uint8 version = 0;
257 {
258 CodedInputStream coded_input_stream(input_stream_.get());
259 coded_input_stream.ReadRaw(&version, 1);
260 }
261 if (version < kMCSVersion) {
262 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
263 connection_callback_.Run(net::ERR_FAILED);
264 return;
265 }
266
267 input_stream_->RebuildBuffer();
268
269 // Process the LoginResponse message tag.
270 OnGotMessageTag();
271 }
272
OnGotMessageTag()273 void ConnectionHandlerImpl::OnGotMessageTag() {
274 if (input_stream_->GetState() != SocketInputStream::READY) {
275 LOG(ERROR) << "Failed to receive protobuf tag.";
276 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
277 return;
278 }
279
280 {
281 CodedInputStream coded_input_stream(input_stream_.get());
282 coded_input_stream.ReadRaw(&message_tag_, 1);
283 }
284
285 DVLOG(1) << "Received proto of type "
286 << static_cast<unsigned int>(message_tag_);
287
288 if (!read_timeout_timer_.IsRunning()) {
289 read_timeout_timer_.Start(FROM_HERE,
290 read_timeout_,
291 base::Bind(&ConnectionHandlerImpl::OnTimeout,
292 weak_ptr_factory_.GetWeakPtr()));
293 }
294 OnGotMessageSize();
295 }
296
OnGotMessageSize()297 void ConnectionHandlerImpl::OnGotMessageSize() {
298 if (input_stream_->GetState() != SocketInputStream::READY) {
299 LOG(ERROR) << "Failed to receive message size.";
300 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
301 return;
302 }
303
304 bool need_another_byte = false;
305 int prev_byte_count = input_stream_->ByteCount();
306 {
307 CodedInputStream coded_input_stream(input_stream_.get());
308 if (!coded_input_stream.ReadVarint32(&message_size_))
309 need_another_byte = true;
310 }
311
312 if (need_another_byte) {
313 DVLOG(1) << "Expecting another message size byte.";
314 if (prev_byte_count >= kSizePacketLenMax) {
315 // Already had enough bytes, something else went wrong.
316 LOG(ERROR) << "Failed to process message size.";
317 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
318 return;
319 }
320 // Back up by the amount read (should always be 1 byte).
321 int bytes_read = prev_byte_count - input_stream_->ByteCount();
322 DCHECK_EQ(bytes_read, 1);
323 input_stream_->BackUp(bytes_read);
324 WaitForData(MCS_FULL_SIZE);
325 return;
326 }
327
328 DVLOG(1) << "Proto size: " << message_size_;
329
330 if (message_size_ > 0)
331 WaitForData(MCS_PROTO_BYTES);
332 else
333 OnGotMessageBytes();
334 }
335
OnGotMessageBytes()336 void ConnectionHandlerImpl::OnGotMessageBytes() {
337 read_timeout_timer_.Stop();
338 scoped_ptr<google::protobuf::MessageLite> protobuf(
339 BuildProtobufFromTag(message_tag_));
340 // Messages with no content are valid; just use the default protobuf for
341 // that tag.
342 if (protobuf.get() && message_size_ == 0) {
343 base::MessageLoop::current()->PostTask(
344 FROM_HERE,
345 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
346 weak_ptr_factory_.GetWeakPtr()));
347 read_callback_.Run(protobuf.Pass());
348 return;
349 }
350
351 if (!protobuf.get() ||
352 input_stream_->GetState() != SocketInputStream::READY) {
353 LOG(ERROR) << "Failed to extract protobuf bytes of type "
354 << static_cast<unsigned int>(message_tag_);
355 protobuf.reset(); // Return a null pointer to denote an error.
356 read_callback_.Run(protobuf.Pass());
357 return;
358 }
359
360 {
361 CodedInputStream coded_input_stream(input_stream_.get());
362 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
363 NOTREACHED() << "Unable to parse GCM message of type "
364 << static_cast<unsigned int>(message_tag_);
365 protobuf.reset(); // Return a null pointer to denote an error.
366 read_callback_.Run(protobuf.Pass());
367 return;
368 }
369 }
370
371 input_stream_->RebuildBuffer();
372 base::MessageLoop::current()->PostTask(
373 FROM_HERE,
374 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
375 weak_ptr_factory_.GetWeakPtr()));
376 if (message_tag_ == kLoginResponseTag) {
377 if (handshake_complete_) {
378 LOG(ERROR) << "Unexpected login response.";
379 } else {
380 handshake_complete_ = true;
381 DVLOG(1) << "GCM Handshake complete.";
382 }
383 }
384 read_callback_.Run(protobuf.Pass());
385 }
386
OnTimeout()387 void ConnectionHandlerImpl::OnTimeout() {
388 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
389 CloseConnection();
390 connection_callback_.Run(net::ERR_TIMED_OUT);
391 }
392
CloseConnection()393 void ConnectionHandlerImpl::CloseConnection() {
394 DVLOG(1) << "Closing connection.";
395 read_callback_.Reset();
396 write_callback_.Reset();
397 read_timeout_timer_.Stop();
398 socket_->Disconnect();
399 input_stream_.reset();
400 output_stream_.reset();
401 weak_ptr_factory_.InvalidateWeakPtrs();
402 }
403
404 } // namespace gcm
405