• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
6 
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
14 
15 using namespace google::protobuf::io;
16 
17 namespace gcm {
18 
namespace {

// Wire-format sizes, in bytes, of the MCS framing fields.

// Protocol version byte; exchanged once, ahead of the login messages.
const int kVersionPacketLen = 1;
// Per-message tag byte identifying the protobuf type that follows.
const int kTagPacketLen = 1;
// The message size is a varint; these bound the number of bytes it may
// occupy (minimum and maximum) in this implementation.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// The current MCS protocol version.
// TODO(zea): bump to 41 once the server supports it.
const int kMCSVersion = 38;

}  // namespace
34 
ConnectionHandlerImpl(base::TimeDelta read_timeout,const ProtoReceivedCallback & read_callback,const ProtoSentCallback & write_callback,const ConnectionChangedCallback & connection_callback)35 ConnectionHandlerImpl::ConnectionHandlerImpl(
36     base::TimeDelta read_timeout,
37     const ProtoReceivedCallback& read_callback,
38     const ProtoSentCallback& write_callback,
39     const ConnectionChangedCallback& connection_callback)
40     : read_timeout_(read_timeout),
41       handshake_complete_(false),
42       message_tag_(0),
43       message_size_(0),
44       read_callback_(read_callback),
45       write_callback_(write_callback),
46       connection_callback_(connection_callback),
47       weak_ptr_factory_(this) {
48 }
49 
~ConnectionHandlerImpl()50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
51 }
52 
Init(const mcs_proto::LoginRequest & login_request,scoped_ptr<net::StreamSocket> socket)53 void ConnectionHandlerImpl::Init(
54     const mcs_proto::LoginRequest& login_request,
55     scoped_ptr<net::StreamSocket> socket) {
56   DCHECK(!read_callback_.is_null());
57   DCHECK(!write_callback_.is_null());
58   DCHECK(!connection_callback_.is_null());
59 
60   // Invalidate any previously outstanding reads.
61   weak_ptr_factory_.InvalidateWeakPtrs();
62 
63   handshake_complete_ = false;
64   message_tag_ = 0;
65   message_size_ = 0;
66   socket_ = socket.Pass();
67   input_stream_.reset(new SocketInputStream(socket_.get()));
68   output_stream_.reset(new SocketOutputStream(socket_.get()));
69 
70   Login(login_request);
71 }
72 
CanSendMessage() const73 bool ConnectionHandlerImpl::CanSendMessage() const {
74   return handshake_complete_ && output_stream_.get() &&
75       output_stream_->GetState() == SocketOutputStream::EMPTY;
76 }
77 
SendMessage(const google::protobuf::MessageLite & message)78 void ConnectionHandlerImpl::SendMessage(
79     const google::protobuf::MessageLite& message) {
80   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
81   DCHECK(handshake_complete_);
82 
83   {
84     CodedOutputStream coded_output_stream(output_stream_.get());
85     DVLOG(1) << "Writing proto of size " << message.ByteSize();
86     int tag = GetMCSProtoTag(message);
87     DCHECK_NE(tag, -1);
88     coded_output_stream.WriteRaw(&tag, 1);
89     coded_output_stream.WriteVarint32(message.ByteSize());
90     message.SerializeToCodedStream(&coded_output_stream);
91   }
92 
93   if (output_stream_->Flush(
94           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
95                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
96     OnMessageSent();
97   }
98 }
99 
Login(const google::protobuf::MessageLite & login_request)100 void ConnectionHandlerImpl::Login(
101     const google::protobuf::MessageLite& login_request) {
102   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
103 
104   const char version_byte[1] = {kMCSVersion};
105   const char login_request_tag[1] = {kLoginRequestTag};
106   {
107     CodedOutputStream coded_output_stream(output_stream_.get());
108     coded_output_stream.WriteRaw(version_byte, 1);
109     coded_output_stream.WriteRaw(login_request_tag, 1);
110     coded_output_stream.WriteVarint32(login_request.ByteSize());
111     login_request.SerializeToCodedStream(&coded_output_stream);
112   }
113 
114   if (output_stream_->Flush(
115           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
116                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
117     base::MessageLoop::current()->PostTask(
118         FROM_HERE,
119         base::Bind(&ConnectionHandlerImpl::OnMessageSent,
120                    weak_ptr_factory_.GetWeakPtr()));
121   }
122 
123   read_timeout_timer_.Start(FROM_HERE,
124                             read_timeout_,
125                             base::Bind(&ConnectionHandlerImpl::OnTimeout,
126                                        weak_ptr_factory_.GetWeakPtr()));
127   WaitForData(MCS_VERSION_TAG_AND_SIZE);
128 }
129 
OnMessageSent()130 void ConnectionHandlerImpl::OnMessageSent() {
131   if (!output_stream_.get()) {
132     // The connection has already been closed. Just return.
133     DCHECK(!input_stream_.get());
134     DCHECK(!read_timeout_timer_.IsRunning());
135     return;
136   }
137 
138   if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
139     int last_error = output_stream_->last_error();
140     CloseConnection();
141     // If the socket stream had an error, plumb it up, else plumb up FAILED.
142     if (last_error == net::OK)
143       last_error = net::ERR_FAILED;
144     connection_callback_.Run(last_error);
145     return;
146   }
147 
148   write_callback_.Run();
149 }
150 
GetNextMessage()151 void ConnectionHandlerImpl::GetNextMessage() {
152   DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
153          SocketInputStream::READY == input_stream_->GetState());
154   message_tag_ = 0;
155   message_size_ = 0;
156 
157   WaitForData(MCS_TAG_AND_SIZE);
158 }
159 
WaitForData(ProcessingState state)160 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
161   DVLOG(1) << "Waiting for MCS data: state == " << state;
162 
163   if (!input_stream_) {
164     // The connection has already been closed. Just return.
165     DCHECK(!output_stream_.get());
166     DCHECK(!read_timeout_timer_.IsRunning());
167     return;
168   }
169 
170   if (input_stream_->GetState() != SocketInputStream::EMPTY &&
171       input_stream_->GetState() != SocketInputStream::READY) {
172     // An error occurred.
173     int last_error = output_stream_->last_error();
174     CloseConnection();
175     // If the socket stream had an error, plumb it up, else plumb up FAILED.
176     if (last_error == net::OK)
177       last_error = net::ERR_FAILED;
178     connection_callback_.Run(last_error);
179     return;
180   }
181 
182   // Used to determine whether a Socket::Read is necessary.
183   int min_bytes_needed = 0;
184   // Used to limit the size of the Socket::Read.
185   int max_bytes_needed = 0;
186 
187   switch(state) {
188     case MCS_VERSION_TAG_AND_SIZE:
189       min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
190       max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
191       break;
192     case MCS_TAG_AND_SIZE:
193       min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
194       max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
195       break;
196     case MCS_FULL_SIZE:
197       // If in this state, the minimum size packet length must already have been
198       // insufficient, so set both to the max length.
199       min_bytes_needed = kSizePacketLenMax;
200       max_bytes_needed = kSizePacketLenMax;
201       break;
202     case MCS_PROTO_BYTES:
203       read_timeout_timer_.Reset();
204       // No variability in the message size, set both to the same.
205       min_bytes_needed = message_size_;
206       max_bytes_needed = message_size_;
207       break;
208     default:
209       NOTREACHED();
210   }
211   DCHECK_GE(max_bytes_needed, min_bytes_needed);
212 
213   int byte_count = input_stream_->UnreadByteCount();
214   if (min_bytes_needed - byte_count > 0 &&
215       input_stream_->Refresh(
216           base::Bind(&ConnectionHandlerImpl::WaitForData,
217                      weak_ptr_factory_.GetWeakPtr(),
218                      state),
219           max_bytes_needed - byte_count) == net::ERR_IO_PENDING) {
220     return;
221   }
222 
223   // Check for refresh errors.
224   if (input_stream_->GetState() != SocketInputStream::READY) {
225     // An error occurred.
226     int last_error = output_stream_->last_error();
227     CloseConnection();
228     // If the socket stream had an error, plumb it up, else plumb up FAILED.
229     if (last_error == net::OK)
230       last_error = net::ERR_FAILED;
231     connection_callback_.Run(last_error);
232     return;
233   }
234 
235   // Received enough bytes, process them.
236   DVLOG(1) << "Processing MCS data: state == " << state;
237   switch(state) {
238     case MCS_VERSION_TAG_AND_SIZE:
239       OnGotVersion();
240       break;
241     case MCS_TAG_AND_SIZE:
242       OnGotMessageTag();
243       break;
244     case MCS_FULL_SIZE:
245       OnGotMessageSize();
246       break;
247     case MCS_PROTO_BYTES:
248       OnGotMessageBytes();
249       break;
250     default:
251       NOTREACHED();
252   }
253 }
254 
OnGotVersion()255 void ConnectionHandlerImpl::OnGotVersion() {
256   uint8 version = 0;
257   {
258     CodedInputStream coded_input_stream(input_stream_.get());
259     coded_input_stream.ReadRaw(&version, 1);
260   }
261   if (version < kMCSVersion) {
262     LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
263     connection_callback_.Run(net::ERR_FAILED);
264     return;
265   }
266 
267   input_stream_->RebuildBuffer();
268 
269   // Process the LoginResponse message tag.
270   OnGotMessageTag();
271 }
272 
OnGotMessageTag()273 void ConnectionHandlerImpl::OnGotMessageTag() {
274   if (input_stream_->GetState() != SocketInputStream::READY) {
275     LOG(ERROR) << "Failed to receive protobuf tag.";
276     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
277     return;
278   }
279 
280   {
281     CodedInputStream coded_input_stream(input_stream_.get());
282     coded_input_stream.ReadRaw(&message_tag_, 1);
283   }
284 
285   DVLOG(1) << "Received proto of type "
286            << static_cast<unsigned int>(message_tag_);
287 
288   if (!read_timeout_timer_.IsRunning()) {
289     read_timeout_timer_.Start(FROM_HERE,
290                               read_timeout_,
291                               base::Bind(&ConnectionHandlerImpl::OnTimeout,
292                                          weak_ptr_factory_.GetWeakPtr()));
293   }
294   OnGotMessageSize();
295 }
296 
OnGotMessageSize()297 void ConnectionHandlerImpl::OnGotMessageSize() {
298   if (input_stream_->GetState() != SocketInputStream::READY) {
299     LOG(ERROR) << "Failed to receive message size.";
300     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
301     return;
302   }
303 
304   bool need_another_byte = false;
305   int prev_byte_count = input_stream_->ByteCount();
306   {
307     CodedInputStream coded_input_stream(input_stream_.get());
308     if (!coded_input_stream.ReadVarint32(&message_size_))
309       need_another_byte = true;
310   }
311 
312   if (need_another_byte) {
313     DVLOG(1) << "Expecting another message size byte.";
314     if (prev_byte_count >= kSizePacketLenMax) {
315       // Already had enough bytes, something else went wrong.
316       LOG(ERROR) << "Failed to process message size.";
317       read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
318       return;
319     }
320     // Back up by the amount read (should always be 1 byte).
321     int bytes_read = prev_byte_count - input_stream_->ByteCount();
322     DCHECK_EQ(bytes_read, 1);
323     input_stream_->BackUp(bytes_read);
324     WaitForData(MCS_FULL_SIZE);
325     return;
326   }
327 
328   DVLOG(1) << "Proto size: " << message_size_;
329 
330   if (message_size_ > 0)
331     WaitForData(MCS_PROTO_BYTES);
332   else
333     OnGotMessageBytes();
334 }
335 
OnGotMessageBytes()336 void ConnectionHandlerImpl::OnGotMessageBytes() {
337   read_timeout_timer_.Stop();
338   scoped_ptr<google::protobuf::MessageLite> protobuf(
339       BuildProtobufFromTag(message_tag_));
340   // Messages with no content are valid; just use the default protobuf for
341   // that tag.
342   if (protobuf.get() && message_size_ == 0) {
343     base::MessageLoop::current()->PostTask(
344         FROM_HERE,
345         base::Bind(&ConnectionHandlerImpl::GetNextMessage,
346                    weak_ptr_factory_.GetWeakPtr()));
347     read_callback_.Run(protobuf.Pass());
348     return;
349   }
350 
351   if (!protobuf.get() ||
352       input_stream_->GetState() != SocketInputStream::READY) {
353     LOG(ERROR) << "Failed to extract protobuf bytes of type "
354                << static_cast<unsigned int>(message_tag_);
355     protobuf.reset();  // Return a null pointer to denote an error.
356     read_callback_.Run(protobuf.Pass());
357     return;
358   }
359 
360   {
361     CodedInputStream coded_input_stream(input_stream_.get());
362     if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
363       NOTREACHED() << "Unable to parse GCM message of type "
364                    << static_cast<unsigned int>(message_tag_);
365       protobuf.reset();  // Return a null pointer to denote an error.
366       read_callback_.Run(protobuf.Pass());
367       return;
368     }
369   }
370 
371   input_stream_->RebuildBuffer();
372   base::MessageLoop::current()->PostTask(
373       FROM_HERE,
374       base::Bind(&ConnectionHandlerImpl::GetNextMessage,
375                  weak_ptr_factory_.GetWeakPtr()));
376   if (message_tag_ == kLoginResponseTag) {
377     if (handshake_complete_) {
378       LOG(ERROR) << "Unexpected login response.";
379     } else {
380       handshake_complete_ = true;
381       DVLOG(1) << "GCM Handshake complete.";
382     }
383   }
384   read_callback_.Run(protobuf.Pass());
385 }
386 
OnTimeout()387 void ConnectionHandlerImpl::OnTimeout() {
388   LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
389   CloseConnection();
390   connection_callback_.Run(net::ERR_TIMED_OUT);
391 }
392 
CloseConnection()393 void ConnectionHandlerImpl::CloseConnection() {
394   DVLOG(1) << "Closing connection.";
395   read_callback_.Reset();
396   write_callback_.Reset();
397   read_timeout_timer_.Stop();
398   socket_->Disconnect();
399   input_stream_.reset();
400   output_stream_.reset();
401   weak_ptr_factory_.InvalidateWeakPtrs();
402 }
403 
404 }  // namespace gcm
405