1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/storage_format_config_factory.h"
18
19 #include <utility>
20
21 namespace mindspore::transform {
set_index_format(size_t index,const GetFormatFunc & func,const std::string & expend_dims)22 StorageFormatConfig &StorageFormatConfig::set_index_format(size_t index, const GetFormatFunc &func,
23 const std::string &expend_dims) {
24 StorageFormatInfo info;
25 info.expand_dims_ = expend_dims;
26 info.func_ = func;
27 auto ret = storage_format_infoes_.emplace(index + 1, info);
28 if (!ret.second) {
29 MS_LOG(ERROR) << "Set index format op type: " << op_type_ << ", index: " << index
30 << ", expand_dims: " << expend_dims << " failed.";
31 }
32 return *this;
33 }
34
GetStorageFormatInfo(size_t index)35 std::optional<StorageFormatInfo> StorageFormatConfig::GetStorageFormatInfo(size_t index) {
36 auto iter = storage_format_infoes_.find(index);
37 if (iter == storage_format_infoes_.end()) {
38 return std::nullopt;
39 }
40 return iter->second;
41 }
42
GetInstance()43 StorageFormatConfigRegister &StorageFormatConfigRegister::GetInstance() {
44 static StorageFormatConfigRegister inst;
45 return inst;
46 }
47
Register(const std::string & op_type)48 StorageFormatConfig &StorageFormatConfigRegister::Register(const std::string &op_type) {
49 auto iter = storage_format_configs_.find(op_type);
50 if (iter != storage_format_configs_.end()) {
51 return iter->second;
52 }
53 auto ret = storage_format_configs_.emplace(op_type, StorageFormatConfig(op_type));
54 if (!ret.second) {
55 MS_LOG(ERROR) << "Reg op failed: " << op_type;
56 }
57 return ret.first->second;
58 }
59
GetStorageFormatConfig(const std::string & op_type) const60 std::optional<StorageFormatConfig> StorageFormatConfigRegister::GetStorageFormatConfig(
61 const std::string &op_type) const {
62 auto iter = storage_format_configs_.find(op_type);
63 if (iter == storage_format_configs_.end()) {
64 return std::nullopt;
65 }
66 return iter->second;
67 }
68 } // namespace mindspore::transform
69