/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/common.h"

#import <Metal/Metal.h>

#include <Availability.h>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

// Compile-time message: print define name and value.
// VAR_NAME_VALUE(X) expands to the string literal "X=<expanded value of X>".
#define VALUE_TO_STRING(x) #x
#define VALUE(x) VALUE_TO_STRING(x)
#define VAR_NAME_VALUE(var) #var "=" VALUE(var)

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Converts `str` to an NSString, decoding it as UTF-8 so it round-trips with the -UTF8String
// calls used for error messages in this file. Returns nil if `str` is not valid UTF-8.
// (The previous +defaultCStringEncoding is platform/user dependent and not guaranteed UTF-8.)
NSString* ToNSString(const std::string& str) {
  return [NSString stringWithUTF8String:str.c_str()];
}

// Wraps `text` in double quotes when it contains a space; otherwise returns it unchanged.
std::string QuoteIfContainsSpace(const std::string& text) {
  return absl::StrContains(text, ' ') ? "\"" + text + "\"" : text;
}

// Creates `*program` from `pipeline_desc`. On failure returns an InternalError that carries
// Metal's error description.
absl::Status CreatePipelineFromDescriptor(id<MTLDevice> device,
                                          MTLComputePipelineDescriptor* pipeline_desc,
                                          id<MTLComputePipelineState>* program) {
  NSError* error = nil;
  *program = [device newComputePipelineStateWithDescriptor:pipeline_desc
                                                   options:MTLPipelineOptionNone
                                                reflection:nullptr
                                                     error:&error];
  if (!*program) {
    NSString* error_string =
        [NSString stringWithFormat:@"newComputePipelineStateWithDescriptor: %@",
                                   [error localizedDescription]];
    return absl::InternalError([error_string UTF8String]);
  }
  return absl::OkStatus();
}

}  // namespace

// Returns the system default Metal device (may be nil when Metal is unavailable).
id<MTLDevice> GetBestSupportedMetalDevice() { return MTLCreateSystemDefaultDevice(); }

// Compiles `code` with `macros`, looks up `function_name` and builds a compute pipeline for it.
absl::Status CreateComputeProgram(id<MTLDevice> device, const std::string& code,
                                  const std::string& function_name,
                                  const std::map<std::string, std::string>& macros,
                                  id<MTLComputePipelineState>* program) {
  id<MTLFunction> function;
  RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &function));

  NSError* error = nil;
  *program = [device newComputePipelineStateWithFunction:function error:&error];
  if (!*program) {
    NSString* errorString =
        [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@",
                                   [error localizedDescription]];
    return absl::InternalError([errorString UTF8String]);
  }
  return absl::OkStatus();
}

// Like CreateComputeProgram, but also returns an argument encoder for buffer index 0 of the
// function. Requires macOS 10.13 / iOS 11 / tvOS 11; returns InternalError on older systems.
absl::Status CreateComputeProgramWithArgumentBuffer(
    id<MTLDevice> device, const std::string& code, const std::string& function_name,
    const std::map<std::string, std::string>& macros, id<MTLComputePipelineState>* program,
    id<MTLArgumentEncoder>* arguments_encoder) {
  if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
    id<MTLFunction> function;
    // Previously the entry point was hard-coded to "ComputeFunction", silently ignoring
    // `function_name`.
    RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &function));
    *arguments_encoder = [function newArgumentEncoderWithBufferIndex:0];
    if (!*arguments_encoder) {
      return absl::InternalError("Failed to get MTLArgumentEncoder.");
    }
    MTLComputePipelineDescriptor* pipeline_desc = [[MTLComputePipelineDescriptor alloc] init];
    pipeline_desc.computeFunction = function;
    return CreatePipelineFromDescriptor(device, pipeline_desc, program);
  }
  return absl::InternalError("Metal argument buffers available since ios 11, tvos 11 or macos "
                             "10.13.");
}

// Like CreateComputeProgramWithArgumentBuffer, but the pipeline is created with
// supportIndirectCommandBuffers enabled. Requires macOS 11 / iOS 13 / tvOS 13; returns
// InternalError on older systems.
absl::Status CreateComputeProgramWithICBSupport(id<MTLDevice> device, const std::string& code,
                                                const std::string& function_name,
                                                const std::map<std::string, std::string>& macros,
                                                id<MTLComputePipelineState>* program,
                                                id<MTLArgumentEncoder>* arguments_encoder) {
  if (@available(macOS 11.00, iOS 13.0, tvOS 13.0, *)) {
    id<MTLFunction> function;
    // Previously the entry point was hard-coded to "ComputeFunction", silently ignoring
    // `function_name`.
    RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &function));
    *arguments_encoder = [function newArgumentEncoderWithBufferIndex:0];
    if (!*arguments_encoder) {
      return absl::InternalError("Failed to get MTLArgumentEncoder.");
    }
    MTLComputePipelineDescriptor* pipeline_desc = [[MTLComputePipelineDescriptor alloc] init];
    pipeline_desc.computeFunction = function;
    pipeline_desc.supportIndirectCommandBuffers = YES;
    return CreatePipelineFromDescriptor(device, pipeline_desc, program);
  }
  return absl::InternalError("Indirect compute command buffer available since ios 13, tvos 13 "
                             "or macos 11.00");
}

// Compiles `code` into a Metal library and returns the function named `function_name`.
//
// The Metal Shading Language version is picked at runtime as the newest one the running OS
// supports. Fast math is enabled. Each entry of `macros` becomes a preprocessor macro; names or
// values containing a space are wrapped in double quotes. Returns InvalidArgumentError if any
// input string is not valid UTF-8, and InternalError if compilation or the lookup fails.
absl::Status CreateFunction(id<MTLDevice> device, const std::string& code,
                            const std::string& function_name,
                            const std::map<std::string, std::string>& macros,
                            id<MTLFunction>* function) {
  MTLCompileOptions* options = [[MTLCompileOptions alloc] init];

  // Runtime checks for the iOS version independently of minimum target iOS.
  if (@available(macOS 11.0, iOS 14.0, tvOS 14.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_3];
  } else if (@available(macOS 10.15, iOS 13.0, tvOS 13.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_2];
  } else if (@available(macOS 10.14, iOS 12.0, tvOS 12.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_1];
  } else if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_0];
  } else if (@available(macOS 10.12, iOS 10.0, tvOS 10.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_2];
  } else if (@available(macOS 10.11, iOS 9.0, tvOS 9.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_1];
  }
#if (defined(__MAC_10_11) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_11) || \
    (defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0) || \
    (defined(__TVOS_9_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_9_0)
  // Minimum target OS version is able to support Metal.
#else
#pragma message(VAR_NAME_VALUE(__MAC_OS_X_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__IPHONE_OS_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__TV_OS_VERSION_MIN_REQUIRED))
// NOLINTBEGIN
#error \
    "The Metal delegate is not supported on current target SDK. Minimum supported os: iOS/tvOS 9.0, macOS 10.11"
// NOLINTEND
#endif

  NSMutableDictionary<NSString*, NSString*>* macros_dict = [NSMutableDictionary dictionary];
  for (const auto& macro : macros) {
    NSString* key = ToNSString(QuoteIfContainsSpace(macro.first));
    NSString* value = ToNSString(QuoteIfContainsSpace(macro.second));
    if (key == nil || value == nil) {
      // Inserting nil into an NSMutableDictionary would raise an exception.
      return absl::InvalidArgumentError("Preprocessor macro name or value is not valid UTF-8.");
    }
    macros_dict[key] = value;
  }

  [options setFastMathEnabled:YES];
  [options setPreprocessorMacros:macros_dict];

  NSString* code_ns = ToNSString(code);
  if (code_ns == nil) {
    return absl::InvalidArgumentError("Metal shader source is not valid UTF-8.");
  }
  NSError* error = nil;
  id<MTLLibrary> library = [device newLibraryWithSource:code_ns options:options error:&error];
  if (!library) {
    NSString* errorString =
        [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]];
    return absl::InternalError([errorString UTF8String]);
  }

  NSString* function_name_ns = ToNSString(function_name);
  *function = function_name_ns != nil ? [library newFunctionWithName:function_name_ns] : nil;
  if (!*function) {
    // -newFunctionWithName: reports no NSError, and `error` above belongs to the (successful)
    // library compile, so name the missing function instead of printing a stale/nil error.
    return absl::InternalError("newFunctionWithName: no function named '" + function_name +
                               "' in the compiled library");
  }
  return absl::OkStatus();
}

// Returns the size in bytes of one pixel for the RGBA pixel formats produced by
// DataTypeToRGBAPixelFormat, or -1 for any other format.
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format) {
  switch (pixel_format) {
    case MTLPixelFormatRGBA32Uint:
    case MTLPixelFormatRGBA32Sint:
    case MTLPixelFormatRGBA32Float:
      return 16;
    case MTLPixelFormatRGBA16Unorm:
    case MTLPixelFormatRGBA16Snorm:
    case MTLPixelFormatRGBA16Uint:
    case MTLPixelFormatRGBA16Sint:
    case MTLPixelFormatRGBA16Float:
      return 8;
    case MTLPixelFormatRGBA8Unorm:
    case MTLPixelFormatRGBA8Snorm:
    case MTLPixelFormatRGBA8Uint:
    case MTLPixelFormatRGBA8Sint:
      return 4;
    default:
      return -1;
  }
}

// Maps a tensor data type to the four-channel pixel format used to store it. For 8/16-bit
// integer types, `normalized` selects the [S|U]norm format instead of the integer one.
// Returns MTLPixelFormatInvalid for unsupported types.
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized) {
  switch (type) {
    case DataType::FLOAT32:
      return MTLPixelFormatRGBA32Float;
    case DataType::FLOAT16:
      return MTLPixelFormatRGBA16Float;
    case DataType::INT8:
      return normalized ? MTLPixelFormatRGBA8Snorm : MTLPixelFormatRGBA8Sint;
    case DataType::UINT8:
      return normalized ? MTLPixelFormatRGBA8Unorm : MTLPixelFormatRGBA8Uint;
    case DataType::INT16:
      return normalized ? MTLPixelFormatRGBA16Snorm : MTLPixelFormatRGBA16Sint;
    case DataType::UINT16:
      return normalized ? MTLPixelFormatRGBA16Unorm : MTLPixelFormatRGBA16Uint;
    case DataType::INT32:
      return MTLPixelFormatRGBA32Sint;
    case DataType::UINT32:
      return MTLPixelFormatRGBA32Uint;
    case DataType::BOOL:
      return MTLPixelFormatRGBA8Uint;
    default:
      return MTLPixelFormatInvalid;
  }
}

namespace {

// Synchronously uploads tightly packed pixels from `data` into mip level 0 of `texture`.
// `data` holds `slices` consecutive blocks, each `depth` images of width x height pixels; block i
// goes to texture slice i. A fresh command queue is created per call and the function waits for
// the blit to complete. Does nothing if the pixel format is unsupported by
// PixelFormatToSizeInBytes or the staging buffer cannot be allocated.
void WriteDataToTextureSlices(id<MTLTexture> texture, id<MTLDevice> device, const void* data,
                              NSUInteger depth, NSUInteger slices) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  if (pixel_size <= 0) {
    return;  // Unknown format: every size computed below would be meaningless.
  }
  const NSUInteger bytes_per_row = static_cast<NSUInteger>(pixel_size) * texture.width;
  const NSUInteger bytes_per_image = bytes_per_row * texture.height;
  const NSUInteger bytes_per_slice = bytes_per_image * depth;
  id<MTLBuffer> staging = [device newBufferWithBytes:data
                                              length:bytes_per_slice * slices
                                             options:MTLResourceStorageModeShared];
  if (!staging) {
    return;
  }

  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  id<MTLBlitCommandEncoder> blit_encoder = [command_buffer blitCommandEncoder];
  for (NSUInteger slice = 0; slice < slices; ++slice) {
    [blit_encoder copyFromBuffer:staging
                    sourceOffset:bytes_per_slice * slice
               sourceBytesPerRow:bytes_per_row
             sourceBytesPerImage:bytes_per_image
                      sourceSize:MTLSizeMake(texture.width, texture.height, depth)
                       toTexture:texture
                destinationSlice:slice
                destinationLevel:0
               destinationOrigin:MTLOriginMake(0, 0, 0)];
  }
  [blit_encoder endEncoding];

  [command_buffer commit];
  [command_buffer waitUntilCompleted];
}

// Synchronously downloads mip level 0 of `texture` into `data` using the same tightly packed
// layout as WriteDataToTextureSlices. Leaves `data` untouched if the pixel format is unsupported
// by PixelFormatToSizeInBytes or the staging buffer cannot be allocated.
void ReadDataFromTextureSlices(id<MTLTexture> texture, id<MTLDevice> device, void* data,
                               NSUInteger depth, NSUInteger slices) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  if (pixel_size <= 0) {
    return;  // Unknown format: every size computed below would be meaningless.
  }
  const NSUInteger bytes_per_row = static_cast<NSUInteger>(pixel_size) * texture.width;
  const NSUInteger bytes_per_image = bytes_per_row * texture.height;
  const NSUInteger bytes_per_slice = bytes_per_image * depth;
  const NSUInteger total_bytes = bytes_per_slice * slices;
  id<MTLBuffer> staging = [device newBufferWithLength:total_bytes
                                              options:MTLResourceStorageModeShared];
  if (!staging) {
    return;  // Avoid memcpy from the nil buffer's NULL contents.
  }

  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  id<MTLBlitCommandEncoder> blit_encoder = [command_buffer blitCommandEncoder];
  for (NSUInteger slice = 0; slice < slices; ++slice) {
    [blit_encoder copyFromTexture:texture
                      sourceSlice:slice
                      sourceLevel:0
                     sourceOrigin:MTLOriginMake(0, 0, 0)
                       sourceSize:MTLSizeMake(texture.width, texture.height, depth)
                         toBuffer:staging
                destinationOffset:bytes_per_slice * slice
           destinationBytesPerRow:bytes_per_row
         destinationBytesPerImage:bytes_per_image];
  }
  [blit_encoder endEncoding];

  [command_buffer commit];
  [command_buffer waitUntilCompleted];
  std::memcpy(data, [staging contents], total_bytes);
}

}  // namespace

// Uploads width x height tightly packed pixels from `data` into a 2D texture (blocking).
void WriteDataToTexture2D(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  WriteDataToTextureSlices(texture, device, data, /*depth=*/1, /*slices=*/1);
}

// Reads a 2D texture into `data` as width x height tightly packed pixels (blocking).
void ReadDataFromTexture2D(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  ReadDataFromTextureSlices(texture, device, data, /*depth=*/1, /*slices=*/1);
}

// Uploads width x height x depth tightly packed pixels from `data` into a 3D texture (blocking).
void WriteDataToTexture3D(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  WriteDataToTextureSlices(texture, device, data, /*depth=*/texture.depth, /*slices=*/1);
}

// Reads a 3D texture into `data` as width x height x depth tightly packed pixels (blocking).
void ReadDataFromTexture3D(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  ReadDataFromTextureSlices(texture, device, data, /*depth=*/texture.depth, /*slices=*/1);
}

// Uploads arrayLength consecutive width x height images from `data` into the slices of a 2D
// array texture (blocking).
void WriteDataToTexture2DArray(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  WriteDataToTextureSlices(texture, device, data, /*depth=*/1, /*slices=*/texture.arrayLength);
}

// Reads every slice of a 2D array texture into `data` as arrayLength consecutive
// width x height images (blocking).
void ReadDataFromTexture2DArray(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  ReadDataFromTextureSlices(texture, device, data, /*depth=*/1, /*slices=*/texture.arrayLength);
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite