#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>

// Key under which the calling thread's current MetalCommandBuffer is stored in
// -[NSThread threadDictionary].
NSString* thread_local_storage_key = @"PTMetalCommandBuffer";

@implementation MetalCommandBuffer {
  // Temporary images registered via -add:; recycled in -beginSynchronization
  // right before the wrapped MTLCommandBuffer is committed.
  NSMutableArray* _images;
  // Subscribers notified before commit (-beginSynchronization) and after
  // completion (-endSynchronization:).
  NSMutableSet<id<PTMetalCommandBuffer>>* _delegates;
}

/// Creates a new wrapper around a fresh MTLCommandBuffer taken from the shared
/// MetalContext command queue. The result is not registered as the thread's
/// current buffer; use +currentBuffer for that.
+ (MetalCommandBuffer*)newBuffer {
  MetalCommandBuffer* cb = [MetalCommandBuffer new];
  cb->_buffer = [[MetalContext sharedInstance].commandQueue commandBuffer];
  cb->_images = [NSMutableArray new];
  cb->_delegates = [NSMutableSet new];
  return cb;
}

/// Returns the calling thread's current command buffer. A new one is created
/// and stored in the thread dictionary when none exists or the stored one is
/// no longer -valid.
+ (MetalCommandBuffer*)currentBuffer {
  // Only the thread dictionary is touched; the thread's name is deliberately
  // left alone so callers' thread names (e.g. the main thread's) survive.
  NSMutableDictionary* dict = [NSThread currentThread].threadDictionary;
  MetalCommandBuffer* cb = dict[thread_local_storage_key];
  if (!cb || !cb.valid) {
    cb = [MetalCommandBuffer newBuffer];
    // The command buffer should only be retained by the thread-local storage.
    dict[thread_local_storage_key] = cb;
  }
  return cb;
}

/// YES while the wrapped MTLCommandBuffer exists (not yet torn down by
/// -endSynchronization) and its status is still NotEnqueued.
- (BOOL)valid {
  return _buffer != nil && _buffer.status == MTLCommandBufferStatusNotEnqueued;
}

/// Registers `subscriber` for the synchronization callbacks. nil is ignored
/// (NSMutableSet does not accept nil).
- (void)addSubscriber:(id<PTMetalCommandBuffer>)subscriber {
  if (subscriber) {
    [_delegates addObject:subscriber];
  }
}

/// Unregisters `subscriber`. nil is ignored.
- (void)removeSubscriber:(id<PTMetalCommandBuffer>)subscriber {
  if (subscriber) {
    [_delegates removeObject:subscriber];
  }
}

/// Tracks `image` so it is recycled before commit. Images for which
/// -isTemporaryImage is NO are ignored.
- (void)add:(MPSTemporaryImage*)image {
  if (![image isTemporaryImage]) {
    return;
  }
  [_images addObject:image];
}

/// Stops tracking `image`. Images for which -isTemporaryImage is NO are
/// ignored.
- (void)remove:(MPSTemporaryImage*)image {
  if (![image isTemporaryImage]) {
    return;
  }
  [_images removeObject:image];
}

/// Commits the wrapped buffer and blocks until it has completed. Subscribers
/// are notified before the commit and after completion, then the wrapper is
/// torn down. Does nothing if the buffer was already enqueued/committed or the
/// wrapper was already torn down.
- (void)commit {
  // Check via -valid rather than `_buffer.status` alone: messaging a nil
  // _buffer returns 0 (== NotEnqueued), which would otherwise re-run the
  // synchronization on an already torn-down wrapper.
  if (!self.valid) {
    return;
  }
  [self beginSynchronization];
  [_buffer commit];
  [_buffer waitUntilCompleted];
  [self endSynchronization];
}

/// Notifies subscribers that implement -beginSynchronization, then recycles
/// all tracked temporary images.
- (void)beginSynchronization {
  // Iterate over snapshots: a subscriber callback or -recycle could call back
  // into -removeSubscriber: / -remove:, and mutating a collection while it is
  // being fast-enumerated raises an exception.
  for (id<PTMetalCommandBuffer> delegate in [_delegates allObjects]) {
    if ([delegate respondsToSelector:@selector(beginSynchronization)]) {
      [delegate beginSynchronization];
    }
  }
  // recycle all temporary images manually before flushing the command buffer
  for (MPSTemporaryImage* image in [_images copy]) {
    [image recycle];
  }
}

/// Passes the buffer's error (nil on success) to subscribers that implement
/// -endSynchronization:, then clears all state and drops the thread-local
/// reference to this wrapper.
- (void)endSynchronization {
  // Snapshot for the same reason as in -beginSynchronization.
  for (id<PTMetalCommandBuffer> delegate in [_delegates allObjects]) {
    if ([delegate respondsToSelector:@selector(endSynchronization:)]) {
      [delegate endSynchronization:_buffer.error];
    }
  }
  [_delegates removeAllObjects];
  [_images removeAllObjects];
  _buffer = nil;
  // Only evict the thread-local entry when it refers to this wrapper, so that
  // committing some other buffer does not drop the thread's current one.
  NSMutableDictionary* dict = [NSThread currentThread].threadDictionary;
  if (dict[thread_local_storage_key] == self) {
    [dict removeObjectForKey:thread_local_storage_key];
  }
}

/// Two wrappers are equal when they wrap the same MTLCommandBuffer instance.
- (BOOL)isEqual:(id)object {
  if (self == object) {
    return YES;
  }
  if (![object isKindOfClass:[MetalCommandBuffer class]]) {
    return NO;
  }
  MetalCommandBuffer* mc = (MetalCommandBuffer*)object;
  return _buffer == mc.buffer;
}

/// Consistent with -isEqual:, which compares only the wrapped buffer.
/// Note: _buffer is set to nil by -endSynchronization, so the hash changes
/// after commit; don't keep wrappers as keys in hashed collections across a
/// commit.
- (NSUInteger)hash {
  return [_buffer hash];
}

@end