1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ 18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include "fl/server/kernel/kernel_factory.h" 24 #include "fl/server/kernel/optimizer_kernel.h" 25 26 namespace mindspore { 27 namespace fl { 28 namespace server { 29 namespace kernel { 30 using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>; 31 class OptimizerKernelFactory : public KernelFactory<std::shared_ptr<OptimizerKernel>, OptimizerKernelCreator> { 32 public: GetInstance()33 static OptimizerKernelFactory &GetInstance() { 34 static OptimizerKernelFactory instance; 35 return instance; 36 } 37 38 private: 39 OptimizerKernelFactory() = default; 40 ~OptimizerKernelFactory() override = default; 41 OptimizerKernelFactory(const OptimizerKernelFactory &) = delete; 42 OptimizerKernelFactory &operator=(const OptimizerKernelFactory &) = delete; 43 44 // Judge whether the server optimizer kernel can be created according to registered ParamsInfo. 45 bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) override; 46 }; 47 48 class OptimizerKernelRegister { 49 public: OptimizerKernelRegister(const std::string & name,const ParamsInfo & params_info,OptimizerKernelCreator && creator)50 OptimizerKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, OptimizerKernelCreator &&creator) { 51 OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator)); 52 } 53 ~OptimizerKernelRegister() = default; 54 }; 55 56 // Register optimizer kernel with one template type T. 57 #define REG_OPTIMIZER_KERNEL(NAME, PARAMS_INFO, CLASS, T) \ 58 static_assert(std::is_base_of<OptimizerKernel, CLASS<T>>::value, " must be base of OptimizerKernel"); \ 59 static const OptimizerKernelRegister g_##NAME##_##T##_optimizer_kernel_reg( \ 60 #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); }); 61 } // namespace kernel 62 } // namespace server 63 } // namespace fl 64 } // namespace mindspore 65 #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ 66