• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_
17 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "include/transform/graph_ir/types.h"
27 
28 namespace mindspore::transform {
29 using GetFormatFunc =
30   std::function<std::optional<std::string>(const AnfNodePtr &, const std::shared_ptr<GeTensorDesc> &)>;
31 
32 struct StorageFormatInfo {
33   std::string expand_dims_;
34   GetFormatFunc func_;
35 };
36 
37 class StorageFormatConfig {
38  public:
StorageFormatConfig(std::string op_type)39   explicit StorageFormatConfig(std::string op_type) : op_type_(std::move(op_type)) {}
40   ~StorageFormatConfig() = default;
41   StorageFormatConfig &set_index_format(size_t index, const GetFormatFunc &func, const std::string &expend_dims = "");
42   std::optional<StorageFormatInfo> GetStorageFormatInfo(size_t index);
43 
44  private:
45   std::string op_type_;
46   std::map<size_t, StorageFormatInfo> storage_format_infoes_{};
47 };
48 
49 class StorageFormatConfigRegister {
50  public:
51   static StorageFormatConfigRegister &GetInstance();
52   StorageFormatConfig &Register(const std::string &op_type);
53   [[nodiscard]] std::optional<StorageFormatConfig> GetStorageFormatConfig(const std::string &op_type) const;
54 
55  private:
56   StorageFormatConfigRegister() = default;
57   ~StorageFormatConfigRegister() = default;
58   std::map<std::string, StorageFormatConfig> storage_format_configs_;
59 };
60 
// Token-pasting helpers. The extra level of indirection makes the preprocessor macro-expand the
// arguments (e.g. __COUNTER__) before pasting them; a bare `a##b` would paste the literal token
// `__COUNTER__` instead of its value.
#define STORAGE_FORMAT_CONFIG_CONCAT_INNER(a, b) a##b
#define STORAGE_FORMAT_CONFIG_CONCAT(a, b) STORAGE_FORMAT_CONFIG_CONCAT_INNER(a, b)

// Defines a TU-local reference named register_acl<name><ctr> bound to the config returned by
// StorageFormatConfigRegister::GetInstance().Register(#name). The expansion ends with that call, so the
// user may chain `.set_index_format(...)` after the macro. `ctr` is expanded before pasting, so passing
// __COUNTER__ gives every use a distinct identifier. Names are fully qualified so the macro also works
// outside namespace mindspore::transform.
#define REGISTER_STORAGE_FORMAT_CONFIG_IMPL(ctr, name)                                                  \
  [[maybe_unused]] static ::mindspore::transform::StorageFormatConfig &STORAGE_FORMAT_CONFIG_CONCAT( \
    register_acl##name, ctr) = ::mindspore::transform::StorageFormatConfigRegister::GetInstance().Register(#name)

#define REGISTER_STORAGE_FORMAT_CONFIG(name) REGISTER_STORAGE_FORMAT_CONFIG_IMPL(__COUNTER__, name)
66 }  // namespace mindspore::transform
67 
68 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONFIG_FACTORY_H_
69