/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/object_reader.h"

#include <cstdint>
#include <optional>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace gpu {

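// Resolves the GPU graph Value that corresponds to a runtime (non-constant)
// TfLite tensor, creating it on first use. For int8/uint8 tensors, when a
// quantization conversion map is provided, a parallel float32 TfLite tensor
// is created to hold the dequantized data; the returned Value then references
// that float tensor instead of the original quantized one, and both indices
// are recorded in `quant_conversion_map`.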
absl::Status ObjectReader::ReadNonConstantTensor(
    TfLiteContext* context, absl::flat_hash_map<int, Value*>* tensor_to_value,
    absl::flat_hash_map<int, int>* quant_conversion_map, GraphFloat32* graph,
    uint32_t tensor_idx, Value** value) {
  if (tensor_idx >= context->tensors_size) {
    return absl::OutOfRangeError(absl::StrCat(
        "ReadNonConstantTensor: input tensor index: ", tensor_idx));
  }

  if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
    TfLiteTensor* tflite_tensor = &context->tensors[tensor_idx];
    if (tflite::IsConstantTensor(tflite_tensor)) {
      return absl::InvalidArgumentError(absl::StrCat(
          "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
    }

    if ((tflite_tensor->type == kTfLiteInt8 ||
         tflite_tensor->type == kTfLiteUInt8) &&
        quant_conversion_map) {
      // Quantized case
      if (quant_conversion_map->find(tensor_idx) ==
          quant_conversion_map->end()) {
        // Since the original tensor is fixed-point, add a new float tensor to
        // the TFLite graph to represent the dequantized data.
        int fp_tensor_index = 0;
        TfLiteTensor* fp_tflite_tensor;
        if (delegates::CreateNewTensorWithDifferentType(
                context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
                &fp_tensor_index) != kTfLiteOk) {
          return absl::InternalError("Could not add new tensor to graph");
        }
        // `tflite_tensor` may have been invalidated if `context->tensors` was
        // reallocated while adding the new tensor, so refresh the pointer.
        tflite_tensor = &context->tensors[tensor_idx];

        // Remember this tensor for later.
        (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
        (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
        // Add a new GPU Value for the new dequantized floating-point tensor.
        Value* value = graph->NewValue();
        RETURN_IF_ERROR(
            ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
        value->tensor.ref = fp_tensor_index;
        value->tensor.is_variable_input = tflite_tensor->is_variable;
        value->quant_params.emplace();
        RETURN_IF_ERROR(
            PopulateQuantParams(*tflite_tensor, &value->quant_params.value()));
        (*tensor_to_value)[fp_tensor_index] = value;
      }
      // We do not use the original tensor index as reference for the GPU
      // Value, instead pointing at the corresponding float version.
      tensor_idx = quant_conversion_map->at(tensor_idx);
    } else {
      // Floating-point case.
      Value* value = graph->NewValue();
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
      value->tensor.ref = tensor_idx;
      value->tensor.is_variable_input = tflite_tensor->is_variable;
      (*tensor_to_value)[tensor_idx] = value;
    }
  }

  if (value) {
    *value = (*tensor_to_value)[tensor_idx];
  }
  return absl::OkStatus();
}

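// Returns the Value for the idx-th input of the TfLite node currently being
// translated.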
absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
  if (idx >= node_->inputs->size) {
    return absl::OutOfRangeError(
        absl::StrCat("ReadValue: input tensor index: ", idx));
  }
  return ReadValueByTensorIdx(node_->inputs->data[idx], value);
}

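// Returns the Value for a runtime tensor addressed by its global TfLite
// tensor index.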
absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
                                                Value** value) {
  // Constant tensors should be handled by ReadTensor.
  return ReadNonConstantTensor(context_, tensor_to_value_,
                               quant_conversion_map_, graph_, tensor_idx,
                               value);
}

int ObjectReader::GetNumberOfRuntimeInputs() const {
  return GetNumberOfRuntimeInputsForNode(context_, node_);
}

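// Translates the input_id-th input of the TfLite node into its global tensor
// index, validating both indices along the way.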
absl::Status ObjectReader::GetTensorId(uint32_t input_id,
                                       int* tensor_id) const {
  if (input_id >= node_->inputs->size) {
    return absl::OutOfRangeError(
        absl::StrCat("Input tensor index: ", input_id));
  }
  *tensor_id = node_->inputs->data[input_id];
  if (*tensor_id < 0 || *tensor_id >= context_->tensors_size) {
    return absl::OutOfRangeError(absl::StrCat("Tensor index: ", *tensor_id));
  }
  return absl::OkStatus();
}

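// Copies the dimensions of the idx-th input tensor of the TfLite node into
// `dimensions`.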
absl::Status ObjectReader::GetTensorDims(uint32_t idx,
                                         TfLiteIntArray* dimensions) const {
  if (idx >= node_->inputs->size) {
    return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
  }
  const int tensor_idx = node_->inputs->data[idx];
  if (tensor_idx < 0 || tensor_idx >= context_->tensors_size) {
    return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
  }
  const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
  *dimensions = *tflite_tensor.dims;
  return absl::OkStatus();
}

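// Marks `node` as the producer of the Value corresponding to the id-th output
// of the TfLite node being translated.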
absl::Status ObjectReader::AddOutput(const Node* node, int id) {
  if (node_->outputs->size <= id) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Data id ", id, " must be less than tflite node outputs size ",
        node_->outputs->size));
  }
  int output_tensor_idx = node_->outputs->data[id];
  Value* value;
  RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
  RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
  return absl::OkStatus();
}

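// Marks `node` as the producer of every output of the TfLite node.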
absl::Status ObjectReader::AddOutputs(const Node* node) {
  for (int i = 0; i < node_->outputs->size; ++i) {
    RETURN_IF_ERROR(AddOutput(node, i));
  }
  return absl::OkStatus();
}

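// Marks `node` as a consumer of the Value corresponding to the idx-th input
// of the TfLite node.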
absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
  Value* input;
  RETURN_IF_ERROR(ReadValue(idx, &input));
  return graph_->AddConsumer(node->id, input->id);
}

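// Handles an in-place update of a variable tensor: a fresh Value aliasing the
// same tensor is produced by `node`, and tensor_to_value_ is re-pointed so
// that nodes added later consume the updated value instead of the original.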
absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
  if (node_->inputs->size <= idx) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Data id ", idx, " must be less than tflite node inputs size ",
        node_->inputs->size));
  }

  int update_tensor_idx = node_->inputs->data[idx];
  TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
  if (!update_tensor->is_variable) {
    return absl::InvalidArgumentError(
        "The tensor must be a variable tensor to update it in place");
  }

  Value* value;
  RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
  if (!value->tensor.is_variable_input) {
    return absl::InternalError(
        "Variable input tensor is not marked as variable");
  }

  // We cannot create a cycle in the graph, so when a node updates a tensor in
  // place we instead add a new Value that points to the same tensor.
  Value* updated_value = graph_->NewValue();
  updated_value->tensor = value->tensor;
  updated_value->quant_params = value->quant_params;
  RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));

  // We also need to update the tensor_to_value map so that nodes added after
  // the current node access the tensor through the updated value rather than
  // the initial one.
  if (quant_conversion_map_ != nullptr &&
      quant_conversion_map_->find(update_tensor_idx) !=
          quant_conversion_map_->end()) {
    // If a quantization conversion map exists, the provided index refers to
    // the original quantized tensor; look up its float counterpart in the map.
    tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
        updated_value;
  } else {
    tensor_to_value_->at(update_tensor_idx) = updated_value;
  }

  return absl::OkStatus();
}

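// Returns the TfLite tensor backing the index-th input of the node, or
// nullptr if the index is out of range.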
TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
  return index >= 0 && index < node_->inputs->size
             ? context_->tensors + node_->inputs->data[index]
             : nullptr;
}

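// Returns the TfLite tensor backing the index-th output of the node, or
// nullptr if the index is out of range.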
TfLiteTensor* ObjectReader::GetOutputTensor(int index) const {
  return index >= 0 && index < node_->outputs->size
             ? context_->tensors + node_->outputs->data[index]
             : nullptr;
}

absl::Status ObjectReader::VerifyInputsConstsOutputs(const TfLiteNode* node,
                                                     int runtime_inputs,
                                                     int const_inputs,
                                                     int outputs) {
  return CheckInputsConstsOutputs(context_, node, runtime_inputs, const_inputs,
                                  outputs);
}

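// A minimal usage sketch (not part of this file's API): an operation parser
// typically receives an ObjectReader and wires a newly created GPU node to
// the TfLite node's tensors. The parser function name below is hypothetical;
// Node, OperationType, and ToString come from the GPU delegate's common
// headers (model.h / operations.h), which are assumed to be included by the
// caller.
//
//   absl::Status ParseElementwiseAdd(GraphFloat32* graph,
//                                    ObjectReader* reader) {
//     Node* node = graph->NewNode();  // New GPU node for this operation.
//     node->operation.type = ToString(OperationType::ADD);
//     RETURN_IF_ERROR(reader->AddInput(node, 0));  // First runtime input.
//     RETURN_IF_ERROR(reader->AddInput(node, 1));  // Second runtime input.
//     return reader->AddOutputs(node);             // All outputs of the op.
//   }
//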
}  // namespace gpu
}  // namespace tflite