1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
17
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
29 #include "tensorflow/lite/delegates/utils.h"
30 #include "tensorflow/lite/kernels/kernel_util.h"
31
32 namespace tflite {
33 namespace gpu {
34
ReadNonConstantTensor(TfLiteContext * context,absl::flat_hash_map<int,Value * > * tensor_to_value,absl::flat_hash_map<int,int> * quant_conversion_map,GraphFloat32 * graph,uint32_t tensor_idx,Value ** value)35 absl::Status ObjectReader::ReadNonConstantTensor(
36 TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
37 absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
38 uint32_t tensor_idx, Value** value) {
39 if (tensor_idx >= context->tensors_size) {
40 return absl::OutOfRangeError(
41 absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
42 }
43
44 if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
45 TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
46 if (tflite::IsConstantTensor(tflite_tensor)) {
47 return absl::InvalidArgumentError(absl::StrCat(
48 "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
49 }
50
51 if ((tflite_tensor->type == kTfLiteInt8 ||
52 tflite_tensor->type == kTfLiteUInt8) &&
53 quant_conversion_map) {
54 // Quantized case
55 if (quant_conversion_map->find(tensor_idx) ==
56 quant_conversion_map->end()) {
57 // Since the original tensor is fixed-point, add a new float tensor to
58 // the TFLite graph to represent the dequantized data.
59 int fp_tensor_index = 0;
60 TfLiteTensor* fp_tflite_tensor;
61 if (delegates::CreateNewTensorWithDifferentType(
62 context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
63 &fp_tensor_index) != kTfLiteOk) {
64 return absl::InternalError("Could not add new tensor to graph");
65 }
66 // `tflite_tensor` value could be invalid when the `context->tensors`
67 // is reallocated. Thus reassigning `tflite_tensor` with a fresh value.
68 tflite_tensor = &context->tensors[tensor_idx];
69
70 // Remember this tensor for later.
71 (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
72 (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
73 // Add a new GPU Value for the new dequantized floating-point tensor.
74 Value* value = graph->NewValue();
75 RETURN_IF_ERROR(
76 ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
77 value->tensor.ref = fp_tensor_index;
78 value->tensor.is_variable_input = tflite_tensor->is_variable;
79 value->quant_params.emplace();
80 RETURN_IF_ERROR(
81 PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
82 (*tensor_to_value)[fp_tensor_index] = value;
83 }
84 // We do not use the original tensor index as reference for the GPU
85 // Value, instead pointing at the corresponding float version.
86 tensor_idx = quant_conversion_map->at(tensor_idx);
87 } else {
88 // Floating-point case.
89 Value* value = graph->NewValue();
90 RETURN_IF_ERROR(
91 ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
92 value->tensor.ref = tensor_idx;
93 value->tensor.is_variable_input = tflite_tensor->is_variable;
94 (*tensor_to_value)[tensor_idx] = value;
95 }
96 }
97
98 if (value) {
99 *value = (*tensor_to_value)[tensor_idx];
100 }
101 return absl::OkStatus();
102 }
103
ReadValue(uint32_t idx,Value ** value)104 absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
105 if (idx >= node_->inputs->size) {
106 return absl::OutOfRangeError(
107 absl::StrCat("ReadValue: input tensor index: ", idx));
108 }
109 return ReadValueByTensorIdx(node_->inputs->data[idx], value);
110 }
111
ReadValueByTensorIdx(uint32_t tensor_idx,Value ** value)112 absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
113 Value** value) {
114 // Constant tensors should be handled by ReadTensor.
115 return ReadNonConstantTensor(context_, tensor_to_value_,
116 quant_conversion_map_, graph_, tensor_idx,
117 value);
118 }
119
GetNumberOfRuntimeInputs() const120 int ObjectReader::GetNumberOfRuntimeInputs() const {
121 return GetNumberOfRuntimeInputsForNode(context_, node_);
122 }
123
GetTensorDims(uint32_t idx,TfLiteIntArray * dimensions) const124 absl::Status ObjectReader::GetTensorDims(uint32_t idx,
125 TfLiteIntArray* dimensions) const {
126 if (idx >= node_->inputs->size) {
127 return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
128 }
129 const int tensor_idx = node_->inputs->data[idx];
130 if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
131 return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
132 }
133 const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
134 *dimensions = *tflite_tensor.dims;
135 return absl::OkStatus();
136 }
137
AddOutput(const Node * node,int id)138 absl::Status ObjectReader::AddOutput(const Node* node, int id) {
139 if (node_->outputs->size <= id) {
140 return absl::InvalidArgumentError(absl::StrCat(
141 "Data id ", id, " must be less than tflite node outputs size ",
142 node_->outputs->size));
143 }
144 int output_tensor_idx = node_->outputs->data[id];
145 Value* value;
146 RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
147 RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
148 return absl::OkStatus();
149 }
150
AddOutputs(const Node * node)151 absl::Status ObjectReader::AddOutputs(const Node* node) {
152 for (int i = 0; i < node_->outputs->size; ++i) {
153 RETURN_IF_ERROR(AddOutput(node, i));
154 }
155 return absl::OkStatus();
156 }
157
AddInput(const Node * node,uint32_t idx)158 absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
159 Value* input;
160 RETURN_IF_ERROR(ReadValue(idx, &input));
161 return graph_->AddConsumer(node->id, input->id);
162 }
163
AddUpdate(const Node * node,uint32_t idx)164 absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
165 if (node_->inputs->size <= idx) {
166 return absl::InvalidArgumentError(absl::StrCat(
167 "Data id ", idx, " must be less than tflite node inputs size ",
168 node_->inputs->size));
169 }
170
171 int update_tensor_idx = node_->inputs->data[idx];
172 TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
173 if (!update_tensor->is_variable) {
174 return absl::InvalidArgumentError(
175 "The tensor must be a variable tensor to update it in place");
176 }
177
178 Value* value;
179 RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
180 if (!value->tensor.is_variable_input) {
181 return absl::InternalError(
182 "Variable input tensor is not marked as variable");
183 }
184
185 // We cannot create a cycle in the graph. The way around this when a node
186 // updates a tensor in place would be to add a new value to the graph that
187 // points to the same tensor.
188 Value* updated_value = graph_->NewValue();
189 updated_value->tensor = value->tensor;
190 updated_value->quant_params = value->quant_params;
191 RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));
192
193 // We also need to update the tensor_to_value arrays so that the nodes added
194 // after the current node will access the tensor with the updated value rather
195 // than the initial value.
196 if (quant_conversion_map_ != nullptr &&
197 quant_conversion_map_->find(update_tensor_idx) !=
198 quant_conversion_map_->end()) {
199 // If quantization conversion map exists, then the index provided is not the
200 // actual tensor idx. We need to find the float version of the tensor from
201 // the map.
202 tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
203 updated_value;
204 } else {
205 tensor_to_value_->at(update_tensor_idx) = updated_value;
206 }
207
208 return absl::OkStatus();
209 }
210
GetInputTensor(int index) const211 TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
212 return index >= 0 && index < node_->inputs->size
213 ? context_->tensors + node_->inputs->data[index]
214 : nullptr;
215 }
216
GetOutputTensor(int index) const217 TfLiteTensor* ObjectReader::GetOutputTensor(int index) const {
218 return index >= 0 && index < node_->outputs->size
219 ? context_->tensors + node_->outputs->data[index]
220 : nullptr;
221 }
222
VerifyInputsConstsOutputs(const TfLiteNode * node,int runtime_inputs,int const_inputs,int outputs)223 absl::Status ObjectReader::VerifyInputsConstsOutputs(const TfLiteNode* node,
224 int runtime_inputs,
225 int const_inputs,
226 int outputs) {
227 return CheckInputsConstsOutputs(context_, node, runtime_inputs, const_inputs,
228 outputs);
229 }
230
231 } // namespace gpu
232 } // namespace tflite
233