• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/ascend/ge_types_convert.h"
18 #include "graph/utils/type_utils.h"
19 
namespace {
// Sentinel serialized-format name. GeTypesConvert::GetGeTilingFormat compares the
// output of ge::TypeUtils::FormatToSerialString against it and rejects a match
// as an unsupported format.
constexpr const char *kInvalidFormat = "RESERVED";
}  // namespace
23 namespace mindspore {
24 namespace device {
25 namespace ascend {
GetGeDataType(TypeId type_id)26 ge::proto::DataType GeTypesConvert::GetGeDataType(TypeId type_id) {
27   static const std::map<TypeId, ge::proto::DataType> data_type_map = {
28     {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED},     {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
29     {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
30     {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8},     {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
31     {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16},   {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
32     {TypeId::kNumberTypeInt64, ge::proto::DT_INT64},     {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
33     {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64},   {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
34     {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE},  {TypeId::kObjectTypeString, ge::proto::DT_STRING},
35   };
36   MS_LOG(INFO) << "Vm origin type_id:" << type_id << ": " << TypeIdLabel(type_id);
37   auto iter = data_type_map.find(type_id);
38   if (iter == data_type_map.end()) {
39     MS_LOG(EXCEPTION) << "MindSpore data type:" << TypeIdLabel(type_id) << " can't been found in GE.";
40   }
41   return iter->second;
42 }
43 
TransTypeIdToGeDataType(TypeId type_id)44 ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) {
45   static const std::map<TypeId, ge::DataType> data_type_map = {
46     {TypeId::kNumberTypeFloat, ge::DataType::DT_FLOAT},     {TypeId::kNumberTypeFloat32, ge::DataType::DT_FLOAT},
47     {TypeId::kNumberTypeFloat16, ge::DataType::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::DataType::DT_INT8},
48     {TypeId::kNumberTypeInt16, ge::DataType::DT_INT16},     {TypeId::kNumberTypeUInt16, ge::DataType::DT_UINT16},
49     {TypeId::kNumberTypeUInt8, ge::DataType::DT_UINT8},     {TypeId::kNumberTypeInt32, ge::DataType::DT_INT32},
50     {TypeId::kNumberTypeInt, ge::DataType::DT_INT32},       {TypeId::kNumberTypeInt64, ge::DataType::DT_INT64},
51     {TypeId::kNumberTypeUInt32, ge::DataType::DT_UINT32},   {TypeId::kNumberTypeUInt, ge::DataType::DT_UINT32},
52     {TypeId::kNumberTypeUInt64, ge::DataType::DT_UINT64},   {TypeId::kNumberTypeBool, ge::DataType::DT_BOOL},
53     {TypeId::kNumberTypeInt64, ge::DataType::DT_DOUBLE},    {TypeId::kTypeUnknown, ge::DataType::DT_UNDEFINED}};
54   auto iter = data_type_map.find(type_id);
55   if (iter == data_type_map.end()) {
56     MS_LOG(EXCEPTION) << "Invalid data type:" << type_id << ": " << TypeIdLabel(type_id);
57   }
58   return iter->second;
59 }
60 
GetGeFormat(const std::string & format,size_t shape_size)61 ge::Format GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
62   static constexpr size_t k4dSize = 4;
63   static const std::map<std::string, ge::Format> format_map = {
64     // default format: nchw, fractal_nz?
65     {kOpFormat_DEFAULT, ge::Format::FORMAT_NCHW},
66     {kOpFormat_NC1KHKWHWC0, ge::Format::FORMAT_NC1KHKWHWC0},
67     {kOpFormat_ND, ge::Format::FORMAT_ND},
68     {kOpFormat_NCHW, ge::Format::FORMAT_NCHW},
69     {kOpFormat_NHWC, ge::Format::FORMAT_NHWC},
70     {kOpFormat_HWCN, ge::Format::FORMAT_HWCN},
71     {kOpFormat_NC1HWC0, ge::Format::FORMAT_NC1HWC0},
72     {kOpFormat_FRAC_Z, ge::Format::FORMAT_FRACTAL_Z},
73     {kOpFormat_FRAC_NZ, ge::Format::FORMAT_FRACTAL_NZ},
74     {kOpFormat_C1HWNCoC0, ge::Format::FORMAT_C1HWNCoC0},
75     {kOpFormat_NC1HWC0_C04, ge::Format::FORMAT_NC1HWC0_C04},
76     {kOpFormat_FRACTAL_Z_C04, ge::Format::FORMAT_FRACTAL_Z_C04},
77     {kOpFormat_NDHWC, ge::Format::FORMAT_NDHWC},
78     {kOpFormat_NCDHW, ge::Format::FORMAT_NCDHW},
79     {kOpFormat_DHWNC, ge::Format::FORMAT_DHWNC},
80     {kOpFormat_DHWCN, ge::Format::FORMAT_DHWCN},
81     {kOpFormat_NDC1HWC0, ge::Format::FORMAT_NDC1HWC0},
82     {kOpFormat_FRACTAL_Z_3D, ge::Format::FORMAT_FRACTAL_Z_3D},
83     {kOpFormat_FRACTAL_ZN_LSTM, ge::Format::FORMAT_FRACTAL_ZN_LSTM}};
84   MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
85   if (format == kOpFormat_DEFAULT) {
86     return shape_size == k4dSize ? ge::Format::FORMAT_NCHW : ge::Format::FORMAT_ND;
87   }
88   auto iter = format_map.find(format);
89   if (iter == format_map.end()) {
90     MS_LOG(EXCEPTION) << "Invalid format:" << format;
91   }
92   return iter->second;
93 }
94 
GetGeTilingFormat(ge::Format ge_format)95 std::string GeTypesConvert::GetGeTilingFormat(ge::Format ge_format) {
96   auto format_str = ge::TypeUtils::FormatToSerialString(ge_format);
97   if (format_str == kInvalidFormat) {
98     MS_LOG(EXCEPTION) << "Not support format:" << ge_format;
99   }
100   return format_str;
101 }
102 }  // namespace ascend
103 }  // namespace device
104 }  // namespace mindspore
105