• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/common/channel/connection_namespace_handler.h"
6 
7 #include <algorithm>
8 #include <string>
9 #include <type_traits>
10 #include <utility>
11 
12 #include "absl/types/optional.h"
13 #include "cast/common/channel/message_util.h"
14 #include "cast/common/channel/proto/cast_channel.pb.h"
15 #include "cast/common/channel/virtual_connection.h"
16 #include "cast/common/channel/virtual_connection_router.h"
17 #include "cast/common/public/cast_socket.h"
18 #include "util/json/json_serialization.h"
19 #include "util/json/json_value.h"
20 #include "util/osp_logging.h"
21 
22 namespace openscreen {
23 namespace cast {
24 
25 using ::cast::channel::CastMessage;
26 using ::cast::channel::CastMessage_PayloadType;
27 
28 namespace {
29 
IsValidProtocolVersion(int version)30 bool IsValidProtocolVersion(int version) {
31   return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version);
32 }
33 
FindMaxProtocolVersion(const Json::Value * version,const Json::Value * version_list)34 absl::optional<int> FindMaxProtocolVersion(const Json::Value* version,
35                                            const Json::Value* version_list) {
36   using ArrayIndex = Json::Value::ArrayIndex;
37   static_assert(std::is_integral<ArrayIndex>::value,
38                 "Assuming ArrayIndex is integral");
39   absl::optional<int> max_version;
40   if (version_list && version_list->isArray()) {
41     max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
42     for (auto it = version_list->begin(), end = version_list->end(); it != end;
43          ++it) {
44       if (it->isInt()) {
45         int version_int = it->asInt();
46         if (IsValidProtocolVersion(version_int) && version_int > *max_version) {
47           max_version = version_int;
48         }
49       }
50     }
51   }
52   if (version && version->isInt()) {
53     int version_int = version->asInt();
54     if (IsValidProtocolVersion(version_int)) {
55       if (!max_version) {
56         max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
57       }
58       if (version_int > max_version) {
59         max_version = version_int;
60       }
61     }
62   }
63   return max_version;
64 }
65 
GetCloseReason(const Json::Value & parsed_message)66 VirtualConnection::CloseReason GetCloseReason(
67     const Json::Value& parsed_message) {
68   VirtualConnection::CloseReason reason =
69       VirtualConnection::CloseReason::kClosedByPeer;
70   absl::optional<int> reason_code = MaybeGetInt(
71       parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode));
72   if (reason_code) {
73     int code = reason_code.value();
74     if (code >= VirtualConnection::CloseReason::kFirstReason &&
75         code <= VirtualConnection::CloseReason::kLastReason) {
76       reason = static_cast<VirtualConnection::CloseReason>(code);
77     }
78   }
79   return reason;
80 }
81 
82 }  // namespace
83 
ConnectionNamespaceHandler(VirtualConnectionRouter * vc_router,VirtualConnectionPolicy * vc_policy)84 ConnectionNamespaceHandler::ConnectionNamespaceHandler(
85     VirtualConnectionRouter* vc_router,
86     VirtualConnectionPolicy* vc_policy)
87     : vc_router_(vc_router), vc_policy_(vc_policy) {
88   OSP_DCHECK(vc_router_);
89   OSP_DCHECK(vc_policy_);
90 
91   vc_router_->set_connection_namespace_handler(this);
92 }
93 
~ConnectionNamespaceHandler()94 ConnectionNamespaceHandler::~ConnectionNamespaceHandler() {
95   vc_router_->set_connection_namespace_handler(nullptr);
96 }
97 
OpenRemoteConnection(VirtualConnection conn,RemoteConnectionResultCallback result_callback)98 void ConnectionNamespaceHandler::OpenRemoteConnection(
99     VirtualConnection conn,
100     RemoteConnectionResultCallback result_callback) {
101   OSP_DCHECK(!vc_router_->GetConnectionData(conn));
102   OSP_DCHECK(std::none_of(
103       pending_remote_requests_.begin(), pending_remote_requests_.end(),
104       [&](const PendingRequest& request) { return request.conn == conn; }));
105   pending_remote_requests_.push_back({conn, std::move(result_callback)});
106 
107   SendConnect(std::move(conn));
108 }
109 
CloseRemoteConnection(VirtualConnection conn)110 void ConnectionNamespaceHandler::CloseRemoteConnection(VirtualConnection conn) {
111   if (RemoveConnection(conn, VirtualConnection::kClosedBySelf)) {
112     SendClose(std::move(conn));
113   }
114 }
115 
OnMessage(VirtualConnectionRouter * router,CastSocket * socket,CastMessage message)116 void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
117                                            CastSocket* socket,
118                                            CastMessage message) {
119   if (message.destination_id() == kBroadcastId ||
120       message.source_id() == kBroadcastId ||
121       message.payload_type() !=
122           CastMessage_PayloadType::CastMessage_PayloadType_STRING) {
123     return;
124   }
125 
126   ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
127   if (result.is_error()) {
128     return;
129   }
130 
131   Json::Value& value = result.value();
132   if (!value.isObject()) {
133     return;
134   }
135 
136   absl::optional<absl::string_view> type =
137       MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
138   if (!type) {
139     // TODO(btolsch): Some of these paths should have error reporting.  One
140     // possibility is to pass errors back through |router| so higher-level code
141     // can decide whether to show an error to the user, stop talking to a
142     // particular device, etc.
143     return;
144   }
145 
146   absl::string_view type_str = type.value();
147   if (type_str == kMessageTypeConnect) {
148     HandleConnect(socket, std::move(message), std::move(value));
149   } else if (type_str == kMessageTypeClose) {
150     HandleClose(socket, std::move(message), std::move(value));
151   } else if (type_str == kMessageTypeConnected) {
152     HandleConnectedResponse(socket, std::move(message), std::move(value));
153   } else {
154     // NOTE: Unknown message type so ignore it.
155     // TODO(btolsch): Should be included in future error reporting.
156   }
157 }
158 
HandleConnect(CastSocket * socket,CastMessage message,Json::Value parsed_message)159 void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket,
160                                                CastMessage message,
161                                                Json::Value parsed_message) {
162   if (message.destination_id() == kBroadcastId ||
163       message.source_id() == kBroadcastId) {
164     return;
165   }
166 
167   VirtualConnection virtual_conn{std::move(message.destination_id()),
168                                  std::move(message.source_id()),
169                                  ToCastSocketId(socket)};
170   if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
171     SendClose(std::move(virtual_conn));
172     return;
173   }
174 
175   absl::optional<int> maybe_conn_type = MaybeGetInt(
176       parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType));
177   VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong;
178   if (maybe_conn_type) {
179     int int_type = maybe_conn_type.value();
180     if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) ||
181         int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) {
182       SendClose(std::move(virtual_conn));
183       return;
184     }
185     conn_type = static_cast<VirtualConnection::Type>(int_type);
186   }
187 
188   VirtualConnection::AssociatedData data;
189 
190   data.type = conn_type;
191 
192   absl::optional<absl::string_view> user_agent = MaybeGetString(
193       parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent));
194   if (user_agent) {
195     data.user_agent = std::string(user_agent.value());
196   }
197 
198   const Json::Value* sender_info_value = parsed_message.find(
199       JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo));
200   if (!sender_info_value || !sender_info_value->isObject()) {
201     // TODO(btolsch): Should this be guessed from user agent?
202     OSP_DVLOG << "No sender info from protocol.";
203   }
204 
205   const Json::Value* version_value = parsed_message.find(
206       JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
207   const Json::Value* version_list_value = parsed_message.find(
208       JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList));
209   absl::optional<int> negotiated_version =
210       FindMaxProtocolVersion(version_value, version_list_value);
211   if (negotiated_version) {
212     data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>(
213         negotiated_version.value());
214   } else {
215     data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0;
216   }
217 
218   if (socket) {
219     data.ip_fragment = socket->GetSanitizedIpAddress();
220   } else {
221     data.ip_fragment = {};
222   }
223 
224   OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", "
225             << virtual_conn.peer_id << ", " << virtual_conn.socket_id;
226 
227   // NOTE: Only send a response for senders that actually sent a version.  This
228   // maintains compatibility with older senders that don't send a version and
229   // don't expect a response.
230   if (negotiated_version) {
231     SendConnectedResponse(virtual_conn, negotiated_version.value());
232   }
233 
234   vc_router_->AddConnection(std::move(virtual_conn), std::move(data));
235 }
236 
HandleClose(CastSocket * socket,CastMessage message,Json::Value parsed_message)237 void ConnectionNamespaceHandler::HandleClose(CastSocket* socket,
238                                              CastMessage message,
239                                              Json::Value parsed_message) {
240   const VirtualConnection conn{std::move(*message.mutable_destination_id()),
241                                std::move(*message.mutable_source_id()),
242                                ToCastSocketId(socket)};
243   const auto reason = GetCloseReason(parsed_message);
244   if (RemoveConnection(conn, reason)) {
245     OSP_DVLOG << "Connection closed (reason: " << reason
246               << "): " << conn.local_id << ", " << conn.peer_id << ", "
247               << conn.socket_id;
248   }
249 }
250 
HandleConnectedResponse(CastSocket * socket,CastMessage message,Json::Value parsed_message)251 void ConnectionNamespaceHandler::HandleConnectedResponse(
252     CastSocket* socket,
253     CastMessage message,
254     Json::Value parsed_message) {
255   const VirtualConnection conn{std::move(message.destination_id()),
256                                std::move(message.source_id()),
257                                ToCastSocketId(socket)};
258   const auto it = std::find_if(
259       pending_remote_requests_.begin(), pending_remote_requests_.end(),
260       [&](const PendingRequest& request) { return request.conn == conn; });
261   if (it == pending_remote_requests_.end()) {
262     return;
263   }
264 
265   vc_router_->AddConnection(conn,
266                             {VirtualConnection::Type::kStrong,
267                              {},
268                              {},
269                              VirtualConnection::ProtocolVersion::kV2_1_3});
270 
271   const auto callback = std::move(it->result_callback);
272   pending_remote_requests_.erase(it);
273   callback(true);
274 }
275 
SendConnect(VirtualConnection virtual_conn)276 void ConnectionNamespaceHandler::SendConnect(VirtualConnection virtual_conn) {
277   ::cast::channel::CastMessage message =
278       MakeConnectMessage(virtual_conn.local_id, virtual_conn.peer_id);
279   vc_router_->Send(std::move(virtual_conn), std::move(message));
280 }
281 
SendClose(VirtualConnection virtual_conn)282 void ConnectionNamespaceHandler::SendClose(VirtualConnection virtual_conn) {
283   ::cast::channel::CastMessage message =
284       MakeCloseMessage(virtual_conn.local_id, virtual_conn.peer_id);
285   vc_router_->Send(std::move(virtual_conn), std::move(message));
286 }
287 
SendConnectedResponse(const VirtualConnection & virtual_conn,int max_protocol_version)288 void ConnectionNamespaceHandler::SendConnectedResponse(
289     const VirtualConnection& virtual_conn,
290     int max_protocol_version) {
291   Json::Value connected_message(Json::ValueType::objectValue);
292   connected_message[kMessageKeyType] = kMessageTypeConnected;
293   connected_message[kMessageKeyProtocolVersion] =
294       static_cast<int>(max_protocol_version);
295 
296   ErrorOr<std::string> result = json::Stringify(connected_message);
297   if (result.is_error()) {
298     return;
299   }
300 
301   vc_router_->Send(
302       virtual_conn,
303       MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
304 }
305 
RemoveConnection(const VirtualConnection & conn,VirtualConnection::CloseReason reason)306 bool ConnectionNamespaceHandler::RemoveConnection(
307     const VirtualConnection& conn,
308     VirtualConnection::CloseReason reason) {
309   bool found_connection = false;
310   if (vc_router_->GetConnectionData(conn)) {
311     vc_router_->RemoveConnection(conn, reason);
312     found_connection = true;
313   }
314 
315   // Cancel pending remote request, if any.
316   const auto it = std::find_if(
317       pending_remote_requests_.begin(), pending_remote_requests_.end(),
318       [&](const PendingRequest& request) { return request.conn == conn; });
319   if (it != pending_remote_requests_.end()) {
320     const auto callback = std::move(it->result_callback);
321     pending_remote_requests_.erase(it);
322     callback(false);
323     found_connection = true;
324   }
325 
326   return found_connection;
327 }
328 
329 }  // namespace cast
330 }  // namespace openscreen
331