• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
19 
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "fl/server/kernel/kernel_factory.h"
#include "fl/server/kernel/aggregation_kernel.h"
25 
26 namespace mindspore {
27 namespace fl {
28 namespace server {
29 namespace kernel {
30 using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
31 class AggregationKernelFactory : public KernelFactory<std::shared_ptr<AggregationKernel>, AggregationKernelCreator> {
32  public:
GetInstance()33   static AggregationKernelFactory &GetInstance() {
34     static AggregationKernelFactory instance;
35     return instance;
36   }
37 
38  private:
39   AggregationKernelFactory() = default;
40   ~AggregationKernelFactory() override = default;
41   AggregationKernelFactory(const AggregationKernelFactory &) = delete;
42   AggregationKernelFactory &operator=(const AggregationKernelFactory &) = delete;
43 
44   // Judge whether the server aggregation kernel can be created according to registered ParamsInfo.
45   bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) override;
46 };
47 
48 class AggregationKernelRegister {
49  public:
AggregationKernelRegister(const std::string & name,const ParamsInfo & params_info,AggregationKernelCreator && creator)50   AggregationKernelRegister(const std::string &name, const ParamsInfo &params_info,
51                             AggregationKernelCreator &&creator) {
52     AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
53   }
54   ~AggregationKernelRegister() = default;
55 };
56 
57 // Register aggregation kernel with one template type T.
58 #define REG_AGGREGATION_KERNEL(NAME, PARAMS_INFO, CLASS, T)                                                 \
59   static_assert(std::is_base_of<AggregationKernel, CLASS<T>>::value, " must be base of AggregationKernel"); \
60   static const AggregationKernelRegister g_##NAME##_##T##_aggregation_kernel_reg(                           \
61     #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
62 
// Register aggregation kernel CLASS<T, S> (two template types) under the name NAME with PARAMS_INFO.
// The static_assert rejects, at compile time, any CLASS<T, S> that does not derive from
// AggregationKernel, and names the offending class in the diagnostic.
// NAME, T and S are token-pasted into the registrar's variable name, so T and S must each be a single
// identifier-like token (e.g. `float`); qualified or multi-word types do not paste into a valid name.
#define REG_AGGREGATION_KERNEL_TWO(NAME, PARAMS_INFO, CLASS, T, S)                       \
  static_assert(std::is_base_of<AggregationKernel, CLASS<T, S>>::value,                  \
                #CLASS " must be derived from AggregationKernel");                       \
  static const AggregationKernelRegister g_##NAME##_##T##_##S##_aggregation_kernel_reg( \
    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
68 }  // namespace kernel
69 }  // namespace server
70 }  // namespace fl
71 }  // namespace mindspore
72 #endif  // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
73