• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
18 
19 #include <unordered_map>
20 #include <map>
21 #include <string>
22 
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "utils/ms_utils.h"
25 
26 namespace mindspore {
27 namespace kernel {
28 namespace tbe {
29 const std::unordered_map<std::string, TypeId> type_str_id_maps = {
30   {"float", TypeId::kNumberTypeFloat32},
31   {"float16", TypeId::kNumberTypeFloat16},
32   {"float32", TypeId::kNumberTypeFloat32},
33   {"float64", TypeId::kNumberTypeFloat64},
34   {"int", TypeId::kNumberTypeInt},
35   {"int8", TypeId::kNumberTypeInt8},
36   {"int16", TypeId::kNumberTypeInt16},
37   {"int32", TypeId::kNumberTypeInt32},
38   {"int64", TypeId::kNumberTypeInt64},
39   {"uint", TypeId::kNumberTypeUInt},
40   {"uint8", TypeId::kNumberTypeUInt8},
41   {"uint16", TypeId::kNumberTypeUInt16},
42   {"uint32", TypeId::kNumberTypeUInt32},
43   {"uint64", TypeId::kNumberTypeUInt64},
44   {"bool", TypeId::kNumberTypeBool},
45   {"int4", TypeId::kNumberTypeInt4},
46   {"complex64", TypeId::kNumberTypeComplex64},
47   {"complex128", TypeId::kNumberTypeComplex128},
48   {"", TypeId::kMetaTypeNone},
49 };
50 
51 const std::map<TypeId, std::string> type_id_str_maps = {
52   {TypeId::kNumberTypeFloat32, "float32"},
53   {TypeId::kNumberTypeFloat16, "float16"},
54   {TypeId::kNumberTypeFloat, "float32"},
55   {TypeId::kNumberTypeFloat64, "float64"},
56   {TypeId::kNumberTypeInt, "int"},
57   {TypeId::kNumberTypeInt8, "int8"},
58   {TypeId::kNumberTypeInt16, "int16"},
59   {TypeId::kNumberTypeInt32, "int32"},
60   {TypeId::kNumberTypeInt64, "int64"},
61   {TypeId::kNumberTypeUInt, "uint"},
62   {TypeId::kNumberTypeUInt8, "uint8"},
63   {TypeId::kNumberTypeUInt16, "uint16"},
64   {TypeId::kNumberTypeUInt32, "uint32"},
65   {TypeId::kNumberTypeUInt64, "uint64"},
66   {TypeId::kNumberTypeBool, "int8"},
67   {TypeId::kNumberTypeInt4, "int4"},
68   {TypeId::kNumberTypeComplex64, "complex64"},
69   {TypeId::kNumberTypeComplex128, "complex128"},
70   {TypeId::kMetaTypeNone, ""},
71 };
72 
// Width in bytes of each dtype name; consulted by GetDtypeNbyte.
//
// The widths are spelled out as literals because they are a property of the
// dtype, not of the machine compiling this file. Deriving them from
// sizeof(int) or sizeof(float) only gives the right answer on hosts where
// those types happen to be 4 bytes wide.
//
// Names missing from this table (for example "float", "int" and "uint") are
// not given a width here.
const std::unordered_map<std::string, size_t> type_nbyte_maps = {
  // Floating point.
  {"float16", 2},
  {"float32", 4},
  {"float64", 8},
  // Signed integers. "int4" is counted as one whole byte.
  {"int4", 1},
  {"int8", 1},
  {"int16", 2},
  {"int32", 4},
  {"int64", 8},
  // Unsigned integers.
  {"uint8", 1},
  {"uint16", 2},
  {"uint32", 4},
  {"uint64", 8},
  // Boolean.
  {"bool", 1},
  // Complex: two float32 / two float64 components.
  {"complex64", 8},
  {"complex128", 16},
};
80 
DtypeToTypeId(const std::string & dtypes)81 TypeId DtypeToTypeId(const std::string &dtypes) {
82   auto iter = type_str_id_maps.find(dtypes);
83   if (iter == type_str_id_maps.end()) {
84     MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes;
85   }
86   return iter->second;
87 }
88 
TypeIdToString(TypeId type_id)89 std::string TypeIdToString(TypeId type_id) {
90   auto iter = type_id_str_maps.find(type_id);
91   if (iter == type_id_str_maps.end()) {
92     MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
93   }
94   return iter->second;
95 }
96 
GetDtypeNbyte(const std::string & dtypes)97 size_t GetDtypeNbyte(const std::string &dtypes) {
98   auto iter = type_nbyte_maps.find(dtypes);
99   if (iter == type_nbyte_maps.end()) {
100     MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes;
101   }
102   return iter->second;
103 }
104 
GetProcessor(const AnfNodePtr & anf_node)105 std::string GetProcessor(const AnfNodePtr &anf_node) {
106   MS_EXCEPTION_IF_NULL(anf_node);
107   std::string device;
108   switch (AnfAlgo::GetProcessor(anf_node)) {
109     case Processor::AICORE:
110       device = kProcessorAiCore;
111       break;
112     default:
113       MS_LOG(INFO) << "Unknown processor type." << anf_node->fullname_with_scope();
114       break;
115   }
116   return device;
117 }
118 }  // namespace tbe
119 }  // namespace kernel
120 }  // namespace mindspore
121