/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _DNS_DNSTLSSOCKET_H
#define _DNS_DNSTLSSOCKET_H

#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <openssl/ssl.h>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>

#include "dns/DnsTlsServer.h"
#include "dns/IDnsTlsSocket.h"
#include "dns/LockedQueue.h"

namespace android {
namespace net {

class IDnsTlsSocketObserver;
class DnsTlsSessionCache;

using netdutils::Slice;

// A class for managing a TLS socket that sends and receives messages in
// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
// This class is not aware of query-response pairing or anything else about DNS.
// For the observer:
// This class is not re-entrant: the observer is not permitted to wait for a call to query()
// or the destructor in a callback. Doing so will result in deadlocks.
// This class may call the observer at any time after initialize(), until the destructor
// returns (but not after).
49 class DnsTlsSocket : public IDnsTlsSocket { 50 public: DnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * _Nonnull observer,DnsTlsSessionCache * _Nonnull cache)51 DnsTlsSocket(const DnsTlsServer& server, unsigned mark, 52 IDnsTlsSocketObserver* _Nonnull observer, 53 DnsTlsSessionCache* _Nonnull cache) : 54 mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} 55 ~DnsTlsSocket(); 56 57 // Creates the SSL context for this session and connect. Returns false on failure. 58 // This method should be called after construction and before use of a DnsTlsSocket. 59 // Only call this method once per DnsTlsSocket. 60 bool initialize() EXCLUDES(mLock); 61 62 // Send a query on the provided SSL socket. |query| contains 63 // the body of a query, not including the ID header. This function will typically return before 64 // the query is actually sent. If this function fails, DnsTlsSocketObserver will be 65 // notified that the socket is closed. 66 // Note that success here indicates successful sending, not receipt of a response. 67 // Thread-safe. 68 bool query(uint16_t id, const Slice query) override EXCLUDES(mLock); 69 70 private: 71 // Lock to be held by the SSL event loop thread. This is not normally in contention. 72 std::mutex mLock; 73 74 // Forwards queries and receives responses. Blocks until the idle timeout. 75 void loop() EXCLUDES(mLock); 76 std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock); 77 78 // On success, sets mSslFd to a socket connected to mAddr (the 79 // connection will likely be in progress if mProtocol is IPPROTO_TCP). 80 // On error, returns the errno. 81 netdutils::Status tcpConnect() REQUIRES(mLock); 82 83 // Connect an SSL session on the provided socket. If connection fails, closing the 84 // socket remains the caller's responsibility. 85 bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock); 86 87 // Disconnect the SSL session and close the socket. 
88 void sslDisconnect() REQUIRES(mLock); 89 90 // Writes a buffer to the socket. 91 bool sslWrite(const Slice buffer) REQUIRES(mLock); 92 93 // Reads exactly the specified number of bytes from the socket, or fails. 94 // Returns SSL_ERROR_NONE on success. 95 // If |wait| is true, then this function always blocks. Otherwise, it 96 // will return SSL_ERROR_WANT_READ if there is no data from the server to read. 97 int sslRead(const Slice buffer, bool wait) REQUIRES(mLock); 98 99 bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); 100 bool readResponse() REQUIRES(mLock); 101 102 // Similar to query(), this function uses incrementEventFd to send a message to the 103 // loop thread. However, instead of incrementing the counter by one (indicating a 104 // new query), it wraps the counter to negative, which we use to indicate a shutdown 105 // request. 106 void requestLoopShutdown() EXCLUDES(mLock); 107 108 // This function sends a message to the loop thread by incrementing mEventFd. 109 bool incrementEventFd(int64_t count) EXCLUDES(mLock); 110 111 // Queue of pending queries. query() pushes items onto the queue and notifies 112 // the loop thread by incrementing mEventFd. loop() reads items off the queue. 113 LockedQueue<std::vector<uint8_t>> mQueue; 114 115 // eventfd socket used for notifying the SSL thread when queries are ready to send. 116 // This socket acts similarly to an atomic counter, incremented by query() and cleared 117 // by loop(). We have to use a socket because the SSL thread needs to wait in poll() 118 // for input from either a remote server or a query thread. Since eventfd does not have 119 // EOF, we indicate a close request by setting the counter to a negative number. 120 // This file descriptor is opened by initialize(), and closed implicitly after 121 // destruction. 122 base::unique_fd mEventFd; 123 124 // SSL Socket fields. 
125 bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); 126 base::unique_fd mSslFd GUARDED_BY(mLock); 127 bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock); 128 static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20); 129 130 const unsigned mMark; // Socket mark 131 const DnsTlsServer mServer; 132 IDnsTlsSocketObserver* _Nonnull const mObserver; 133 DnsTlsSessionCache* _Nonnull const mCache; 134 }; 135 136 } // end of namespace net 137 } // end of namespace android 138 139 #endif // _DNS_DNSTLSSOCKET_H 140