• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <utility>
24 #include <unordered_map>
25 #include "fl/server/common.h"
26 #include "fl/server/kernel/params_info.h"
27 
28 namespace mindspore {
29 namespace fl {
30 namespace server {
31 namespace kernel {
32 // KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
33 // and AggregationKernelFactory.
34 
35 // Unlike normal MindSpore operator kernels, the server defines multiple types of kernels. For example: Aggregation
36 // Kernel, Optimizer Kernel, Forward Kernel, etc. So we define KernelFactory as a template class for register of all
37 // types of kernels.
38 
39 // Because most information we need to create a server kernel is in func_graph passed by the front end, we create a
40 // server kernel based on a cnode.
41 
42 // Typename K refers to the shared_ptr of the kernel type.
43 // Typename C refers to the creator function of the kernel.
44 template <typename K, typename C>
45 class KernelFactory {
46  public:
47   KernelFactory() = default;
48   virtual ~KernelFactory() = default;
49 
GetInstance()50   static KernelFactory &GetInstance() {
51     static KernelFactory instance;
52     return instance;
53   }
54 
55   // Kernels are registered by parameter information and its creator(constructor).
Register(const std::string & name,const ParamsInfo & params_info,C && creator)56   void Register(const std::string &name, const ParamsInfo &params_info, C &&creator) {
57     name_to_creator_map_[name].push_back(std::make_pair(params_info, creator));
58   }
59 
60   // The kernels in server are created from func_graph's kernel_node passed by the front end.
Create(const std::string & name,const CNodePtr & kernel_node)61   K Create(const std::string &name, const CNodePtr &kernel_node) {
62     if (name_to_creator_map_.count(name) == 0) {
63       MS_LOG(ERROR) << "Creating kernel failed: " << name << " is not registered.";
64     }
65     for (const auto &name_type_creator : name_to_creator_map_[name]) {
66       const ParamsInfo &params_info = name_type_creator.first;
67       const C &creator = name_type_creator.second;
68       if (Matched(params_info, kernel_node)) {
69         auto kernel = creator();
70         kernel->set_params_info(params_info);
71         return kernel;
72       }
73     }
74     return nullptr;
75   }
76 
77  private:
78   KernelFactory(const KernelFactory &) = delete;
79   KernelFactory &operator=(const KernelFactory &) = delete;
80 
81   // Judge whether the server kernel can be created according to registered ParamsInfo.
Matched(const ParamsInfo & params_info,const CNodePtr & kernel_node)82   virtual bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) { return true; }
83 
84   // Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in
85   // server kernel's *.cc files.
86   std::unordered_map<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_;
87 };
88 }  // namespace kernel
89 }  // namespace server
90 }  // namespace fl
91 }  // namespace mindspore
92 #endif  // MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
93