1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/public/cast_socket.h"
6
7 #include "cast/common/channel/message_framer.h"
8 #include "cast/common/channel/proto/cast_channel.pb.h"
9 #include "util/osp_logging.h"
10
11 namespace openscreen {
12 namespace cast {
13
14 using ::cast::channel::CastMessage;
15 using message_serialization::DeserializeResult;
16
CastSocket(std::unique_ptr<TlsConnection> connection,Client * client)17 CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection,
18 Client* client)
19 : connection_(std::move(connection)),
20 client_(client),
21 socket_id_(g_next_socket_id_++) {
22 OSP_DCHECK(client);
23 connection_->SetClient(this);
24 }
25
~CastSocket()26 CastSocket::~CastSocket() {
27 connection_->SetClient(nullptr);
28 }
29
Send(const CastMessage & message)30 Error CastSocket::Send(const CastMessage& message) {
31 if (state_ == State::kError) {
32 return Error::Code::kSocketClosedFailure;
33 }
34
35 const ErrorOr<std::vector<uint8_t>> out =
36 message_serialization::Serialize(message);
37 if (!out) {
38 return out.error();
39 }
40
41 if (!connection_->Send(out.value().data(), out.value().size())) {
42 return Error::Code::kAgain;
43 }
44 return Error::Code::kNone;
45 }
46
SetClient(Client * client)47 void CastSocket::SetClient(Client* client) {
48 OSP_DCHECK(client);
49 client_ = client;
50 }
51
GetSanitizedIpAddress()52 std::array<uint8_t, 2> CastSocket::GetSanitizedIpAddress() {
53 IPEndpoint remote = connection_->GetRemoteEndpoint();
54 std::array<uint8_t, 2> result;
55 uint8_t bytes[16];
56 if (remote.address.IsV4()) {
57 remote.address.CopyToV4(bytes);
58 result[0] = bytes[2];
59 result[1] = bytes[3];
60 } else {
61 remote.address.CopyToV6(bytes);
62 result[0] = bytes[14];
63 result[1] = bytes[15];
64 }
65 return result;
66 }
67
OnError(TlsConnection * connection,Error error)68 void CastSocket::OnError(TlsConnection* connection, Error error) {
69 state_ = State::kError;
70 client_->OnError(this, error);
71 }
72
OnRead(TlsConnection * connection,std::vector<uint8_t> block)73 void CastSocket::OnRead(TlsConnection* connection, std::vector<uint8_t> block) {
74 read_buffer_.insert(read_buffer_.end(), block.begin(), block.end());
75 // NOTE: Read as many messages as possible out of |read_buffer_| since we only
76 // get one callback opportunity for this.
77 do {
78 ErrorOr<DeserializeResult> message_or_error =
79 message_serialization::TryDeserialize(
80 absl::Span<uint8_t>(&read_buffer_[0], read_buffer_.size()));
81 if (!message_or_error) {
82 return;
83 }
84 read_buffer_.erase(read_buffer_.begin(),
85 read_buffer_.begin() + message_or_error.value().length);
86 client_->OnMessage(this, std::move(message_or_error.value().message));
87 } while (!read_buffer_.empty());
88 }
89
90 int CastSocket::g_next_socket_id_ = 1;
91
92 } // namespace cast
93 } // namespace openscreen
94