• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
6 #define MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
7 
8 #include <string>
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/callback.h"
13 #include "base/macros.h"
14 #include "base/memory/ptr_util.h"
15 #include "mojo/public/cpp/bindings/binding.h"
16 #include "mojo/public/cpp/bindings/connection_error_callback.h"
17 #include "mojo/public/cpp/bindings/interface_ptr.h"
18 #include "mojo/public/cpp/bindings/interface_request.h"
19 #include "mojo/public/cpp/bindings/message.h"
20 
21 namespace mojo {
22 
23 template <typename BindingType>
24 struct BindingSetTraits;
25 
26 template <typename Interface, typename ImplRefTraits>
27 struct BindingSetTraits<Binding<Interface, ImplRefTraits>> {
28   using ProxyType = InterfacePtr<Interface>;
29   using RequestType = InterfaceRequest<Interface>;
30   using BindingType = Binding<Interface, ImplRefTraits>;
31   using ImplPointerType = typename BindingType::ImplPointerType;
32 
33   static RequestType MakeRequest(ProxyType* proxy) {
34     return mojo::MakeRequest(proxy);
35   }
36 };
37 
// Identifies one binding within a binding set; returned by AddBinding().
using BindingId = size_t;

// Per-binding context description for a non-void |ContextType|: the context is
// stored as |ContextType| itself and context values are supported.
template <typename ContextType>
struct BindingSetContextTraits {
  static constexpr bool SupportsContext() { return true; }

  using Type = ContextType;
};

// A void |ContextType| means "no context". Context values are not supported.
template <>
struct BindingSetContextTraits<void> {
  static constexpr bool SupportsContext() { return false; }

  // Placeholder storage type. It only affects the size of the per-binding
  // |context_| field of BindingSetBase::Entry; the stored value is never used
  // when the context type is void.
  using Type = bool;
};
56 
57 // Generic definition used for BindingSet and AssociatedBindingSet to own a
58 // collection of bindings which point to the same implementation.
59 //
60 // If |ContextType| is non-void, then every added binding must include a context
61 // value of that type, and |dispatch_context()| will return that value during
62 // the extent of any message dispatch targeting that specific binding.
63 template <typename Interface, typename BindingType, typename ContextType>
64 class BindingSetBase {
65  public:
66   using ContextTraits = BindingSetContextTraits<ContextType>;
67   using Context = typename ContextTraits::Type;
68   using PreDispatchCallback = base::Callback<void(const Context&)>;
69   using Traits = BindingSetTraits<BindingType>;
70   using ProxyType = typename Traits::ProxyType;
71   using RequestType = typename Traits::RequestType;
72   using ImplPointerType = typename Traits::ImplPointerType;
73 
74   BindingSetBase() : weak_ptr_factory_(this) {}
75 
76   void set_connection_error_handler(base::RepeatingClosure error_handler) {
77     error_handler_ = std::move(error_handler);
78     error_with_reason_handler_.Reset();
79   }
80 
81   void set_connection_error_with_reason_handler(
82       RepeatingConnectionErrorWithReasonCallback error_handler) {
83     error_with_reason_handler_ = std::move(error_handler);
84     error_handler_.Reset();
85   }
86 
87   // Sets a callback to be invoked immediately before dispatching any message or
88   // error received by any of the bindings in the set. This may only be used
89   // with a non-void |ContextType|.
90   void set_pre_dispatch_handler(const PreDispatchCallback& handler) {
91     static_assert(ContextTraits::SupportsContext(),
92                   "Pre-dispatch handler usage requires non-void context type.");
93     pre_dispatch_handler_ = handler;
94   }
95 
96   // Adds a new binding to the set which binds |request| to |impl| with no
97   // additional context.
98   BindingId AddBinding(ImplPointerType impl, RequestType request) {
99     static_assert(!ContextTraits::SupportsContext(),
100                   "Context value required for non-void context type.");
101     return AddBindingImpl(std::move(impl), std::move(request), false);
102   }
103 
104   // Adds a new binding associated with |context|.
105   BindingId AddBinding(ImplPointerType impl,
106                        RequestType request,
107                        Context context) {
108     static_assert(ContextTraits::SupportsContext(),
109                   "Context value unsupported for void context type.");
110     return AddBindingImpl(std::move(impl), std::move(request),
111                           std::move(context));
112   }
113 
114   // Removes a binding from the set. Note that this is safe to call even if the
115   // binding corresponding to |id| has already been removed.
116   //
117   // Returns |true| if the binding was removed and |false| if it didn't exist.
118   bool RemoveBinding(BindingId id) {
119     auto it = bindings_.find(id);
120     if (it == bindings_.end())
121       return false;
122     bindings_.erase(it);
123     return true;
124   }
125 
126   // Swaps the interface implementation with a different one, to allow tests
127   // to modify behavior.
128   //
129   // Returns the existing interface implementation to the caller.
130   ImplPointerType SwapImplForTesting(BindingId id, ImplPointerType new_impl) {
131     auto it = bindings_.find(id);
132     if (it == bindings_.end())
133       return nullptr;
134 
135     return it->second->SwapImplForTesting(new_impl);
136   }
137 
138   void CloseAllBindings() { bindings_.clear(); }
139 
140   bool empty() const { return bindings_.empty(); }
141 
142   size_t size() const { return bindings_.size(); }
143 
144   // Implementations may call this when processing a dispatched message or
145   // error. During the extent of message or error dispatch, this will return the
146   // context associated with the specific binding which received the message or
147   // error. Use AddBinding() to associated a context with a specific binding.
148   const Context& dispatch_context() const {
149     static_assert(ContextTraits::SupportsContext(),
150                   "dispatch_context() requires non-void context type.");
151     DCHECK(dispatch_context_);
152     return *dispatch_context_;
153   }
154 
155   // Implementations may call this when processing a dispatched message or
156   // error. During the extent of message or error dispatch, this will return the
157   // BindingId of the specific binding which received the message or error.
158   BindingId dispatch_binding() const {
159     DCHECK(dispatch_context_);
160     return dispatch_binding_;
161   }
162 
163   // Reports the currently dispatching Message as bad and closes the binding the
164   // message was received from. Note that this is only legal to call from
165   // directly within the stack frame of a message dispatch. If you need to do
166   // asynchronous work before you can determine the legitimacy of a message, use
167   // GetBadMessageCallback() and retain its result until you're ready to invoke
168   // or discard it.
169   void ReportBadMessage(const std::string& error) {
170     GetBadMessageCallback().Run(error);
171   }
172 
173   // Acquires a callback which may be run to report the currently dispatching
174   // Message as bad and close the binding the message was received from. Note
175   // that this is only legal to call from directly within the stack frame of a
176   // message dispatch, but the returned callback may be called exactly once any
177   // time thereafter as long as the binding set itself hasn't been destroyed yet
178   // to report the message as bad. This may only be called once per message.
179   // The returned callback must be called on the BindingSet's own sequence.
180   ReportBadMessageCallback GetBadMessageCallback() {
181     DCHECK(dispatch_context_);
182     return base::BindOnce(
183         [](ReportBadMessageCallback error_callback,
184            base::WeakPtr<BindingSetBase> binding_set, BindingId binding_id,
185            const std::string& error) {
186           std::move(error_callback).Run(error);
187           if (binding_set)
188             binding_set->RemoveBinding(binding_id);
189         },
190         mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr(),
191         dispatch_binding());
192   }
193 
194   void FlushForTesting() {
195     DCHECK(!is_flushing_);
196     is_flushing_ = true;
197     for (auto& binding : bindings_)
198       if (binding.second)
199         binding.second->FlushForTesting();
200     is_flushing_ = false;
201     // Clean up any bindings that were destroyed.
202     for (auto it = bindings_.begin(); it != bindings_.end();) {
203       if (!it->second)
204         it = bindings_.erase(it);
205       else
206         ++it;
207     }
208   }
209 
210  private:
211   friend class Entry;
212 
213   class Entry {
214    public:
215     Entry(ImplPointerType impl,
216           RequestType request,
217           BindingSetBase* binding_set,
218           BindingId binding_id,
219           Context context)
220         : binding_(std::move(impl), std::move(request)),
221           binding_set_(binding_set),
222           binding_id_(binding_id),
223           context_(std::move(context)) {
224       binding_.AddFilter(std::make_unique<DispatchFilter>(this));
225       binding_.set_connection_error_with_reason_handler(
226           base::BindOnce(&Entry::OnConnectionError, base::Unretained(this)));
227     }
228 
229     void FlushForTesting() { binding_.FlushForTesting(); }
230 
231     ImplPointerType SwapImplForTesting(ImplPointerType new_impl) {
232       return binding_.SwapImplForTesting(new_impl);
233     }
234 
235    private:
236     class DispatchFilter : public MessageReceiver {
237      public:
238       explicit DispatchFilter(Entry* entry) : entry_(entry) {}
239       ~DispatchFilter() override {}
240 
241      private:
242       // MessageReceiver:
243       bool Accept(Message* message) override {
244         entry_->WillDispatch();
245         return true;
246       }
247 
248       Entry* entry_;
249 
250       DISALLOW_COPY_AND_ASSIGN(DispatchFilter);
251     };
252 
253     void WillDispatch() {
254       binding_set_->SetDispatchContext(&context_, binding_id_);
255     }
256 
257     void OnConnectionError(uint32_t custom_reason,
258                            const std::string& description) {
259       WillDispatch();
260       binding_set_->OnConnectionError(binding_id_, custom_reason, description);
261     }
262 
263     BindingType binding_;
264     BindingSetBase* const binding_set_;
265     const BindingId binding_id_;
266     Context const context_;
267 
268     DISALLOW_COPY_AND_ASSIGN(Entry);
269   };
270 
271   void SetDispatchContext(const Context* context, BindingId binding_id) {
272     dispatch_context_ = context;
273     dispatch_binding_ = binding_id;
274     if (!pre_dispatch_handler_.is_null())
275       pre_dispatch_handler_.Run(*context);
276   }
277 
278   BindingId AddBindingImpl(ImplPointerType impl,
279                            RequestType request,
280                            Context context) {
281     BindingId id = next_binding_id_++;
282     DCHECK_GE(next_binding_id_, 0u);
283     auto entry = std::make_unique<Entry>(std::move(impl), std::move(request),
284                                          this, id, std::move(context));
285     bindings_.insert(std::make_pair(id, std::move(entry)));
286     return id;
287   }
288 
289   void OnConnectionError(BindingId id,
290                          uint32_t custom_reason,
291                          const std::string& description) {
292     auto it = bindings_.find(id);
293     DCHECK(it != bindings_.end());
294 
295     // We keep the Entry alive throughout error dispatch.
296     std::unique_ptr<Entry> entry = std::move(it->second);
297     if (!is_flushing_)
298       bindings_.erase(it);
299 
300     if (error_handler_) {
301       error_handler_.Run();
302     } else if (error_with_reason_handler_) {
303       error_with_reason_handler_.Run(custom_reason, description);
304     }
305   }
306 
307   base::RepeatingClosure error_handler_;
308   RepeatingConnectionErrorWithReasonCallback error_with_reason_handler_;
309   PreDispatchCallback pre_dispatch_handler_;
310   BindingId next_binding_id_ = 0;
311   std::map<BindingId, std::unique_ptr<Entry>> bindings_;
312   bool is_flushing_ = false;
313   const Context* dispatch_context_ = nullptr;
314   BindingId dispatch_binding_;
315   base::WeakPtrFactory<BindingSetBase> weak_ptr_factory_;
316 
317   DISALLOW_COPY_AND_ASSIGN(BindingSetBase);
318 };
319 
320 template <typename Interface, typename ContextType = void>
321 using BindingSet = BindingSetBase<Interface, Binding<Interface>, ContextType>;
322 
323 }  // namespace mojo
324 
325 #endif  // MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
326