/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#import "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"

#import <Metal/Metal.h>

// <map> and <vector> are used directly below (shader macro map, uniform
// vector); include them explicitly rather than relying on transitive
// includes from the project headers.
#include <map>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"

using ::tflite::gpu::BHWC;
using ::tflite::gpu::DivideRoundUp;
using ::tflite::gpu::metal::CreateComputeProgram;

// Converts tensor buffers between the plain BHWC float layout and the
// PBHWC4 layout (channels padded to multiples of 4 and stored as FLT4
// slices) used by the Metal GPU delegate.
@implementation TFLBufferConvert {
  // Compiled compute pipeline for the one-direction conversion chosen at
  // init time. Set once in init and never mutated afterwards.
  id<MTLComputePipelineState> _program;
}

// Builds and compiles the conversion kernel.
//
//   device           Metal device the pipeline is compiled for.
//   isFloat16        When true the packed side uses half4 (FP16) elements;
//                    otherwise float4. Controls the FLT4 macro substitution.
//   convertToPBHWC4  true  -> kernel reads BHWC floats, writes FLT4 slices;
//                    false -> kernel reads FLT4 slices, writes BHWC floats.
//
// Returns nil if the Metal program fails to compile.
- (id)initWithDevice:(id<MTLDevice>)device
           isFloat16:(bool)isFloat16
     convertToPBHWC4:(bool)convertToPBHWC4 {
  if (self = [super init]) {
    std::string shaderSource;
    if (convertToPBHWC4) {
      // BHWC -> PBHWC4. Each thread gathers up to 4 consecutive channels
      // for one (B, X, Y, S) coordinate into a single FLT4; channels past
      // size.z (the true channel count) are left zero-padded.
      shaderSource = R"(
        #include <metal_stdlib>
        using namespace metal;
        kernel void ComputeFunction(device float* const input_buffer [[buffer(0)]],
                                    device FLT4* output_buffer [[buffer(1)]],
                                    constant int4& size [[buffer(2)]],
                                    uint3 gid[[thread_position_in_grid]]) {
          int linear_id = static_cast<int>(gid.x);
          int X = linear_id / size.w;
          int B = linear_id % size.w;
          int Y = static_cast<int>(gid.y);
          int S = static_cast<int>(gid.z);
          if (X >= size.x || Y >= size.y) {
            return;
          }
          FLT4 value = FLT4(0.0);
          for (int i = 0; i < 4; i++) {
            int channel = S * 4 + i;
            if (channel >= size.z) break;
            const int bhwc_index = ((B * size.y + Y) * size.x + X) * size.z + channel;
            value[i] = input_buffer[bhwc_index];
          }
          const int shwbc4_index = ((S * size.y + Y) * size.x + X) * size.w + B;
          output_buffer[shwbc4_index] = value;
        }
      )";
    } else {
      // PBHWC4 -> BHWC. Inverse scatter of the kernel above: one FLT4 read,
      // up to 4 float writes, skipping the zero padding beyond size.z.
      shaderSource = R"(
        #include <metal_stdlib>
        using namespace metal;
        kernel void ComputeFunction(device FLT4* const input_buffer [[buffer(0)]],
                                    device float* output_buffer [[buffer(1)]],
                                    constant int4& size [[buffer(2)]],
                                    uint3 gid[[thread_position_in_grid]]) {
          int linear_id = static_cast<int>(gid.x);
          int X = linear_id / size.w;
          int B = linear_id % size.w;
          int Y = static_cast<int>(gid.y);
          int S = static_cast<int>(gid.z);
          if (X >= size.x || Y >= size.y) {
            return;
          }
          const int shwbc4_index = ((S * size.y + Y) * size.x + X) * size.w + B;
          FLT4 value = input_buffer[shwbc4_index];
          for (int i = 0; i < 4; i++) {
            int channel = S * 4 + i;
            if (channel >= size.z) break;
            const int bhwc_index = ((B * size.y + Y) * size.x + X) * size.z + channel;
            output_buffer[bhwc_index] = value[i];
          }
        }
      )";
    }
    // FLT4 selects the packed element type at shader-compile time.
    const std::map<std::string, std::string> macros = {{"FLT4", isFloat16 ? "half4" : "float4"}};
    id<MTLComputePipelineState> program;
    if (CreateComputeProgram(device, shaderSource, "ComputeFunction", macros, &program).ok()) {
      _program = program;
      return self;
    }
  }
  // Program compilation failed; signal failure to the caller.
  return nil;
}

// Encodes the conversion of `sourceBuffer` into `convertedBuffer` for a
// tensor of the given BHWC shape. The caller owns the encoder lifecycle
// (endEncoding / commit) and must ensure both buffers are large enough for
// `shape` in their respective layouts.
- (void)convertWithEncoder:(id<MTLComputeCommandEncoder>)encoder
                     shape:(const BHWC&)shape
              sourceBuffer:(id<MTLBuffer>)sourceBuffer
           convertedBuffer:(id<MTLBuffer>)convertedBuffer {
  [encoder setComputePipelineState:_program];
  [encoder setBuffer:sourceBuffer offset:0 atIndex:0];
  [encoder setBuffer:convertedBuffer offset:0 atIndex:1];

  // Packed as int4 {x=W, y=H, z=C, w=B} to match the `size` uniform read by
  // the kernel. 4 ints == 16 bytes, the size expected for a constant int4.
  std::vector<int> uniforms = {shape.w, shape.h, shape.c, shape.b};
  [encoder setBytes:uniforms.data() length:uniforms.size() * sizeof(int) atIndex:2];

  // Grid: x covers W*B (linearized in the kernel), y covers H, z covers the
  // C/4 slices. Threadgroup shape 16x8x1 = 128 threads.
  MTLSize group_size = MTLSizeMake(16, 8, 1);
  int slices = DivideRoundUp(shape.c, 4);
  int groups_x = DivideRoundUp(shape.w * shape.b, group_size.width);
  int groups_y = DivideRoundUp(shape.h, group_size.height);
  int groups_z = DivideRoundUp(slices, group_size.depth);
  MTLSize groups_count = MTLSizeMake(groups_x, groups_y, groups_z);
  [encoder dispatchThreadgroups:groups_count threadsPerThreadgroup:group_size];
}

@end