1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_ 6 #define MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_ 7 8 #include <string> 9 #include <utility> 10 11 #include "base/bind.h" 12 #include "base/callback.h" 13 #include "base/macros.h" 14 #include "base/memory/ptr_util.h" 15 #include "mojo/public/cpp/bindings/binding.h" 16 #include "mojo/public/cpp/bindings/connection_error_callback.h" 17 #include "mojo/public/cpp/bindings/interface_ptr.h" 18 #include "mojo/public/cpp/bindings/interface_request.h" 19 #include "mojo/public/cpp/bindings/message.h" 20 21 namespace mojo { 22 23 template <typename BindingType> 24 struct BindingSetTraits; 25 26 template <typename Interface, typename ImplRefTraits> 27 struct BindingSetTraits<Binding<Interface, ImplRefTraits>> { 28 using ProxyType = InterfacePtr<Interface>; 29 using RequestType = InterfaceRequest<Interface>; 30 using BindingType = Binding<Interface, ImplRefTraits>; 31 using ImplPointerType = typename BindingType::ImplPointerType; 32 33 static RequestType MakeRequest(ProxyType* proxy) { 34 return mojo::MakeRequest(proxy); 35 } 36 }; 37 38 using BindingId = size_t; 39 40 template <typename ContextType> 41 struct BindingSetContextTraits { 42 using Type = ContextType; 43 44 static constexpr bool SupportsContext() { return true; } 45 }; 46 47 template <> 48 struct BindingSetContextTraits<void> { 49 // NOTE: This choice of Type only matters insofar as it affects the size of 50 // the |context_| field of a BindingSetBase::Entry with void context. The 51 // context value is never used in this case. 52 using Type = bool; 53 54 static constexpr bool SupportsContext() { return false; } 55 }; 56 57 // Generic definition used for BindingSet and AssociatedBindingSet to own a 58 // collection of bindings which point to the same implementation. 59 // 60 // If |ContextType| is non-void, then every added binding must include a context 61 // value of that type, and |dispatch_context()| will return that value during 62 // the extent of any message dispatch targeting that specific binding. 63 template <typename Interface, typename BindingType, typename ContextType> 64 class BindingSetBase { 65 public: 66 using ContextTraits = BindingSetContextTraits<ContextType>; 67 using Context = typename ContextTraits::Type; 68 using PreDispatchCallback = base::Callback<void(const Context&)>; 69 using Traits = BindingSetTraits<BindingType>; 70 using ProxyType = typename Traits::ProxyType; 71 using RequestType = typename Traits::RequestType; 72 using ImplPointerType = typename Traits::ImplPointerType; 73 74 BindingSetBase() : weak_ptr_factory_(this) {} 75 76 void set_connection_error_handler(base::RepeatingClosure error_handler) { 77 error_handler_ = std::move(error_handler); 78 error_with_reason_handler_.Reset(); 79 } 80 81 void set_connection_error_with_reason_handler( 82 RepeatingConnectionErrorWithReasonCallback error_handler) { 83 error_with_reason_handler_ = std::move(error_handler); 84 error_handler_.Reset(); 85 } 86 87 // Sets a callback to be invoked immediately before dispatching any message or 88 // error received by any of the bindings in the set. This may only be used 89 // with a non-void |ContextType|. 90 void set_pre_dispatch_handler(const PreDispatchCallback& handler) { 91 static_assert(ContextTraits::SupportsContext(), 92 "Pre-dispatch handler usage requires non-void context type."); 93 pre_dispatch_handler_ = handler; 94 } 95 96 // Adds a new binding to the set which binds |request| to |impl| with no 97 // additional context. 98 BindingId AddBinding(ImplPointerType impl, RequestType request) { 99 static_assert(!ContextTraits::SupportsContext(), 100 "Context value required for non-void context type."); 101 return AddBindingImpl(std::move(impl), std::move(request), false); 102 } 103 104 // Adds a new binding associated with |context|. 105 BindingId AddBinding(ImplPointerType impl, 106 RequestType request, 107 Context context) { 108 static_assert(ContextTraits::SupportsContext(), 109 "Context value unsupported for void context type."); 110 return AddBindingImpl(std::move(impl), std::move(request), 111 std::move(context)); 112 } 113 114 // Removes a binding from the set. Note that this is safe to call even if the 115 // binding corresponding to |id| has already been removed. 116 // 117 // Returns |true| if the binding was removed and |false| if it didn't exist. 118 bool RemoveBinding(BindingId id) { 119 auto it = bindings_.find(id); 120 if (it == bindings_.end()) 121 return false; 122 bindings_.erase(it); 123 return true; 124 } 125 126 // Swaps the interface implementation with a different one, to allow tests 127 // to modify behavior. 128 // 129 // Returns the existing interface implementation to the caller. 130 ImplPointerType SwapImplForTesting(BindingId id, ImplPointerType new_impl) { 131 auto it = bindings_.find(id); 132 if (it == bindings_.end()) 133 return nullptr; 134 135 return it->second->SwapImplForTesting(new_impl); 136 } 137 138 void CloseAllBindings() { bindings_.clear(); } 139 140 bool empty() const { return bindings_.empty(); } 141 142 size_t size() const { return bindings_.size(); } 143 144 // Implementations may call this when processing a dispatched message or 145 // error. During the extent of message or error dispatch, this will return the 146 // context associated with the specific binding which received the message or 147 // error. Use AddBinding() to associated a context with a specific binding. 148 const Context& dispatch_context() const { 149 static_assert(ContextTraits::SupportsContext(), 150 "dispatch_context() requires non-void context type."); 151 DCHECK(dispatch_context_); 152 return *dispatch_context_; 153 } 154 155 // Implementations may call this when processing a dispatched message or 156 // error. During the extent of message or error dispatch, this will return the 157 // BindingId of the specific binding which received the message or error. 158 BindingId dispatch_binding() const { 159 DCHECK(dispatch_context_); 160 return dispatch_binding_; 161 } 162 163 // Reports the currently dispatching Message as bad and closes the binding the 164 // message was received from. Note that this is only legal to call from 165 // directly within the stack frame of a message dispatch. If you need to do 166 // asynchronous work before you can determine the legitimacy of a message, use 167 // GetBadMessageCallback() and retain its result until you're ready to invoke 168 // or discard it. 169 void ReportBadMessage(const std::string& error) { 170 GetBadMessageCallback().Run(error); 171 } 172 173 // Acquires a callback which may be run to report the currently dispatching 174 // Message as bad and close the binding the message was received from. Note 175 // that this is only legal to call from directly within the stack frame of a 176 // message dispatch, but the returned callback may be called exactly once any 177 // time thereafter as long as the binding set itself hasn't been destroyed yet 178 // to report the message as bad. This may only be called once per message. 179 // The returned callback must be called on the BindingSet's own sequence. 180 ReportBadMessageCallback GetBadMessageCallback() { 181 DCHECK(dispatch_context_); 182 return base::BindOnce( 183 [](ReportBadMessageCallback error_callback, 184 base::WeakPtr<BindingSetBase> binding_set, BindingId binding_id, 185 const std::string& error) { 186 std::move(error_callback).Run(error); 187 if (binding_set) 188 binding_set->RemoveBinding(binding_id); 189 }, 190 mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr(), 191 dispatch_binding()); 192 } 193 194 void FlushForTesting() { 195 DCHECK(!is_flushing_); 196 is_flushing_ = true; 197 for (auto& binding : bindings_) 198 if (binding.second) 199 binding.second->FlushForTesting(); 200 is_flushing_ = false; 201 // Clean up any bindings that were destroyed. 202 for (auto it = bindings_.begin(); it != bindings_.end();) { 203 if (!it->second) 204 it = bindings_.erase(it); 205 else 206 ++it; 207 } 208 } 209 210 private: 211 friend class Entry; 212 213 class Entry { 214 public: 215 Entry(ImplPointerType impl, 216 RequestType request, 217 BindingSetBase* binding_set, 218 BindingId binding_id, 219 Context context) 220 : binding_(std::move(impl), std::move(request)), 221 binding_set_(binding_set), 222 binding_id_(binding_id), 223 context_(std::move(context)) { 224 binding_.AddFilter(std::make_unique<DispatchFilter>(this)); 225 binding_.set_connection_error_with_reason_handler( 226 base::BindOnce(&Entry::OnConnectionError, base::Unretained(this))); 227 } 228 229 void FlushForTesting() { binding_.FlushForTesting(); } 230 231 ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { 232 return binding_.SwapImplForTesting(new_impl); 233 } 234 235 private: 236 class DispatchFilter : public MessageReceiver { 237 public: 238 explicit DispatchFilter(Entry* entry) : entry_(entry) {} 239 ~DispatchFilter() override {} 240 241 private: 242 // MessageReceiver: 243 bool Accept(Message* message) override { 244 entry_->WillDispatch(); 245 return true; 246 } 247 248 Entry* entry_; 249 250 DISALLOW_COPY_AND_ASSIGN(DispatchFilter); 251 }; 252 253 void WillDispatch() { 254 binding_set_->SetDispatchContext(&context_, binding_id_); 255 } 256 257 void OnConnectionError(uint32_t custom_reason, 258 const std::string& description) { 259 WillDispatch(); 260 binding_set_->OnConnectionError(binding_id_, custom_reason, description); 261 } 262 263 BindingType binding_; 264 BindingSetBase* const binding_set_; 265 const BindingId binding_id_; 266 Context const context_; 267 268 DISALLOW_COPY_AND_ASSIGN(Entry); 269 }; 270 271 void SetDispatchContext(const Context* context, BindingId binding_id) { 272 dispatch_context_ = context; 273 dispatch_binding_ = binding_id; 274 if (!pre_dispatch_handler_.is_null()) 275 pre_dispatch_handler_.Run(*context); 276 } 277 278 BindingId AddBindingImpl(ImplPointerType impl, 279 RequestType request, 280 Context context) { 281 BindingId id = next_binding_id_++; 282 DCHECK_GE(next_binding_id_, 0u); 283 auto entry = std::make_unique<Entry>(std::move(impl), std::move(request), 284 this, id, std::move(context)); 285 bindings_.insert(std::make_pair(id, std::move(entry))); 286 return id; 287 } 288 289 void OnConnectionError(BindingId id, 290 uint32_t custom_reason, 291 const std::string& description) { 292 auto it = bindings_.find(id); 293 DCHECK(it != bindings_.end()); 294 295 // We keep the Entry alive throughout error dispatch. 296 std::unique_ptr<Entry> entry = std::move(it->second); 297 if (!is_flushing_) 298 bindings_.erase(it); 299 300 if (error_handler_) { 301 error_handler_.Run(); 302 } else if (error_with_reason_handler_) { 303 error_with_reason_handler_.Run(custom_reason, description); 304 } 305 } 306 307 base::RepeatingClosure error_handler_; 308 RepeatingConnectionErrorWithReasonCallback error_with_reason_handler_; 309 PreDispatchCallback pre_dispatch_handler_; 310 BindingId next_binding_id_ = 0; 311 std::map<BindingId, std::unique_ptr<Entry>> bindings_; 312 bool is_flushing_ = false; 313 const Context* dispatch_context_ = nullptr; 314 BindingId dispatch_binding_; 315 base::WeakPtrFactory<BindingSetBase> weak_ptr_factory_; 316 317 DISALLOW_COPY_AND_ASSIGN(BindingSetBase); 318 }; 319 320 template <typename Interface, typename ContextType = void> 321 using BindingSet = BindingSetBase<Interface, Binding<Interface>, ContextType>; 322 323 } // namespace mojo 324 325 #endif // MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_ 326