• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"

#include <Availability.h>

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"

using ::tflite::gpu::AlignByN;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::HalfBits;
using ::tflite::gpu::InternalError;
using ::tflite::gpu::InvalidArgumentError;
using ::tflite::gpu::metal::ComputeTaskDescriptorPtr;
using ::tflite::gpu::metal::CreateComputeProgram;
using ::tflite::gpu::metal::DispatchParamsFunction;
using ::tflite::gpu::metal::OutputDimensions;
using ::tflite::gpu::metal::RuntimeOptions;
using ::tflite::gpu::metal::UniformsFunction;
using ::tflite::gpu::OkStatus;
using ::tflite::gpu::Status;
using ::tflite::gpu::uint3;
using ::tflite::gpu::ValueId;
45
@implementation TFLComputeTask {
  // Pairs a tensor value id with the GPU buffer currently backing it.
  // metalHandle is nil until -assignBuffers:... runs.
  struct InputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
  };
  // Output tensor record. dimensionsFunction recomputes the output shape from
  // the input shapes; alias lists other value ids that share this buffer.
  struct OutputBuffer {
    ValueId uid;
    id<MTLBuffer> metalHandle;
    OutputDimensions dimensionsFunction;
    std::vector<ValueId> alias;
  };
  // Uniform data regenerated on every resize via dataFunction and passed to
  // the encoder with setBytes (no dedicated MTLBuffer).
  struct UniformBuffer {
    std::vector<uint8_t> data;
    UniformsFunction dataFunction;
  };

  id<MTLComputePipelineState> _program;        // compiled compute pipeline
  std::vector<InputBuffer> _inputBuffers;
  std::vector<OutputBuffer> _outputBuffers;
  std::vector<id<MTLBuffer>> _immutableBuffers;  // constant weights, padded
  std::vector<UniformBuffer> _uniformBuffers;
  uint3 _groupsSize;   // threads per threadgroup
  uint3 _groupsCount;  // threadgroups per grid
  DispatchParamsFunction _resizeFunction;      // recomputes dispatch params
}
// Compiles the task's shader for `device`, substituting precision-dependent
// type macros per `options`, and allocates the constant (immutable) buffers.
// Returns InternalError on shader compilation failure.
- (Status)compileWithDevice:(id<MTLDevice>)device
             taskDescriptor:(ComputeTaskDescriptorPtr)desc
             runtimeOptions:(const RuntimeOptions&)options {
  NSString* barrier;
  // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
  if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }
  NSString* storageType;
  NSString* accumulatorType;
  // TO_ACCUM*_TYPE expand to casts only when storage and accumulation
  // precisions differ; otherwise they expand to nothing.
  NSString* toAccumulatorType = @"";
  NSString* toAccumulatorType2 = @"";
  NSString* toAccumulatorType3 = @"";
  NSString* toAccumulatorType4 = @"";
  if (options.storage_precision == RuntimeOptions::Precision::FP32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16
    storageType = @"half";
    if (options.accumulator_precision == RuntimeOptions::Precision::FP32) {
      accumulatorType = @"float";
      toAccumulatorType = @"float";
      toAccumulatorType2 = @"float2";
      toAccumulatorType3 = @"float3";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
  NSDictionary<NSString*, NSString*>* macros = @{
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType,
    @"TO_ACCUM2_TYPE" : toAccumulatorType2,
    @"TO_ACCUM3_TYPE" : toAccumulatorType3,
    @"TO_ACCUM4_TYPE" : toAccumulatorType4,
    @"BARRIER" : barrier,
  };

  // Decode the shader source as UTF-8 explicitly: defaultCStringEncoding is
  // platform/locale dependent and may mangle non-ASCII bytes in the source.
  NSString* code = [NSString stringWithCString:desc->shader_source.c_str()
                                      encoding:NSUTF8StringEncoding];
  id<MTLComputePipelineState> program;
  RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program));
  if (!program) {
    return InternalError("Unknown shader compilation error");
  }
  for (auto& buffer : desc->input_buffers) {
    _inputBuffers.emplace_back(InputBuffer{buffer.id, nil});
  }
  for (auto& uniform : desc->uniform_buffers) {
    _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function});
  }
  _outputBuffers.emplace_back(OutputBuffer{desc->output_buffer.id, nil,
                                           desc->output_buffer.dimensions_function,
                                           desc->output_buffer.alias});
  // Constant data is padded to a whole 4-element vector of the storage type;
  // the padding amount is the same for every buffer, so compute it once.
  const int padding =
      4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float)
                                                                        : sizeof(HalfBits));
  for (auto& immutable : desc->immutable_buffers) {
    int paddedSize = AlignByN(immutable.data.size(), padding);
    immutable.data.resize(paddedSize);
    id<MTLBuffer> metalBuffer = [device newBufferWithBytes:immutable.data.data()
                                                    length:immutable.data.size()
                                                   options:MTLResourceStorageModeShared];
    _immutableBuffers.emplace_back(metalBuffer);
  }
  _resizeFunction = desc->resize_function;
  _program = program;
  return OkStatus();
}
151
// Recomputes output shapes, uniform payloads and dispatch parameters after the
// model's input dimensions changed. `dimensions` maps value id -> shape; the
// computed output shapes (and their aliases) are written back into it.
// Returns InvalidArgumentError when the requested threadgroup exceeds the
// device limit.
- (Status)setInputDimensionsWithDevice:(id<MTLDevice>)device
                            dimensions:
                                (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions {
  // Re-calculate output buffers dimensions
  for (auto& buffer : _outputBuffers) {
    auto outputDimensions = buffer.dimensionsFunction(*dimensions);
    // Aliased value ids share the output buffer and therefore its shape.
    for (ValueId duplicate : buffer.alias) {
      (*dimensions)[duplicate] = outputDimensions;
    }
    // Store buffer dimensions
    (*dimensions)[buffer.uid] = outputDimensions;
  }

  // Uniform data depends on shapes, so it must be regenerated here.
  for (auto& uniform : _uniformBuffers) {
    uniform.data = uniform.dataFunction(*dimensions);
  }

  // Dispatch parameters re-calculation
  auto workGroups = _resizeFunction(*dimensions);
  _groupsSize = workGroups.first;
  MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup];
  if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height ||
      _groupsSize.z > threadsPerGroup.depth) {
    std::string error("Threads per working group: ");
    error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " +
             std::to_string(_groupsSize.z);
    // Note the leading space: without it the message reads "…512is larger…".
    error += " is larger than the MTLDevice can support: ";
    error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) +
             ", " + std::to_string(threadsPerGroup.depth);
    return InvalidArgumentError(error);
  }
  _groupsCount = workGroups.second;
  return OkStatus();
}
186
// Resolves the MTLBuffer for every tensor this task touches. Intermediate
// outputs (ids not present in `outputIds`) are bound to slots of the shared
// buffer pool and registered in `buffers`; inputs then pick up whatever buffer
// is registered for their id. Returns InternalError when an intermediate
// tensor has no usage record.
- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id<MTLBuffer>>*)buffers
              outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds
         usageRecordIds:(const std::map<ValueId, size_t>&)usageRecordIds
        sharedBufferIds:(const std::vector<size_t>&)sharedBufferIds
          sharedBuffers:(const std::vector<id<MTLBuffer>>&)sharedBuffers {
  for (auto& output : _outputBuffers) {
    const bool isGraphOutput =
        std::find(outputIds.begin(), outputIds.end(), output.uid) != outputIds.end();
    // Graph outputs keep their externally provided buffers; only
    // intermediates draw from the shared pool.
    if (isGraphOutput) continue;
    const auto recordIt = usageRecordIds.find(output.uid);
    if (recordIt == usageRecordIds.end()) {
      return InternalError("TensorUsageRecord for intermediate tensor is not found.");
    }
    output.metalHandle = sharedBuffers.at(sharedBufferIds.at(recordIt->second));
    (*buffers)[output.uid] = output.metalHandle;
  }

  // Re-assign input buffers from the (possibly just updated) registry.
  for (auto& input : _inputBuffers) {
    input.metalHandle = (*buffers)[input.uid];
  }
  return OkStatus();
}
210
// Encodes this task's compute dispatch: binds outputs, inputs, constant
// buffers and uniform bytes to consecutive argument slots, then dispatches
// the precomputed threadgroup grid.
- (void)encodeWithEncoder:(id<MTLComputeCommandEncoder>)encoder
       inputOutputBuffers:(const std::map<ValueId, id<MTLBuffer>>&)inputOutputBuffers {
  // The dispatch call is intended to be skipped when the grid is empty.
  if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) {
    return;
  }

  [encoder setComputePipelineState:_program];

  // Model-level inputs/outputs come from the externally supplied map; every
  // other tensor uses the handle resolved in -assignBuffers:....
  int slot = 0;
  auto bindTensor = [&](ValueId uid, id<MTLBuffer> internalHandle) {
    const auto external = inputOutputBuffers.find(uid);
    id<MTLBuffer> handle =
        (external == inputOutputBuffers.end()) ? internalHandle : external->second;
    [encoder setBuffer:handle offset:0 atIndex:slot++];
  };

  for (auto& output : _outputBuffers) {
    bindTensor(output.uid, output.metalHandle);
  }
  for (auto& input : _inputBuffers) {
    bindTensor(input.uid, input.metalHandle);
  }
  for (auto& constant : _immutableBuffers) {
    [encoder setBuffer:constant offset:0 atIndex:slot++];
  }
  for (auto& uniform : _uniformBuffers) {
    // Uniforms are small; setBytes avoids allocating a dedicated buffer.
    [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:slot++];
  }

  MTLSize threadgroups = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
  MTLSize threadsPerGroup = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
  [encoder dispatchThreadgroups:threadgroups threadsPerThreadgroup:threadsPerGroup];
}

@end
256