1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_ 6 #define EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_ 7 8 #include <string> 9 10 #include "base/gtest_prod_util.h" 11 #include "base/memory/ref_counted.h" 12 #include "extensions/browser/api/api_resource_manager.h" 13 #include "extensions/browser/api/async_api_function.h" 14 #include "extensions/browser/extension_function.h" 15 #include "extensions/common/api/socket.h" 16 #include "net/base/address_list.h" 17 #include "net/dns/host_resolver.h" 18 #include "net/socket/tcp_client_socket.h" 19 20 namespace content { 21 class BrowserContext; 22 class ResourceContext; 23 } 24 25 namespace net { 26 class IOBuffer; 27 class URLRequestContextGetter; 28 class SSLClientSocket; 29 } 30 31 namespace extensions { 32 class TLSSocket; 33 class Socket; 34 35 // A simple interface to ApiResourceManager<Socket> or derived class. The goal 36 // of this interface is to allow Socket API functions to use distinct instances 37 // of ApiResourceManager<> depending on the type of socket (old version in 38 // "socket" namespace vs new version in "socket.xxx" namespaces). 39 class SocketResourceManagerInterface { 40 public: ~SocketResourceManagerInterface()41 virtual ~SocketResourceManagerInterface() {} 42 43 virtual bool SetBrowserContext(content::BrowserContext* context) = 0; 44 virtual int Add(Socket* socket) = 0; 45 virtual Socket* Get(const std::string& extension_id, int api_resource_id) = 0; 46 virtual void Remove(const std::string& extension_id, int api_resource_id) = 0; 47 virtual void Replace(const std::string& extension_id, 48 int api_resource_id, 49 Socket* socket) = 0; 50 virtual base::hash_set<int>* GetResourceIds( 51 const std::string& extension_id) = 0; 52 }; 53 54 // Implementation of SocketResourceManagerInterface using an 55 // ApiResourceManager<T> instance (where T derives from Socket). 56 template <typename T> 57 class SocketResourceManager : public SocketResourceManagerInterface { 58 public: SocketResourceManager()59 SocketResourceManager() : manager_(NULL) {} 60 SetBrowserContext(content::BrowserContext * context)61 virtual bool SetBrowserContext(content::BrowserContext* context) OVERRIDE { 62 manager_ = ApiResourceManager<T>::Get(context); 63 DCHECK(manager_) 64 << "There is no socket manager. " 65 "If this assertion is failing during a test, then it is likely that " 66 "TestExtensionSystem is failing to provide an instance of " 67 "ApiResourceManager<Socket>."; 68 return manager_ != NULL; 69 } 70 Add(Socket * socket)71 virtual int Add(Socket* socket) OVERRIDE { 72 // Note: Cast needed here, because "T" may be a subclass of "Socket". 73 return manager_->Add(static_cast<T*>(socket)); 74 } 75 Get(const std::string & extension_id,int api_resource_id)76 virtual Socket* Get(const std::string& extension_id, 77 int api_resource_id) OVERRIDE { 78 return manager_->Get(extension_id, api_resource_id); 79 } 80 Replace(const std::string & extension_id,int api_resource_id,Socket * socket)81 virtual void Replace(const std::string& extension_id, 82 int api_resource_id, 83 Socket* socket) OVERRIDE { 84 manager_->Replace(extension_id, api_resource_id, static_cast<T*>(socket)); 85 } 86 Remove(const std::string & extension_id,int api_resource_id)87 virtual void Remove(const std::string& extension_id, 88 int api_resource_id) OVERRIDE { 89 manager_->Remove(extension_id, api_resource_id); 90 } 91 GetResourceIds(const std::string & extension_id)92 virtual base::hash_set<int>* GetResourceIds(const std::string& extension_id) 93 OVERRIDE { 94 return manager_->GetResourceIds(extension_id); 95 } 96 97 private: 98 ApiResourceManager<T>* manager_; 99 }; 100 101 class SocketAsyncApiFunction : public AsyncApiFunction { 102 public: 103 SocketAsyncApiFunction(); 104 105 protected: 106 virtual ~SocketAsyncApiFunction(); 107 108 // AsyncApiFunction: 109 virtual bool PrePrepare() OVERRIDE; 110 virtual bool Respond() OVERRIDE; 111 112 virtual scoped_ptr<SocketResourceManagerInterface> 113 CreateSocketResourceManager(); 114 115 int AddSocket(Socket* socket); 116 Socket* GetSocket(int api_resource_id); 117 void ReplaceSocket(int api_resource_id, Socket* socket); 118 void RemoveSocket(int api_resource_id); 119 base::hash_set<int>* GetSocketIds(); 120 121 private: 122 scoped_ptr<SocketResourceManagerInterface> manager_; 123 }; 124 125 class SocketExtensionWithDnsLookupFunction : public SocketAsyncApiFunction { 126 protected: 127 SocketExtensionWithDnsLookupFunction(); 128 virtual ~SocketExtensionWithDnsLookupFunction(); 129 130 // AsyncApiFunction: 131 virtual bool PrePrepare() OVERRIDE; 132 133 void StartDnsLookup(const std::string& hostname); 134 virtual void AfterDnsLookup(int lookup_result) = 0; 135 136 std::string resolved_address_; 137 138 private: 139 void OnDnsLookup(int resolve_result); 140 141 // Weak pointer to the resource context. 142 content::ResourceContext* resource_context_; 143 144 scoped_ptr<net::HostResolver::RequestHandle> request_handle_; 145 scoped_ptr<net::AddressList> addresses_; 146 }; 147 148 class SocketCreateFunction : public SocketAsyncApiFunction { 149 public: 150 DECLARE_EXTENSION_FUNCTION("socket.create", SOCKET_CREATE) 151 152 SocketCreateFunction(); 153 154 protected: 155 virtual ~SocketCreateFunction(); 156 157 // AsyncApiFunction: 158 virtual bool Prepare() OVERRIDE; 159 virtual void Work() OVERRIDE; 160 161 private: 162 FRIEND_TEST_ALL_PREFIXES(SocketUnitTest, Create); 163 enum SocketType { kSocketTypeInvalid = -1, kSocketTypeTCP, kSocketTypeUDP }; 164 165 scoped_ptr<core_api::socket::Create::Params> params_; 166 SocketType socket_type_; 167 }; 168 169 class SocketDestroyFunction : public SocketAsyncApiFunction { 170 public: 171 DECLARE_EXTENSION_FUNCTION("socket.destroy", SOCKET_DESTROY) 172 173 protected: ~SocketDestroyFunction()174 virtual ~SocketDestroyFunction() {} 175 176 // AsyncApiFunction: 177 virtual bool Prepare() OVERRIDE; 178 virtual void Work() OVERRIDE; 179 180 private: 181 int socket_id_; 182 }; 183 184 class SocketConnectFunction : public SocketExtensionWithDnsLookupFunction { 185 public: 186 DECLARE_EXTENSION_FUNCTION("socket.connect", SOCKET_CONNECT) 187 188 SocketConnectFunction(); 189 190 protected: 191 virtual ~SocketConnectFunction(); 192 193 // AsyncApiFunction: 194 virtual bool Prepare() OVERRIDE; 195 virtual void AsyncWorkStart() OVERRIDE; 196 197 // SocketExtensionWithDnsLookupFunction: 198 virtual void AfterDnsLookup(int lookup_result) OVERRIDE; 199 200 private: 201 void StartConnect(); 202 void OnConnect(int result); 203 204 int socket_id_; 205 std::string hostname_; 206 int port_; 207 Socket* socket_; 208 }; 209 210 class SocketDisconnectFunction : public SocketAsyncApiFunction { 211 public: 212 DECLARE_EXTENSION_FUNCTION("socket.disconnect", SOCKET_DISCONNECT) 213 214 protected: ~SocketDisconnectFunction()215 virtual ~SocketDisconnectFunction() {} 216 217 // AsyncApiFunction: 218 virtual bool Prepare() OVERRIDE; 219 virtual void Work() OVERRIDE; 220 221 private: 222 int socket_id_; 223 }; 224 225 class SocketBindFunction : public SocketAsyncApiFunction { 226 public: 227 DECLARE_EXTENSION_FUNCTION("socket.bind", SOCKET_BIND) 228 229 protected: ~SocketBindFunction()230 virtual ~SocketBindFunction() {} 231 232 // AsyncApiFunction: 233 virtual bool Prepare() OVERRIDE; 234 virtual void Work() OVERRIDE; 235 236 private: 237 int socket_id_; 238 std::string address_; 239 int port_; 240 }; 241 242 class SocketListenFunction : public SocketAsyncApiFunction { 243 public: 244 DECLARE_EXTENSION_FUNCTION("socket.listen", SOCKET_LISTEN) 245 246 SocketListenFunction(); 247 248 protected: 249 virtual ~SocketListenFunction(); 250 251 // AsyncApiFunction: 252 virtual bool Prepare() OVERRIDE; 253 virtual void Work() OVERRIDE; 254 255 private: 256 scoped_ptr<core_api::socket::Listen::Params> params_; 257 }; 258 259 class SocketAcceptFunction : public SocketAsyncApiFunction { 260 public: 261 DECLARE_EXTENSION_FUNCTION("socket.accept", SOCKET_ACCEPT) 262 263 SocketAcceptFunction(); 264 265 protected: 266 virtual ~SocketAcceptFunction(); 267 268 // AsyncApiFunction: 269 virtual bool Prepare() OVERRIDE; 270 virtual void AsyncWorkStart() OVERRIDE; 271 272 private: 273 void OnAccept(int result_code, net::TCPClientSocket* socket); 274 275 scoped_ptr<core_api::socket::Accept::Params> params_; 276 }; 277 278 class SocketReadFunction : public SocketAsyncApiFunction { 279 public: 280 DECLARE_EXTENSION_FUNCTION("socket.read", SOCKET_READ) 281 282 SocketReadFunction(); 283 284 protected: 285 virtual ~SocketReadFunction(); 286 287 // AsyncApiFunction: 288 virtual bool Prepare() OVERRIDE; 289 virtual void AsyncWorkStart() OVERRIDE; 290 void OnCompleted(int result, scoped_refptr<net::IOBuffer> io_buffer); 291 292 private: 293 scoped_ptr<core_api::socket::Read::Params> params_; 294 }; 295 296 class SocketWriteFunction : public SocketAsyncApiFunction { 297 public: 298 DECLARE_EXTENSION_FUNCTION("socket.write", SOCKET_WRITE) 299 300 SocketWriteFunction(); 301 302 protected: 303 virtual ~SocketWriteFunction(); 304 305 // AsyncApiFunction: 306 virtual bool Prepare() OVERRIDE; 307 virtual void AsyncWorkStart() OVERRIDE; 308 void OnCompleted(int result); 309 310 private: 311 int socket_id_; 312 scoped_refptr<net::IOBuffer> io_buffer_; 313 size_t io_buffer_size_; 314 }; 315 316 class SocketRecvFromFunction : public SocketAsyncApiFunction { 317 public: 318 DECLARE_EXTENSION_FUNCTION("socket.recvFrom", SOCKET_RECVFROM) 319 320 SocketRecvFromFunction(); 321 322 protected: 323 virtual ~SocketRecvFromFunction(); 324 325 // AsyncApiFunction 326 virtual bool Prepare() OVERRIDE; 327 virtual void AsyncWorkStart() OVERRIDE; 328 void OnCompleted(int result, 329 scoped_refptr<net::IOBuffer> io_buffer, 330 const std::string& address, 331 int port); 332 333 private: 334 scoped_ptr<core_api::socket::RecvFrom::Params> params_; 335 }; 336 337 class SocketSendToFunction : public SocketExtensionWithDnsLookupFunction { 338 public: 339 DECLARE_EXTENSION_FUNCTION("socket.sendTo", SOCKET_SENDTO) 340 341 SocketSendToFunction(); 342 343 protected: 344 virtual ~SocketSendToFunction(); 345 346 // AsyncApiFunction: 347 virtual bool Prepare() OVERRIDE; 348 virtual void AsyncWorkStart() OVERRIDE; 349 void OnCompleted(int result); 350 351 // SocketExtensionWithDnsLookupFunction: 352 virtual void AfterDnsLookup(int lookup_result) OVERRIDE; 353 354 private: 355 void StartSendTo(); 356 357 int socket_id_; 358 scoped_refptr<net::IOBuffer> io_buffer_; 359 size_t io_buffer_size_; 360 std::string hostname_; 361 int port_; 362 Socket* socket_; 363 }; 364 365 class SocketSetKeepAliveFunction : public SocketAsyncApiFunction { 366 public: 367 DECLARE_EXTENSION_FUNCTION("socket.setKeepAlive", SOCKET_SETKEEPALIVE) 368 369 SocketSetKeepAliveFunction(); 370 371 protected: 372 virtual ~SocketSetKeepAliveFunction(); 373 374 // AsyncApiFunction: 375 virtual bool Prepare() OVERRIDE; 376 virtual void Work() OVERRIDE; 377 378 private: 379 scoped_ptr<core_api::socket::SetKeepAlive::Params> params_; 380 }; 381 382 class SocketSetNoDelayFunction : public SocketAsyncApiFunction { 383 public: 384 DECLARE_EXTENSION_FUNCTION("socket.setNoDelay", SOCKET_SETNODELAY) 385 386 SocketSetNoDelayFunction(); 387 388 protected: 389 virtual ~SocketSetNoDelayFunction(); 390 391 // AsyncApiFunction: 392 virtual bool Prepare() OVERRIDE; 393 virtual void Work() OVERRIDE; 394 395 private: 396 scoped_ptr<core_api::socket::SetNoDelay::Params> params_; 397 }; 398 399 class SocketGetInfoFunction : public SocketAsyncApiFunction { 400 public: 401 DECLARE_EXTENSION_FUNCTION("socket.getInfo", SOCKET_GETINFO) 402 403 SocketGetInfoFunction(); 404 405 protected: 406 virtual ~SocketGetInfoFunction(); 407 408 // AsyncApiFunction: 409 virtual bool Prepare() OVERRIDE; 410 virtual void Work() OVERRIDE; 411 412 private: 413 scoped_ptr<core_api::socket::GetInfo::Params> params_; 414 }; 415 416 class SocketGetNetworkListFunction : public AsyncExtensionFunction { 417 public: 418 DECLARE_EXTENSION_FUNCTION("socket.getNetworkList", SOCKET_GETNETWORKLIST) 419 420 protected: ~SocketGetNetworkListFunction()421 virtual ~SocketGetNetworkListFunction() {} 422 virtual bool RunAsync() OVERRIDE; 423 424 private: 425 void GetNetworkListOnFileThread(); 426 void HandleGetNetworkListError(); 427 void SendResponseOnUIThread(const net::NetworkInterfaceList& interface_list); 428 }; 429 430 class SocketJoinGroupFunction : public SocketAsyncApiFunction { 431 public: 432 DECLARE_EXTENSION_FUNCTION("socket.joinGroup", SOCKET_MULTICAST_JOIN_GROUP) 433 434 SocketJoinGroupFunction(); 435 436 protected: 437 virtual ~SocketJoinGroupFunction(); 438 439 // AsyncApiFunction 440 virtual bool Prepare() OVERRIDE; 441 virtual void Work() OVERRIDE; 442 443 private: 444 scoped_ptr<core_api::socket::JoinGroup::Params> params_; 445 }; 446 447 class SocketLeaveGroupFunction : public SocketAsyncApiFunction { 448 public: 449 DECLARE_EXTENSION_FUNCTION("socket.leaveGroup", SOCKET_MULTICAST_LEAVE_GROUP) 450 451 SocketLeaveGroupFunction(); 452 453 protected: 454 virtual ~SocketLeaveGroupFunction(); 455 456 // AsyncApiFunction 457 virtual bool Prepare() OVERRIDE; 458 virtual void Work() OVERRIDE; 459 460 private: 461 scoped_ptr<core_api::socket::LeaveGroup::Params> params_; 462 }; 463 464 class SocketSetMulticastTimeToLiveFunction : public SocketAsyncApiFunction { 465 public: 466 DECLARE_EXTENSION_FUNCTION("socket.setMulticastTimeToLive", 467 SOCKET_MULTICAST_SET_TIME_TO_LIVE) 468 469 SocketSetMulticastTimeToLiveFunction(); 470 471 protected: 472 virtual ~SocketSetMulticastTimeToLiveFunction(); 473 474 // AsyncApiFunction 475 virtual bool Prepare() OVERRIDE; 476 virtual void Work() OVERRIDE; 477 478 private: 479 scoped_ptr<core_api::socket::SetMulticastTimeToLive::Params> params_; 480 }; 481 482 class SocketSetMulticastLoopbackModeFunction : public SocketAsyncApiFunction { 483 public: 484 DECLARE_EXTENSION_FUNCTION("socket.setMulticastLoopbackMode", 485 SOCKET_MULTICAST_SET_LOOPBACK_MODE) 486 487 SocketSetMulticastLoopbackModeFunction(); 488 489 protected: 490 virtual ~SocketSetMulticastLoopbackModeFunction(); 491 492 // AsyncApiFunction 493 virtual bool Prepare() OVERRIDE; 494 virtual void Work() OVERRIDE; 495 496 private: 497 scoped_ptr<core_api::socket::SetMulticastLoopbackMode::Params> params_; 498 }; 499 500 class SocketGetJoinedGroupsFunction : public SocketAsyncApiFunction { 501 public: 502 DECLARE_EXTENSION_FUNCTION("socket.getJoinedGroups", 503 SOCKET_MULTICAST_GET_JOINED_GROUPS) 504 505 SocketGetJoinedGroupsFunction(); 506 507 protected: 508 virtual ~SocketGetJoinedGroupsFunction(); 509 510 // AsyncApiFunction 511 virtual bool Prepare() OVERRIDE; 512 virtual void Work() OVERRIDE; 513 514 private: 515 scoped_ptr<core_api::socket::GetJoinedGroups::Params> params_; 516 }; 517 518 class SocketSecureFunction : public SocketAsyncApiFunction { 519 public: 520 DECLARE_EXTENSION_FUNCTION("socket.secure", SOCKET_SECURE); 521 SocketSecureFunction(); 522 523 protected: 524 virtual ~SocketSecureFunction(); 525 526 // AsyncApiFunction 527 virtual bool Prepare() OVERRIDE; 528 virtual void AsyncWorkStart() OVERRIDE; 529 530 private: 531 // Callback from TLSSocket::UpgradeSocketToTLS(). 532 void TlsConnectDone(scoped_ptr<TLSSocket> socket, int result); 533 534 scoped_ptr<core_api::socket::Secure::Params> params_; 535 scoped_refptr<net::URLRequestContextGetter> url_request_getter_; 536 537 DISALLOW_COPY_AND_ASSIGN(SocketSecureFunction); 538 }; 539 540 } // namespace extensions 541 542 #endif // EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_ 543