1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
18
19 #include <string>
20 #include <vector>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/IR/Operation.h" // from @llvm-project
27
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31
// Name of the string attribute on TFL ops that records the target device.
constexpr char kDevice[]{"tac.device"};

// Name of the string attribute on TFL ops that records the inference type.
constexpr char kInferenceType[]{"tac.inference_type"};
37
// Inference type attached to an op via the "tac.inference_type" attribute.
// Every enumerator is given an explicit value.
// TODO(renjieliu): Add more inference types.
enum InferenceType {
  UNKNOWN = 0,  // Produced for any unrecognized inference-type string.
  FLOAT = 1,
  QUANTIZED_INT8 = 2,
  QUANTIZED_UINT8 = 3,
  HYBRID = 4,
};
46
GetInferenceTypeEnum(llvm::StringRef inference_type_str)47 inline InferenceType GetInferenceTypeEnum(llvm::StringRef inference_type_str) {
48 if (inference_type_str == "FLOAT") {
49 return FLOAT;
50 } else if (inference_type_str == "QUANTIZED_INT8") {
51 return QUANTIZED_INT8;
52 } else if (inference_type_str == "QUANTIZED_UINT8") {
53 return QUANTIZED_UINT8;
54 } else if (inference_type_str == "HYBRID") {
55 return HYBRID;
56 } else {
57 return UNKNOWN;
58 }
59 }
60
GetInferenceString(InferenceType inference_type)61 inline std::string GetInferenceString(InferenceType inference_type) {
62 if (inference_type == FLOAT) {
63 return "FLOAT";
64 } else if (inference_type == QUANTIZED_INT8) {
65 return "QUANTIZED_INT8";
66 } else if (inference_type == QUANTIZED_UINT8) {
67 return "QUANTIZED_UINT8";
68 } else if (inference_type == HYBRID) {
69 return "HYBRID";
70 } else {
71 return "UNKNOWN";
72 }
73 }
74
// Returns canonical representation for hardware name (All uppercase).
//
// Only the ASCII letters 'a'-'z' are upper-cased. Every other byte (digits,
// punctuation, non-ASCII/UTF-8 bytes) is copied unchanged.
//
// The conversion is deliberately locale-independent. std::toupper consults
// the process-global C locale, so the "canonical" name could change if the
// host program calls setlocale() (e.g. Turkish locales treat 'i' specially).
// In the default "C" locale the result is the same as std::toupper's.
// TODO(b/177376459): Remove this in favor of the string defined by hardwares
// MyHardware::kId.
inline std::string GetCanonicalHardwareName(const std::string& hardware_name) {
  std::string name = hardware_name;
  for (char& c : name) {
    if (c >= 'a' && c <= 'z') c = static_cast<char>(c - 'a' + 'A');
  }
  return name;
}
85
86 // Get the target annotation form the op.
GetTargetAnnotation(Operation * op)87 inline llvm::Optional<std::string> GetTargetAnnotation(Operation* op) {
88 auto device = op->getAttrOfType<StringAttr>(kDevice);
89 if (device == nullptr || device.getValue().empty()) return llvm::None;
90
91 return GetCanonicalHardwareName(device.getValue().str());
92 }
93
94 // Get inference type attribute from the operation if available.
GetInferenceTypeAnnotation(Operation * op)95 inline llvm::Optional<InferenceType> GetInferenceTypeAnnotation(Operation* op) {
96 auto inference_type = op->getAttrOfType<StringAttr>(kInferenceType);
97 if (inference_type == nullptr) return llvm::None;
98
99 llvm::StringRef device_name_str = inference_type.getValue();
100 return GetInferenceTypeEnum(device_name_str);
101 }
102
103 // InferenceDeviceType is a combination of the hardware with inference type.
104 struct InferenceDeviceType {
105 std::string hardware;
106 InferenceType inference_type;
107
108 bool operator==(const InferenceDeviceType& other) const {
109 return (hardware == other.hardware) &&
110 (inference_type == other.inference_type);
111 }
112
113 bool operator!=(const InferenceDeviceType& other) const {
114 return !(*this == other);
115 }
116
117 struct inference_device_type_hash {
operatorInferenceDeviceType::inference_device_type_hash118 size_t operator()(const InferenceDeviceType& p) const {
119 auto hash1 = std::hash<std::string>{}(p.hardware);
120 auto hash2 = std::hash<InferenceType>{}(p.inference_type);
121 return hash1 ^ hash2;
122 }
123 };
124 };
125
126 // Get InferenceDeviceType attribute from the operation if available.
GetInferenceDeviceTypeForOp(Operation * op)127 inline llvm::Optional<InferenceDeviceType> GetInferenceDeviceTypeForOp(
128 Operation* op) {
129 auto hardware = GetTargetAnnotation(op);
130 if (!hardware.hasValue()) return llvm::None;
131
132 auto inference_type = GetInferenceTypeAnnotation(op);
133 if (!inference_type.hasValue()) return llvm::None;
134
135 InferenceDeviceType inference_device_type;
136 inference_device_type.hardware = hardware.getValue();
137 inference_device_type.inference_type = inference_type.getValue();
138 return inference_device_type;
139 }
140
141 } // namespace tac
142 } // namespace TFL
143 } // namespace mlir
144
145 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
146