• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MAILBOX_H
18 #define MINDSPORE_MAILBOX_H
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <utility>

#include "actor/msg.h"
#include "thread/hqueue.h"
27 
28 namespace mindspore {
29 class MailBox {
30  public:
31   virtual ~MailBox() = default;
32   virtual int EnqueueMessage(std::unique_ptr<MessageBase> msg) = 0;
33   virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;
34   virtual std::unique_ptr<MessageBase> GetMsg() = 0;
SetNotifyHook(std::unique_ptr<std::function<void ()>> && hook)35   inline void SetNotifyHook(std::unique_ptr<std::function<void()>> &&hook) { notifyHook = std::move(hook); }
TakeAllMsgsEachTime()36   inline bool TakeAllMsgsEachTime() const { return takeAllMsgsEachTime; }
37 
38  protected:
39   // if this flag is true, GetMsgs() should be invoked to take all enqueued msgs each time, otherwise we can only get
40   // one msg by GetMsg() each time.
41   bool takeAllMsgsEachTime = true;
42   std::unique_ptr<std::function<void()>> notifyHook;
43 };
44 
45 class BlockingMailBox : public MailBox {
46  public:
BlockingMailBox()47   BlockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
~BlockingMailBox()48   virtual ~BlockingMailBox() {
49     mailbox1.clear();
50     mailbox2.clear();
51   }
52   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
53   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
GetMsg()54   std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
55 
56  private:
57   std::list<std::unique_ptr<MessageBase>> mailbox1;
58   std::list<std::unique_ptr<MessageBase>> mailbox2;
59   std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
60   std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
61   std::mutex lock;
62   std::condition_variable cond;
63 };
64 
65 class NonblockingMailBox : public MailBox {
66  public:
NonblockingMailBox()67   NonblockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
~NonblockingMailBox()68   virtual ~NonblockingMailBox() {
69     mailbox1.clear();
70     mailbox2.clear();
71   }
72   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
73   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
GetMsg()74   std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
75 
76  private:
77   std::list<std::unique_ptr<MessageBase>> mailbox1;
78   std::list<std::unique_ptr<MessageBase>> mailbox2;
79   std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
80   std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
81   std::mutex lock;
82   bool released_ = true;
83 };
84 
85 class HQueMailBox : public MailBox {
86  public:
HQueMailBox()87   HQueMailBox() { takeAllMsgsEachTime = false; }
Init()88   inline bool Init() { return mailbox.Init(MAX_MSG_QUE_SIZE); }
89   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
GetMsgs()90   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override { return nullptr; }
91   std::unique_ptr<MessageBase> GetMsg() override;
92 
93  private:
94   HQueue<MessageBase> mailbox;
95   static const int32_t MAX_MSG_QUE_SIZE = 4096;
96 };
97 }  // namespace mindspore
98 
99 #endif  // MINDSPORE_MAILBOX_H
100