1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ 18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include "fl/server/kernel/kernel_factory.h" 24 #include "fl/server/kernel/aggregation_kernel.h" 25 26 namespace mindspore { 27 namespace fl { 28 namespace server { 29 namespace kernel { 30 using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>; 31 class AggregationKernelFactory : public KernelFactory<std::shared_ptr<AggregationKernel>, AggregationKernelCreator> { 32 public: GetInstance()33 static AggregationKernelFactory &GetInstance() { 34 static AggregationKernelFactory instance; 35 return instance; 36 } 37 38 private: 39 AggregationKernelFactory() = default; 40 ~AggregationKernelFactory() override = default; 41 AggregationKernelFactory(const AggregationKernelFactory &) = delete; 42 AggregationKernelFactory &operator=(const AggregationKernelFactory &) = delete; 43 44 // Judge whether the server aggregation kernel can be created according to registered ParamsInfo. 45 bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) override; 46 }; 47 48 class AggregationKernelRegister { 49 public: AggregationKernelRegister(const std::string & name,const ParamsInfo & params_info,AggregationKernelCreator && creator)50 AggregationKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, 51 AggregationKernelCreator &&creator) { 52 AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator)); 53 } 54 ~AggregationKernelRegister() = default; 55 }; 56 57 // Register aggregation kernel with one template type T. 58 #define REG_AGGREGATION_KERNEL(NAME, PARAMS_INFO, CLASS, T) \ 59 static_assert(std::is_base_of<AggregationKernel, CLASS<T>>::value, " must be base of AggregationKernel"); \ 60 static const AggregationKernelRegister g_##NAME##_##T##_aggregation_kernel_reg( \ 61 #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); }); 62 63 // Register aggregation kernel with two template types: T and S. 64 #define REG_AGGREGATION_KERNEL_TWO(NAME, PARAMS_INFO, CLASS, T, S) \ 65 static_assert(std::is_base_of<AggregationKernel, CLASS<T, S>>::value, " must be base of AggregationKernel"); \ 66 static const AggregationKernelRegister g_##NAME##_##T##_##S##_aggregation_kernel_reg( \ 67 #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); }); 68 } // namespace kernel 69 } // namespace server 70 } // namespace fl 71 } // namespace mindspore 72 #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ 73