• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
6 
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
14 
15 using namespace google::protobuf::io;
16 
17 namespace gcm {
18 
namespace {

// Sizes, in bytes, of the fixed-length MCS framing fields.
const int kVersionPacketLen = 1;  // Protocol version byte (login exchange).
const int kTagPacketLen = 1;      // Message tag byte.

// The message size is encoded as a varint32; these bound how many bytes the
// size field may occupy on the wire.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// MCS protocol version written as the first byte of the login request and
// used as the minimum acceptable version in the peer's response.
const int kMCSVersion = 41;

}  // namespace
33 
ConnectionHandlerImpl(base::TimeDelta read_timeout,const ProtoReceivedCallback & read_callback,const ProtoSentCallback & write_callback,const ConnectionChangedCallback & connection_callback)34 ConnectionHandlerImpl::ConnectionHandlerImpl(
35     base::TimeDelta read_timeout,
36     const ProtoReceivedCallback& read_callback,
37     const ProtoSentCallback& write_callback,
38     const ConnectionChangedCallback& connection_callback)
39     : read_timeout_(read_timeout),
40       socket_(NULL),
41       handshake_complete_(false),
42       message_tag_(0),
43       message_size_(0),
44       read_callback_(read_callback),
45       write_callback_(write_callback),
46       connection_callback_(connection_callback),
47       weak_ptr_factory_(this) {
48 }
49 
~ConnectionHandlerImpl()50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
51 }
52 
Init(const mcs_proto::LoginRequest & login_request,net::StreamSocket * socket)53 void ConnectionHandlerImpl::Init(
54     const mcs_proto::LoginRequest& login_request,
55     net::StreamSocket* socket) {
56   DCHECK(!read_callback_.is_null());
57   DCHECK(!write_callback_.is_null());
58   DCHECK(!connection_callback_.is_null());
59 
60   // Invalidate any previously outstanding reads.
61   weak_ptr_factory_.InvalidateWeakPtrs();
62 
63   handshake_complete_ = false;
64   message_tag_ = 0;
65   message_size_ = 0;
66   socket_ = socket;
67   input_stream_.reset(new SocketInputStream(socket_));
68   output_stream_.reset(new SocketOutputStream(socket_));
69 
70   Login(login_request);
71 }
72 
Reset()73 void ConnectionHandlerImpl::Reset() {
74   CloseConnection();
75 }
76 
CanSendMessage() const77 bool ConnectionHandlerImpl::CanSendMessage() const {
78   return handshake_complete_ && output_stream_.get() &&
79       output_stream_->GetState() == SocketOutputStream::EMPTY;
80 }
81 
SendMessage(const google::protobuf::MessageLite & message)82 void ConnectionHandlerImpl::SendMessage(
83     const google::protobuf::MessageLite& message) {
84   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
85   DCHECK(handshake_complete_);
86 
87   {
88     CodedOutputStream coded_output_stream(output_stream_.get());
89     DVLOG(1) << "Writing proto of size " << message.ByteSize();
90     int tag = GetMCSProtoTag(message);
91     DCHECK_NE(tag, -1);
92     coded_output_stream.WriteRaw(&tag, 1);
93     coded_output_stream.WriteVarint32(message.ByteSize());
94     message.SerializeToCodedStream(&coded_output_stream);
95   }
96 
97   if (output_stream_->Flush(
98           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
99                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
100     OnMessageSent();
101   }
102 }
103 
Login(const google::protobuf::MessageLite & login_request)104 void ConnectionHandlerImpl::Login(
105     const google::protobuf::MessageLite& login_request) {
106   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
107 
108   const char version_byte[1] = {kMCSVersion};
109   const char login_request_tag[1] = {kLoginRequestTag};
110   {
111     CodedOutputStream coded_output_stream(output_stream_.get());
112     coded_output_stream.WriteRaw(version_byte, 1);
113     coded_output_stream.WriteRaw(login_request_tag, 1);
114     coded_output_stream.WriteVarint32(login_request.ByteSize());
115     login_request.SerializeToCodedStream(&coded_output_stream);
116   }
117 
118   if (output_stream_->Flush(
119           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
120                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
121     base::MessageLoop::current()->PostTask(
122         FROM_HERE,
123         base::Bind(&ConnectionHandlerImpl::OnMessageSent,
124                    weak_ptr_factory_.GetWeakPtr()));
125   }
126 
127   read_timeout_timer_.Start(FROM_HERE,
128                             read_timeout_,
129                             base::Bind(&ConnectionHandlerImpl::OnTimeout,
130                                        weak_ptr_factory_.GetWeakPtr()));
131   WaitForData(MCS_VERSION_TAG_AND_SIZE);
132 }
133 
OnMessageSent()134 void ConnectionHandlerImpl::OnMessageSent() {
135   if (!output_stream_.get()) {
136     // The connection has already been closed. Just return.
137     DCHECK(!input_stream_.get());
138     DCHECK(!read_timeout_timer_.IsRunning());
139     return;
140   }
141 
142   if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
143     int last_error = output_stream_->last_error();
144     CloseConnection();
145     // If the socket stream had an error, plumb it up, else plumb up FAILED.
146     if (last_error == net::OK)
147       last_error = net::ERR_FAILED;
148     connection_callback_.Run(last_error);
149     return;
150   }
151 
152   write_callback_.Run();
153 }
154 
GetNextMessage()155 void ConnectionHandlerImpl::GetNextMessage() {
156   DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
157          SocketInputStream::READY == input_stream_->GetState());
158   message_tag_ = 0;
159   message_size_ = 0;
160 
161   WaitForData(MCS_TAG_AND_SIZE);
162 }
163 
WaitForData(ProcessingState state)164 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
165   DVLOG(1) << "Waiting for MCS data: state == " << state;
166 
167   if (!input_stream_) {
168     // The connection has already been closed. Just return.
169     DCHECK(!output_stream_.get());
170     DCHECK(!read_timeout_timer_.IsRunning());
171     return;
172   }
173 
174   if (input_stream_->GetState() != SocketInputStream::EMPTY &&
175       input_stream_->GetState() != SocketInputStream::READY) {
176     // An error occurred.
177     int last_error = output_stream_->last_error();
178     CloseConnection();
179     // If the socket stream had an error, plumb it up, else plumb up FAILED.
180     if (last_error == net::OK)
181       last_error = net::ERR_FAILED;
182     connection_callback_.Run(last_error);
183     return;
184   }
185 
186   // Used to determine whether a Socket::Read is necessary.
187   size_t min_bytes_needed = 0;
188   // Used to limit the size of the Socket::Read.
189   size_t max_bytes_needed = 0;
190 
191   switch(state) {
192     case MCS_VERSION_TAG_AND_SIZE:
193       min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
194       max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
195       break;
196     case MCS_TAG_AND_SIZE:
197       min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
198       max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
199       break;
200     case MCS_FULL_SIZE:
201       // If in this state, the minimum size packet length must already have been
202       // insufficient, so set both to the max length.
203       min_bytes_needed = kSizePacketLenMax;
204       max_bytes_needed = kSizePacketLenMax;
205       break;
206     case MCS_PROTO_BYTES:
207       read_timeout_timer_.Reset();
208       // No variability in the message size, set both to the same.
209       min_bytes_needed = message_size_;
210       max_bytes_needed = message_size_;
211       break;
212     default:
213       NOTREACHED();
214   }
215   DCHECK_GE(max_bytes_needed, min_bytes_needed);
216 
217   size_t unread_byte_count = input_stream_->UnreadByteCount();
218   if (min_bytes_needed > unread_byte_count &&
219       input_stream_->Refresh(
220           base::Bind(&ConnectionHandlerImpl::WaitForData,
221                      weak_ptr_factory_.GetWeakPtr(),
222                      state),
223           max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
224     return;
225   }
226 
227   // Check for refresh errors.
228   if (input_stream_->GetState() != SocketInputStream::READY) {
229     // An error occurred.
230     int last_error = input_stream_->last_error();
231     CloseConnection();
232     // If the socket stream had an error, plumb it up, else plumb up FAILED.
233     if (last_error == net::OK)
234       last_error = net::ERR_FAILED;
235     connection_callback_.Run(last_error);
236     return;
237   }
238 
239   // Check whether read is complete, or needs to be continued (
240   // SocketInputStream::Refresh can finish without reading all the data).
241   if (input_stream_->UnreadByteCount() < min_bytes_needed) {
242     DVLOG(1) << "Socket read finished prematurely. Waiting for "
243              << min_bytes_needed - input_stream_->UnreadByteCount()
244              << " more bytes.";
245     base::MessageLoop::current()->PostTask(
246         FROM_HERE,
247         base::Bind(&ConnectionHandlerImpl::WaitForData,
248                    weak_ptr_factory_.GetWeakPtr(),
249                    MCS_PROTO_BYTES));
250     return;
251   }
252 
253   // Received enough bytes, process them.
254   DVLOG(1) << "Processing MCS data: state == " << state;
255   switch(state) {
256     case MCS_VERSION_TAG_AND_SIZE:
257       OnGotVersion();
258       break;
259     case MCS_TAG_AND_SIZE:
260       OnGotMessageTag();
261       break;
262     case MCS_FULL_SIZE:
263       OnGotMessageSize();
264       break;
265     case MCS_PROTO_BYTES:
266       OnGotMessageBytes();
267       break;
268     default:
269       NOTREACHED();
270   }
271 }
272 
OnGotVersion()273 void ConnectionHandlerImpl::OnGotVersion() {
274   uint8 version = 0;
275   {
276     CodedInputStream coded_input_stream(input_stream_.get());
277     coded_input_stream.ReadRaw(&version, 1);
278   }
279   // TODO(zea): remove this when the server is ready.
280   if (version < kMCSVersion && version != 38) {
281     LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
282     connection_callback_.Run(net::ERR_FAILED);
283     return;
284   }
285 
286   input_stream_->RebuildBuffer();
287 
288   // Process the LoginResponse message tag.
289   OnGotMessageTag();
290 }
291 
OnGotMessageTag()292 void ConnectionHandlerImpl::OnGotMessageTag() {
293   if (input_stream_->GetState() != SocketInputStream::READY) {
294     LOG(ERROR) << "Failed to receive protobuf tag.";
295     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
296     return;
297   }
298 
299   {
300     CodedInputStream coded_input_stream(input_stream_.get());
301     coded_input_stream.ReadRaw(&message_tag_, 1);
302   }
303 
304   DVLOG(1) << "Received proto of type "
305            << static_cast<unsigned int>(message_tag_);
306 
307   if (!read_timeout_timer_.IsRunning()) {
308     read_timeout_timer_.Start(FROM_HERE,
309                               read_timeout_,
310                               base::Bind(&ConnectionHandlerImpl::OnTimeout,
311                                          weak_ptr_factory_.GetWeakPtr()));
312   }
313   OnGotMessageSize();
314 }
315 
OnGotMessageSize()316 void ConnectionHandlerImpl::OnGotMessageSize() {
317   if (input_stream_->GetState() != SocketInputStream::READY) {
318     LOG(ERROR) << "Failed to receive message size.";
319     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
320     return;
321   }
322 
323   bool need_another_byte = false;
324   int prev_byte_count = input_stream_->ByteCount();
325   {
326     CodedInputStream coded_input_stream(input_stream_.get());
327     if (!coded_input_stream.ReadVarint32(&message_size_))
328       need_another_byte = true;
329   }
330 
331   if (need_another_byte) {
332     DVLOG(1) << "Expecting another message size byte.";
333     if (prev_byte_count >= kSizePacketLenMax) {
334       // Already had enough bytes, something else went wrong.
335       LOG(ERROR) << "Failed to process message size.";
336       read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
337       return;
338     }
339     // Back up by the amount read (should always be 1 byte).
340     int bytes_read = prev_byte_count - input_stream_->ByteCount();
341     DCHECK_EQ(bytes_read, 1);
342     input_stream_->BackUp(bytes_read);
343     WaitForData(MCS_FULL_SIZE);
344     return;
345   }
346 
347   DVLOG(1) << "Proto size: " << message_size_;
348 
349   if (message_size_ > 0)
350     WaitForData(MCS_PROTO_BYTES);
351   else
352     OnGotMessageBytes();
353 }
354 
OnGotMessageBytes()355 void ConnectionHandlerImpl::OnGotMessageBytes() {
356   read_timeout_timer_.Stop();
357   scoped_ptr<google::protobuf::MessageLite> protobuf(
358       BuildProtobufFromTag(message_tag_));
359   // Messages with no content are valid; just use the default protobuf for
360   // that tag.
361   if (protobuf.get() && message_size_ == 0) {
362     base::MessageLoop::current()->PostTask(
363         FROM_HERE,
364         base::Bind(&ConnectionHandlerImpl::GetNextMessage,
365                    weak_ptr_factory_.GetWeakPtr()));
366     read_callback_.Run(protobuf.Pass());
367     return;
368   }
369 
370   if (!protobuf.get() ||
371       input_stream_->GetState() != SocketInputStream::READY) {
372     LOG(ERROR) << "Failed to extract protobuf bytes of type "
373                << static_cast<unsigned int>(message_tag_);
374     // Reset the connection.
375     connection_callback_.Run(net::ERR_FAILED);
376     return;
377   }
378 
379   {
380     CodedInputStream coded_input_stream(input_stream_.get());
381     if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
382       LOG(ERROR) << "Unable to parse GCM message of type "
383                  << static_cast<unsigned int>(message_tag_);
384       // Reset the connection.
385       connection_callback_.Run(net::ERR_FAILED);
386       return;
387     }
388   }
389 
390   input_stream_->RebuildBuffer();
391   base::MessageLoop::current()->PostTask(
392       FROM_HERE,
393       base::Bind(&ConnectionHandlerImpl::GetNextMessage,
394                  weak_ptr_factory_.GetWeakPtr()));
395   if (message_tag_ == kLoginResponseTag) {
396     if (handshake_complete_) {
397       LOG(ERROR) << "Unexpected login response.";
398     } else {
399       handshake_complete_ = true;
400       DVLOG(1) << "GCM Handshake complete.";
401       connection_callback_.Run(net::OK);
402     }
403   }
404   read_callback_.Run(protobuf.Pass());
405 }
406 
OnTimeout()407 void ConnectionHandlerImpl::OnTimeout() {
408   LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
409   CloseConnection();
410   connection_callback_.Run(net::ERR_TIMED_OUT);
411 }
412 
CloseConnection()413 void ConnectionHandlerImpl::CloseConnection() {
414   DVLOG(1) << "Closing connection.";
415   read_timeout_timer_.Stop();
416   if (socket_)
417     socket_->Disconnect();
418   socket_ = NULL;
419   handshake_complete_ = false;
420   message_tag_ = 0;
421   message_size_ = 0;
422   input_stream_.reset();
423   output_stream_.reset();
424   weak_ptr_factory_.InvalidateWeakPtrs();
425 }
426 
427 }  // namespace gcm
428