1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_ 18 #define MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_ 19 #include <vector> 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <set> 24 #include <utility> 25 #include "src/litert/lite_mindrt.h" 26 27 namespace mindspore::lite { 28 class LiteSwitchOpActor : public LiteOpActor { 29 public: LiteSwitchOpActor(kernel::KernelExec * kernel,lite::InnerContext * ctx)30 explicit LiteSwitchOpActor(kernel::KernelExec *kernel, lite::InnerContext *ctx) : LiteOpActor(kernel, ctx) {} ~LiteSwitchOpActor()31 ~LiteSwitchOpActor() override { 32 delete switch_type_node_; 33 for (auto &partial_node : partial_nodes_) { 34 delete partial_node; 35 } 36 }; 37 void RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *context = nullptr) override; 38 int CompileArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map) override; GetPartialKernels()39 std::set<kernel::KernelExec *> GetPartialKernels() const override { 40 std::set<kernel::KernelExec *> ret{}; 41 for (auto &item : partial_nodes_) { 42 (void)ret.insert(item); 43 } 44 return ret; 45 } 46 47 protected: 48 int UpdateActorOutput() override; 49 int PrepareOutputData() override; 50 51 private: 52 STATUS AsyncBranchOutput(const size_t &index, OpContext<Tensor> *context); 53 void DecreaseOtherBranchInputTensor(const size_t &index); 54 int GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel); 55 void AppendOutputTensors(); 56 int CompileArrowThroughSwitchCall(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map); 57 int CreateSwitchTypeArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map, 58 const std::set<void *> &receiver_tensors, const Tensor *partial_in_tensor, 59 std::vector<DataArrowPtr> *branch_output_data_arrows); 60 int ModifySubgraphKernel(); 61 int SetSwitchPartialNodes(); 62 int SetSwitchLayerPartialNodes(); 63 64 // each element is a set of data arrow sent to the next target actor. 65 std::vector<std::vector<DataArrowPtr>> all_branch_output_data_arrows_; 66 67 std::vector<kernel::KernelExec *> partial_nodes_{}; 68 kernel::KernelExec *switch_type_node_ = nullptr; 69 70 // each element is a set of output data which is going to be send to the next target actor. 71 std::vector<std::vector<OpDataPtr<Tensor>>> all_branchs_output_data_; 72 }; 73 } // namespace mindspore::lite 74 #endif // MINDSPORE_LITE_SRC_CONTROL_FLOW_ACTOR_SWITCH_ACTOR_H_ 75