1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/channel/connection_namespace_handler.h"
6
7 #include <algorithm>
8 #include <string>
9 #include <type_traits>
10 #include <utility>
11
12 #include "absl/types/optional.h"
13 #include "cast/common/channel/message_util.h"
14 #include "cast/common/channel/proto/cast_channel.pb.h"
15 #include "cast/common/channel/virtual_connection.h"
16 #include "cast/common/channel/virtual_connection_router.h"
17 #include "cast/common/public/cast_socket.h"
18 #include "util/json/json_serialization.h"
19 #include "util/json/json_value.h"
20 #include "util/osp_logging.h"
21
22 namespace openscreen {
23 namespace cast {
24
25 using ::cast::channel::CastMessage;
26 using ::cast::channel::CastMessage_PayloadType;
27
28 namespace {
29
IsValidProtocolVersion(int version)30 bool IsValidProtocolVersion(int version) {
31 return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version);
32 }
33
FindMaxProtocolVersion(const Json::Value * version,const Json::Value * version_list)34 absl::optional<int> FindMaxProtocolVersion(const Json::Value* version,
35 const Json::Value* version_list) {
36 using ArrayIndex = Json::Value::ArrayIndex;
37 static_assert(std::is_integral<ArrayIndex>::value,
38 "Assuming ArrayIndex is integral");
39 absl::optional<int> max_version;
40 if (version_list && version_list->isArray()) {
41 max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
42 for (auto it = version_list->begin(), end = version_list->end(); it != end;
43 ++it) {
44 if (it->isInt()) {
45 int version_int = it->asInt();
46 if (IsValidProtocolVersion(version_int) && version_int > *max_version) {
47 max_version = version_int;
48 }
49 }
50 }
51 }
52 if (version && version->isInt()) {
53 int version_int = version->asInt();
54 if (IsValidProtocolVersion(version_int)) {
55 if (!max_version) {
56 max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
57 }
58 if (version_int > max_version) {
59 max_version = version_int;
60 }
61 }
62 }
63 return max_version;
64 }
65
GetCloseReason(const Json::Value & parsed_message)66 VirtualConnection::CloseReason GetCloseReason(
67 const Json::Value& parsed_message) {
68 VirtualConnection::CloseReason reason =
69 VirtualConnection::CloseReason::kClosedByPeer;
70 absl::optional<int> reason_code = MaybeGetInt(
71 parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode));
72 if (reason_code) {
73 int code = reason_code.value();
74 if (code >= VirtualConnection::CloseReason::kFirstReason &&
75 code <= VirtualConnection::CloseReason::kLastReason) {
76 reason = static_cast<VirtualConnection::CloseReason>(code);
77 }
78 }
79 return reason;
80 }
81
82 } // namespace
83
ConnectionNamespaceHandler(VirtualConnectionRouter * vc_router,VirtualConnectionPolicy * vc_policy)84 ConnectionNamespaceHandler::ConnectionNamespaceHandler(
85 VirtualConnectionRouter* vc_router,
86 VirtualConnectionPolicy* vc_policy)
87 : vc_router_(vc_router), vc_policy_(vc_policy) {
88 OSP_DCHECK(vc_router_);
89 OSP_DCHECK(vc_policy_);
90
91 vc_router_->set_connection_namespace_handler(this);
92 }
93
~ConnectionNamespaceHandler()94 ConnectionNamespaceHandler::~ConnectionNamespaceHandler() {
95 vc_router_->set_connection_namespace_handler(nullptr);
96 }
97
OpenRemoteConnection(VirtualConnection conn,RemoteConnectionResultCallback result_callback)98 void ConnectionNamespaceHandler::OpenRemoteConnection(
99 VirtualConnection conn,
100 RemoteConnectionResultCallback result_callback) {
101 OSP_DCHECK(!vc_router_->GetConnectionData(conn));
102 OSP_DCHECK(std::none_of(
103 pending_remote_requests_.begin(), pending_remote_requests_.end(),
104 [&](const PendingRequest& request) { return request.conn == conn; }));
105 pending_remote_requests_.push_back({conn, std::move(result_callback)});
106
107 SendConnect(std::move(conn));
108 }
109
CloseRemoteConnection(VirtualConnection conn)110 void ConnectionNamespaceHandler::CloseRemoteConnection(VirtualConnection conn) {
111 if (RemoveConnection(conn, VirtualConnection::kClosedBySelf)) {
112 SendClose(std::move(conn));
113 }
114 }
115
OnMessage(VirtualConnectionRouter * router,CastSocket * socket,CastMessage message)116 void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
117 CastSocket* socket,
118 CastMessage message) {
119 if (message.destination_id() == kBroadcastId ||
120 message.source_id() == kBroadcastId ||
121 message.payload_type() !=
122 CastMessage_PayloadType::CastMessage_PayloadType_STRING) {
123 return;
124 }
125
126 ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
127 if (result.is_error()) {
128 return;
129 }
130
131 Json::Value& value = result.value();
132 if (!value.isObject()) {
133 return;
134 }
135
136 absl::optional<absl::string_view> type =
137 MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
138 if (!type) {
139 // TODO(btolsch): Some of these paths should have error reporting. One
140 // possibility is to pass errors back through |router| so higher-level code
141 // can decide whether to show an error to the user, stop talking to a
142 // particular device, etc.
143 return;
144 }
145
146 absl::string_view type_str = type.value();
147 if (type_str == kMessageTypeConnect) {
148 HandleConnect(socket, std::move(message), std::move(value));
149 } else if (type_str == kMessageTypeClose) {
150 HandleClose(socket, std::move(message), std::move(value));
151 } else if (type_str == kMessageTypeConnected) {
152 HandleConnectedResponse(socket, std::move(message), std::move(value));
153 } else {
154 // NOTE: Unknown message type so ignore it.
155 // TODO(btolsch): Should be included in future error reporting.
156 }
157 }
158
HandleConnect(CastSocket * socket,CastMessage message,Json::Value parsed_message)159 void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket,
160 CastMessage message,
161 Json::Value parsed_message) {
162 if (message.destination_id() == kBroadcastId ||
163 message.source_id() == kBroadcastId) {
164 return;
165 }
166
167 VirtualConnection virtual_conn{std::move(message.destination_id()),
168 std::move(message.source_id()),
169 ToCastSocketId(socket)};
170 if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
171 SendClose(std::move(virtual_conn));
172 return;
173 }
174
175 absl::optional<int> maybe_conn_type = MaybeGetInt(
176 parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType));
177 VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong;
178 if (maybe_conn_type) {
179 int int_type = maybe_conn_type.value();
180 if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) ||
181 int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) {
182 SendClose(std::move(virtual_conn));
183 return;
184 }
185 conn_type = static_cast<VirtualConnection::Type>(int_type);
186 }
187
188 VirtualConnection::AssociatedData data;
189
190 data.type = conn_type;
191
192 absl::optional<absl::string_view> user_agent = MaybeGetString(
193 parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent));
194 if (user_agent) {
195 data.user_agent = std::string(user_agent.value());
196 }
197
198 const Json::Value* sender_info_value = parsed_message.find(
199 JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo));
200 if (!sender_info_value || !sender_info_value->isObject()) {
201 // TODO(btolsch): Should this be guessed from user agent?
202 OSP_DVLOG << "No sender info from protocol.";
203 }
204
205 const Json::Value* version_value = parsed_message.find(
206 JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
207 const Json::Value* version_list_value = parsed_message.find(
208 JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList));
209 absl::optional<int> negotiated_version =
210 FindMaxProtocolVersion(version_value, version_list_value);
211 if (negotiated_version) {
212 data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>(
213 negotiated_version.value());
214 } else {
215 data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0;
216 }
217
218 if (socket) {
219 data.ip_fragment = socket->GetSanitizedIpAddress();
220 } else {
221 data.ip_fragment = {};
222 }
223
224 OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", "
225 << virtual_conn.peer_id << ", " << virtual_conn.socket_id;
226
227 // NOTE: Only send a response for senders that actually sent a version. This
228 // maintains compatibility with older senders that don't send a version and
229 // don't expect a response.
230 if (negotiated_version) {
231 SendConnectedResponse(virtual_conn, negotiated_version.value());
232 }
233
234 vc_router_->AddConnection(std::move(virtual_conn), std::move(data));
235 }
236
HandleClose(CastSocket * socket,CastMessage message,Json::Value parsed_message)237 void ConnectionNamespaceHandler::HandleClose(CastSocket* socket,
238 CastMessage message,
239 Json::Value parsed_message) {
240 const VirtualConnection conn{std::move(*message.mutable_destination_id()),
241 std::move(*message.mutable_source_id()),
242 ToCastSocketId(socket)};
243 const auto reason = GetCloseReason(parsed_message);
244 if (RemoveConnection(conn, reason)) {
245 OSP_DVLOG << "Connection closed (reason: " << reason
246 << "): " << conn.local_id << ", " << conn.peer_id << ", "
247 << conn.socket_id;
248 }
249 }
250
HandleConnectedResponse(CastSocket * socket,CastMessage message,Json::Value parsed_message)251 void ConnectionNamespaceHandler::HandleConnectedResponse(
252 CastSocket* socket,
253 CastMessage message,
254 Json::Value parsed_message) {
255 const VirtualConnection conn{std::move(message.destination_id()),
256 std::move(message.source_id()),
257 ToCastSocketId(socket)};
258 const auto it = std::find_if(
259 pending_remote_requests_.begin(), pending_remote_requests_.end(),
260 [&](const PendingRequest& request) { return request.conn == conn; });
261 if (it == pending_remote_requests_.end()) {
262 return;
263 }
264
265 vc_router_->AddConnection(conn,
266 {VirtualConnection::Type::kStrong,
267 {},
268 {},
269 VirtualConnection::ProtocolVersion::kV2_1_3});
270
271 const auto callback = std::move(it->result_callback);
272 pending_remote_requests_.erase(it);
273 callback(true);
274 }
275
SendConnect(VirtualConnection virtual_conn)276 void ConnectionNamespaceHandler::SendConnect(VirtualConnection virtual_conn) {
277 ::cast::channel::CastMessage message =
278 MakeConnectMessage(virtual_conn.local_id, virtual_conn.peer_id);
279 vc_router_->Send(std::move(virtual_conn), std::move(message));
280 }
281
SendClose(VirtualConnection virtual_conn)282 void ConnectionNamespaceHandler::SendClose(VirtualConnection virtual_conn) {
283 ::cast::channel::CastMessage message =
284 MakeCloseMessage(virtual_conn.local_id, virtual_conn.peer_id);
285 vc_router_->Send(std::move(virtual_conn), std::move(message));
286 }
287
SendConnectedResponse(const VirtualConnection & virtual_conn,int max_protocol_version)288 void ConnectionNamespaceHandler::SendConnectedResponse(
289 const VirtualConnection& virtual_conn,
290 int max_protocol_version) {
291 Json::Value connected_message(Json::ValueType::objectValue);
292 connected_message[kMessageKeyType] = kMessageTypeConnected;
293 connected_message[kMessageKeyProtocolVersion] =
294 static_cast<int>(max_protocol_version);
295
296 ErrorOr<std::string> result = json::Stringify(connected_message);
297 if (result.is_error()) {
298 return;
299 }
300
301 vc_router_->Send(
302 virtual_conn,
303 MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
304 }
305
RemoveConnection(const VirtualConnection & conn,VirtualConnection::CloseReason reason)306 bool ConnectionNamespaceHandler::RemoveConnection(
307 const VirtualConnection& conn,
308 VirtualConnection::CloseReason reason) {
309 bool found_connection = false;
310 if (vc_router_->GetConnectionData(conn)) {
311 vc_router_->RemoveConnection(conn, reason);
312 found_connection = true;
313 }
314
315 // Cancel pending remote request, if any.
316 const auto it = std::find_if(
317 pending_remote_requests_.begin(), pending_remote_requests_.end(),
318 [&](const PendingRequest& request) { return request.conn == conn; });
319 if (it != pending_remote_requests_.end()) {
320 const auto callback = std::move(it->result_callback);
321 pending_remote_requests_.erase(it);
322 callback(false);
323 found_connection = true;
324 }
325
326 return found_connection;
327 }
328
329 } // namespace cast
330 } // namespace openscreen
331