• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalContext.h>
3#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
4
// Key under which each thread's current MetalCommandBuffer is stored in its
// -[NSThread threadDictionary].
// NOTE(review): kept non-static and non-const in case another translation unit
// references this symbol; confirm before tightening its linkage.
NSString* thread_local_storage_key = @"PTMetalCommandBuffer";

// Wraps one id<MTLCommandBuffer> together with the temporary images encoded
// into it and the subscribers that are notified around -commit.
@implementation MetalCommandBuffer {
  // MPSTemporaryImages registered via -add:; recycled in -beginSynchronization.
  NSMutableArray* _images;
  // Subscribers notified by -beginSynchronization / -endSynchronization.
  NSMutableSet<id<PTMetalCommandBuffer>>* _delegates;
}

// Returns a new wrapper around a command buffer obtained from the shared
// MetalContext's command queue, with empty image and subscriber collections.
// As a `new`-family method, the caller owns the returned object.
+ (MetalCommandBuffer*)newBuffer {
  MetalCommandBuffer* cb = [MetalCommandBuffer new];
  cb->_buffer = [[MetalContext sharedInstance].commandQueue commandBuffer];
  cb->_images = [NSMutableArray new];
  cb->_delegates = [NSMutableSet new];
  return cb;
}

// Returns the calling thread's current command buffer. A new one is created
// and stored when none exists yet or the stored one is no longer -valid.
+ (MetalCommandBuffer*)currentBuffer {
  // Only the thread dictionary is needed here; the thread's name is
  // deliberately left alone so callers' thread names (debugger, crash logs)
  // are not clobbered.
  NSMutableDictionary* dict = [[NSThread currentThread] threadDictionary];
  MetalCommandBuffer* cb = dict[thread_local_storage_key];
  if (!cb || !cb.valid) {
    cb = [MetalCommandBuffer newBuffer];
    // The command buffer should only be retained by the thread-local storage.
    dict[thread_local_storage_key] = cb;
  }
  return cb;
}

// YES while the wrapped buffer exists and is still in the NotEnqueued state.
- (BOOL)valid {
  return _buffer != nil &&
      _buffer.status == MTLCommandBufferStatusNotEnqueued;
}

// Registers a subscriber to be notified around -commit. nil is ignored.
- (void)addSubscriber:(id<PTMetalCommandBuffer>)subscriber {
  if (subscriber) {
    [_delegates addObject:subscriber];
  }
}

// Unregisters a previously added subscriber. nil is ignored.
- (void)removeSubscriber:(id<PTMetalCommandBuffer>)subscriber {
  if (subscriber) {
    [_delegates removeObject:subscriber];
  }
}

// Tracks a temporary image so it is recycled before the buffer is flushed.
// Images for which -isTemporaryImage is NO (and nil) are ignored.
- (void)add:(MPSTemporaryImage*)image {
  if (![image isTemporaryImage]) {
    return;
  }
  [_images addObject:image];
}

// Stops tracking a temporary image. Non-temporary images (and nil) are ignored.
- (void)remove:(MPSTemporaryImage*)image {
  if (![image isTemporaryImage]) {
    return;
  }
  [_images removeObject:image];
}

// Flushes synchronously. It notifies subscribers and recycles temporary images,
// commits, and blocks until the buffer completes. It then notifies subscribers
// again and tears this wrapper down. Does nothing if the buffer is not -valid.
- (void)commit {
  // Check -valid rather than `_buffer.status` alone. After
  // -endSynchronization, _buffer is nil and messaging nil reads status 0
  // (NotEnqueued). A repeated -commit would then re-run synchronization and
  // evict whichever buffer is current for this thread.
  if (!self.valid) {
    return;
  }
  [self beginSynchronization];
  [_buffer commit];
  [_buffer waitUntilCompleted];
  [self endSynchronization];
}

// Called before the buffer is committed.
- (void)beginSynchronization {
  // Iterate a snapshot: a subscriber that (un)registers from inside its
  // callback must not mutate the set while it is being enumerated.
  for (id<PTMetalCommandBuffer> delegate in [_delegates allObjects]) {
    if ([delegate respondsToSelector:@selector(beginSynchronization)]) {
      [delegate beginSynchronization];
    }
  }
  // recycle all temporary images manually before flushing the command buffer
  for (MPSTemporaryImage* image in _images) {
    [image recycle];
  }
}

// Called after the buffer has completed. It forwards the buffer's error
// (possibly nil) to subscribers and then clears all state. Finally it drops
// this wrapper from the current thread's storage if it is still registered
// there.
- (void)endSynchronization {
  for (id<PTMetalCommandBuffer> delegate in [_delegates allObjects]) {
    if ([delegate respondsToSelector:@selector(endSynchronization:)]) {
      [delegate endSynchronization:_buffer.error];
    }
  }
  [_delegates removeAllObjects];
  [_images removeAllObjects];
  _buffer = nil;
  // Only evict the entry if it still refers to this wrapper. A different
  // (newer) buffer may already be current, or this commit may be running on a
  // thread other than the one that created the buffer.
  NSMutableDictionary* dict = [NSThread currentThread].threadDictionary;
  if (dict[thread_local_storage_key] == self) {
    [dict removeObjectForKey:thread_local_storage_key];
  }
}

// Two wrappers are equal when they wrap the identical id<MTLCommandBuffer>.
- (BOOL)isEqual:(id)object {
  if (self == object) {
    return YES;
  }
  if (![object isKindOfClass:[MetalCommandBuffer class]]) {
    return NO;
  }
  MetalCommandBuffer* mc = (MetalCommandBuffer*)object;
  return _buffer == mc.buffer;
}

// Consistent with -isEqual:, which compares the wrapped buffer by pointer
// identity. The value changes once -endSynchronization nils out _buffer, so
// do not keep instances in hashed collections across a commit.
- (NSUInteger)hash {
  return (NSUInteger)(__bridge void*)_buffer;
}

@end
104