• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_
18 #define MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <set>
24 #include <utility>
25 #include "src/litert/lite_mindrt.h"
26 
27 namespace mindspore::lite {
28 class LiteSwitchOpActor : public LiteOpActor {
29  public:
LiteSwitchOpActor(kernel::KernelExec * kernel,lite::InnerContext * ctx)30   explicit LiteSwitchOpActor(kernel::KernelExec *kernel, lite::InnerContext *ctx) : LiteOpActor(kernel, ctx) {}
~LiteSwitchOpActor()31   ~LiteSwitchOpActor() override {
32     delete switch_type_node_;
33     for (auto &partial_node : partial_nodes_) {
34       delete partial_node;
35     }
36   };
37   void RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context = nullptr) override;
38   int CompileArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) override;
GetPartialKernels()39   std::set<kernel::KernelExec *> GetPartialKernels() const override {
40     std::set<kernel::KernelExec *> ret{};
41     for (auto &item : partial_nodes_) {
42       (void)ret.insert(item);
43     }
44     return ret;
45   }
46 
47  protected:
48   int UpdateActorOutput() override;
49   int PrepareOutputData() override;
50 
51  private:
52   STATUS AsyncBranchOutput(const size_t &index, OpContext<Tensor> *context);
53   void DecreaseOtherBranchInputTensor(const size_t &index);
54   int GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel);
55   void AppendOutputTensors();
56   int CompileArrowThroughSwitchCall(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map);
57   int CreateSwitchTypeArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map,
58                             const std::set<void *> &receiver_tensors, const Tensor *partial_in_tensor,
59                             std::vector<DataArrowPtr> *branch_output_data_arrows);
60   int ModifySubgraphKernel();
61   int SetSwitchPartialNodes();
62   int SetSwitchLayerPartialNodes();
63 
64   // each element is a set of data arrow sent to the next target actor.
65   std::vector<std::vector<DataArrowPtr>> all_branch_output_data_arrows_;
66 
67   std::vector<kernel::KernelExec *> partial_nodes_{};
68   kernel::KernelExec *switch_type_node_ = nullptr;
69 
70   // each element is a set of output data which is going to be send to the next target actor.
71   std::vector<std::vector<OpDataPtr<Tensor>>> all_branchs_output_data_;
72 };
73 }  // namespace mindspore::lite
74 #endif  // MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_
75