• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MAILBOX_H
18 #define MINDSPORE_MAILBOX_H
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <utility>

#include "actor/msg.h"
#include "thread/hqueue.h"
27 
28 namespace mindspore {
29 class MailBox {
30  public:
31   virtual ~MailBox() = default;
32   virtual int EnqueueMessage(std::unique_ptr<MessageBase> msg) = 0;
33   virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;
34   virtual std::unique_ptr<MessageBase> GetMsg() = 0;
SetNotifyHook(std::unique_ptr<std::function<void ()>> && hook)35   inline void SetNotifyHook(std::unique_ptr<std::function<void()>> &&hook) { notifyHook = std::move(hook); }
TakeAllMsgsEachTime()36   inline bool TakeAllMsgsEachTime() { return takeAllMsgsEachTime; }
SwapMailBox(std::list<std::unique_ptr<MessageBase>> ** box1,std::list<std::unique_ptr<MessageBase>> ** box2)37   void SwapMailBox(std::list<std::unique_ptr<MessageBase>> **box1, std::list<std::unique_ptr<MessageBase>> **box2) {
38     std::list<std::unique_ptr<MessageBase>> *tmp = *box1;
39     *box1 = *box2;
40     *box2 = tmp;
41   }
42 
43  protected:
44   // if this flag is true, GetMsgs() should be invoked to take all enqueued msgs each time, otherwise we can only get
45   // one msg by GetMsg() each time.
46   bool takeAllMsgsEachTime = true;
47   std::unique_ptr<std::function<void()>> notifyHook;
48 };
49 
50 class BlockingMailBox : public MailBox {
51  public:
BlockingMailBox()52   BlockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
~BlockingMailBox()53   virtual ~BlockingMailBox() {
54     mailbox1.clear();
55     mailbox2.clear();
56   }
57   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
58   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
GetMsg()59   std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
60 
61  private:
62   std::list<std::unique_ptr<MessageBase>> mailbox1;
63   std::list<std::unique_ptr<MessageBase>> mailbox2;
64   std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
65   std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
66   std::mutex lock;
67   std::condition_variable cond;
68 };
69 
70 class NonblockingMailBox : public MailBox {
71  public:
NonblockingMailBox()72   NonblockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
~NonblockingMailBox()73   virtual ~NonblockingMailBox() {
74     mailbox1.clear();
75     mailbox2.clear();
76   }
77   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
78   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
GetMsg()79   std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
80 
81  private:
82   std::list<std::unique_ptr<MessageBase>> mailbox1;
83   std::list<std::unique_ptr<MessageBase>> mailbox2;
84   std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
85   std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
86   std::mutex lock;
87   bool released_ = true;
88 };
89 
90 class HQueMailBox : public MailBox {
91  public:
HQueMailBox()92   HQueMailBox() { takeAllMsgsEachTime = false; }
Init()93   inline bool Init() { return mailbox.Init(MAX_MSG_QUE_SIZE); }
94   int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
GetMsgs()95   std::list<std::unique_ptr<MessageBase>> *GetMsgs() override { return nullptr; }
96   std::unique_ptr<MessageBase> GetMsg() override;
97 
98  private:
99   HQueue<MessageBase> mailbox;
100   static const int32_t MAX_MSG_QUE_SIZE = 4096;
101 };
102 
103 }  // namespace mindspore
104 
105 #endif  // MINDSPORE_MAILBOX_H
106