1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_ 17 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_ 18 #include <cstddef> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include <map> 24 #include <optional> 25 26 #include "include/transform/graph_ir/types.h" 27 28 namespace mindspore::transform { 29 using GetFormatFunc = 30 std::function<std::optional<std::string>(const AnfNodePtr &, const std::shared_ptr<GeTensorDesc> &)>; 31 32 struct StorageFormatInfo { 33 std::string expand_dims_; 34 GetFormatFunc func_; 35 }; 36 37 class StorageFormatConfig { 38 public: StorageFormatConfig(std::string op_type)39 explicit StorageFormatConfig(std::string op_type) : op_type_(std::move(op_type)) {} 40 ~StorageFormatConfig() = default; 41 StorageFormatConfig &set_index_format(size_t index, const GetFormatFunc &func, const std::string &expend_dims = ""); 42 std::optional<StorageFormatInfo> GetStorageFormatInfo(size_t index); 43 44 private: 45 std::string op_type_; 46 std::map<size_t, StorageFormatInfo> storage_format_infoes_{}; 47 }; 48 49 class StorageFormatConfigRegister { 50 public: 51 static StorageFormatConfigRegister &GetInstance(); 52 StorageFormatConfig &Register(const std::string &op_type); 53 [[nodiscard]] std::optional<StorageFormatConfig> GetStorageFormatConfig(const std::string &op_type) const; 54 55 private: 56 StorageFormatConfigRegister() = default; 57 ~StorageFormatConfigRegister() = default; 58 std::map<std::string, StorageFormatConfig> storage_format_configs_; 59 }; 60 61 #define REGISTER_STORAGE_FORMAT_CONFIG_IMPL(ctr, name) \ 62 static transform::StorageFormatConfig ®ister_acl##name##ctr = \ 63 StorageFormatConfigRegister::GetInstance().Register(#name) 64 65 #define REGISTER_STORAGE_FORMAT_CONFIG(name) REGISTER_STORAGE_FORMAT_CONFIG_IMPL(__COUNTER__, name) 66 } // namespace mindspore::transform 67 68 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_ 69