• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_C_API_SRC_DYNAMIC_OP_INFO_H_
18 #define MINDSPORE_CCSRC_C_API_SRC_DYNAMIC_OP_INFO_H_
19 
20 #include <vector>
21 #include <utility>
22 #include <string>
23 #include <map>
24 #include <memory>
25 #include "base/base.h"
26 #include "include/c_api/ms/node.h"
27 #include "c_api/src/common.h"
28 
29 struct InnerOpInfo {
30   std::string op_name;
31   std::vector<ValuePtr> input_values{};
32   std::vector<ShapeVector> input_shapes{};
33   std::vector<DataTypeC> input_dtypes{};
34   std::vector<ShapeVector> output_shapes{};
35   std::vector<DataTypeC> output_dtypes{};
36   std::vector<std::pair<std::string, ValuePtr>> attrs{};
37 
InnerOpInfoInnerOpInfo38   InnerOpInfo(const char *op_type, const std::vector<ValuePtr> &inputs, const std::vector<ShapeVector> &out_shapes,
39               const std::vector<DataTypeC> &out_dtypes,
40               const std::vector<std::pair<std::string, ValuePtr>> &attrs_pair) {
41     op_name = op_type;
42     for (auto input : inputs) {
43       MS_EXCEPTION_IF_NULL(input);
44       if (input->isa<TensorImpl>()) {
45         auto in_tensor = input->cast<TensorPtr>();
46         (void)input_shapes.emplace_back(in_tensor->shape());
47         (void)input_dtypes.emplace_back(DataTypeC(in_tensor->data_type_c()));
48       } else {
49         (void)input_values.emplace_back(input);
50       }
51     }
52     output_shapes = out_shapes;
53     output_dtypes = out_dtypes;
54     attrs = attrs_pair;
55   }
56 
57   bool operator==(const InnerOpInfo &op_info) const {
58     return op_name == op_info.op_name && input_values == op_info.input_values && input_shapes == op_info.input_shapes &&
59            output_shapes == op_info.output_shapes && input_dtypes == op_info.input_dtypes &&
60            output_dtypes == op_info.output_dtypes && attrs == op_info.attrs;
61   }
62 };
63 
64 template <>
65 struct std::hash<std::vector<ValuePtr>> {
66   size_t operator()(const std::vector<ValuePtr> &value_ptr_vec) const {
67     size_t res = 17;
68     for (const auto &value_ptr : value_ptr_vec) {
69       res = res * 31 + std::hash<ValuePtr>()(value_ptr);
70     }
71     return res;
72   }
73 };
74 
// Order-sensitive hash for a flat vector of int64 values (such as a single shape).
// Starts from 17 and folds in each element as res = res * 31 + std::hash<int64_t>(elem),
// so an empty vector hashes to 17.
// NOTE(review): std::vector<int64_t> involves no program-defined type, so specializing std::hash for it is
// undefined behavior under [namespace.std]. It could also collide with another library doing the same. A
// project-local hasher passed as the container's Hash argument would be the conforming alternative.
template <>
struct std::hash<std::vector<int64_t>> {
  size_t operator()(const std::vector<int64_t> &values) const {
    size_t res = 17;
    for (const int64_t value : values) {
      res = res * 31 + std::hash<int64_t>()(value);
    }
    return res;
  }
};
85 
86 template <>
87 struct std::hash<std::vector<ShapeVector>> {
88   size_t operator()(const std::vector<ShapeVector> &shape_vec) const {
89     size_t res = 17;
90     for (const auto &shape : shape_vec) {
91       res = res * 31 + std::hash<ShapeVector>()(shape);
92     }
93     return res;
94   }
95 };
96 
97 template <>
98 struct std::hash<std::vector<DataTypeC>> {
99   size_t operator()(const std::vector<DataTypeC> &dtype_vec) const {
100     size_t res = 17;
101     for (const auto &dtype : dtype_vec) {
102       res = res * 31 + std::hash<DataTypeC>()(dtype);
103     }
104     return res;
105   }
106 };
107 
108 template <>
109 struct std::hash<std::vector<std::pair<std::string, ValuePtr>>> {
110   size_t operator()(const std::vector<std::pair<std::string, ValuePtr>> &attrs_vec) const {
111     size_t res = 17;
112     for (const auto &attr : attrs_vec) {
113       res = res * 31 + std::hash<std::string>()(attr.first);
114       res = res * 31 + std::hash<ValuePtr>()(attr.second);
115     }
116     return res;
117   }
118 };
119 
120 template <>
121 struct std::hash<InnerOpInfo> {
122   size_t operator()(const InnerOpInfo &op_info) const {
123     size_t res = 17;
124     res = res * 31 + std::hash<std::string>()(op_info.op_name);
125     res = res * 31 + std::hash<std::vector<ValuePtr>>()(op_info.input_values);
126     res = res * 31 + std::hash<std::vector<ShapeVector>>()(op_info.input_shapes);
127     res = res * 31 + std::hash<std::vector<ShapeVector>>()(op_info.output_shapes);
128     res = res * 31 + std::hash<std::vector<DataTypeC>>()(op_info.input_dtypes);
129     res = res * 31 + std::hash<std::vector<DataTypeC>>()(op_info.output_dtypes);
130     res = res * 31 + std::hash<std::vector<std::pair<std::string, ValuePtr>>>()(op_info.attrs);
131     return res;
132   }
133 };
134 #endif  // MINDSPORE_CCSRC_C_API_SRC_DYNAMIC_OP_INFO_H_
135