1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/ascend/ge_types_convert.h"
18 #include "graph/utils/type_utils.h"
19
namespace {
// Serialized format name that GetGeTilingFormat treats as "unsupported" when it is
// returned by ge::TypeUtils::FormatToSerialString.
constexpr const char *kInvalidFormat = "RESERVED";
}  // namespace
23 namespace mindspore {
24 namespace device {
25 namespace ascend {
GetGeDataType(TypeId type_id)26 ge::proto::DataType GeTypesConvert::GetGeDataType(TypeId type_id) {
27 static const std::map<TypeId, ge::proto::DataType> data_type_map = {
28 {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
29 {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
30 {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
31 {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
32 {TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
33 {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
34 {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE}, {TypeId::kObjectTypeString, ge::proto::DT_STRING},
35 };
36 MS_LOG(INFO) << "Vm origin type_id:" << type_id << ": " << TypeIdLabel(type_id);
37 auto iter = data_type_map.find(type_id);
38 if (iter == data_type_map.end()) {
39 MS_LOG(EXCEPTION) << "MindSpore data type:" << TypeIdLabel(type_id) << " can't been found in GE.";
40 }
41 return iter->second;
42 }
43
TransTypeIdToGeDataType(TypeId type_id)44 ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) {
45 static const std::map<TypeId, ge::DataType> data_type_map = {
46 {TypeId::kNumberTypeFloat, ge::DataType::DT_FLOAT}, {TypeId::kNumberTypeFloat32, ge::DataType::DT_FLOAT},
47 {TypeId::kNumberTypeFloat16, ge::DataType::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::DataType::DT_INT8},
48 {TypeId::kNumberTypeInt16, ge::DataType::DT_INT16}, {TypeId::kNumberTypeUInt16, ge::DataType::DT_UINT16},
49 {TypeId::kNumberTypeUInt8, ge::DataType::DT_UINT8}, {TypeId::kNumberTypeInt32, ge::DataType::DT_INT32},
50 {TypeId::kNumberTypeInt, ge::DataType::DT_INT32}, {TypeId::kNumberTypeInt64, ge::DataType::DT_INT64},
51 {TypeId::kNumberTypeUInt32, ge::DataType::DT_UINT32}, {TypeId::kNumberTypeUInt, ge::DataType::DT_UINT32},
52 {TypeId::kNumberTypeUInt64, ge::DataType::DT_UINT64}, {TypeId::kNumberTypeBool, ge::DataType::DT_BOOL},
53 {TypeId::kNumberTypeInt64, ge::DataType::DT_DOUBLE}, {TypeId::kTypeUnknown, ge::DataType::DT_UNDEFINED}};
54 auto iter = data_type_map.find(type_id);
55 if (iter == data_type_map.end()) {
56 MS_LOG(EXCEPTION) << "Invalid data type:" << type_id << ": " << TypeIdLabel(type_id);
57 }
58 return iter->second;
59 }
60
GetGeFormat(const std::string & format,size_t shape_size)61 ge::Format GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
62 static constexpr size_t k4dSize = 4;
63 static const std::map<std::string, ge::Format> format_map = {
64 // default format: nchw, fractal_nz?
65 {kOpFormat_DEFAULT, ge::Format::FORMAT_NCHW},
66 {kOpFormat_NC1KHKWHWC0, ge::Format::FORMAT_NC1KHKWHWC0},
67 {kOpFormat_ND, ge::Format::FORMAT_ND},
68 {kOpFormat_NCHW, ge::Format::FORMAT_NCHW},
69 {kOpFormat_NHWC, ge::Format::FORMAT_NHWC},
70 {kOpFormat_HWCN, ge::Format::FORMAT_HWCN},
71 {kOpFormat_NC1HWC0, ge::Format::FORMAT_NC1HWC0},
72 {kOpFormat_FRAC_Z, ge::Format::FORMAT_FRACTAL_Z},
73 {kOpFormat_FRAC_NZ, ge::Format::FORMAT_FRACTAL_NZ},
74 {kOpFormat_C1HWNCoC0, ge::Format::FORMAT_C1HWNCoC0},
75 {kOpFormat_NC1HWC0_C04, ge::Format::FORMAT_NC1HWC0_C04},
76 {kOpFormat_FRACTAL_Z_C04, ge::Format::FORMAT_FRACTAL_Z_C04},
77 {kOpFormat_NDHWC, ge::Format::FORMAT_NDHWC},
78 {kOpFormat_NCDHW, ge::Format::FORMAT_NCDHW},
79 {kOpFormat_DHWNC, ge::Format::FORMAT_DHWNC},
80 {kOpFormat_DHWCN, ge::Format::FORMAT_DHWCN},
81 {kOpFormat_NDC1HWC0, ge::Format::FORMAT_NDC1HWC0},
82 {kOpFormat_FRACTAL_Z_3D, ge::Format::FORMAT_FRACTAL_Z_3D},
83 {kOpFormat_FRACTAL_ZN_LSTM, ge::Format::FORMAT_FRACTAL_ZN_LSTM}};
84 MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
85 if (format == kOpFormat_DEFAULT) {
86 return shape_size == k4dSize ? ge::Format::FORMAT_NCHW : ge::Format::FORMAT_ND;
87 }
88 auto iter = format_map.find(format);
89 if (iter == format_map.end()) {
90 MS_LOG(EXCEPTION) << "Invalid format:" << format;
91 }
92 return iter->second;
93 }
94
GetGeTilingFormat(ge::Format ge_format)95 std::string GeTypesConvert::GetGeTilingFormat(ge::Format ge_format) {
96 auto format_str = ge::TypeUtils::FormatToSerialString(ge_format);
97 if (format_str == kInvalidFormat) {
98 MS_LOG(EXCEPTION) << "Not support format:" << ge_format;
99 }
100 return format_str;
101 }
102 } // namespace ascend
103 } // namespace device
104 } // namespace mindspore
105