/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

#include <map>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "absl/time/clock.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Returns true if the actual memory for this storage type is a buffer.
bool IsBufferBased(const TensorStorageType& type) {
  return type == TensorStorageType::BUFFER ||
         type == TensorStorageType::IMAGE_BUFFER;
}

bool HasIntersection(const std::vector<ValueId>& vec_ids,
                     const std::set<ValueId>& ids) {
  for (ValueId id : vec_ids) {
    if (ids.find(id) != ids.end()) {
      return true;
    }
  }
  return false;
}

bool IsReady(const std::set<ValueId>& ready_tensors, const MetalNode& node) {
  for (const ValueId in_id : node.inputs) {
    if (ready_tensors.find(in_id) == ready_tensors.end()) {
      return false;
    }
  }
  return true;
}

void AddUsage(ValueId id, int task_index,
              std::map<ValueId, int2>* usage_records) {
  auto it = usage_records->find(id);
  if (it == usage_records->end()) {
    // First usage: initialize both the start index (.x) and the end index (.y).
    (*usage_records)[id].x = task_index;
    (*usage_records)[id].y = task_index;
  } else {
    // Later usage: only the end index (.y) moves forward.
    (*usage_records)[id].y = task_index;
  }
}
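
// A usage record therefore ends up as the inclusive interval
// [first task that touches the tensor, last task that touches it].
// Illustrative example (ids made up): if tensor 7 is first written by task 2
// and last read by task 5, then AddUsage(7, 2, ...) followed by
// AddUsage(7, 5, ...) leaves usage_records[7] == int2(2, 5), which the memory
// planner below uses to decide which tensors may share the same buffer.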

// A generic ADD is an ADD with several runtime inputs and no broadcasting,
// i.e. a pointwise add of N tensors where N > 1.
bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
                  const std::vector<Value*>& outputs) {
  if (inputs.size() == 1) {
    return false;
  }
  const OperationType op_type = OperationTypeFromString(node.operation.type);
  if (op_type != OperationType::ADD) {
    return false;
  }

  const auto dst_shape = outputs[0]->tensor.shape;
  for (int i = 0; i < inputs.size(); ++i) {
    const auto src_shape = inputs[i]->tensor.shape;
    if (dst_shape.b != src_shape.b && src_shape.b == 1) {
      return false;
    }
    if (dst_shape.h != src_shape.h && src_shape.h == 1) {
      return false;
    }
    if (dst_shape.w != src_shape.w && src_shape.w == 1) {
      return false;
    }
    if (dst_shape.c != src_shape.c && src_shape.c == 1) {
      return false;
    }
  }
  return true;
}
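
// Example (illustrative shapes, not from a real model): adding two inputs of
// shape BHWC(1, 16, 16, 8) is a generic add, while adding BHWC(1, 16, 16, 8)
// to BHWC(1, 1, 1, 8) is a broadcast add and is rejected above, because the
// second input has h == 1 and w == 1 while the destination does not.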

absl::Status MergeNodes(MetalNode* src, MetalNode* dst) {
  for (int j = 1; j < src->inputs.size(); ++j) {
    dst->inputs.push_back(src->inputs[j]);
  }
  dst->outputs[0] = src->outputs[0];
  dst->name += " linked : " + src->name;
  return dst->task.AddTask(&src->task);
}
}  // namespace

absl::Status InferenceContext::InitFromGraphWithTransforms(
    const CreateInferenceInfo& create_info, GraphFloat32* graph,
    id<MTLDevice> device_id) {
  RETURN_IF_ERROR(RunGraphTransforms(graph));
  RETURN_IF_ERROR(InitFromGraph(create_info, *graph, device_id));
  return absl::OkStatus();
}

absl::Status InferenceContext::InitFromGraph(
    const CreateInferenceInfo& create_info, const GraphFloat32& graph,
    id<MTLDevice> device_id) {
  std::set<ValueId> preallocated_ids;
  const auto inputs = graph.inputs();
  for (const auto& input : inputs) {
    input_ids_.push_back(input->id);
    preallocated_ids.insert(input->id);
  }

  const auto outputs = graph.outputs();
  for (const auto& output : outputs) {
    output_ids_.push_back(output->id);
    preallocated_ids.insert(output->id);
  }
  precision_ = create_info.precision;

  MetalDevice metal_device(device_id);
  ReserveGraphTensors(create_info, metal_device.GetInfo(), graph,
                      preallocated_ids);
  RETURN_IF_ERROR(Compile(graph, metal_device.GetInfo(), create_info.hints));
  RETURN_IF_ERROR(Merge());
  RETURN_IF_ERROR(CompileOperations(&metal_device));
  RETURN_IF_ERROR(AllocateTensors(&metal_device, preallocated_ids));
  BindTensorsToOperations();
  RETURN_IF_ERROR(UpdateParams(metal_device.GetInfo()));
  RETURN_IF_ERROR(Tune(TuningType::kFast, &metal_device));
  return absl::OkStatus();
}
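
// Initialization runs in fixed stages: reserve tensor descriptors, map graph
// nodes to GPU operations (Compile), fuse linkable element-wise operations
// (Merge), compile Metal kernels, allocate and bind memory, and finally tune
// dispatch parameters. A minimal client-side sketch, assuming a `graph` and an
// `id<MTLDevice> device` already exist (those names are placeholders, not part
// of this file):
//
//   InferenceContext context;
//   CreateInferenceInfo create_info;
//   create_info.precision = CalculationsPrecision::F32;
//   create_info.storage_type = TensorStorageType::BUFFER;
//   RETURN_IF_ERROR(
//       context.InitFromGraphWithTransforms(create_info, &graph, device));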

void InferenceContext::ReserveGraphTensors(
    const CreateInferenceInfo& create_info, const GpuInfo& gpu_info,
    const GraphFloat32& graph, const std::set<ValueId>& preallocated_ids) {
  ValueId max_id = 0;
  auto tensors = graph.values();
  auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
  for (auto& t : tensors) {
    TensorStorageType storage_type = create_info.storage_type;
    if (preallocated_ids.find(t->id) != preallocated_ids.end()) {
      storage_type = TensorStorageType::BUFFER;
    }
    const auto shape = graph.GetValue(t->id)->tensor.shape;
    Layout layout = shape.b == 1 ? Layout::HWC : Layout::BHWC;
    // Temporarily disabled because SINGLE_TEXTURE_2D is not supported in
    // Metal; only the BUFFER storage type is currently supported here.
    // if (graph.IsGraphInput(t->id) || graph.IsGraphOutput(t->id)) {
    //   if (false && shape.c < 4 &&
    //       CanCreateTensorWithShape(
    //           gpu_info, shape,
    //           TensorDescriptor{data_type,
    //                            TensorStorageType::SINGLE_TEXTURE_2D,
    //                            layout})) {
    //     storage_type = TensorStorageType::SINGLE_TEXTURE_2D;
    //   }
    // }
    storage_type =
        SelectBestStorageType(gpu_info, shape, storage_type, data_type, layout);
    tensor_reserver_.Add(
        t->id, {shape, TensorDescriptor{data_type, storage_type, layout}});
    max_id = std::max(max_id, t->id);
  }
  tensor_reserver_.SetNext(max_id + 1);
}

absl::Status InferenceContext::Compile(const GraphFloat32& graph,
                                       const GpuInfo& gpu_info,
                                       ModelHints hints) {
  if (!IsBatchMatchesForAllValues(graph)) {
    return absl::InvalidArgumentError(
        "Only identical batch dimension is supported");
  }
  std::map<ValueId, TensorDescriptor> tensor_descriptors;
  const auto values = graph.values();
  for (auto value : values) {
    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
  }
  std::set<NodeId> consumed_nodes;
  // For every tensor, keeps the index of the last operation that updated it.
  std::map<ValueId, int> tensor_usages;
  for (const auto& input_id : input_ids_) {
    // Graph inputs are "updated" before operation 0, so mark them with -1.
    tensor_usages[input_id] = -1;
  }
  std::vector<Node*> graph_nodes = graph.nodes();
  for (int i = 0; i < graph_nodes.size(); ++i) {
    const Node& node = *graph_nodes[i];
    auto op_type = OperationTypeFromString(node.operation.type);
    if (op_type == OperationType::CONSTANT) {
      auto attr =
          absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
      auto outputs = graph.FindOutputs(node.id);
      const_tensors_descs_[outputs[0]->id] =
          tensor_reserver_.Get(outputs[0]->id).descriptor;
      const_tensors_descs_[outputs[0]->id].UploadData(attr.tensor);
      continue;
    }
    std::string op_name = node.operation.type + " " + std::to_string(node.id);
    GPUOperationsSubgraph gpu_subgraph;
    if (hints.Check(ModelHints::kAllowSpecialKernels) &&
        GPUSubgraphFromGraph(gpu_info, precision_, graph, node.id,
                             tensor_descriptors, &consumed_nodes, &gpu_subgraph,
                             &op_name)
            .ok()) {
      // A subgraph (a set of nodes) was mapped to GPU operations. This must be
      // attempted before the straightforward mapping below.
    } else {
      // Straightforward mapping of one graph node to GPU operations.
      auto inputs = graph.FindInputs(node.id);
      auto outputs = graph.FindOutputs(node.id);
      // Reorder input ids using the temporary tensor_usages map. This stage is
      // necessary because the OperationDef built below relies on the order of
      // input ids, and the input in the first position may become a "linking"
      // tensor that is later eliminated (unused). We apply the reordering only
      // to ADD, because ADD is associative and can be linked. In the current
      // approach, the "linking" tensor can only be the most recently written
      // tensor (in linear execution order) among the input tensors.
      if (IsGenericAdd(node, inputs, outputs)) {
        int latest_written_tensor_index = 0;
        int last_usage = tensor_usages[inputs[0]->id];
        for (int j = 1; j < inputs.size(); ++j) {
          if (tensor_usages[inputs[j]->id] > last_usage) {
            last_usage = tensor_usages[inputs[j]->id];
            latest_written_tensor_index = j;
          }
        }
        std::swap(inputs[0], inputs[latest_written_tensor_index]);
      }
      consumed_nodes.insert(node.id);
      OperationDef op_def;
      op_def.precision = precision_;
      for (int j = 0; j < inputs.size(); ++j) {
        op_def.src_tensors.push_back(
            tensor_reserver_.Get(inputs[j]->id).descriptor);
      }
      for (int j = 0; j < outputs.size(); ++j) {
        op_def.dst_tensors.push_back(
            tensor_reserver_.Get(outputs[j]->id).descriptor);
      }
      RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, hints, inputs,
                                           outputs, node, &gpu_subgraph));
    }
    std::map<int, ValueId> mapping_to_global_ids;
    for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
      const auto& t = gpu_subgraph.new_tensors[j];
      auto global_id = tensor_reserver_.Add({t.first, t.second});
      mapping_to_global_ids[j] = global_id;
    }
    for (auto& gpu_op : gpu_subgraph.operations) {
      MetalNode metal_node;
      metal_node.task.Init(std::move(gpu_op.operation));
      metal_node.inputs.resize(gpu_op.input_ids.size());
      for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
        int id = gpu_op.input_ids[j];
        if (id >= 0) {
          metal_node.inputs[j] = id;
        } else {
          // Negative ids refer to new tensors created for this subgraph.
          metal_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      metal_node.outputs.resize(gpu_op.output_ids.size());
      for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
        int id = gpu_op.output_ids[j];
        if (id >= 0) {
          metal_node.outputs[j] = id;
          tensor_usages[id] = i;
        } else {
          metal_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      metal_node.name = op_name;
      nodes_.push_back(std::move(metal_node));
    }
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::Merge() {
  std::set<ValueId> ready_tensors;
  for (const auto& input_id : input_ids_) {
    ready_tensors.insert(input_id);
  }
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& node = nodes_[i];
    for (const auto& out_id : node.outputs) {
      ready_tensors.insert(out_id);
    }
    if (node.outputs.size() != 1) {
      continue;
    }
    std::vector<int> next_nodes;
    int link_index = 0;
    for (int j = i + 1; j < nodes_.size(); ++j) {
      for (int k = 0; k < nodes_[j].inputs.size(); ++k) {
        if (nodes_[j].inputs[k] == node.outputs[0]) {
          next_nodes.push_back(j);
          link_index = k;
        }
      }
    }
    if (next_nodes.size() != 1 || link_index != 0) {
      continue;
    }
    auto& linkable_node = nodes_[next_nodes[0]];
    if (!linkable_node.task.IsLinkable() || linkable_node.outputs.size() != 1 ||
        !IsReady(ready_tensors, linkable_node)) {
      continue;
    }
    const auto& original_dst_def = node.task.GetDefinition().dst_tensors[0];
    const auto& link_dst_def =
        linkable_node.task.GetDefinition().dst_tensors[0];
    if (original_dst_def != link_dst_def) {
      continue;
    }
    RETURN_IF_ERROR(MergeNodes(&linkable_node, &node));
    nodes_.erase(nodes_.begin() + next_nodes[0]);
    i -= 1;
  }
  return absl::OkStatus();
}
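
// Merge() folds a "linkable" node into its producer when the producer's single
// output is consumed only by that node as its first input and both nodes share
// the same destination tensor descriptor. A sketch of the effect, with
// hypothetical node names and tensor ids: a fragment
//
//   producer op -> tensor 5 -> element-wise op -> tensor 6
//
// becomes one MetalNode whose name is "<producer> linked : <element-wise>" and
// which writes tensor 6 directly, so tensor 5 no longer appears in any node's
// inputs or outputs and never needs its own GPU memory.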

absl::Status InferenceContext::CompileOperations(MetalDevice* device) {
  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.Compile(device));
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::AllocateTensors(
    MetalDevice* device, const std::set<ValueId>& preallocated_ids) {
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& node = nodes_[i];
    if (HasIntersection(node.inputs, preallocated_ids) ||
        HasIntersection(node.outputs, preallocated_ids)) {
      task_ids_with_preallocated_tensors_.push_back(i);
    }
  }

  for (auto& tensor_id : preallocated_ids) {
    const auto& t = tensor_reserver_.Get(tensor_id);
    RETURN_IF_ERROR(CreateSharedBufferTensor(
        nil, t.shape, t.descriptor, &preallocated_tensors_[tensor_id]));
  }

  RETURN_IF_ERROR(AllocateMemoryForConstTensors(device));
  RETURN_IF_ERROR(AllocateMemoryForBuffers(device));
  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device));
  return absl::OkStatus();
}

MetalSpatialTensor* InferenceContext::GetTensor(ValueId tensor_id) {
  if (preallocated_tensors_.find(tensor_id) != preallocated_tensors_.end()) {
    return &preallocated_tensors_[tensor_id];
  } else if (const_tensors_.find(tensor_id) != const_tensors_.end()) {
    return &const_tensors_[tensor_id];
  } else if (graph_ids_to_shared_buffer_tensors_.find(tensor_id) !=
             graph_ids_to_shared_buffer_tensors_.end()) {
    return &shared_buffer_tensors_
        [graph_ids_to_shared_buffer_tensors_[tensor_id]];
  } else if (graph_ids_to_strong_shape_tensors_.find(tensor_id) !=
             graph_ids_to_strong_shape_tensors_.end()) {
    return &strong_shape_tensors_
        [graph_ids_to_strong_shape_tensors_[tensor_id]];
  }
  return nullptr;
}

void InferenceContext::BindTensorsToOperations() {
  for (auto& node : nodes_) {
    const auto& src_ids = node.inputs;
    for (int i = 0; i < src_ids.size(); ++i) {
      node.task.SetSrcTensor(GetTensor(src_ids[i]), i);
    }
    const auto& dst_ids = node.outputs;
    for (int i = 0; i < dst_ids.size(); ++i) {
      node.task.SetDstTensor(GetTensor(dst_ids[i]), i);
    }
  }
}

absl::Status InferenceContext::UpdateParams(const GpuInfo& gpu_info) {
  for (auto& node : nodes_) {
    std::vector<BHWC> src_shapes;
    std::vector<BHWC> dst_shapes;
    for (const auto& in_id : node.inputs) {
      src_shapes.push_back(tensor_reserver_.Get(in_id).shape);
    }
    for (const auto& out_id : node.outputs) {
      dst_shapes.push_back(tensor_reserver_.Get(out_id).shape);
    }
    RETURN_IF_ERROR(node.task.UpdateParams());
  }
  return absl::OkStatus();
}

InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
    ValueId id) {
  if (preallocated_tensors_.find(id) != preallocated_tensors_.end()) {
    return TensorMemoryType::kPreallocated;
  } else if (const_tensors_.find(id) != const_tensors_.end()) {
    return TensorMemoryType::kConst;
  } else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) {
    return TensorMemoryType::kBuffer;
  } else {
    return TensorMemoryType::kStrongShape;
  }
}

void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
                                 std::map<ValueId, int2>* usages) {
  for (ValueId in_id : input_ids_) {
    if (functor(in_id)) {
      AddUsage(in_id, 0, usages);
    }
  }
  for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
    for (auto& tensor_id : nodes_[op_index].inputs) {
      if (functor(tensor_id)) {
        AddUsage(tensor_id, op_index, usages);
      }
    }
    for (auto& tensor_id : nodes_[op_index].outputs) {
      if (functor(tensor_id)) {
        AddUsage(tensor_id, op_index, usages);
      }
    }
  }
  for (ValueId out_id : output_ids_) {
    if (functor(out_id)) {
      AddUsage(out_id, nodes_.size(), usages);
    }
  }
}
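
// GetUsages collects, for every tensor accepted by `functor`, the interval of
// task indices during which the tensor must stay alive: graph inputs are
// pinned to task 0 and graph outputs to nodes_.size(), so they are never
// considered dead in the middle of the schedule. For example (hypothetical
// ids), an intermediate tensor first written by node 3 and last read by node 7
// yields (*usages)[id] == int2(3, 7).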

absl::Status InferenceContext::AllocateMemoryForConstTensors(
    MetalDevice* device) {
  for (auto& description : const_tensors_descs_) {
    RETURN_IF_ERROR(const_tensors_[description.first].CreateFromDescriptor(
        description.second, device->device()));
  }
  const_tensors_descs_.clear();
  return absl::OkStatus();
}

absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) {
  std::map<ValueId, int2> buffer_usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kBuffer;
      },
      &buffer_usages);

  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
  for (auto& usage : buffer_usages) {
    const auto& shape = tensor_reserver_.Get(usage.first).shape;
    const size_t buffer_size =
        shape.b * shape.w * shape.h * AlignByN(shape.c, 4);
    graph_ids_to_shared_buffer_tensors_[usage.first] =
        buffer_usage_records.size();
    buffer_usage_records.push_back({buffer_size,
                                    static_cast<TaskId>(usage.second.x),
                                    static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<size_t> buffer_assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));

  const bool f32_storage = precision_ == CalculationsPrecision::F32;
  size_t dataTypeSize = f32_storage ? sizeof(float) : sizeof(HalfBits);
  shared_buffers_.resize(buffer_assignment.object_sizes.size());
  for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
    // Initialize metal buffer
    NSUInteger bufferSize = dataTypeSize * buffer_assignment.object_sizes[i];

    if (bufferSize > device->GetInfo().GetMaxBufferSize()) {
      std::string error("Tensor id: ");
      error += std::to_string(buffer_assignment.object_ids[i]) +
               " with size: " + std::to_string(bufferSize) +
               " exceeds MTLDevice maxBufferLength: " +
               std::to_string(device->GetInfo().GetMaxBufferSize());
      return absl::ResourceExhaustedError(error);
    }

    shared_buffers_[i] =
        [device->device() newBufferWithLength:bufferSize
                                      options:MTLResourceStorageModeShared];
  }

  std::vector<bool> created_tensors(buffer_usage_records.size(), false);
  shared_buffer_tensors_.resize(buffer_usage_records.size());
  for (auto& node : nodes_) {
    std::vector<ValueId> all_ids = node.inputs;
    all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
    for (auto& tensor_id : all_ids) {
      if (preallocated_tensors_.find(tensor_id) != preallocated_tensors_.end())
        continue;
      const int tensor_index = graph_ids_to_shared_buffer_tensors_[tensor_id];
      if (created_tensors[tensor_index]) continue;
      const auto& tensor_dummy = tensor_reserver_.Get(tensor_id);
      const int buffer_index = buffer_assignment.object_ids[tensor_index];
      RETURN_IF_ERROR(CreateSharedBufferTensor(
          shared_buffers_[buffer_index], tensor_dummy.shape,
          tensor_dummy.descriptor, &shared_buffer_tensors_[tensor_index]));
      created_tensors[tensor_index] = true;
    }
  }
  return absl::OkStatus();
}
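
// The element count above is b * w * h * AlignByN(c, 4) because buffer-based
// tensors store channels padded to groups of four. A worked example with an
// assumed shape BHWC(1, 8, 8, 6) under F16 precision: AlignByN(6, 4) == 8, so
// the usage record asks for 1 * 8 * 8 * 8 == 512 elements, and the MTLBuffer
// created for it is 512 * sizeof(HalfBits) == 1024 bytes (2048 bytes under
// F32). Tensors whose usage intervals do not overlap may be assigned to the
// same shared buffer by AssignObjectsToTensors.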

absl::Status InferenceContext::AllocateMemoryForStrongShapes(
    MetalDevice* device) {
  std::map<ValueId, int2> usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kStrongShape;
      },
      &usages);

  std::vector<TensorUsageRecord<DummyTensor>> usage_records;
  std::map<ValueId, ValueId> remap_from_graph_ids;
  for (auto& usage : usages) {
    remap_from_graph_ids[usage.first] = usage_records.size();
    usage_records.push_back({tensor_reserver_.Get(usage.first),
                             static_cast<TaskId>(usage.second.x),
                             static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<DummyTensor> assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      usage_records, MemoryStrategy::EQUALITY, &assignment));

  for (auto& node : nodes_) {
    std::vector<ValueId> all_ids = node.inputs;
    all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
    for (auto& tensor_id : all_ids) {
      const auto& tensor_dummy = tensor_reserver_.Get(tensor_id);
      if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kStrongShape) {
        continue;
      }
      const auto id = assignment.object_ids[remap_from_graph_ids[tensor_id]];
      graph_ids_to_strong_shape_tensors_[tensor_id] = id;
      const auto& it = strong_shape_tensors_.find(id);
      if (it == strong_shape_tensors_.end()) {
        RETURN_IF_ERROR(CreateTensor(device->device(), tensor_dummy.shape,
                                     tensor_dummy.descriptor,
                                     &strong_shape_tensors_[id]));
      }
    }
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::Tune(TuningType tuning_type,
                                    MetalDevice* device) {
  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.Tune(tuning_type, device));
  }
  return absl::OkStatus();
}

void InferenceContext::EncodeWithEncoder(
    id<MTLComputeCommandEncoder> command_encoder) {
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& task = nodes_[i].task;
    task.Encode(command_encoder);
  }
}

void InferenceContext::Profile(id<MTLDevice> device, ProfilingInfo* result) {
  result->dispatches.resize(nodes_.size());
  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  for (int k = 0; k < nodes_.size(); ++k) {
    @autoreleasepool {
      id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
      id<MTLComputeCommandEncoder> encoder =
          [command_buffer computeCommandEncoder];
      auto& task = nodes_[k].task;
      const int kRuns = 500;
      for (int i = 0; i < kRuns; ++i) {
        task.Encode(encoder);
      }
      [encoder endEncoding];
      auto start = absl::Now();
      [command_buffer commit];
      [command_buffer waitUntilCompleted];
      auto end = absl::Now();
      auto& dispatch_info = result->dispatches[k];
      dispatch_info.label = nodes_[k].name;
      dispatch_info.duration = (end - start) / static_cast<float>(kRuns);
    }
  }
}
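
// Each node is encoded kRuns (500) times into one command buffer, so
// dispatch_info.duration is the time of a single dispatch averaged over those
// runs. A minimal usage sketch, assuming an initialized `context` and an
// `id<MTLDevice> device` (placeholder names):
//
//   ProfilingInfo profiling_info;
//   context.Profile(device, &profiling_info);
//   for (const auto& d : profiling_info.dispatches) {
//     // d.label is the node name, d.duration its averaged dispatch time.
//   }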

void InferenceContext::EncodeWithCommandBuffer(
    id<MTLCommandBuffer> command_buffer) {
  for (int i = 0; i < nodes_.size(); ++i) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    auto& task = nodes_[i].task;
    task.Encode(encoder);
    [encoder endEncoding];
  }
}

void InferenceContext::EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                                              int flush_period) {
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  for (int i = 0; i < nodes_.size(); ++i) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    auto& task = nodes_[i].task;
    task.Encode(encoder);
    [encoder endEncoding];
    if (i % flush_period == (flush_period - 1)) {
      [command_buffer commit];
      command_buffer = [command_queue commandBuffer];
    }
  }
  [command_buffer commit];
}
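
// EncodeWithCommandQueue splits the work across several command buffers so the
// GPU can start executing earlier batches while later nodes are still being
// encoded. For example, with an assumed flush_period of 8, a commit happens
// after the nodes at indices 7, 15, 23, ... and the final commit flushes any
// remaining tail.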

void InferenceContext::UpdatePreallocatedTensors(
    const std::map<ValueId, id<MTLBuffer>>& preallocated) {
  for (const auto& it : preallocated) {
    auto status = preallocated_tensors_[it.first].SetBufferHandle(it.second);
  }
  for (auto& task_index : task_ids_with_preallocated_tensors_) {
    auto& task = nodes_[task_index].task;
    const auto& src_ids = nodes_[task_index].inputs;
    for (int i = 0; i < src_ids.size(); ++i) {
      const auto& it = preallocated_tensors_.find(src_ids[i]);
      if (it != preallocated_tensors_.end()) {
        task.SetSrcTensor(&it->second, i);
      }
    }
    const auto& dst_ids = nodes_[task_index].outputs;
    for (int i = 0; i < dst_ids.size(); ++i) {
      const auto& it = preallocated_tensors_.find(dst_ids[i]);
      if (it != preallocated_tensors_.end()) {
        task.SetDstTensor(&it->second, i);
      }
    }
  }
}

absl::Status RunGraphTransforms(GraphFloat32* graph) {
  auto merge_padding_transform = NewMergePaddingWithAdd();
  auto add_bias_transform = NewAddBias();
  auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
  ModelTransformer transformer(graph, /*reporter=*/nullptr);
  if (!transformer.Apply("add_bias", add_bias_transform.get())) {
    return absl::InternalError("Invalid add_bias transform");
  }
  if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
    return absl::InternalError("Invalid merge_padding transform");
  }
  if (!transformer.Apply("global pooling to mean",
                         pooling_to_reduce_op.get())) {
    return absl::InternalError("Invalid global pooling to mean transform");
  }
  return absl::OkStatus();
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite