1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
17
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
29 #include "tensorflow/lite/delegates/utils.h"
30 #include "tensorflow/lite/kernels/kernel_util.h"
31
32 namespace tflite {
33 namespace gpu {
34
ReadNonConstantTensor(TfLiteContext * context,absl::flat_hash_map<int,Value * > * tensor_to_value,absl::flat_hash_map<int,int> * quant_conversion_map,GraphFloat32 * graph,uint32_t tensor_idx,Value ** value)35 absl::Status ObjectReader::ReadNonConstantTensor(
36 TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
37 absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
38 uint32_t tensor_idx, Value** value) {
39 if (tensor_idx >= context->tensors_size) {
40 return absl::OutOfRangeError(
41 absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
42 }
43
44 if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
45 TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
46 if (tflite::IsConstantTensor(tflite_tensor)) {
47 return absl::InvalidArgumentError(absl::StrCat(
48 "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
49 }
50
51 if ((tflite_tensor->type == kTfLiteInt8 ||
52 tflite_tensor->type == kTfLiteUInt8) &&
53 quant_conversion_map) {
54 // Quantized case
55 if (quant_conversion_map->find(tensor_idx) ==
56 quant_conversion_map->end()) {
57 // Since the original tensor is fixed-point, add a new float tensor to
58 // the TFLite graph to represent the dequantized data.
59 int fp_tensor_index = 0;
60 TfLiteTensor* fp_tflite_tensor;
61 if (delegates::CreateNewTensorWithDifferentType(
62 context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
63 &fp_tensor_index) != kTfLiteOk) {
64 return absl::InternalError("Could not add new tensor to graph");
65 }
66 // `tflite_tensor` value could be invalid when the `context->tensors`
67 // is reallocated. Thus reassigning `tflite_tensor` with a fresh value.
68 tflite_tensor = &context->tensors[tensor_idx];
69
70 // Remember this tensor for later.
71 (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
72 (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
73 // Add a new GPU Value for the new dequantized floating-point tensor.
74 Value* value = graph->NewValue();
75 RETURN_IF_ERROR(
76 ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
77 value->tensor.ref = fp_tensor_index;
78 value->tensor.is_variable_input = tflite_tensor->is_variable;
79 value->quant_params.emplace();
80 RETURN_IF_ERROR(
81 PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
82 (*tensor_to_value)[fp_tensor_index] = value;
83 }
84 // We do not use the original tensor index as reference for the GPU
85 // Value, instead pointing at the corresponding float version.
86 tensor_idx = quant_conversion_map->at(tensor_idx);
87 } else {
88 // Floating-point case.
89 Value* value = graph->NewValue();
90 RETURN_IF_ERROR(
91 ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
92 value->tensor.ref = tensor_idx;
93 value->tensor.is_variable_input = tflite_tensor->is_variable;
94 (*tensor_to_value)[tensor_idx] = value;
95 }
96 }
97
98 if (value) {
99 *value = (*tensor_to_value)[tensor_idx];
100 }
101 return absl::OkStatus();
102 }
103
ReadValue(uint32_t idx,Value ** value)104 absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
105 if (idx >= node_->inputs->size) {
106 return absl::OutOfRangeError(
107 absl::StrCat("ReadValue: input tensor index: ", idx));
108 }
109 return ReadValueByTensorIdx(node_->inputs->data[idx], value);
110 }
111
ReadValueByTensorIdx(uint32_t tensor_idx,Value ** value)112 absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
113 Value** value) {
114 // Constant tensors should be handled by ReadTensor.
115 return ReadNonConstantTensor(context_, tensor_to_value_,
116 quant_conversion_map_, graph_, tensor_idx,
117 value);
118 }
119
GetNumberOfRuntimeInputs() const120 int ObjectReader::GetNumberOfRuntimeInputs() const {
121 return GetNumberOfRuntimeInputsForNode(context_, node_);
122 }
123
GetTensorId(uint32_t input_id,int * tensor_id) const124 absl::Status ObjectReader::GetTensorId(uint32_t input_id,
125 int* tensor_id) const {
126 if (input_id >= node_->inputs->size) {
127 return absl::OutOfRangeError(
128 absl::StrCat("Input tensor index: ", input_id));
129 }
130 *tensor_id = node_->inputs->data[input_id];
131 if (*tensor_id < 0 || *tensor_id > context_->tensors_size) {
132 return absl::OutOfRangeError(absl::StrCat("Tensor index: ", *tensor_id));
133 }
134 return absl::OkStatus();
135 }
136
GetTensorDims(uint32_t idx,TfLiteIntArray * dimensions) const137 absl::Status ObjectReader::GetTensorDims(uint32_t idx,
138 TfLiteIntArray* dimensions) const {
139 if (idx >= node_->inputs->size) {
140 return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
141 }
142 const int tensor_idx = node_->inputs->data[idx];
143 if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
144 return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
145 }
146 const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
147 *dimensions = *tflite_tensor.dims;
148 return absl::OkStatus();
149 }
150
AddOutput(const Node * node,int id)151 absl::Status ObjectReader::AddOutput(const Node* node, int id) {
152 if (node_->outputs->size <= id) {
153 return absl::InvalidArgumentError(absl::StrCat(
154 "Data id ", id, " must be less than tflite node outputs size ",
155 node_->outputs->size));
156 }
157 int output_tensor_idx = node_->outputs->data[id];
158 Value* value;
159 RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
160 RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
161 return absl::OkStatus();
162 }
163
AddOutputs(const Node * node)164 absl::Status ObjectReader::AddOutputs(const Node* node) {
165 for (int i = 0; i < node_->outputs->size; ++i) {
166 RETURN_IF_ERROR(AddOutput(node, i));
167 }
168 return absl::OkStatus();
169 }
170
AddInput(const Node * node,uint32_t idx)171 absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
172 Value* input;
173 RETURN_IF_ERROR(ReadValue(idx, &input));
174 return graph_->AddConsumer(node->id, input->id);
175 }
176
AddUpdate(const Node * node,uint32_t idx)177 absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
178 if (node_->inputs->size <= idx) {
179 return absl::InvalidArgumentError(absl::StrCat(
180 "Data id ", idx, " must be less than tflite node inputs size ",
181 node_->inputs->size));
182 }
183
184 int update_tensor_idx = node_->inputs->data[idx];
185 TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
186 if (!update_tensor->is_variable) {
187 return absl::InvalidArgumentError(
188 "The tensor must be a variable tensor to update it in place");
189 }
190
191 Value* value;
192 RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
193 if (!value->tensor.is_variable_input) {
194 return absl::InternalError(
195 "Variable input tensor is not marked as variable");
196 }
197
198 // We cannot create a cycle in the graph. The way around this when a node
199 // updates a tensor in place would be to add a new value to the graph that
200 // points to the same tensor.
201 Value* updated_value = graph_->NewValue();
202 updated_value->tensor = value->tensor;
203 updated_value->quant_params = value->quant_params;
204 RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));
205
206 // We also need to update the tensor_to_value arrays so that the nodes added
207 // after the current node will access the tensor with the updated value rather
208 // than the initial value.
209 if (quant_conversion_map_ != nullptr &&
210 quant_conversion_map_->find(update_tensor_idx) !=
211 quant_conversion_map_->end()) {
212 // If quantization conversion map exists, then the index provided is not the
213 // actual tensor idx. We need to find the float version of the tensor from
214 // the map.
215 tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
216 updated_value;
217 } else {
218 tensor_to_value_->at(update_tensor_idx) = updated_value;
219 }
220
221 return absl::OkStatus();
222 }
223
GetInputTensor(int index) const224 TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
225 return index >= 0 && index < node_->inputs->size
226 ? context_->tensors + node_->inputs->data[index]
227 : nullptr;
228 }
229
GetOutputTensor(int index) const230 TfLiteTensor* ObjectReader::GetOutputTensor(int index) const {
231 return index >= 0 && index < node_->outputs->size
232 ? context_->tensors + node_->outputs->data[index]
233 : nullptr;
234 }
235
VerifyInputsConstsOutputs(const TfLiteNode * node,int runtime_inputs,int const_inputs,int outputs)236 absl::Status ObjectReader::VerifyInputsConstsOutputs(const TfLiteNode* node,
237 int runtime_inputs,
238 int const_inputs,
239 int outputs) {
240 return CheckInputsConstsOutputs(context_, node, runtime_inputs, const_inputs,
241 outputs);
242 }
243
244 } // namespace gpu
245 } // namespace tflite
246