/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"

#include <Availability.h>

#include <algorithm>  // std::find
#include <cstdint>    // std::uint8_t (uniform buffer payloads)
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"

using ::tflite::gpu::AlignByN;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::InternalError;
using ::tflite::gpu::InvalidArgumentError;
using ::tflite::gpu::HalfBits;
using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
using ::tflite::gpu::metal::CreateComputeProgram;
using ::tflite::gpu::metal::DispatchParamsFunction;
using ::tflite::gpu::metal::OutputDimensions;
using ::tflite::gpu::metal::RuntimeOptions;
using ::tflite::gpu::metal::UniformsFunction;
using ::tflite::gpu::OkStatus;
using ::tflite::gpu::Status;
using ::tflite::gpu::uint3;
using ::tflite::gpu::ValueId;

@implementation TFLComputeTask {
  // Input tensor binding: value id in the graph plus the Metal buffer that
  // backs it (assigned later in assignBuffers:...).
  struct InputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
  };
  // Output tensor binding. `dimensionsFunction` recomputes the output shape
  // from the current input shapes; `alias` lists value ids that share this
  // buffer's storage and dimensions.
  struct OutputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
    OutputDimensions dimensionsFunction;
    std::vector<ValueId> alias;
  };
  // Uniform payload: raw bytes regenerated by `dataFunction` whenever input
  // dimensions change, bound via setBytes: at encode time.
  struct UniformBuffer {
    std::vector<uint8_t> data;
    UniformsFunction dataFunction;
  };

  id<MTLComputePipelineState> _program;
  std::vector<InputBuffer> _inputBuffers;
  std::vector<OutputBuffer> _outputBuffers;
  std::vector<id<MTLBuffer>> _immutableBuffers;
  std::vector<UniformBuffer> _uniformBuffers;
  uint3 _groupsSize;   // threads per threadgroup
  uint3 _groupsCount;  // threadgroups per dispatch
  DispatchParamsFunction _resizeFunction;
}

/// Compiles the task's shader into a pipeline state and sets up input/output/
/// uniform/immutable buffer bookkeeping from the descriptor. Precision macros
/// (FLT*, ACCUM_FLT*, TO_ACCUM*_TYPE, BARRIER) are substituted according to
/// `options` before compilation.
- (Status)compileWithDevice:(id<MTLDevice>)device
             taskDescriptor:(ComputeTaskDescriptorPtr)desc
             runtimeOptions:(const RuntimeOptions&)options {
  NSString* barrier;
  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
  // NOTE(review): MSL 2.0 on iOS ships with iOS 11; the iOS 10.0 floor here
  // looks optimistic — confirm against the deployment targets in use.
  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }
  NSString* storageType;
  NSString* accumulatorType;
  NSString* toAccumulatorType = @"";
  NSString* toAccumulatorType2 = @"";
  NSString* toAccumulatorType3 = @"";
  NSString* toAccumulatorType4 = @"";
  if (options.storage_precision == RuntimeOptions::Precision::FP32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16 storage; accumulation may still be widened to FP32, in which case
    // the TO_ACCUM* macros become explicit half->float conversions.
    storageType = @"half";
    if (options.accumulator_precision == RuntimeOptions::Precision::FP32) {
      accumulatorType = @"float";
      toAccumulatorType = @"float";
      toAccumulatorType2 = @"float2";
      toAccumulatorType3 = @"float3";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
  NSDictionary<NSString*, NSString*>* macros = @{
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType,
    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
    @"BARRIER" : barrier,
  };

  // NOTE(review): defaultCStringEncoding is locale-dependent; shader sources
  // are expected to be ASCII — confirm, or switch to NSUTF8StringEncoding.
  NSString* code = [NSString stringWithCString:desc->shader_source.c_str()
                                      encoding:[NSString defaultCStringEncoding]];
  id<MTLComputePipelineState> program;
  RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
  if (!program) {
    return InternalError("Unknown shader compilation error");
  }
  for (auto& buffer : desc->input_buffers) {
    // Metal handles are filled in later by assignBuffers:...
    _inputBuffers.emplace_back(InputBuffer{buffer.id, nil});
  }
  for (auto& uniform : desc->uniform_buffers) {
    // Uniform bytes are produced on first setInputDimensions... call.
    _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
  }
  _outputBuffers.emplace_back(OutputBuffer{desc->output_buffer.id, nil,
                                           desc->output_buffer.dimensions_function,
                                           desc->output_buffer.alias});
  for (auto& immutable : desc->immutable_buffers) {
    // Pad constant data to a whole number of 4-element vectors of the storage
    // type so the GPU never reads past the uploaded bytes.
    // NOTE(review): this resizes the descriptor's data in place — the
    // descriptor is mutated by compilation.
    int padding =
        4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float)
                                                                          : sizeof(HalfBits));
    int paddedSize = AlignByN(immutable.data.size(), padding);
    immutable.data.resize(paddedSize);
    id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
                                                    length:immutable.data.size()
                                                   options:MTLResourceStorageModeShared];
    _immutableBuffers.emplace_back(metalBuffer);
  }
  _resizeFunction = desc->resize_function;
  _program = program;
  return OkStatus();
}

/// Recomputes output dimensions (and those of aliased values) from the current
/// input dimensions in `dimensions`, regenerates uniform payloads, and derives
/// the dispatch geometry. Fails if the requested threadgroup size exceeds the
/// device limit.
- (Status)setInputDimensionsWithDevice:(id<MTLDevice>)device
                            dimensions:
                                (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions {
  // Re-calculate output buffers dimensions
  for (auto& buffer : _outputBuffers) {
    auto outputDimensions = buffer.dimensionsFunction(*dimensions);
    for (ValueId duplicate : buffer.alias) {
      (*dimensions)[duplicate] = outputDimensions;
    }
    // Store buffer dimensions
    (*dimensions)[buffer.uid] = outputDimensions;
  }

  for (auto& uniform : _uniformBuffers) {
    uniform.data = uniform.dataFunction(*dimensions);
  }

  // Dispatch parameters re-calculation
  auto workGroups = _resizeFunction(*dimensions);
  _groupsSize = workGroups.first;
  MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
  if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
      _groupsSize.z > threadsPerGroup.depth) {
    std::string error("Threads per working group: ");
    error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
             std::to_string(_groupsSize.z);
    // BUGFIX: leading space was missing, producing e.g. "...1, 2, 3is larger".
    error += " is larger than the MTLDevice can support: ";
    error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
             ", " + std::to_string(threadsPerGroup.depth);
    return InvalidArgumentError(error);
  }
  _groupsCount = workGroups.second;
  return OkStatus();
}

/// Binds Metal buffers to this task's tensors. Intermediate outputs (not in
/// `outputIds`) are mapped to a shared buffer chosen via their usage record;
/// inputs pick up whatever handle is already registered in `buffers`.
- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
              outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
         usageRecordIds:(const std::map<ValueId, size_t>&)usageRecordIds
        sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
          sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers {
  for (auto& buffer : _outputBuffers) {
    // If the buffer is intermediate: set its metalHandle from sharedBuffers
    if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) {
      auto usageRecordIt = usageRecordIds.find(buffer.uid);
      if (usageRecordIt == usageRecordIds.end()) {
        return InternalError("TensorUsageRecord for intermediate tensor is not found.");
      }
      buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second));
      // Publish the handle so downstream tasks can find this tensor as input.
      (*buffers)[buffer.uid] = buffer.metalHandle;
    }
  }

  // Re-assign input buffers
  for (auto& buffer : _inputBuffers) {
    buffer.metalHandle = (*buffers)[buffer.uid];
  }
  return OkStatus();
}

/// Encodes this task into `encoder`: binds outputs, inputs, immutable buffers
/// and uniform bytes in that fixed index order, then dispatches. Buffers whose
/// value ids appear in `inputOutputBuffers` (graph inputs/outputs) override the
/// internally assigned handles.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder
       inputOutputBuffers:(const std::map<ValueId, id<MTLBuffer>>&)inputOutputBuffers {
  // The dispatch call is intended to be skipped.
  if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
    return;
  }

  [encoder setComputePipelineState:_program];

  // Binding order must match the shader's buffer indices:
  // outputs, then inputs, then immutable constants, then uniforms.
  int bindIndex = 0;
  for (auto& buffer : _outputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer == inputOutputBuffers.end()) {
      [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    } else {
      // the buffer is input or output
      [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex];
    }
    bindIndex++;
  }
  for (auto& buffer : _inputBuffers) {
    const auto externalBuffer = inputOutputBuffers.find(buffer.uid);
    if (externalBuffer == inputOutputBuffers.end()) {
      [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex];
    } else {
      // the buffer is input or output
      [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex];
    }
    bindIndex++;
  }
  for (auto& immutable : _immutableBuffers) {
    [encoder setBuffer:immutable offset:0 atIndex:bindIndex];
    bindIndex++;
  }
  for (auto& uniform : _uniformBuffers) {
    // Uniforms are small; setBytes: avoids allocating a dedicated MTLBuffer.
    [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
    bindIndex++;
  }

  MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
  MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}

@end