1 // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
2
3 #pragma once
4
5 /*! \file rx-group_by.hpp
6
7 \brief Return an observable that emits grouped_observables, each of which corresponds to a unique key value and each of which emits those items from the source observable that share that key value.
8
9 \tparam KeySelector the type of the key extracting function
10 \tparam MarbleSelector the type of the element extracting function
11 \tparam BinaryPredicate the type of the key comparing function
12 \tparam DurationSelector the type of the duration observable function
13
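\param ks  a function that extracts the key for each item (optional)
\param ms  a function that extracts the element to emit on the group for each item (optional)
\param p   a function that implements the ordering comparison of two keys (optional)
\param ds  a function that receives each newly created grouped_observable and returns an observable; when that observable first emits, completes, or errors, the group is completed and removed (optional)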
17
18 \return Observable that emits values of grouped_observable type, each of which corresponds to a unique key value and each of which emits those items from the source observable that share that key value.
19
20 \sample
21 \snippet group_by.cpp group_by full intro
22 \snippet group_by.cpp group_by full sample
23 \snippet output.txt group_by full sample
24
25 \sample
26 \snippet group_by.cpp group_by sample
27 \snippet output.txt group_by sample
28 */
29
30 #if !defined(RXCPP_OPERATORS_RX_GROUP_BY_HPP)
31 #define RXCPP_OPERATORS_RX_GROUP_BY_HPP
32
33 #include "../rx-includes.hpp"
34
35 namespace rxcpp {
36
37 namespace operators {
38
39 namespace detail {
40
41 template<class... AN>
42 struct group_by_invalid_arguments {};
43
44 template<class... AN>
45 struct group_by_invalid : public rxo::operator_base<group_by_invalid_arguments<AN...>> {
46 using type = observable<group_by_invalid_arguments<AN...>, group_by_invalid<AN...>>;
47 };
48 template<class... AN>
49 using group_by_invalid_t = typename group_by_invalid<AN...>::type;
50
51 template<class T, class Selector>
52 struct is_group_by_selector_for {
53
54 typedef rxu::decay_t<Selector> selector_type;
55 typedef T source_value_type;
56
57 struct tag_not_valid {};
58 template<class CV, class CS>
59 static auto check(int) -> decltype((*(CS*)nullptr)(*(CV*)nullptr));
60 template<class CV, class CS>
61 static tag_not_valid check(...);
62
63 typedef decltype(check<source_value_type, selector_type>(0)) type;
64 static const bool value = !std::is_same<type, tag_not_valid>::value;
65 };
66
67 template<class T, class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
68 struct group_by_traits
69 {
70 typedef T source_value_type;
71 typedef rxu::decay_t<Observable> source_type;
72 typedef rxu::decay_t<KeySelector> key_selector_type;
73 typedef rxu::decay_t<MarbleSelector> marble_selector_type;
74 typedef rxu::decay_t<BinaryPredicate> predicate_type;
75 typedef rxu::decay_t<DurationSelector> duration_selector_type;
76
77 static_assert(is_group_by_selector_for<source_value_type, key_selector_type>::value, "group_by KeySelector must be a function with the signature key_type(source_value_type)");
78
79 typedef typename is_group_by_selector_for<source_value_type, key_selector_type>::type key_type;
80
81 static_assert(is_group_by_selector_for<source_value_type, marble_selector_type>::value, "group_by MarbleSelector must be a function with the signature marble_type(source_value_type)");
82
83 typedef typename is_group_by_selector_for<source_value_type, marble_selector_type>::type marble_type;
84
85 typedef rxsub::subject<marble_type> subject_type;
86
87 typedef std::map<key_type, typename subject_type::subscriber_type, predicate_type> key_subscriber_map_type;
88
89 typedef grouped_observable<key_type, marble_type> grouped_observable_type;
90 };
91
92 template<class T, class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
93 struct group_by
94 {
95 typedef group_by_traits<T, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> traits_type;
96 typedef typename traits_type::key_selector_type key_selector_type;
97 typedef typename traits_type::marble_selector_type marble_selector_type;
98 typedef typename traits_type::marble_type marble_type;
99 typedef typename traits_type::predicate_type predicate_type;
100 typedef typename traits_type::duration_selector_type duration_selector_type;
101 typedef typename traits_type::subject_type subject_type;
102 typedef typename traits_type::key_type key_type;
103
104 typedef typename traits_type::key_subscriber_map_type group_map_type;
105 typedef std::vector<typename composite_subscription::weak_subscription> bindings_type;
106
107 struct group_by_state_type
108 {
group_by_state_typerxcpp::operators::detail::group_by::group_by_state_type109 group_by_state_type(composite_subscription sl, predicate_type p)
110 : source_lifetime(sl)
111 , groups(p)
112 , observers(0)
113 {}
114 composite_subscription source_lifetime;
115 rxsc::worker worker;
116 group_map_type groups;
117 std::atomic<int> observers;
118 };
119
120 template<class Subscriber>
stopsourcerxcpp::operators::detail::group_by121 static void stopsource(Subscriber&& dest, std::shared_ptr<group_by_state_type>& state) {
122 ++state->observers;
123 dest.add([state](){
124 if (!state->source_lifetime.is_subscribed()) {
125 return;
126 }
127 --state->observers;
128 if (state->observers == 0) {
129 state->source_lifetime.unsubscribe();
130 }
131 });
132 }
133
134 struct group_by_values
135 {
group_by_valuesrxcpp::operators::detail::group_by::group_by_values136 group_by_values(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
137 : keySelector(std::move(ks))
138 , marbleSelector(std::move(ms))
139 , predicate(std::move(p))
140 , durationSelector(std::move(ds))
141 {
142 }
143 mutable key_selector_type keySelector;
144 mutable marble_selector_type marbleSelector;
145 mutable predicate_type predicate;
146 mutable duration_selector_type durationSelector;
147 };
148
149 group_by_values initial;
150
group_byrxcpp::operators::detail::group_by151 group_by(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
152 : initial(std::move(ks), std::move(ms), std::move(p), std::move(ds))
153 {
154 }
155
156 struct group_by_observable : public rxs::source_base<marble_type>
157 {
158 mutable std::shared_ptr<group_by_state_type> state;
159 subject_type subject;
160 key_type key;
161
group_by_observablerxcpp::operators::detail::group_by::group_by_observable162 group_by_observable(std::shared_ptr<group_by_state_type> st, subject_type s, key_type k)
163 : state(std::move(st))
164 , subject(std::move(s))
165 , key(k)
166 {
167 }
168
169 template<class Subscriber>
on_subscriberxcpp::operators::detail::group_by::group_by_observable170 void on_subscribe(Subscriber&& o) const {
171 group_by::stopsource(o, state);
172 subject.get_observable().subscribe(std::forward<Subscriber>(o));
173 }
174
on_get_keyrxcpp::operators::detail::group_by::group_by_observable175 key_type on_get_key() {
176 return key;
177 }
178 };
179
180 template<class Subscriber>
181 struct group_by_observer : public group_by_values
182 {
183 typedef group_by_observer<Subscriber> this_type;
184 typedef typename traits_type::grouped_observable_type value_type;
185 typedef rxu::decay_t<Subscriber> dest_type;
186 typedef observer<T, this_type> observer_type;
187
188 dest_type dest;
189
190 mutable std::shared_ptr<group_by_state_type> state;
191
group_by_observerrxcpp::operators::detail::group_by::group_by_observer192 group_by_observer(composite_subscription l, dest_type d, group_by_values v)
193 : group_by_values(v)
194 , dest(std::move(d))
195 , state(std::make_shared<group_by_state_type>(l, group_by_values::predicate))
196 {
197 group_by::stopsource(dest, state);
198 }
on_nextrxcpp::operators::detail::group_by::group_by_observer199 void on_next(T v) const {
200 auto selectedKey = on_exception(
201 [&](){
202 return this->keySelector(v);},
203 [this](rxu::error_ptr e){on_error(e);});
204 if (selectedKey.empty()) {
205 return;
206 }
207 auto g = state->groups.find(selectedKey.get());
208 if (g == state->groups.end()) {
209 if (!dest.is_subscribed()) {
210 return;
211 }
212 auto sub = subject_type();
213 g = state->groups.insert(std::make_pair(selectedKey.get(), sub.get_subscriber())).first;
214 auto obs = make_dynamic_grouped_observable<key_type, marble_type>(group_by_observable(state, sub, selectedKey.get()));
215 auto durationObs = on_exception(
216 [&](){
217 return this->durationSelector(obs);},
218 [this](rxu::error_ptr e){on_error(e);});
219 if (durationObs.empty()) {
220 return;
221 }
222
223 dest.on_next(obs);
224 composite_subscription duration_sub;
225 auto ssub = state->source_lifetime.add(duration_sub);
226
227 auto expire_state = state;
228 auto expire_dest = g->second;
229 auto expire = [=]() {
230 auto g = expire_state->groups.find(selectedKey.get());
231 if (g != expire_state->groups.end()) {
232 expire_state->groups.erase(g);
233 expire_dest.on_completed();
234 }
235 expire_state->source_lifetime.remove(ssub);
236 };
237 auto robs = durationObs.get().take(1);
238 duration_sub.add(robs.subscribe(
239 [](const typename decltype(robs)::value_type &){},
240 [=](rxu::error_ptr) {expire();},
241 [=](){expire();}
242 ));
243 }
244 auto selectedMarble = on_exception(
245 [&](){
246 return this->marbleSelector(v);},
247 [this](rxu::error_ptr e){on_error(e);});
248 if (selectedMarble.empty()) {
249 return;
250 }
251 g->second.on_next(std::move(selectedMarble.get()));
252 }
on_errorrxcpp::operators::detail::group_by::group_by_observer253 void on_error(rxu::error_ptr e) const {
254 for(auto& g : state->groups) {
255 g.second.on_error(e);
256 }
257 dest.on_error(e);
258 }
on_completedrxcpp::operators::detail::group_by::group_by_observer259 void on_completed() const {
260 for(auto& g : state->groups) {
261 g.second.on_completed();
262 }
263 dest.on_completed();
264 }
265
makerxcpp::operators::detail::group_by::group_by_observer266 static subscriber<T, observer_type> make(dest_type d, group_by_values v) {
267 auto cs = composite_subscription();
268 return make_subscriber<T>(cs, observer_type(this_type(cs, std::move(d), std::move(v))));
269 }
270 };
271
272 template<class Subscriber>
operator ()rxcpp::operators::detail::group_by273 auto operator()(Subscriber dest) const
274 -> decltype(group_by_observer<Subscriber>::make(std::move(dest), initial)) {
275 return group_by_observer<Subscriber>::make(std::move(dest), initial);
276 }
277 };
278
279 template<class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector>
280 class group_by_factory
281 {
282 typedef rxu::decay_t<KeySelector> key_selector_type;
283 typedef rxu::decay_t<MarbleSelector> marble_selector_type;
284 typedef rxu::decay_t<BinaryPredicate> predicate_type;
285 typedef rxu::decay_t<DurationSelector> duration_selector_type;
286 key_selector_type keySelector;
287 marble_selector_type marbleSelector;
288 predicate_type predicate;
289 duration_selector_type durationSelector;
290 public:
group_by_factory(key_selector_type ks,marble_selector_type ms,predicate_type p,duration_selector_type ds)291 group_by_factory(key_selector_type ks, marble_selector_type ms, predicate_type p, duration_selector_type ds)
292 : keySelector(std::move(ks))
293 , marbleSelector(std::move(ms))
294 , predicate(std::move(p))
295 , durationSelector(std::move(ds))
296 {
297 }
298 template<class Observable>
299 struct group_by_factory_traits
300 {
301 typedef rxu::value_type_t<rxu::decay_t<Observable>> value_type;
302 typedef detail::group_by_traits<value_type, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> traits_type;
303 typedef detail::group_by<value_type, Observable, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector> group_by_type;
304 };
305 template<class Observable>
operator ()(Observable && source)306 auto operator()(Observable&& source)
307 -> decltype(source.template lift<typename group_by_factory_traits<Observable>::traits_type::grouped_observable_type>(typename group_by_factory_traits<Observable>::group_by_type(std::move(keySelector), std::move(marbleSelector), std::move(predicate), std::move(durationSelector)))) {
308 return source.template lift<typename group_by_factory_traits<Observable>::traits_type::grouped_observable_type>(typename group_by_factory_traits<Observable>::group_by_type(std::move(keySelector), std::move(marbleSelector), std::move(predicate), std::move(durationSelector)));
309 }
310 };
311
312 }
313
314 /*! @copydoc rx-group_by.hpp
315 */
316 template<class... AN>
group_by(AN &&...an)317 auto group_by(AN&&... an)
318 -> operator_factory<group_by_tag, AN...> {
319 return operator_factory<group_by_tag, AN...>(std::make_tuple(std::forward<AN>(an)...));
320 }
321
322 }
323
324 template<>
325 struct member_overload<group_by_tag>
326 {
327 template<class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate, class DurationSelector,
328 class SourceValue = rxu::value_type_t<Observable>,
329 class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
330 class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
331 class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload332 static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms, BinaryPredicate&& p, DurationSelector&& ds)
333 -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), std::forward<DurationSelector>(ds)))) {
334 return o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), std::forward<DurationSelector>(ds)));
335 }
336
337 template<class Observable, class KeySelector, class MarbleSelector, class BinaryPredicate,
338 class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
339 class SourceValue = rxu::value_type_t<Observable>,
340 class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
341 class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
342 class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload343 static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms, BinaryPredicate&& p)
344 -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
345 return o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), std::forward<BinaryPredicate>(p), rxu::ret<observable<int, rxs::detail::never<int>>>()));
346 }
347
348 template<class Observable, class KeySelector, class MarbleSelector,
349 class BinaryPredicate=rxu::less,
350 class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
351 class SourceValue = rxu::value_type_t<Observable>,
352 class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
353 class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
354 class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload355 static auto member(Observable&& o, KeySelector&& ks, MarbleSelector&& ms)
356 -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
357 return o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), std::forward<MarbleSelector>(ms), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
358 }
359
360
361 template<class Observable, class KeySelector,
362 class MarbleSelector=rxu::detail::take_at<0>,
363 class BinaryPredicate=rxu::less,
364 class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
365 class SourceValue = rxu::value_type_t<Observable>,
366 class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
367 class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
368 class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload369 static auto member(Observable&& o, KeySelector&& ks)
370 -> decltype(o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
371 return o.template lift<Value>(GroupBy(std::forward<KeySelector>(ks), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
372 }
373
374 template<class Observable,
375 class KeySelector=rxu::detail::take_at<0>,
376 class MarbleSelector=rxu::detail::take_at<0>,
377 class BinaryPredicate=rxu::less,
378 class DurationSelector=rxu::ret<observable<int, rxs::detail::never<int>>>,
379 class Enabled = rxu::enable_if_all_true_type_t<
380 all_observables<Observable>>,
381 class SourceValue = rxu::value_type_t<Observable>,
382 class Traits = rxo::detail::group_by_traits<SourceValue, rxu::decay_t<Observable>, KeySelector, MarbleSelector, BinaryPredicate, DurationSelector>,
383 class GroupBy = rxo::detail::group_by<SourceValue, rxu::decay_t<Observable>, rxu::decay_t<KeySelector>, rxu::decay_t<MarbleSelector>, rxu::decay_t<BinaryPredicate>, rxu::decay_t<DurationSelector>>,
384 class Value = typename Traits::grouped_observable_type>
memberrxcpp::member_overload385 static auto member(Observable&& o)
386 -> decltype(o.template lift<Value>(GroupBy(rxu::detail::take_at<0>(), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()))) {
387 return o.template lift<Value>(GroupBy(rxu::detail::take_at<0>(), rxu::detail::take_at<0>(), rxu::less(), rxu::ret<observable<int, rxs::detail::never<int>>>()));
388 }
389
390 template<class... AN>
memberrxcpp::member_overload391 static operators::detail::group_by_invalid_t<AN...> member(const AN&...) {
392 std::terminate();
393 return {};
394 static_assert(sizeof...(AN) == 10000, "group_by takes (optional KeySelector, optional MarbleSelector, optional BinaryKeyPredicate, optional DurationSelector), KeySelector takes (Observable::value_type) -> KeyValue, MarbleSelector takes (Observable::value_type) -> MarbleValue, BinaryKeyPredicate takes (KeyValue, KeyValue) -> bool, DurationSelector takes (Observable::value_type) -> Observable");
395 }
396
397 };
398
399 }
400
401 #endif
402
403