• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
29 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
30 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
31 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
32 #include "tensorflow/lite/delegates/gpu/common/model.h"
33 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
34 #include "tensorflow/lite/delegates/gpu/common/operations.h"
35 #include "tensorflow/lite/delegates/gpu/common/precision.h"
36 #include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
37 #include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
38 #include "tensorflow/lite/delegates/gpu/common/shape.h"
39 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
40 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
41 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
42 #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
43 #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
44 #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
45 #include "tensorflow/lite/delegates/gpu/common/types.h"
46 #include "tensorflow/lite/delegates/gpu/common/util.h"
47 
48 namespace tflite {
49 namespace gpu {
50 namespace cl {
51 
52 namespace {
IsReady(const absl::flat_hash_set<ValueId> & ready_tensors,const CLNode & node)53 bool IsReady(const absl::flat_hash_set<ValueId>& ready_tensors,
54              const CLNode& node) {
55   for (const ValueId in_id : node.inputs) {
56     if (ready_tensors.find(in_id) == ready_tensors.end()) {
57       return false;
58     }
59   }
60   return true;
61 }
62 
GetCLNodeTensors(const CLNode & node)63 std::vector<std::pair<ValueId, TensorDescriptor>> GetCLNodeTensors(
64     const CLNode& node) {
65   std::vector<std::pair<ValueId, TensorDescriptor>> result;
66   result.reserve(node.inputs.size() + node.outputs.size());
67   const OperationDef op_def = node.cl_operation.GetDefinition();
68   for (int j = 0; j < node.inputs.size(); ++j) {
69     result.push_back({node.inputs[j], op_def.src_tensors[j]});
70   }
71   for (int j = 0; j < node.outputs.size(); ++j) {
72     result.push_back({node.outputs[j], op_def.dst_tensors[j]});
73   }
74 
75   return result;
76 }
77 
MergeCLNodes(CLNode * src,CLNode * dst)78 absl::Status MergeCLNodes(CLNode* src, CLNode* dst) {
79   for (int j = 1; j < src->inputs.size(); ++j) {
80     dst->inputs.push_back(src->inputs[j]);
81   }
82   dst->outputs[0] = src->outputs[0];
83   dst->name += " linked : " + src->name;
84   return dst->cl_operation.AddOperation(&src->cl_operation);
85 }
86 
AddUsage(ValueId id,int task_index,std::map<ValueId,int2> * usage_records)87 void AddUsage(ValueId id, int task_index,
88               std::map<ValueId, int2>* usage_records) {
89   auto it = usage_records->find(id);
90   if (it == usage_records->end()) {
91     (*usage_records)[id].x = task_index;
92     (*usage_records)[id].y = task_index;
93   } else {
94     (*usage_records)[id].y = task_index;
95   }
96 }
97 
98 // returns true if actual memory for this storage type will be allocated with
99 // clCreateBuffer.
IsBufferBased(const TensorStorageType & type)100 bool IsBufferBased(const TensorStorageType& type) {
101   return type == TensorStorageType::BUFFER ||
102          type == TensorStorageType::IMAGE_BUFFER;
103 }
104 
105 // Generic add is add that have several runtime inputs and they are not
106 // broadcasted, i.e. pointwise add for N tensors where N > 1.
IsGenericAdd(const Node & node,const std::vector<Value * > & inputs,const std::vector<Value * > & outputs)107 bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
108                   const std::vector<Value*>& outputs) {
109   if (inputs.size() == 1) {
110     return false;
111   }
112   const OperationType op_type = OperationTypeFromString(node.operation.type);
113   if (op_type != OperationType::ADD) {
114     return false;
115   }
116 
117   const auto dst_shape = outputs[0]->tensor.shape;
118   for (int i = 0; i < inputs.size(); ++i) {
119     const auto src_shape = inputs[i]->tensor.shape;
120     if (dst_shape.b != src_shape.b && src_shape.b == 1) {
121       return false;
122     }
123     if (dst_shape.h != src_shape.h && src_shape.h == 1) {
124       return false;
125     }
126     if (dst_shape.w != src_shape.w && src_shape.w == 1) {
127       return false;
128     }
129     if (dst_shape.c != src_shape.c && src_shape.c == 1) {
130       return false;
131     }
132   }
133   return true;
134 }
135 
136 }  // namespace
137 
InitFromGraph(const CreateInferenceInfo & create_info,const GraphFloat32 & graph,Environment * env,std::vector<uint8_t> * serialized_model)138 absl::Status InferenceContext::InitFromGraph(
139     const CreateInferenceInfo& create_info, const GraphFloat32& graph,
140     Environment* env, std::vector<uint8_t>* serialized_model) {
141   CreationContext creation_context;
142   creation_context.device = env->GetDevicePtr();
143   creation_context.context = &env->context();
144   creation_context.queue = env->queue();
145   creation_context.cache = env->program_cache();
146 
147   ReserveGraphTensors(create_info, creation_context.GetGpuInfo(), graph);
148   precision_ = create_info.precision;
149   storage_type_ = create_info.storage_type;
150   if (env->device().GetInfo().IsMali()) {
151     need_flush_ = true;
152     need_manual_release_ = true;
153 
154     flush_periodically_ = true;
155     flush_period_ = 24;
156   }
157   if (env->device().GetInfo().IsPowerVR()) {
158     need_flush_ = true;
159   }
160   CopyInAndOutIds(graph);
161   RETURN_IF_ERROR(ConvertOperations(creation_context.GetGpuInfo(), graph,
162                                     create_info.hints));
163   RETURN_IF_ERROR(Merge());
164   RETURN_IF_ERROR(AllocateMemory(creation_context.context));
165   BindMemoryToOperations();
166   RETURN_IF_ERROR(Compile(creation_context));
167   RETURN_IF_ERROR(UpdateParams());
168 
169   TuningType tuning_type = TuningType::kExhaustive;
170   if (create_info.hints.Check(ModelHints::kFastTuning)) {
171     tuning_type = TuningType::kFast;
172   }
173   if (env->device().GetInfo().IsMali()) {
174     const MaliInfo& info = env->device().GetInfo().mali_info;
175     if (info.IsMaliT6xx()) {
176       // Mali T628 hangs forever in clFinish when used profiling queue
177       // TuningType::FAST does not use profiling queue.
178       tuning_type = TuningType::kFast;
179     }
180   }
181   RETURN_IF_ERROR(
182       Tune(tuning_type, env->device().GetInfo(), env->profiling_queue()));
183 
184   if (serialized_model) {
185     for (auto& node : nodes_) {
186       node.cl_operation.MoveObjectRefsFromCLToGeneric();
187       node.cl_operation.SyncScalarValues();
188     }
189     flatbuffers::FlatBufferBuilder builder;
190     auto encoded_fb = Encode(*this, &builder);
191     data::FinishInferenceContextBuffer(builder, encoded_fb);
192     serialized_model->resize(builder.GetSize());
193     std::memcpy(serialized_model->data(), builder.GetBufferPointer(),
194                 builder.GetSize());
195     for (auto& node : nodes_) {
196       node.cl_operation.MoveObjectRefsFromGenericToCL();
197     }
198   }
199   ReleaseCPURepresentation();
200   return absl::OkStatus();
201 }
202 
// Rebuilds the context from a flatbuffer previously produced by
// InitFromGraph(..., serialized_model).
// - Returns DataLossError if the buffer fails flatbuffer verification.
// - Otherwise decodes it, allocates and binds memory, and compiles each
//   operation from its deserialized form.
// - Finally updates params and releases CPU-side data.
absl::Status InferenceContext::RestoreDeserialized(
    const absl::Span<const uint8_t> serialized_model, Environment* env) {
  const uint8_t* fb_data = serialized_model.data();
  flatbuffers::Verifier verifier(fb_data, serialized_model.size());
  if (!data::VerifyInferenceContextBuffer(verifier)) {
    return absl::DataLossError("Deserialization failed.");
  }
  RETURN_IF_ERROR(Decode(data::GetInferenceContext(fb_data), this));

  CreationContext creation_context;
  creation_context.device = env->GetDevicePtr();
  creation_context.context = &env->context();
  creation_context.queue = env->queue();
  creation_context.cache = env->program_cache();

  RETURN_IF_ERROR(AllocateMemory(creation_context.context));
  BindMemoryToOperations();
  for (size_t node_index = 0; node_index < nodes_.size(); ++node_index) {
    RETURN_IF_ERROR(
        nodes_[node_index].cl_operation.CompileDeserialized(creation_context));
  }
  RETURN_IF_ERROR(UpdateParams());
  ReleaseCPURepresentation();
  return absl::OkStatus();
}
228 
InitFromGraphWithTransforms(const CreateInferenceInfo & create_info,GraphFloat32 * graph,Environment * env,std::vector<uint8_t> * serialized_model)229 absl::Status InferenceContext::InitFromGraphWithTransforms(
230     const CreateInferenceInfo& create_info, GraphFloat32* graph,
231     Environment* env, std::vector<uint8_t>* serialized_model) {
232   RETURN_IF_ERROR(RunGraphTransforms(graph));
233   RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env, serialized_model));
234   return absl::OkStatus();
235 }
236 
CopyInAndOutIds(const GraphFloat32 & graph)237 void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) {
238   const auto inputs = graph.inputs();
239   for (const auto& input : inputs) {
240     input_ids_.push_back(input->id);
241   }
242 
243   const auto variable_inputs = graph.variable_inputs();
244   for (const auto& variable_input : variable_inputs) {
245     variable_ids_and_refs_[variable_input->id] = variable_input->tensor.ref;
246   }
247 
248   const auto outputs = graph.outputs();
249   for (const auto& output : outputs) {
250     output_ids_.push_back(output->id);
251   }
252 
253   in_refs_.resize(inputs.size());
254   out_refs_.resize(outputs.size());
255   for (int i = 0; i < inputs.size(); ++i) {
256     in_refs_[i] = inputs[i]->tensor.ref;
257   }
258   for (int i = 0; i < outputs.size(); ++i) {
259     out_refs_[i] = outputs[i]->tensor.ref;
260   }
261 }
262 
ReserveGraphTensors(const CreateInferenceInfo & create_info,const GpuInfo & gpu_info,const GraphFloat32 & graph)263 void InferenceContext::ReserveGraphTensors(
264     const CreateInferenceInfo& create_info, const GpuInfo& gpu_info,
265     const GraphFloat32& graph) {
266   ValueId max_id = 0;
267   auto tensors = graph.values();
268   auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
269   for (auto& t : tensors) {
270     TensorStorageType storage_type = create_info.storage_type;
271     const auto shape = graph.GetValue(t->id)->tensor.shape;
272     Layout layout = shape.b == 1 ? Layout::HWC : Layout::BHWC;
273     if (graph.IsGraphInput(t->id) || graph.IsGraphOutput(t->id)) {
274       if (shape.c < 4 &&
275           CanCreateTensorWithShape(
276               gpu_info, shape,
277               TensorDescriptor{data_type, TensorStorageType::SINGLE_TEXTURE_2D,
278                                layout})) {
279         storage_type = TensorStorageType::SINGLE_TEXTURE_2D;
280       }
281     }
282     storage_type =
283         SelectBestStorageType(gpu_info, shape, storage_type, data_type, layout);
284     tensor_reserver_.Add(
285         t->id, {shape, TensorDescriptor{data_type, storage_type, layout}});
286     max_id = std::max(max_id, t->id);
287   }
288   tensor_reserver_.SetNext(max_id + 1);
289 }
290 
ConvertOperations(const GpuInfo & gpu_info,const GraphFloat32 & graph,ModelHints hints)291 absl::Status InferenceContext::ConvertOperations(const GpuInfo& gpu_info,
292                                                  const GraphFloat32& graph,
293                                                  ModelHints hints) {
294   std::map<ValueId, TensorDescriptor> tensor_descriptors;
295   const auto values = graph.values();
296   for (auto value : values) {
297     tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
298   }
299   std::set<NodeId> consumed_nodes;
300   std::vector<Node*> graph_nodes = graph.nodes();
301   std::map<ValueId, int>
302       tensor_usages;  // keeps latest index of operation that updated tensor
303   for (const auto& input_id : input_ids_) {
304     tensor_usages[input_id] = -1;  // so as inputs "updated" before operation 0,
305                                    // we will mark them with -1
306   }
307   for (int i = 0; i < graph_nodes.size(); ++i) {
308     const Node& node = *graph_nodes[i];
309     if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
310       continue;
311     }
312     auto op_type = OperationTypeFromString(node.operation.type);
313     if (op_type == OperationType::CONSTANT) {
314       auto attr =
315           absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
316       auto outputs = graph.FindOutputs(node.id);
317       const_tensors_descs_[outputs[0]->id] =
318           tensor_reserver_.Get(outputs[0]->id).descriptor;
319       const_tensors_descs_[outputs[0]->id].UploadData(attr.tensor);
320       continue;
321     }
322     std::string op_name = node.operation.type + " " + std::to_string(node.id);
323     GPUOperationsSubgraph gpu_subgraph;
324     if (hints.Check(ModelHints::kAllowSpecialKernels) &&
325         GPUSubgraphFromGraph(gpu_info, precision_, graph, node.id,
326                              tensor_descriptors, &consumed_nodes, &gpu_subgraph,
327                              &op_name)
328             .ok()) {
329       // Mapping of subgraph (set of nodes) to GPU operations. Should happen
330       // before straigtforward mapping.
331     } else {
332       // Straigtforward mapping of one graph node to GPU operations.
333       auto inputs = graph.FindInputs(node.id);
334       auto outputs = graph.FindOutputs(node.id);
335       // Reordering of input ids and updating of temporary tensors_usage struct.
336       // This stage is necessary because we are building OperationDef that rely
337       // on order of input ids. But we also should have input id on first
338       // position that potentially can be "linking" tensor and as result
339       // eliminated(unused) We apply it only for ADD operation, because of ADD
340       // associativity and ADD can be linked. In current approach "linking"
341       // tensor can be only latest written tensor(during linear order of
342       // execution) among input tensors.
343       if (IsGenericAdd(node, inputs, outputs)) {
344         int latest_written_tensor_index = 0;
345         int last_usage = tensor_usages[inputs[0]->id];
346         for (int j = 1; j < inputs.size(); ++j) {
347           if (tensor_usages[inputs[j]->id] > last_usage) {
348             last_usage = tensor_usages[inputs[j]->id];
349             latest_written_tensor_index = j;
350           }
351         }
352         std::swap(inputs[0], inputs[latest_written_tensor_index]);
353       }
354       consumed_nodes.insert(node.id);
355       OperationDef op_def;
356       op_def.precision = precision_;
357       for (int j = 0; j < inputs.size(); ++j) {
358         op_def.src_tensors.push_back(
359             tensor_reserver_.Get(inputs[j]->id).descriptor);
360       }
361       for (int j = 0; j < outputs.size(); ++j) {
362         op_def.dst_tensors.push_back(
363             tensor_reserver_.Get(outputs[j]->id).descriptor);
364       }
365       RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, hints, inputs,
366                                            outputs, node, &gpu_subgraph));
367     }
368     absl::flat_hash_map<int, ValueId> mapping_to_global_ids;
369     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
370       const auto& t = gpu_subgraph.new_tensors[j];
371       auto global_id = tensor_reserver_.Add({t.first, t.second});
372       mapping_to_global_ids[j] = global_id;
373     }
374     for (auto& gpu_op : gpu_subgraph.operations) {
375       CLNode cl_node;
376       cl_node.cl_operation.Init(std::move(gpu_op.operation));
377       cl_node.inputs.resize(gpu_op.input_ids.size());
378       for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
379         int id = gpu_op.input_ids[j];
380         if (id >= 0) {
381           cl_node.inputs[j] = id;
382         } else {
383           cl_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
384         }
385       }
386       cl_node.outputs.resize(gpu_op.output_ids.size());
387       for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
388         int id = gpu_op.output_ids[j];
389         if (id >= 0) {
390           cl_node.outputs[j] = id;
391           tensor_usages[id] = i;
392         } else {
393           cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
394         }
395       }
396       cl_node.name = op_name;
397       nodes_.push_back(std::move(cl_node));
398     }
399   }
400 
401   return absl::OkStatus();
402 }
403 
Merge()404 absl::Status InferenceContext::Merge() {
405   absl::flat_hash_set<ValueId> ready_tensors;
406   for (const auto& input_id : input_ids_) {
407     ready_tensors.insert(input_id);
408   }
409   for (int i = 0; i < nodes_.size(); ++i) {
410     auto& node = nodes_[i];
411     for (const auto& out_id : node.outputs) {
412       ready_tensors.insert(out_id);
413     }
414     if (node.outputs.size() != 1) {
415       continue;
416     }
417     std::vector<int> next_nodes;
418     int link_index = 0;
419     for (int j = i + 1; j < nodes_.size(); ++j) {
420       for (int k = 0; k < nodes_[j].inputs.size(); ++k) {
421         if (nodes_[j].inputs[k] == node.outputs[0]) {
422           next_nodes.push_back(j);
423           link_index = k;
424         }
425       }
426     }
427     if (next_nodes.size() != 1 || link_index != 0) {
428       continue;
429     }
430     auto& linkable_node = nodes_[next_nodes[0]];
431     if (!linkable_node.cl_operation.GetGpuOperation().IsLinkable() ||
432         linkable_node.outputs.size() != 1 ||
433         !IsReady(ready_tensors, linkable_node)) {
434       continue;
435     }
436     const auto& original_dst_def =
437         node.cl_operation.GetDefinition().dst_tensors[0];
438     const auto& link_dst_def =
439         linkable_node.cl_operation.GetDefinition().dst_tensors[0];
440     if (original_dst_def != link_dst_def) {
441       continue;
442     }
443     RETURN_IF_ERROR(MergeCLNodes(&linkable_node, &node));
444     nodes_.erase(nodes_.begin() + next_nodes[0]);
445     i -= 1;
446   }
447   return absl::OkStatus();
448 }
449 
GetUsages(const std::function<bool (ValueId)> & functor,std::map<ValueId,int2> * usages)450 void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
451                                  std::map<ValueId, int2>* usages) {
452   for (ValueId in_id : input_ids_) {
453     if (functor(in_id)) {
454       AddUsage(in_id, 0, usages);
455     }
456   }
457   for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
458     auto tensors = GetCLNodeTensors(nodes_[op_index]);
459     for (auto& tensor : tensors) {
460       if (functor(tensor.first)) {
461         AddUsage(tensor.first, op_index, usages);
462       }
463     }
464   }
465   for (ValueId out_id : output_ids_) {
466     if (functor(out_id)) {
467       AddUsage(out_id, nodes_.size(), usages);
468     }
469   }
470 }
471 
GetTensorMemoryType(ValueId id)472 InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
473     ValueId id) {
474   if (const_tensors_.find(id) != const_tensors_.end()) {
475     return TensorMemoryType::kConst;
476   } else if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
477     return TensorMemoryType::kVariable;
478   } else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) {
479     return TensorMemoryType::kBuffer;
480   } else {
481     return TensorMemoryType::kStrongShape;
482   }
483 }
484 
AllocateMemory(CLContext * context)485 absl::Status InferenceContext::AllocateMemory(CLContext* context) {
486   RETURN_IF_ERROR(AllocateMemoryForConstTensors(context));
487   RETURN_IF_ERROR(AllocateMemoryForVariableTensors(context));
488   RETURN_IF_ERROR(AllocateMemoryForBuffers(context));
489   RETURN_IF_ERROR(AllocateMemoryForStrongShapes(context));
490   return absl::OkStatus();
491 }
492 
AllocateMemoryForConstTensors(CLContext * context)493 absl::Status InferenceContext::AllocateMemoryForConstTensors(
494     CLContext* context) {
495   for (auto& description : const_tensors_descs_) {
496     RETURN_IF_ERROR(const_tensors_[description.first].CreateFromDescriptor(
497         description.second, context));
498   }
499   return absl::OkStatus();
500 }
501 
AllocateMemoryForVariableTensors(CLContext * context)502 absl::Status InferenceContext::AllocateMemoryForVariableTensors(
503     CLContext* context) {
504   std::map<ValueId, int> ref_value_to_tensor_index;
505 
506   for (auto value_and_ref_value : variable_ids_and_refs_) {
507     if (ref_value_to_tensor_index.find(value_and_ref_value.second) ==
508         ref_value_to_tensor_index.end()) {
509       const auto& t = tensor_reserver_.Get(value_and_ref_value.first);
510       const auto& shape = t.shape;
511       const auto& descriptor = t.descriptor;
512 
513       RETURN_IF_ERROR(
514           CreateTensor(*context, shape, descriptor,
515                        &variable_tensors_[value_and_ref_value.second]));
516     }
517   }
518   return absl::OkStatus();
519 }
520 
AllocateMemoryForBuffers(CLContext * context)521 absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
522   std::map<ValueId, int2> buffer_usages;
523   GetUsages(
524       [this](ValueId id) {
525         return GetTensorMemoryType(id) == TensorMemoryType::kBuffer;
526       },
527       &buffer_usages);
528 
529   std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
530   for (auto& usage : buffer_usages) {
531     const auto& t = tensor_reserver_.Get(usage.first);
532     const auto& shape = t.shape;
533     const auto& descriptor = t.descriptor;
534     const size_t element_size =
535         descriptor.data_type == DataType::FLOAT32 ? 4 : 2;
536     const size_t buffer_size =
537         shape.b * shape.w * shape.h * AlignByN(shape.c, 4) * element_size;
538     graph_ids_to_shared_buffer_tensors_[usage.first] =
539         buffer_usage_records.size();
540     buffer_usage_records.push_back({buffer_size,
541                                     static_cast<TaskId>(usage.second.x),
542                                     static_cast<TaskId>(usage.second.y)});
543   }
544 
545   ObjectsAssignment<size_t> buffer_assignment;
546   RETURN_IF_ERROR(AssignObjectsToTensors(
547       buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));
548 
549   shared_buffers_.resize(buffer_assignment.object_sizes.size());
550   for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
551     RETURN_IF_ERROR(CreateReadWriteBuffer(buffer_assignment.object_sizes[i],
552                                           context, &shared_buffers_[i]));
553   }
554 
555   std::vector<bool> created_tensors(buffer_usage_records.size(), false);
556   shared_buffer_tensors_.resize(buffer_usage_records.size());
557   for (auto& node : nodes_) {
558     auto tensors = GetCLNodeTensors(node);
559     for (auto& t : tensors) {
560       if (GetTensorMemoryType(t.first) != TensorMemoryType::kBuffer) continue;
561       const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first];
562       if (created_tensors[tensor_index]) continue;
563       const auto& shape = tensor_reserver_.Get(t.first).shape;
564       const int buffer_index = buffer_assignment.object_ids[tensor_index];
565       RETURN_IF_ERROR(CreateSharedTensor(
566           *context, shared_buffers_[buffer_index].GetMemoryPtr(), shape,
567           t.second, &shared_buffer_tensors_[tensor_index]));
568       created_tensors[tensor_index] = true;
569     }
570   }
571   return absl::OkStatus();
572 }
573 
AllocateMemoryForStrongShapes(CLContext * context)574 absl::Status InferenceContext::AllocateMemoryForStrongShapes(
575     CLContext* context) {
576   std::map<ValueId, int2> usages;
577   GetUsages(
578       [this](ValueId id) {
579         return GetTensorMemoryType(id) == TensorMemoryType::kStrongShape;
580       },
581       &usages);
582 
583   std::vector<TensorUsageRecord<DummyTensor>> usage_records;
584   std::map<ValueId, ValueId> remap_from_graph_ids;
585   for (auto& usage : usages) {
586     remap_from_graph_ids[usage.first] = usage_records.size();
587     usage_records.push_back({tensor_reserver_.Get(usage.first),
588                              static_cast<TaskId>(usage.second.x),
589                              static_cast<TaskId>(usage.second.y)});
590   }
591 
592   ObjectsAssignment<DummyTensor> assignment;
593   RETURN_IF_ERROR(AssignObjectsToTensors(
594       usage_records, MemoryStrategy::EQUALITY, &assignment));
595 
596   for (auto& node : nodes_) {
597     auto tensors = GetCLNodeTensors(node);
598     for (auto& t : tensors) {
599       if (GetTensorMemoryType(t.first) != TensorMemoryType::kStrongShape) {
600         continue;
601       }
602       const auto& shape = tensor_reserver_.Get(t.first).shape;
603       const auto id = assignment.object_ids[remap_from_graph_ids[t.first]];
604       graph_ids_to_strong_shape_tensors_[t.first] = id;
605       const auto& it = strong_shape_tensors_.find(id);
606       if (it == strong_shape_tensors_.end()) {
607         RETURN_IF_ERROR(CreateTensor(*context, shape, t.second,
608                                      &strong_shape_tensors_[id]));
609       }
610     }
611   }
612   return absl::OkStatus();
613 }
614 
BindMemoryToOperations()615 void InferenceContext::BindMemoryToOperations() {
616   for (auto& node : nodes_) {
617     for (int i = 0; i < node.inputs.size(); ++i) {
618       node.cl_operation.GetGpuOperation().SetSrc(GetTensor(node.inputs[i]), i);
619     }
620     for (int i = 0; i < node.outputs.size(); ++i) {
621       node.cl_operation.GetGpuOperation().SetDst(GetTensor(node.outputs[i]), i);
622     }
623   }
624 }
625 
Compile(const CreationContext & creation_context)626 absl::Status InferenceContext::Compile(
627     const CreationContext& creation_context) {
628   for (auto& node : nodes_) {
629     RETURN_IF_ERROR(node.cl_operation.Compile(creation_context));
630   }
631   return absl::OkStatus();
632 }
633 
Tune(TuningType tuning_type,const GpuInfo & gpu_info,ProfilingCommandQueue * profiling_queue)634 absl::Status InferenceContext::Tune(TuningType tuning_type,
635                                     const GpuInfo& gpu_info,
636                                     ProfilingCommandQueue* profiling_queue) {
637   for (auto& node : nodes_) {
638     RETURN_IF_ERROR(
639         node.cl_operation.Tune(tuning_type, gpu_info, profiling_queue));
640   }
641   return absl::OkStatus();
642 }
643 
UpdateParams()644 absl::Status InferenceContext::UpdateParams() {
645   for (auto& node : nodes_) {
646     RETURN_IF_ERROR(node.cl_operation.UpdateParams());
647   }
648   return absl::OkStatus();
649 }
650 
AddToQueue(CLCommandQueue * queue)651 absl::Status InferenceContext::AddToQueue(CLCommandQueue* queue) {
652   if (need_manual_release_) {
653     if (prev_enqueue_start_point_.is_valid()) {
654       prev_enqueue_start_point_.Wait();
655     }
656     RETURN_IF_ERROR(queue->EnqueueEvent(&prev_enqueue_start_point_));
657   }
658   int counter = 0;
659   for (auto& node : nodes_) {
660     RETURN_IF_ERROR(node.cl_operation.AddToQueue(queue));
661     counter++;
662     if (flush_periodically_ && counter % flush_period_ == 0) {
663       clFlush(queue->queue());
664     }
665   }
666   if (need_flush_) {
667     clFlush(queue->queue());
668   }
669   return absl::OkStatus();
670 }
671 
Profile(ProfilingCommandQueue * queue,ProfilingInfo * result)672 absl::Status InferenceContext::Profile(ProfilingCommandQueue* queue,
673                                        ProfilingInfo* result) {
674   queue->ResetMeasurements();
675   for (auto& node : nodes_) {
676     queue->SetEventsLabel(node.name);
677     RETURN_IF_ERROR(node.cl_operation.AddToQueue(queue));
678   }
679   RETURN_IF_ERROR(queue->WaitForCompletion());
680   *result = queue->GetProfilingInfo();
681   return absl::OkStatus();
682 }
683 
GetSizeOfMemoryAllocatedForIntermediateTensors() const684 uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors()
685     const {
686   uint64_t total_memory = 0;
687   for (const auto& t : strong_shape_tensors_) {
688     total_memory += t.second.GetMemorySizeInBytes();
689   }
690   for (const auto& b : shared_buffers_) {
691     total_memory += b.GetMemorySizeInBytes();
692   }
693   for (const auto& t : variable_tensors_) {
694     total_memory += t.second.GetMemorySizeInBytes();
695   }
696 
697   return total_memory;
698 }
699 
GetTensor(ValueId id)700 Tensor* InferenceContext::GetTensor(ValueId id) {
701   if (const_tensors_.find(id) != const_tensors_.end()) {
702     return &const_tensors_[id];
703   } else if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
704     return &variable_tensors_[variable_ids_and_refs_[id]];
705   } else if (graph_ids_to_shared_buffer_tensors_.find(id) !=
706              graph_ids_to_shared_buffer_tensors_.end()) {
707     return &shared_buffer_tensors_[graph_ids_to_shared_buffer_tensors_[id]];
708   } else {
709     return &strong_shape_tensors_[graph_ids_to_strong_shape_tensors_[id]];
710   }
711 }
712 
SetInputTensor(ValueId id,const TensorFloat32 & tensor,CLCommandQueue * queue)713 absl::Status InferenceContext::SetInputTensor(ValueId id,
714                                               const TensorFloat32& tensor,
715                                               CLCommandQueue* queue) {
716   return GetTensor(id)->WriteData(queue, tensor);
717 }
718 
GetOutputTensor(ValueId id,CLCommandQueue * queue,TensorFloat32 * result)719 absl::Status InferenceContext::GetOutputTensor(ValueId id,
720                                                CLCommandQueue* queue,
721                                                TensorFloat32* result) {
722   const auto& gpu_tensor = *GetTensor(id);
723   const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(),
724                               gpu_tensor.Width(), gpu_tensor.Channels());
725   result->id = id;
726   result->shape = dst_shape;
727   result->data.resize(dst_shape.DimensionsProduct());
728   return gpu_tensor.ReadData(queue, result);
729 }
730 
ReleaseCPURepresentation()731 void InferenceContext::ReleaseCPURepresentation() {
732   for (auto& node : nodes_) {
733     node.cl_operation.GetGpuOperation().args_.ReleaseCPURepresentation();
734   }
735   const_tensors_descs_.clear();
736 }
737 
RunGraphTransforms(GraphFloat32 * graph)738 absl::Status RunGraphTransforms(GraphFloat32* graph) {
739   auto merge_padding_transform = NewMergePaddingWithAdd();
740   auto add_bias_transform = NewAddBias();
741   auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
742   ModelTransformer transformer(graph, /*reporter=*/nullptr);
743   if (!transformer.Apply("add_bias", add_bias_transform.get())) {
744     return absl::InternalError("Invalid add_bias transform");
745   }
746   if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
747     return absl::InternalError("Invalid merge_padding transform");
748   }
749   if (!transformer.Apply("global pooling to mean",
750                          pooling_to_reduce_op.get())) {
751     return absl::InternalError("Invalid global pooling to mean transform");
752   }
753   return absl::OkStatus();
754 }
755 
756 }  // namespace cl
757 }  // namespace gpu
758 }  // namespace tflite
759