• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "remoting/host/gnubby_auth_handler_posix.h"
6 
7 #include <unistd.h>
8 #include <utility>
9 
10 #include "base/bind.h"
11 #include "base/file_util.h"
12 #include "base/json/json_reader.h"
13 #include "base/json/json_writer.h"
14 #include "base/lazy_instance.h"
15 #include "base/stl_util.h"
16 #include "base/values.h"
17 #include "net/socket/unix_domain_socket_posix.h"
18 #include "remoting/base/logging.h"
19 #include "remoting/host/gnubby_socket.h"
20 #include "remoting/proto/control.pb.h"
21 #include "remoting/protocol/client_stub.h"
22 
23 namespace remoting {
24 
25 namespace {
26 
// Field names used in gnubby-auth JSON messages.
const char kMessageType[] = "type";
const char kConnectionId[] = "connectionId";
const char kControlOption[] = "option";
const char kDataPayload[] = "data";

// Values of the |kMessageType| field.
const char kControlMessage[] = "control";
const char kDataMessage[] = "data";
const char kErrorMessage[] = "error";

// The only |kControlOption| value that enables the local gnubby listener.
const char kGnubbyAuthV1[] = "auth-v1";

// Type of the protocol::ExtensionMessage that carries gnubby-auth traffic.
const char kGnubbyAuthMessage[] = "gnubby-auth";
36 
37 // The name of the socket to listen for gnubby requests on.
38 base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name =
39     LAZY_INSTANCE_INITIALIZER;
40 
41 // STL predicate to match by a StreamListenSocket pointer.
42 class CompareSocket {
43  public:
CompareSocket(net::StreamListenSocket * socket)44   explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}
45 
operator ()(const std::pair<int,GnubbySocket * > element) const46   bool operator()(const std::pair<int, GnubbySocket*> element) const {
47     return element.second->IsSocket(socket_);
48   }
49 
50  private:
51   net::StreamListenSocket* socket_;
52 };
53 
54 // Socket authentication function that only allows connections from callers with
55 // the current uid.
MatchUid(uid_t user_id,gid_t)56 bool MatchUid(uid_t user_id, gid_t) {
57   bool allowed = user_id == getuid();
58   if (!allowed)
59     HOST_LOG << "Refused socket connection from uid " << user_id;
60   return allowed;
61 }
62 
// Returns the command code, i.e. the first byte of |data| as a value in
// [0, 255], or -1 if |data| is empty. Used for log output.
int GetCommandCode(const std::string& data) {
  // Convert through unsigned char: |char| may be signed, and widening a
  // negative char directly would sign-extend bytes >= 0x80 (e.g. 0x83 would
  // become 0xFFFFFF83).
  return data.empty() ? -1 : static_cast<unsigned char>(data[0]);
}
68 
69 // Creates a string of byte data from a ListValue of numbers. Returns true if
70 // all of the list elements are numbers.
ConvertListValueToString(base::ListValue * bytes,std::string * out)71 bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
72   out->clear();
73 
74   unsigned int byte_count = bytes->GetSize();
75   if (byte_count != 0) {
76     out->reserve(byte_count);
77     for (unsigned int i = 0; i < byte_count; i++) {
78       int value;
79       if (!bytes->GetInteger(i, &value))
80         return false;
81       out->push_back(static_cast<char>(value));
82     }
83   }
84   return true;
85 }
86 
87 }  // namespace
88 
GnubbyAuthHandlerPosix(protocol::ClientStub * client_stub)89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
90     protocol::ClientStub* client_stub)
91     : client_stub_(client_stub), last_connection_id_(0) {
92   DCHECK(client_stub_);
93 }
94 
~GnubbyAuthHandlerPosix()95 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
96   STLDeleteValues(&active_sockets_);
97 }
98 
99 // static
Create(protocol::ClientStub * client_stub)100 scoped_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create(
101     protocol::ClientStub* client_stub) {
102   return scoped_ptr<GnubbyAuthHandler>(new GnubbyAuthHandlerPosix(client_stub));
103 }
104 
105 // static
SetGnubbySocketName(const base::FilePath & gnubby_socket_name)106 void GnubbyAuthHandler::SetGnubbySocketName(
107     const base::FilePath& gnubby_socket_name) {
108   g_gnubby_socket_name.Get() = gnubby_socket_name;
109 }
110 
DeliverClientMessage(const std::string & message)111 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
112   DCHECK(CalledOnValidThread());
113 
114   scoped_ptr<base::Value> value(base::JSONReader::Read(message));
115   base::DictionaryValue* client_message;
116   if (value && value->GetAsDictionary(&client_message)) {
117     std::string type;
118     if (!client_message->GetString(kMessageType, &type)) {
119       LOG(ERROR) << "Invalid gnubby-auth message";
120       return;
121     }
122 
123     if (type == kControlMessage) {
124       std::string option;
125       if (client_message->GetString(kControlOption, &option) &&
126           option == kGnubbyAuthV1) {
127         CreateAuthorizationSocket();
128       } else {
129         LOG(ERROR) << "Invalid gnubby-auth control option";
130       }
131     } else if (type == kDataMessage) {
132       ActiveSockets::iterator iter = GetSocketForMessage(client_message);
133       if (iter != active_sockets_.end()) {
134         base::ListValue* bytes;
135         std::string response;
136         if (client_message->GetList(kDataPayload, &bytes) &&
137             ConvertListValueToString(bytes, &response)) {
138           HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
139           iter->second->SendResponse(response);
140         } else {
141           LOG(ERROR) << "Invalid gnubby data";
142           SendErrorAndCloseActiveSocket(iter);
143         }
144       } else {
145         LOG(ERROR) << "Unknown gnubby-auth data connection";
146       }
147     } else if (type == kErrorMessage) {
148       ActiveSockets::iterator iter = GetSocketForMessage(client_message);
149       if (iter != active_sockets_.end()) {
150         HOST_LOG << "Sending gnubby error";
151         SendErrorAndCloseActiveSocket(iter);
152       } else {
153         LOG(ERROR) << "Unknown gnubby-auth error connection";
154       }
155     } else {
156       LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
157     }
158   }
159 }
160 
DeliverHostDataMessage(int connection_id,const std::string & data) const161 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
162     int connection_id,
163     const std::string& data) const {
164   DCHECK(CalledOnValidThread());
165 
166   base::DictionaryValue request;
167   request.SetString(kMessageType, kDataMessage);
168   request.SetInteger(kConnectionId, connection_id);
169 
170   base::ListValue* bytes = new base::ListValue();
171   for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
172     bytes->AppendInteger(static_cast<unsigned char>(*i));
173   }
174   request.Set(kDataPayload, bytes);
175 
176   std::string request_json;
177   if (!base::JSONWriter::Write(&request, &request_json)) {
178     LOG(ERROR) << "Failed to create request json";
179     return;
180   }
181 
182   protocol::ExtensionMessage message;
183   message.set_type(kGnubbyAuthMessage);
184   message.set_data(request_json);
185 
186   client_stub_->DeliverHostMessage(message);
187 }
188 
HasActiveSocketForTesting(net::StreamListenSocket * socket) const189 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
190     net::StreamListenSocket* socket) const {
191   return std::find_if(active_sockets_.begin(),
192                       active_sockets_.end(),
193                       CompareSocket(socket)) != active_sockets_.end();
194 }
195 
GetConnectionIdForTesting(net::StreamListenSocket * socket) const196 int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
197     net::StreamListenSocket* socket) const {
198   ActiveSockets::const_iterator iter = std::find_if(
199       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
200   return iter->first;
201 }
202 
GetGnubbySocketForTesting(net::StreamListenSocket * socket) const203 GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
204     net::StreamListenSocket* socket) const {
205   ActiveSockets::const_iterator iter = std::find_if(
206       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
207   return iter->second;
208 }
209 
DidAccept(net::StreamListenSocket * server,scoped_ptr<net::StreamListenSocket> socket)210 void GnubbyAuthHandlerPosix::DidAccept(
211     net::StreamListenSocket* server,
212     scoped_ptr<net::StreamListenSocket> socket) {
213   DCHECK(CalledOnValidThread());
214 
215   int connection_id = ++last_connection_id_;
216   active_sockets_[connection_id] =
217       new GnubbySocket(socket.Pass(),
218                        base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
219                                   base::Unretained(this),
220                                   connection_id));
221 }
222 
DidRead(net::StreamListenSocket * socket,const char * data,int len)223 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
224                                      const char* data,
225                                      int len) {
226   DCHECK(CalledOnValidThread());
227 
228   ActiveSockets::iterator iter = std::find_if(
229       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
230   if (iter != active_sockets_.end()) {
231     GnubbySocket* gnubby_socket = iter->second;
232     gnubby_socket->AddRequestData(data, len);
233     if (gnubby_socket->IsRequestTooLarge()) {
234       SendErrorAndCloseActiveSocket(iter);
235     } else if (gnubby_socket->IsRequestComplete()) {
236       std::string request_data;
237       gnubby_socket->GetAndClearRequestData(&request_data);
238       ProcessGnubbyRequest(iter->first, request_data);
239     }
240   } else {
241     LOG(ERROR) << "Received data for unknown connection";
242   }
243 }
244 
DidClose(net::StreamListenSocket * socket)245 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
246   DCHECK(CalledOnValidThread());
247 
248   ActiveSockets::iterator iter = std::find_if(
249       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
250   if (iter != active_sockets_.end()) {
251     delete iter->second;
252     active_sockets_.erase(iter);
253   }
254 }
255 
CreateAuthorizationSocket()256 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
257   DCHECK(CalledOnValidThread());
258 
259   if (!g_gnubby_socket_name.Get().empty()) {
260     // If the file already exists, a socket in use error is returned.
261     base::DeleteFile(g_gnubby_socket_name.Get(), false);
262 
263     HOST_LOG << "Listening for gnubby requests on "
264              << g_gnubby_socket_name.Get().value();
265 
266     auth_socket_ = net::UnixDomainSocket::CreateAndListen(
267         g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
268     if (!auth_socket_.get()) {
269       LOG(ERROR) << "Failed to open socket for gnubby requests";
270     }
271   } else {
272     HOST_LOG << "No gnubby socket name specified";
273   }
274 }
275 
ProcessGnubbyRequest(int connection_id,const std::string & request_data)276 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
277     int connection_id,
278     const std::string& request_data) {
279   HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
280   DeliverHostDataMessage(connection_id, request_data);
281 }
282 
283 GnubbyAuthHandlerPosix::ActiveSockets::iterator
GetSocketForMessage(base::DictionaryValue * message)284 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
285   int connection_id;
286   if (message->GetInteger(kConnectionId, &connection_id)) {
287     return active_sockets_.find(connection_id);
288   }
289   return active_sockets_.end();
290 }
291 
SendErrorAndCloseActiveSocket(const ActiveSockets::iterator & iter)292 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
293     const ActiveSockets::iterator& iter) {
294   iter->second->SendSshError();
295 
296   delete iter->second;
297   active_sockets_.erase(iter);
298 }
299 
RequestTimedOut(int connection_id)300 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
301   HOST_LOG << "Gnubby request timed out";
302   ActiveSockets::iterator iter = active_sockets_.find(connection_id);
303   if (iter != active_sockets_.end())
304     SendErrorAndCloseActiveSocket(iter);
305 }
306 
307 }  // namespace remoting
308