• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/storage_format_config_factory.h"
18 
19 #include <utility>
20 
21 namespace mindspore::transform {
set_index_format(size_t index,const GetFormatFunc & func,const std::string & expend_dims)22 StorageFormatConfig &StorageFormatConfig::set_index_format(size_t index, const GetFormatFunc &func,
23                                                            const std::string &expend_dims) {
24   StorageFormatInfo info;
25   info.expand_dims_ = expend_dims;
26   info.func_ = func;
27   auto ret = storage_format_infoes_.emplace(index + 1, info);
28   if (!ret.second) {
29     MS_LOG(ERROR) << "Set index format op type: " << op_type_ << ", index: " << index
30                   << ", expand_dims: " << expend_dims << " failed.";
31   }
32   return *this;
33 }
34 
GetStorageFormatInfo(size_t index)35 std::optional<StorageFormatInfo> StorageFormatConfig::GetStorageFormatInfo(size_t index) {
36   auto iter = storage_format_infoes_.find(index);
37   if (iter == storage_format_infoes_.end()) {
38     return std::nullopt;
39   }
40   return iter->second;
41 }
42 
GetInstance()43 StorageFormatConfigRegister &StorageFormatConfigRegister::GetInstance() {
44   static StorageFormatConfigRegister inst;
45   return inst;
46 }
47 
Register(const std::string & op_type)48 StorageFormatConfig &StorageFormatConfigRegister::Register(const std::string &op_type) {
49   auto iter = storage_format_configs_.find(op_type);
50   if (iter != storage_format_configs_.end()) {
51     return iter->second;
52   }
53   auto ret = storage_format_configs_.emplace(op_type, StorageFormatConfig(op_type));
54   if (!ret.second) {
55     MS_LOG(ERROR) << "Reg op failed: " << op_type;
56   }
57   return ret.first->second;
58 }
59 
GetStorageFormatConfig(const std::string & op_type) const60 std::optional<StorageFormatConfig> StorageFormatConfigRegister::GetStorageFormatConfig(
61   const std::string &op_type) const {
62   auto iter = storage_format_configs_.find(op_type);
63   if (iter == storage_format_configs_.end()) {
64     return std::nullopt;
65   }
66   return iter->second;
67 }
68 }  // namespace mindspore::transform
69