• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
17 
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
29 #include "tensorflow/lite/delegates/utils.h"
30 #include "tensorflow/lite/kernels/kernel_util.h"
31 
32 namespace tflite {
33 namespace gpu {
34 
// Maps the TFLite tensor at `tensor_idx` to a GPU graph Value, creating the
// Value on first use and caching it in `tensor_to_value`.
//
// For int8/uint8 tensors (when `quant_conversion_map` is provided) a new
// float32 TFLite tensor is created to hold dequantized data; the GPU Value
// then references that float tensor, and both directions of the
// original<->float index mapping are recorded in `quant_conversion_map`.
//
// Args:
//   context: TFLite context owning the tensors array.
//   tensor_to_value: cache from TFLite tensor index to GPU Value.
//   quant_conversion_map: optional fixed-point<->float tensor index map;
//     nullptr disables the quantized path.
//   graph: GPU graph that owns created Values.
//   tensor_idx: index into context->tensors.
//   value: optional out-param; receives the cached/created Value if non-null.
//
// Returns: OutOfRangeError for a bad index, InvalidArgumentError for a
// constant tensor, InternalError if the float tensor cannot be created,
// otherwise OkStatus.
absl::Status ObjectReader::ReadNonConstantTensor(
    TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
    absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
    uint32_t tensor_idx, Value** value) {
  if (tensor_idx >= context->tensors_size) {
    return absl::OutOfRangeError(
        absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
  }

  if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
    TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
    if (tflite::IsConstantTensor(tflite_tensor)) {
      return absl::InvalidArgumentError(absl::StrCat(
          "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
    }

    if ((tflite_tensor->type == kTfLiteInt8 ||
         tflite_tensor->type == kTfLiteUInt8) &&
        quant_conversion_map) {
      // Quantized case
      if (quant_conversion_map->find(tensor_idx) ==
          quant_conversion_map->end()) {
        // Since the original tensor is fixed-point, add a new float tensor to
        // the TFLite graph to represent the dequantized data.
        int fp_tensor_index = 0;
        TfLiteTensor* fp_tflite_tensor;
        if (delegates::CreateNewTensorWithDifferentType(
                context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
                &fp_tensor_index) != kTfLiteOk) {
          return absl::InternalError("Could not add new tensor to graph");
        }
        // `tflite_tensor` value could be invalid when the `context->tensors`
        // is reallocated. Thus reassigning `tflite_tensor` with a fresh value.
        tflite_tensor = &context->tensors[tensor_idx];

        // Remember this tensor for later.
        (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
        (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
        // Add a new GPU Value for the new dequantized floating-point tensor.
        // NOTE: this local `value` intentionally shadows the out-parameter;
        // the out-parameter is populated from the cache at the end.
        Value* value = graph->NewValue();
        RETURN_IF_ERROR(
            ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
        value->tensor.ref = fp_tensor_index;
        value->tensor.is_variable_input = tflite_tensor->is_variable;
        value->quant_params.emplace();
        // Quantization params are read from the original fixed-point tensor.
        RETURN_IF_ERROR(
            PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
        (*tensor_to_value)[fp_tensor_index] = value;
      }
      // We do not use the original tensor index as reference for the GPU
      // Value, instead pointing at the corresponding float version.
      tensor_idx = quant_conversion_map->at(tensor_idx);
    } else {
      // Floating-point case.
      Value* value = graph->NewValue();
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
      value->tensor.ref = tensor_idx;
      value->tensor.is_variable_input = tflite_tensor->is_variable;
      (*tensor_to_value)[tensor_idx] = value;
    }
  }

  // `value` is optional; callers may pass nullptr when they only want the
  // side effect of registering the tensor.
  if (value) {
    *value = (*tensor_to_value)[tensor_idx];
  }
  return absl::OkStatus();
}
103 
ReadValue(uint32_t idx,Value ** value)104 absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
105   if (idx >= node_->inputs->size) {
106     return absl::OutOfRangeError(
107         absl::StrCat("ReadValue: input tensor index: ", idx));
108   }
109   return ReadValueByTensorIdx(node_->inputs->data[idx], value);
110 }
111 
ReadValueByTensorIdx(uint32_t tensor_idx,Value ** value)112 absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
113                                                 Value** value) {
114   // Constant tensors should be handled by ReadTensor.
115   return ReadNonConstantTensor(context_, tensor_to_value_,
116                                quant_conversion_map_, graph_, tensor_idx,
117                                value);
118 }
119 
GetNumberOfRuntimeInputs() const120 int ObjectReader::GetNumberOfRuntimeInputs() const {
121   return GetNumberOfRuntimeInputsForNode(context_, node_);
122 }
123 
GetTensorDims(uint32_t idx,TfLiteIntArray * dimensions) const124 absl::Status ObjectReader::GetTensorDims(uint32_t idx,
125                                          TfLiteIntArray* dimensions) const {
126   if (idx >= node_->inputs->size) {
127     return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
128   }
129   const int tensor_idx = node_->inputs->data[idx];
130   if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
131     return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
132   }
133   const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
134   *dimensions = *tflite_tensor.dims;
135   return absl::OkStatus();
136 }
137 
AddOutput(const Node * node,int id)138 absl::Status ObjectReader::AddOutput(const Node* node, int id) {
139   if (node_->outputs->size <= id) {
140     return absl::InvalidArgumentError(absl::StrCat(
141         "Data id ", id, " must be less than tflite node outputs size ",
142         node_->outputs->size));
143   }
144   int output_tensor_idx = node_->outputs->data[id];
145   Value* value;
146   RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
147   RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
148   return absl::OkStatus();
149 }
150 
AddOutputs(const Node * node)151 absl::Status ObjectReader::AddOutputs(const Node* node) {
152   for (int i = 0; i < node_->outputs->size; ++i) {
153     RETURN_IF_ERROR(AddOutput(node, i));
154   }
155   return absl::OkStatus();
156 }
157 
AddInput(const Node * node,uint32_t idx)158 absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
159   Value* input;
160   RETURN_IF_ERROR(ReadValue(idx, &input));
161   return graph_->AddConsumer(node->id, input->id);
162 }
163 
// Marks `node` as updating the variable tensor at input slot `idx` in place.
//
// Because the GPU graph must stay acyclic, the in-place update is modeled by
// creating a fresh Value that aliases the same tensor; `node` produces that
// new Value, and the tensor_to_value cache is repointed so later nodes read
// the updated Value instead of the original one.
//
// Args:
//   node: GPU graph node performing the update.
//   idx: input slot of the current TFLite node holding the variable tensor.
//
// Returns: InvalidArgumentError if `idx` is out of range or the tensor is
// not a variable, InternalError if the cached Value is inconsistent,
// otherwise OkStatus.
absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
  if (node_->inputs->size <= idx) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Data id ", idx, " must be less than tflite node inputs size ",
        node_->inputs->size));
  }

  int update_tensor_idx = node_->inputs->data[idx];
  TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
  // Only TFLite variable tensors may be written in place.
  if (!update_tensor->is_variable) {
    return absl::InvalidArgumentError(
        "The tensor must be a variable tensor to update it in place");
  }

  Value* value;
  RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
  if (!value->tensor.is_variable_input) {
    return absl::InternalError(
        "Variable input tensor is not marked as variable");
  }

  // We cannot create a cycle in the graph. The way around this when a node
  // updates a tensor in place would be to add a new value to the graph that
  // points to the same tensor.
  Value* updated_value = graph_->NewValue();
  updated_value->tensor = value->tensor;
  updated_value->quant_params = value->quant_params;
  RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));

  // We also need to update the tensor_to_value arrays so that the nodes added
  // after the current node will access the tensor with the updated value rather
  // than the initial value.
  if (quant_conversion_map_ != nullptr &&
      quant_conversion_map_->find(update_tensor_idx) !=
          quant_conversion_map_->end()) {
    // If quantization conversion map exists, then the index provided is not the
    // actual tensor idx. We need to find the float version of the tensor from
    // the map.
    tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
        updated_value;
  } else {
    tensor_to_value_->at(update_tensor_idx) = updated_value;
  }

  return absl::OkStatus();
}
210 
GetInputTensor(int index) const211 TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
212   return index >= 0 && index < node_->inputs->size
213              ? context_->tensors + node_->inputs->data[index]
214              : nullptr;
215 }
216 
GetOutputTensor(int index) const217 TfLiteTensor* ObjectReader::GetOutputTensor(int index) const {
218   return index >= 0 && index < node_->outputs->size
219              ? context_->tensors + node_->outputs->data[index]
220              : nullptr;
221 }
222 
VerifyInputsConstsOutputs(const TfLiteNode * node,int runtime_inputs,int const_inputs,int outputs)223 absl::Status ObjectReader::VerifyInputsConstsOutputs(const TfLiteNode* node,
224                                                      int runtime_inputs,
225                                                      int const_inputs,
226                                                      int outputs) {
227   return CheckInputsConstsOutputs(context_, node, runtime_inputs, const_inputs,
228                                   outputs);
229 }
230 
231 }  // namespace gpu
232 }  // namespace tflite
233