• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/metal/common.h"

#import <Metal/Metal.h>

#include <Availability.h>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
28
// Helpers for `#pragma message` diagnostics. VAR_NAME_VALUE(X) produces the
// string literal "X=<expansion of X>"; the extra VALUE indirection forces the
// argument to be macro-expanded before it is stringified.
#define VALUE_TO_STRING(token) #token
#define VALUE(token) VALUE_TO_STRING(token)
#define VAR_NAME_VALUE(var) #var "=" VALUE(var)
33
34namespace tflite {
35namespace gpu {
36namespace metal {
37
// Returns the device produced by MTLCreateSystemDefaultDevice(); no further
// device selection is performed here.
id<MTLDevice> GetBestSupportedMetalDevice() {
  id<MTLDevice> system_default_device = MTLCreateSystemDefaultDevice();
  return system_default_device;
}
39
// Compiles `code` with the given preprocessor `macros`, looks up the kernel
// named `function_name`, and builds a compute pipeline from it.
//
// On success `*program` holds the pipeline. On failure it is left nil and an
// InternalError carrying the Metal error description is returned (errors from
// CreateFunction are propagated unchanged).
absl::Status CreateComputeProgram(id<MTLDevice> device, const std::string& code,
                                  const std::string& function_name,
                                  const std::map<std::string, std::string>& macros,
                                  id<MTLComputePipelineState>* program) {
  id<MTLFunction> kernel;
  RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &kernel));

  NSError* pipeline_error = nil;
  id<MTLComputePipelineState> pipeline =
      [device newComputePipelineStateWithFunction:kernel error:&pipeline_error];
  *program = pipeline;
  if (pipeline) {
    return absl::OkStatus();
  }

  NSString* message =
      [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@",
                                 [pipeline_error localizedDescription]];
  return absl::InternalError([message UTF8String]);
}
57
// Compiles `code` with the given preprocessor `macros`, looks up the kernel
// named `function_name`, creates an argument encoder for the argument buffer at
// index 0 of that kernel, and builds a compute pipeline from a descriptor.
//
// Returns an InternalError when argument buffers are not available on the
// running OS (macOS 10.13 / iOS 11 / tvOS 11 required), when the argument
// encoder cannot be created, or when pipeline creation fails. Errors from
// CreateFunction are propagated unchanged.
absl::Status CreateComputeProgramWithArgumentBuffer(
    id<MTLDevice> device, const std::string& code, const std::string& function_name,
    const std::map<std::string, std::string>& macros, id<MTLComputePipelineState>* program,
    id<MTLArgumentEncoder>* arguments_encoder) {
  if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
    id<MTLFunction> function;
    // Use the caller-supplied entry point rather than a fixed kernel name so
    // that `function_name` is actually honored.
    RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &function));
    *arguments_encoder = [function newArgumentEncoderWithBufferIndex:0];
    if (!*arguments_encoder) {
      return absl::InternalError("Failed to get MTLArgumentEncoder.");
    }
    MTLComputePipelineDescriptor* pipeline_desc = [[MTLComputePipelineDescriptor alloc] init];
    pipeline_desc.computeFunction = function;
    NSError* error = nil;
    *program = [device newComputePipelineStateWithDescriptor:pipeline_desc
                                                     options:MTLPipelineOptionNone
                                                  reflection:nullptr
                                                       error:&error];
    if (!*program) {
      NSString* error_string =
          [NSString stringWithFormat:@"newComputePipelineStateWithDescriptor: %@",
                                     [error localizedDescription]];
      return absl::InternalError([error_string UTF8String]);
    }
    return absl::OkStatus();
  } else {
    return absl::InternalError("Metal argument buffers available since ios 11, tvos 11 or macos "
                               "10.13.");
  }
}
88
// Compiles `code` with the given preprocessor `macros`, looks up the kernel
// named `function_name`, creates an argument encoder for the argument buffer at
// index 0 of that kernel, and builds a compute pipeline with indirect command
// buffer support enabled.
//
// Returns an InternalError when indirect compute command buffers are not
// available on the running OS (macOS 11 / iOS 13 / tvOS 13 required), when the
// argument encoder cannot be created, or when pipeline creation fails. Errors
// from CreateFunction are propagated unchanged.
absl::Status CreateComputeProgramWithICBSupport(id<MTLDevice> device, const std::string& code,
                                                const std::string& function_name,
                                                const std::map<std::string, std::string>& macros,
                                                id<MTLComputePipelineState>* program,
                                                id<MTLArgumentEncoder>* arguments_encoder) {
  if (@available(macOS 11.00, iOS 13.0, tvOS 13.0, *)) {
    id<MTLFunction> function;
    // Use the caller-supplied entry point rather than a fixed kernel name so
    // that `function_name` is actually honored.
    RETURN_IF_ERROR(CreateFunction(device, code, function_name, macros, &function));
    *arguments_encoder = [function newArgumentEncoderWithBufferIndex:0];
    if (!*arguments_encoder) {
      return absl::InternalError("Failed to get MTLArgumentEncoder.");
    }
    MTLComputePipelineDescriptor* pipeline_desc = [[MTLComputePipelineDescriptor alloc] init];
    pipeline_desc.computeFunction = function;
    // Lets this pipeline be referenced from indirect command buffers.
    pipeline_desc.supportIndirectCommandBuffers = YES;
    NSError* error = nil;
    *program = [device newComputePipelineStateWithDescriptor:pipeline_desc
                                                     options:MTLPipelineOptionNone
                                                  reflection:nullptr
                                                       error:&error];
    if (!*program) {
      NSString* error_string =
          [NSString stringWithFormat:@"newComputePipelineStateWithDescriptor: %@",
                                     [error localizedDescription]];
      return absl::InternalError([error_string UTF8String]);
    }
    return absl::OkStatus();
  } else {
    return absl::InternalError("Indirect compute command buffer available since ios 13, tvos 13 "
                               "or macos 11.00");
  }
}
121
// Converts `str`, expected to contain UTF-8 text, into an NSString.
// Returns nil when the bytes are not valid UTF-8. Like the C-string based
// conversion it replaces, the result stops at the first NUL byte.
static NSString* StdStringToNSString(const std::string& str) {
  return [NSString stringWithUTF8String:str.c_str()];
}

// Compiles Metal shading language `code` with the preprocessor `macros` and
// returns the function named `function_name` in `*function`.
//
// The language version is selected at runtime as the newest one this code
// knows about for the running OS. Macro keys and values that contain a space
// are wrapped in double quotes before being handed to the compiler. Fast math
// is always enabled.
//
// Returns:
//   - InvalidArgumentError if `code`, `function_name`, or any macro key/value
//     is not valid UTF-8;
//   - InternalError if library compilation fails (with the Metal error text);
//   - InternalError if `function_name` is not found in the compiled library.
absl::Status CreateFunction(id<MTLDevice> device, const std::string& code,
                            const std::string& function_name,
                            const std::map<std::string, std::string>& macros,
                            id<MTLFunction>* function) {
  MTLCompileOptions* options = [[MTLCompileOptions alloc] init];

  // Runtime checks for the OS version, independent of the minimum deployment target.
  if (@available(macOS 11.0, iOS 14.0, tvOS 14.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_3];
  } else if (@available(macOS 10.15, iOS 13.0, tvOS 13.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_2];
  } else if (@available(macOS 10.14, iOS 12.0, tvOS 12.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_1];
  } else if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_0];
  } else if (@available(macOS 10.12, iOS 10.0, tvOS 10.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_2];
  } else if (@available(macOS 10.11, iOS 9.0, tvOS 9.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_1];
  }
#if (defined(__MAC_10_11) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_11) ||    \
    (defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0) || \
    (defined(__TVOS_9_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_9_0)
  // Minimum target OS version is able to support Metal.
#else
#pragma message(VAR_NAME_VALUE(__MAC_OS_X_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__IPHONE_OS_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__TV_OS_VERSION_MIN_REQUIRED))
// NOLINTBEGIN
#error \
    "The Metal delegate is not supported on current target SDK. Minimum supported os: iOS/tvOS 9.0, macOS 10.11"
// NOLINTEND
#endif

  NSMutableDictionary<NSString*, NSString*>* macros_dict = [NSMutableDictionary dictionary];
  for (const auto& pair : macros) {
    std::string key = pair.first;
    std::string value = pair.second;
    if (absl::StrContains(key, ' ')) {
      key = "\"" + key + "\"";
    }
    if (absl::StrContains(value, ' ')) {
      value = "\"" + value + "\"";
    }
    NSString* key_ns = StdStringToNSString(key);
    NSString* value_ns = StdStringToNSString(value);
    // A nil key or value would make -setObject:forKey: throw, so report it instead.
    if (!key_ns || !value_ns) {
      return absl::InvalidArgumentError("Preprocessor macro '" + pair.first +
                                        "' is not valid UTF-8.");
    }
    [macros_dict setObject:value_ns forKey:key_ns];
  }

  [options setFastMathEnabled:YES];
  [options setPreprocessorMacros:macros_dict];

  NSString* code_ns = StdStringToNSString(code);
  if (!code_ns) {
    return absl::InvalidArgumentError("Metal shader source is not valid UTF-8.");
  }
  NSError* error = nil;
  id<MTLLibrary> library = [device newLibraryWithSource:code_ns options:options error:&error];
  if (!library) {
    NSString* errorString =
        [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]];
    return absl::InternalError([errorString UTF8String]);
  }

  NSString* function_name_ns = StdStringToNSString(function_name);
  if (!function_name_ns) {
    return absl::InvalidArgumentError("Metal function name is not valid UTF-8.");
  }
  *function = [library newFunctionWithName:function_name_ns];
  if (!*function) {
    // -newFunctionWithName: reports no NSError; `error` above describes the
    // library compilation, not this lookup, so name the missing function instead.
    return absl::InternalError("newFunctionWithName: function '" + function_name +
                               "' not found in the compiled library.");
  }
  return absl::OkStatus();
}
194
// Bytes per pixel for the four-channel (RGBA) 8/16/32-bit integer, normalized
// and float pixel formats listed below. Any other format yields -1.
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format) {
  switch (pixel_format) {
    // 4 channels x 32 bits.
    case MTLPixelFormatRGBA32Uint:
    case MTLPixelFormatRGBA32Sint:
    case MTLPixelFormatRGBA32Float:
      return 16;
    // 4 channels x 16 bits.
    case MTLPixelFormatRGBA16Unorm:
    case MTLPixelFormatRGBA16Snorm:
    case MTLPixelFormatRGBA16Uint:
    case MTLPixelFormatRGBA16Sint:
    case MTLPixelFormatRGBA16Float:
      return 8;
    // 4 channels x 8 bits.
    case MTLPixelFormatRGBA8Unorm:
    case MTLPixelFormatRGBA8Snorm:
    case MTLPixelFormatRGBA8Uint:
    case MTLPixelFormatRGBA8Sint:
      return 4;
    default:
      return -1;
  }
}
214
// Maps a tensor element type to the matching four-channel Metal pixel format.
// For the 8/16-bit integer types, `normalized` selects the (U|S)norm variant
// instead of the plain integer one; it has no effect on other types. BOOL is
// stored as RGBA8Uint, and unhandled types map to MTLPixelFormatInvalid.
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized) {
  MTLPixelFormat format = MTLPixelFormatInvalid;
  switch (type) {
    case DataType::FLOAT32:
      format = MTLPixelFormatRGBA32Float;
      break;
    case DataType::FLOAT16:
      format = MTLPixelFormatRGBA16Float;
      break;
    case DataType::INT8:
      format = normalized ? MTLPixelFormatRGBA8Snorm : MTLPixelFormatRGBA8Sint;
      break;
    case DataType::UINT8:
      format = normalized ? MTLPixelFormatRGBA8Unorm : MTLPixelFormatRGBA8Uint;
      break;
    case DataType::INT16:
      format = normalized ? MTLPixelFormatRGBA16Snorm : MTLPixelFormatRGBA16Sint;
      break;
    case DataType::UINT16:
      format = normalized ? MTLPixelFormatRGBA16Unorm : MTLPixelFormatRGBA16Uint;
      break;
    case DataType::INT32:
      format = MTLPixelFormatRGBA32Sint;
      break;
    case DataType::UINT32:
      format = MTLPixelFormatRGBA32Uint;
      break;
    case DataType::BOOL:
      format = MTLPixelFormatRGBA8Uint;
      break;
    default:
      break;
  }
  return format;
}
239
// Uploads pixel data into slice 0, mip level 0 of a 2D texture.
// `data` must hold width * height tightly packed pixels (rows without padding)
// in the texture's pixel format. The format must be one that
// PixelFormatToSizeInBytes recognizes. Blocks until the GPU copy has completed.
void WriteDataToTexture2D(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  const NSUInteger row_bytes = pixel_size * texture.width;
  const NSUInteger image_bytes = row_bytes * texture.height;

  id<MTLBuffer> staging_buffer = [device newBufferWithBytes:data
                                                     length:image_bytes
                                                    options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> queue = [device newCommandQueue];
  id<MTLCommandBuffer> commands = [queue commandBuffer];

  id<MTLBlitCommandEncoder> blit = [commands blitCommandEncoder];
  [blit copyFromBuffer:staging_buffer
             sourceOffset:0
        sourceBytesPerRow:row_bytes
      sourceBytesPerImage:image_bytes
               sourceSize:MTLSizeMake(texture.width, texture.height, 1)
                toTexture:texture
         destinationSlice:0
         destinationLevel:0
        destinationOrigin:MTLOriginMake(0, 0, 0)];
  [blit endEncoding];

  [commands commit];
  [commands waitUntilCompleted];
}
264
// Downloads slice 0, mip level 0 of a 2D texture into `data`.
// `data` must have room for width * height tightly packed pixels (rows without
// padding) in the texture's pixel format. The format must be one that
// PixelFormatToSizeInBytes recognizes. Blocks until the GPU copy has completed.
void ReadDataFromTexture2D(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  // Keep byte counts in NSUInteger: narrowing them to int overflows for large
  // textures (e.g. 16384x16384 RGBA32Float is 4 GiB) and corrupts both the
  // staging buffer length and the memcpy size.
  const NSUInteger bytes_per_row = pixel_size * texture.width;
  const NSUInteger buffer_size = bytes_per_row * texture.height;
  id<MTLBuffer> temp_buffer = [device newBufferWithLength:buffer_size
                                                  options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];

  id<MTLBlitCommandEncoder> blitCommandEncoder = [command_buffer blitCommandEncoder];
  [blitCommandEncoder copyFromTexture:texture
                          sourceSlice:0
                          sourceLevel:0
                         sourceOrigin:MTLOriginMake(0, 0, 0)
                           sourceSize:MTLSizeMake(texture.width, texture.height, 1)
                             toBuffer:temp_buffer
                    destinationOffset:0
               destinationBytesPerRow:bytes_per_row
             destinationBytesPerImage:buffer_size];
  [blitCommandEncoder endEncoding];

  [command_buffer commit];
  [command_buffer waitUntilCompleted];
  std::memcpy(data, [temp_buffer contents], buffer_size);
}
290
// Uploads pixel data into mip level 0 of a 3D texture.
// `data` must hold depth consecutive images, each of width * height tightly
// packed pixels, in the texture's pixel format. The format must be one that
// PixelFormatToSizeInBytes recognizes. Blocks until the GPU copy has completed.
void WriteDataToTexture3D(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  const NSUInteger row_bytes = pixel_size * texture.width;
  const NSUInteger image_bytes = row_bytes * texture.height;
  const NSUInteger volume_bytes = image_bytes * texture.depth;

  id<MTLBuffer> staging_buffer = [device newBufferWithBytes:data
                                                     length:volume_bytes
                                                    options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> queue = [device newCommandQueue];
  id<MTLCommandBuffer> commands = [queue commandBuffer];

  id<MTLBlitCommandEncoder> blit = [commands blitCommandEncoder];
  [blit copyFromBuffer:staging_buffer
             sourceOffset:0
        sourceBytesPerRow:row_bytes
      sourceBytesPerImage:image_bytes
               sourceSize:MTLSizeMake(texture.width, texture.height, texture.depth)
                toTexture:texture
         destinationSlice:0
         destinationLevel:0
        destinationOrigin:MTLOriginMake(0, 0, 0)];
  [blit endEncoding];

  [commands commit];
  [commands waitUntilCompleted];
}
316
// Downloads mip level 0 of a 3D texture into `data`.
// `data` must have room for depth consecutive images, each of width * height
// tightly packed pixels, in the texture's pixel format. The format must be one
// that PixelFormatToSizeInBytes recognizes. Blocks until the GPU copy has completed.
void ReadDataFromTexture3D(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  // Keep byte counts in NSUInteger: narrowing them to int overflows once the
  // volume exceeds 2 GiB and corrupts both the staging buffer length and the
  // memcpy size.
  const NSUInteger bytes_per_row = pixel_size * texture.width;
  const NSUInteger bytes_per_image = bytes_per_row * texture.height;
  const NSUInteger buffer_size = bytes_per_image * texture.depth;
  id<MTLBuffer> temp_buffer = [device newBufferWithLength:buffer_size
                                                  options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];

  id<MTLBlitCommandEncoder> blitCommandEncoder = [command_buffer blitCommandEncoder];
  [blitCommandEncoder copyFromTexture:texture
                          sourceSlice:0
                          sourceLevel:0
                         sourceOrigin:MTLOriginMake(0, 0, 0)
                           sourceSize:MTLSizeMake(texture.width, texture.height, texture.depth)
                             toBuffer:temp_buffer
                    destinationOffset:0
               destinationBytesPerRow:bytes_per_row
             destinationBytesPerImage:bytes_per_image];
  [blitCommandEncoder endEncoding];

  [command_buffer commit];
  [command_buffer waitUntilCompleted];
  std::memcpy(data, [temp_buffer contents], buffer_size);
}
342
// Uploads pixel data into mip level 0 of every slice of a 2D texture array.
// `data` must hold arrayLength consecutive images, each of width * height
// tightly packed pixels, in the texture's pixel format. The format must be one
// that PixelFormatToSizeInBytes recognizes. Blocks until the GPU copy has completed.
void WriteDataToTexture2DArray(id<MTLTexture> texture, id<MTLDevice> device, const void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  const NSUInteger row_bytes = pixel_size * texture.width;
  const NSUInteger slice_bytes = row_bytes * texture.height;

  id<MTLBuffer> staging_buffer = [device newBufferWithBytes:data
                                                     length:slice_bytes * texture.arrayLength
                                                    options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> queue = [device newCommandQueue];
  id<MTLCommandBuffer> commands = [queue commandBuffer];

  // One copy per slice, all recorded on a single blit encoder.
  id<MTLBlitCommandEncoder> blit = [commands blitCommandEncoder];
  for (NSUInteger slice = 0; slice < texture.arrayLength; ++slice) {
    [blit copyFromBuffer:staging_buffer
               sourceOffset:slice_bytes * slice
          sourceBytesPerRow:row_bytes
        sourceBytesPerImage:slice_bytes
                 sourceSize:MTLSizeMake(texture.width, texture.height, 1)
                  toTexture:texture
           destinationSlice:slice
           destinationLevel:0
          destinationOrigin:MTLOriginMake(0, 0, 0)];
  }
  [blit endEncoding];

  [commands commit];
  [commands waitUntilCompleted];
}
370
// Downloads mip level 0 of every slice of a 2D texture array into `data`.
// `data` must have room for arrayLength consecutive images, each of
// width * height tightly packed pixels, in the texture's pixel format. The
// format must be one that PixelFormatToSizeInBytes recognizes. Blocks until the
// GPU copy has completed.
void ReadDataFromTexture2DArray(id<MTLTexture> texture, id<MTLDevice> device, void* data) {
  const int pixel_size = PixelFormatToSizeInBytes(texture.pixelFormat);
  // Keep byte counts in NSUInteger: narrowing them to int overflows once the
  // array exceeds 2 GiB and corrupts both the staging buffer length and the
  // memcpy size.
  const NSUInteger bytes_per_row = pixel_size * texture.width;
  const NSUInteger bytes_per_slice = bytes_per_row * texture.height;
  const NSUInteger buffer_size = bytes_per_slice * texture.arrayLength;
  id<MTLBuffer> temp_buffer = [device newBufferWithLength:buffer_size
                                                  options:MTLResourceStorageModeShared];

  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];

  // One encoder records every per-slice copy instead of creating one per slice.
  id<MTLBlitCommandEncoder> blitCommandEncoder = [command_buffer blitCommandEncoder];
  for (NSUInteger slice = 0; slice < texture.arrayLength; ++slice) {
    [blitCommandEncoder copyFromTexture:texture
                            sourceSlice:slice
                            sourceLevel:0
                           sourceOrigin:MTLOriginMake(0, 0, 0)
                             sourceSize:MTLSizeMake(texture.width, texture.height, 1)
                               toBuffer:temp_buffer
                      destinationOffset:bytes_per_slice * slice
                 destinationBytesPerRow:bytes_per_row
               destinationBytesPerImage:bytes_per_slice];
  }
  [blitCommandEncoder endEncoding];

  [command_buffer commit];
  [command_buffer waitUntilCompleted];
  std::memcpy(data, [temp_buffer contents], buffer_size);
}
398
399}  // namespace metal
400}  // namespace gpu
401}  // namespace tflite
402