1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ 19 20 #include <utility> 21 #include <string> 22 #include <vector> 23 24 #include "ir/anf.h" 25 #include "ir/dtype/type.h" 26 #include "utils/utils.h" 27 28 namespace mindspore { 29 namespace device { 30 namespace cpu { 31 void SetKernelInfo(const CNodePtr &apply_kernel_ptr); 32 // Indicate whether the kernel input/output number are variable. 33 bool IsDynamicParamKernel(const std::string &op_name); 34 35 class KernelAttr { 36 public: 37 using DataType = std::pair<TypeId, std::string>; KernelAttr()38 KernelAttr() : all_same_(0) {} 39 ~KernelAttr() = default; 40 41 KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { 42 input_type_.emplace_back(ms_type, format); 43 return *this; 44 } 45 46 KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { 47 output_type_.emplace_back(ms_type, format); 48 return *this; 49 } 50 SetAllSameAttr(bool all_same)51 KernelAttr &SetAllSameAttr(bool all_same) { 52 all_same_ = all_same; 53 return *this; 54 } 55 GetInputAttr(const size_t index)56 const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } GetOutputAttr(const size_t index)57 const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } GetAllSame()58 bool GetAllSame() const { return all_same_; } 59 GetInputSize()60 size_t GetInputSize() const { return input_type_.size(); } GetOutputSize()61 size_t GetOutputSize() const { return output_type_.size(); } 62 63 private: 64 std::vector<DataType> input_type_; 65 std::vector<DataType> output_type_; 66 bool all_same_; 67 }; 68 } // namespace cpu 69 } // namespace device 70 } // namespace mindspore 71 72 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_KERNEL_SELECT_CPU_H_ 73