/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_

#include <algorithm>   // for std::transform
#include <cctype>      // for std::toupper
#include <functional>  // for std::hash
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Operation.h"  // from @llvm-project

namespace mlir {
namespace TFL {
namespace tac {

// Device attribute string on the TFL dialect.
constexpr char kDevice[] = "tac.device";

// Inference type attribute string on the TFL dialect.
constexpr char kInferenceType[] = "tac.inference_type";

// TODO(renjieliu): Add more inference types.
enum InferenceType {
  UNKNOWN = 0,
  FLOAT = 1,
  QUANTIZED_INT8 = 2,
  QUANTIZED_UINT8 = 3,
  HYBRID = 4
};
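
// Illustrative sketch (not part of the original header): these constants name
// string attributes attached to TFL ops in the IR, roughly of the form
//
//   %0 = "tfl.add"(%a, %b) {tac.device = "GPU", tac.inference_type = "FLOAT"}
//       : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
//
// The helpers below parse such annotations back into typed values.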

// Parses an inference type attribute string into the corresponding
// InferenceType; returns UNKNOWN for unrecognized strings.
inline InferenceType GetInferenceTypeEnum(llvm::StringRef inference_type_str) {
  if (inference_type_str == "FLOAT") {
    return FLOAT;
  } else if (inference_type_str == "QUANTIZED_INT8") {
    return QUANTIZED_INT8;
  } else if (inference_type_str == "QUANTIZED_UINT8") {
    return QUANTIZED_UINT8;
  } else if (inference_type_str == "HYBRID") {
    return HYBRID;
  } else {
    return UNKNOWN;
  }
}

// Converts an InferenceType back to its canonical string form; returns
// "UNKNOWN" for unrecognized values.
inline std::string GetInferenceString(InferenceType inference_type) {
  if (inference_type == FLOAT) {
    return "FLOAT";
  } else if (inference_type == QUANTIZED_INT8) {
    return "QUANTIZED_INT8";
  } else if (inference_type == QUANTIZED_UINT8) {
    return "QUANTIZED_UINT8";
  } else if (inference_type == HYBRID) {
    return "HYBRID";
  } else {
    return "UNKNOWN";
  }
}
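
// Illustrative round trip (not part of the original header):
//
//   InferenceType t = GetInferenceTypeEnum("QUANTIZED_INT8");
//   assert(t == QUANTIZED_INT8);
//   assert(GetInferenceString(t) == "QUANTIZED_INT8");
//   // Unrecognized strings map to UNKNOWN:
//   assert(GetInferenceTypeEnum("BFLOAT16") == UNKNOWN);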

// Returns the canonical representation for a hardware name (all uppercase).
// TODO(b/177376459): Remove this in favor of the string defined by hardwares
// MyHardware::kId.
inline std::string GetCanonicalHardwareName(const std::string& hardware_name) {
  std::string name = hardware_name;
  std::transform(
      name.begin(), name.end(), name.begin(),
      [](unsigned char c) -> unsigned char { return std::toupper(c); });
  return name;
}

// Gets the target (device) annotation from the op.
inline llvm::Optional<std::string> GetTargetAnnotation(Operation* op) {
  auto device = op->getAttrOfType<StringAttr>(kDevice);
  if (device == nullptr || device.getValue().empty()) return llvm::None;

  return GetCanonicalHardwareName(device.getValue().str());
}

// Gets the inference type attribute from the operation if available.
inline llvm::Optional<InferenceType> GetInferenceTypeAnnotation(Operation* op) {
  auto inference_type = op->getAttrOfType<StringAttr>(kInferenceType);
  if (inference_type == nullptr) return llvm::None;

  llvm::StringRef inference_type_str = inference_type.getValue();
  return GetInferenceTypeEnum(inference_type_str);
}
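
// Illustrative usage (not part of the original header): reading both
// annotations off an op inside a pass:
//
//   if (auto device = GetTargetAnnotation(op)) {
//     llvm::Optional<InferenceType> type = GetInferenceTypeAnnotation(op);
//     // e.g. device == "GPU" and type == FLOAT for the IR snippet above.
//   }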

// InferenceDeviceType is a combination of the hardware with inference type.
struct InferenceDeviceType {
  std::string hardware;
  InferenceType inference_type;

  bool operator==(const InferenceDeviceType& other) const {
    return (hardware == other.hardware) &&
           (inference_type == other.inference_type);
  }

  bool operator!=(const InferenceDeviceType& other) const {
    return !(*this == other);
  }

  struct inference_device_type_hash {
    size_t operator()(const InferenceDeviceType& p) const {
      auto hash1 = std::hash<std::string>{}(p.hardware);
      auto hash2 = std::hash<InferenceType>{}(p.inference_type);
      return hash1 ^ hash2;
    }
  };
};
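
// Illustrative usage (not part of the original header): the nested hash
// functor lets InferenceDeviceType serve as a key in unordered containers
// (requires <unordered_map>):
//
//   std::unordered_map<InferenceDeviceType, int,
//                      InferenceDeviceType::inference_device_type_hash>
//       op_counts;
//   op_counts[{"GPU", FLOAT}] += 1;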

// Returns the InferenceDeviceType for the operation, if both the device and
// inference type annotations are present.
inline llvm::Optional<InferenceDeviceType> GetInferenceDeviceTypeForOp(
    Operation* op) {
  auto hardware = GetTargetAnnotation(op);
  if (!hardware.hasValue()) return llvm::None;

  auto inference_type = GetInferenceTypeAnnotation(op);
  if (!inference_type.hasValue()) return llvm::None;

  InferenceDeviceType inference_device_type;
  inference_device_type.hardware = hardware.getValue();
  inference_device_type.inference_type = inference_type.getValue();
  return inference_device_type;
}
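
// Illustrative usage (not part of the original header): grouping annotated
// ops by device/inference type, e.g. when partitioning a module; `buckets` is
// a hypothetical map like the one sketched above:
//
//   module.walk([&](Operation* op) {
//     if (auto idt = GetInferenceDeviceTypeForOp(op)) {
//       buckets[idt.getValue()].push_back(op);
//     }
//   });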

}  // namespace tac
}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_