• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
17 
18 #include <Availability.h>
19 #include <string>
20 #include <tuple>
21 
22 #include "tensorflow/lite/delegates/gpu/common/model.h"
23 #include "tensorflow/lite/delegates/gpu/common/shape.h"
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25 #include "tensorflow/lite/delegates/gpu/common/types.h"
26 #include "tensorflow/lite/delegates/gpu/common/util.h"
27 #include "tensorflow/lite/delegates/gpu/metal/common.h"
28 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
29 
30 using ::tflite::gpu::AlignByN;
31 using ::tflite::gpu::BHWC;
32 using ::tflite::gpu::InternalError;
33 using ::tflite::gpu::InvalidArgumentError;
34 using ::tflite::gpu::HalfBits;
35 using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
36 using ::tflite::gpu::metal::CreateComputeProgram;
37 using ::tflite::gpu::metal::DispatchParamsFunction;
38 using ::tflite::gpu::metal::OutputDimensions;
39 using ::tflite::gpu::metal::RuntimeOptions;
40 using ::tflite::gpu::metal::UniformsFunction;
41 using ::tflite::gpu::OkStatus;
42 using ::tflite::gpu::Status;
43 using ::tflite::gpu::uint3;
44 using ::tflite::gpu::ValueId;
45 
@implementation TFLComputeTask {
  // Input tensor id paired with the Metal buffer bound for it at encode time.
  struct InputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
  };
  // Output tensor id, its buffer, the function that derives its dimensions from
  // the input dimensions, and any tensor ids that alias this output's storage.
  struct OutputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
    OutputDimensions dimensionsFunction;
    std::vector<ValueId> alias;
  };
  // Uniform payload plus the function that recomputes it when shapes change.
  struct UniformBuffer {
    std::vector<uint8_t> data;
    UniformsFunction dataFunction;
  };

  id<MTLComputePipelineState> _program;
  std::vector<InputBuffer> _inputBuffers;
  std::vector<OutputBuffer> _outputBuffers;
  std::vector<id<MTLBuffer>> _immutableBuffers;
  std::vector<UniformBuffer> _uniformBuffers;
  uint3 _groupsSize;   // threads per threadgroup
  uint3 _groupsCount;  // threadgroups per dispatch
  DispatchParamsFunction _resizeFunction;
}

/// Compiles the task's shader into a pipeline state and records buffer layout.
/// Precision macros (FLT*, ACCUM_FLT*, TO_ACCUM*_TYPE, BARRIER) are substituted
/// into the shader source according to `options`.
- (Status)compileWithDevice:(id<MTLDevice>)device
             taskDescriptor:(ComputeTaskDescriptorPtr)desc
             runtimeOptions:(const RuntimeOptions&)options {
  NSString* barrier;
  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }
  NSString* storageType;
  NSString* accumulatorType;
  // TO_ACCUM*_TYPE expand to explicit casts only when storage and accumulator
  // precisions differ (FP16 storage, FP32 accumulation); otherwise they are
  // empty and the expression is used as-is.
  NSString* toAccumulatorType = @"";
  NSString* toAccumulatorType2 = @"";
  NSString* toAccumulatorType3 = @"";
  NSString* toAccumulatorType4 = @"";
  if (options.storage_precision == RuntimeOptions::Precision::FP32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16
    storageType = @"half";
    if (options.accumulator_precision == RuntimeOptions::Precision::FP32) {
      accumulatorType = @"float";
      toAccumulatorType = @"float";
      toAccumulatorType2 = @"float2";
      toAccumulatorType3 = @"float3";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
  NSDictionary<NSString*, NSString*>* macros = @{
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType,
    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
    @"BARRIER" : barrier,
  };

  // Shader sources are UTF-8 std::strings; decode them explicitly as UTF-8
  // rather than with +defaultCStringEncoding, which is locale/platform
  // dependent and can silently mangle (or nil out) the source.
  NSString* code = [NSString stringWithUTF8String:desc->shader_source.c_str()];
  if (!code) {
    return InternalError("Shader source is not valid UTF-8");
  }
  id<MTLComputePipelineState> program;
  RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
  if (!program) {
    return InternalError("Unknown shader compilation error");
  }
  for (auto& buffer : desc->input_buffers) {
    // Metal handles are assigned later in assignBuffers:.
    _inputBuffers.emplace_back(InputBuffer{buffer.id, nil});
  }
  for (auto& uniform : desc->uniform_buffers) {
    // Uniform payloads are filled in setInputDimensionsWithDevice:.
    _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
  }
  _outputBuffers.emplace_back(OutputBuffer{desc->output_buffer.id, nil,
                                           desc->output_buffer.dimensions_function,
                                           desc->output_buffer.alias});
  for (auto& immutable : desc->immutable_buffers) {
    // Pad constant data to a 4-element vector boundary of the storage type so
    // the shader can read it as FLT4 without running off the end.
    int padding =
        4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float)
                                                                          : sizeof(HalfBits));
    int paddedSize = AlignByN(immutable.data.size(), padding);
    immutable.data.resize(paddedSize);
    id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
                                                    length:immutable.data.size()
                                                   options:MTLResourceStorageModeShared];
    _immutableBuffers.emplace_back(metalBuffer);
  }
  _resizeFunction = desc->resize_function;
  _program = program;
  return OkStatus();
}

/// Propagates new input dimensions: recomputes output dimensions (and their
/// aliases) into `dimensions`, refreshes uniform payloads, and re-derives the
/// dispatch geometry. Fails if the requested threadgroup exceeds device limits.
- (Status)setInputDimensionsWithDevice:(id<MTLDevice>)device
                            dimensions:
                                (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions {
  // Re-calculate output buffers dimensions
  for (auto& buffer : _outputBuffers) {
    auto outputDimensions = buffer.dimensionsFunction(*dimensions);
    for (ValueId duplicate : buffer.alias) {
      (*dimensions)[duplicate] = outputDimensions;
    }
    // Store buffer dimensions
    (*dimensions)[buffer.uid] = outputDimensions;
  }

  for (auto& uniform : _uniformBuffers) {
    uniform.data = uniform.dataFunction(*dimensions);
  }

  // Dispatch parameters re-calculation
  auto workGroups = _resizeFunction(*dimensions);
  _groupsSize = workGroups.first;
  MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
  if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
      _groupsSize.z > threadsPerGroup.depth) {
    std::string error("Threads per working group: ");
    error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
             std::to_string(_groupsSize.z);
    // Spaces around the clause keep the message readable: "…4 is larger than…"
    // instead of the previous "…4is larger than…".
    error += " is larger than the MTLDevice can support: ";
    error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
             ", " + std::to_string(threadsPerGroup.depth);
    return InvalidArgumentError(error);
  }
  _groupsCount = workGroups.second;
  return OkStatus();
}

/// Binds Metal buffers to tensor ids. Intermediate tensors (those not listed in
/// `outputIds`) are backed by entries of `sharedBuffers`, selected through the
/// usage-record -> shared-buffer-id indirection; input handles are then looked
/// up from the (possibly updated) `buffers` map.
- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
              outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
         usageRecordIds:(const std::map<ValueId, size_t>&)usageRecordIds
        sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
          sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers {
  for (auto& buffer : _outputBuffers) {
    // If the buffer is intermediate: set its metalHandle from sharedBuffers
    if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) {
      auto usageRecordIt = usageRecordIds.find(buffer.uid);
      if (usageRecordIt == usageRecordIds.end()) {
        return InternalError("TensorUsageRecord for intermediate tensor is not found.");
      }
      buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second));
      (*buffers)[buffer.uid] = buffer.metalHandle;
    }
  }

  // Re-assign input buffers
  for (auto& buffer : _inputBuffers) {
    buffer.metalHandle = (*buffers)[buffer.uid];
  }
  return OkStatus();
}

/// Encodes this task into `encoder`: binds outputs, inputs, immutable constant
/// buffers, and uniform bytes (in that fixed index order, matching the shader's
/// buffer bindings), then dispatches the precomputed threadgroup grid.
/// Buffers present in `inputOutputBuffers` (model inputs/outputs) override the
/// handles stored on the task.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder
       inputOutputBuffers:(const std::map<ValueId, id<MTLBuffer>>&)inputOutputBuffers {
  // The dispatch call is intended to be skipped.
  if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
    return;
  }

  [encoder setComputePipelineState:_program];

  int bindIndex = 0;
  for (auto& buffer : _outputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer == inputOutputBuffers.end()) {
      [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    } else {
      // the buffer is input or output
      [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex];
    }
    bindIndex++;
  }
  for (auto& buffer : _inputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer == inputOutputBuffers.end()) {
      [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    } else {
      // the buffer is input or output
      [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex];
    }
    bindIndex++;
  }
  for (auto& immutable : _immutableBuffers) {
    [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
    bindIndex++;
  }
  for (auto& uniform : _uniformBuffers) {
    // setBytes copies the payload into the command buffer; suitable for the
    // small uniform structs used here.
    [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
    bindIndex++;
  }

  MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
  MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}

@end
256