1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_BUILDER_H_ 18 #define MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_BUILDER_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include "backend/kernel_compiler/kernel.h" 24 #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" 25 #include "ps/optimizer_info.h" 26 27 namespace mindspore { 28 namespace ps { 29 using mindspore::kernel::KernelMod; 30 using mindspore::kernel::ps::PServerKernel; 31 class OptimizerInfoBuilder { 32 public: OptimizerInfoBuilder(size_t worker_num)33 explicit OptimizerInfoBuilder(size_t worker_num) : worker_num_(worker_num) {} 34 virtual ~OptimizerInfoBuilder() = default; 35 36 OptimizerInfo *Build(const std::shared_ptr<PServerKernel> &pserver_kernel, const WeightPtr &weight, const Keys &keys, 37 const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num, 38 bool sharded); 39 40 virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, 41 const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num, 42 const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) = 0; 43 44 virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t worker_num); BuildOutputs(OptimizerInfo * info,size_t worker_num)45 virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {} 46 47 protected: 48 template <typename T> 49 AddressPtr GenInputAddrPtr(const std::string &optim_type, const std::string &input_name, void *ps_data, 50 const Lengths &lens, const InputsShapePtr &inputs_shape = nullptr); 51 52 size_t worker_num_; 53 }; 54 55 class MomentumOptimInfoBuilder : public OptimizerInfoBuilder { 56 public: MomentumOptimInfoBuilder(size_t worker_num)57 explicit MomentumOptimInfoBuilder(size_t worker_num) : OptimizerInfoBuilder(worker_num) {} 58 ~MomentumOptimInfoBuilder() = default; 59 OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, 60 const InputsShapePtr &inputs_shape, size_t worker_num, 61 const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override; 62 }; 63 64 class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder { 65 public: SparseAdamOptimInfoBuilder(size_t worker_num)66 explicit SparseAdamOptimInfoBuilder(size_t worker_num) : OptimizerInfoBuilder(worker_num) {} 67 ~SparseAdamOptimInfoBuilder() = default; 68 OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, 69 const InputsShapePtr &inputs_shape, size_t worker_num, 70 const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override; 71 }; 72 73 class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder { 74 public: SparseFtrlOptimInfoBuilder(size_t worker_num)75 explicit SparseFtrlOptimInfoBuilder(size_t worker_num) : OptimizerInfoBuilder(worker_num) {} 76 ~SparseFtrlOptimInfoBuilder() = default; 77 OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, 78 const InputsShapePtr &inputs_shape, size_t worker_num, 79 const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override; 80 }; 81 } // namespace ps 82 } // namespace mindspore 83 #endif // MINDSPORE_CCSRC_PS_OPTIMIZER_INFO_BUILDER_H_ 84