/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"

#include <Availability.h>

#include <map>
#include <string>
#include <tuple>

#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {
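// Translates a grid size into the number of threadgroups to dispatch: each
// dimension is divided by the work group size (rounded up), and for 2D and
// 3D grids the per-axis counts are permuted by work_group_launch_order.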
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}
}  // namespace

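// Takes ownership of the wrapped GPUOperation.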
void ComputeTask::Init(std::unique_ptr<GPUOperation>&& operation) {
  operation_ = std::move(operation);
}

const OperationDef& ComputeTask::GetDefinition() const {
  return operation_->definition_;
}

bool ComputeTask::IsLinkable() const { return operation_->IsLinkable(); }

absl::Status ComputeTask::AddTask(ComputeTask* task) {
  return operation_->AddOperation(task->operation_.get());
}

absl::Status ComputeTask::Compile(MetalDevice* device) {
  operation_->AssembleCode(device->GetInfo());
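  // The fused elementwise code (if any) is linked against the first
  // destination tensor while the arguments are translated to Metal objects.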
  const std::map<std::string, std::string> linkables = {
      {operation_->dst_tensors_names_[0], operation_->elementwise_code_}};
  RETURN_IF_ERROR(metal_args_.Init(linkables, device, &operation_->args_,
                                   &operation_->code_));

  operation_->args_.ReleaseCPURepresentation();

  return CompileProgram(device, operation_->definition_.precision,
                        operation_->code_);
}

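// Compiles the assembled kernel source into a Metal compute pipeline,
// mapping precision- and device-dependent macros to concrete MSL snippets.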
absl::Status ComputeTask::CompileProgram(MetalDevice* device,
                                         CalculationsPrecision precision,
                                         const std::string& kernel_code) {
  NSString* barrier;
  // simdgroup_barrier is supported since Metal shading language version 2.0
  if (device->IsLanguageVersion2orHigher()) {
    barrier = @"simdgroup_barrier";
  } else {
    barrier = @"threadgroup_barrier";
  }
  NSString* storageType;
  NSString* accumulatorType;
  NSString* toAccumulatorType4 = @"";
  if (precision == CalculationsPrecision::F32) {
    storageType = @"float";
    accumulatorType = @"float";
  } else {
    // FP16
    storageType = @"half";
    if (precision == CalculationsPrecision::F32_F16) {
      accumulatorType = @"float";
      toAccumulatorType4 = @"float4";
    } else {
      accumulatorType = @"half";
    }
  }
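  // Substitutions mapping the backend-neutral macro names used in the
  // generated kernel source to their Metal Shading Language equivalents.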
  NSDictionary<NSString*, NSString*>* macros = @{
    @"float16" : @"float4x4",
    @"half16" : @"half4x4",
    @"FLT16_0123(V)" : @"V[0]",
    @"FLT16_4567(V)" : @"V[1]",
    @"FLT16_89ab(V)" : @"V[2]",
    @"FLT16_cdef(V)" : @"V[3]",
    @"FLT" : storageType,
    @"FLT2" : [NSString stringWithFormat:@"%@2", storageType],
    @"FLT3" : [NSString stringWithFormat:@"%@3", storageType],
    @"FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"ACCUM_FLT" : accumulatorType,
    @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType],
    @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType],
    @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType],
    @"INIT_ACCUM_FLT4(value)" :
        [NSString stringWithFormat:@"%@4(value)", accumulatorType],
    @"TO_ACCUM_TYPE" : toAccumulatorType4,
    @"TO_ACCUM_FLT" : accumulatorType,
    @"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
    @"SIMDGROUP_BARRIER" : barrier,
    @"SIMD_LOCAL_MEM_BARRIER" : barrier,
    @"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
    @"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
    @"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",
    @"GLOBAL_ID_2" : @"static_cast<int>(reserved_gid.z)",
    @"LOCAL_ID_0" : @"static_cast<int>(reserved_lid.x)",
    @"LOCAL_ID_1" : @"static_cast<int>(reserved_lid.y)",
    @"LOCAL_ID_2" : @"static_cast<int>(reserved_lid.z)",
    @"GROUP_ID_0" : @"static_cast<int>(reserved_group_id.x)",
    @"GROUP_ID_1" : @"static_cast<int>(reserved_group_id.y)",
    @"GROUP_ID_2" : @"static_cast<int>(reserved_group_id.z)",
    @"GROUP_SIZE_0" : @"static_cast<int>(reserved_group_size.x)",
    @"GROUP_SIZE_1" : @"static_cast<int>(reserved_group_size.y)",
    @"GROUP_SIZE_2" : @"static_cast<int>(reserved_group_size.z)",
    @"SUB_GROUP_LOCAL_ID" : @"static_cast<int>(reserved_simd_id)",
    @"\"SUB_GROUP_BROADCAST(V, ID)\"" : @"\"simd_broadcast(V, ID)\"",
    @"__local" : @"threadgroup",
    @"__global" : @"device",
    @"__constant" : @"constant",
    @"LOCAL_MEM_BARRIER" : @"threadgroup_barrier(mem_flags::mem_threadgroup)",
    @"INIT_FLT(value)" : [NSString stringWithFormat:@"%@(value)", storageType],
    @"INIT_FLT4(value)" :
        [NSString stringWithFormat:@"%@4(value)", storageType],
    @"\"INIT_FLT4v4(v0, v1, v2, v3)\"" :
        [NSString stringWithFormat:@"\"%@4(v0, v1, v2, v3)\"", storageType],
    @"INIT_FLOAT(value)" : @"float(value)",
    @"INIT_FLOAT2(value)" : @"float2(value)",
    @"\"INIT_FLOAT2v2(v0, v1)\"" : @"\"float2(v0, v1)\"",
    @"INIT_FLOAT3(value)" : @"float3(value)",
    @"\"INIT_FLOAT3v3(v0, v1, v2)\"" : @"\"float3(v0, v1, v2)\"",
    @"INIT_FLOAT4(value)" : @"float4(value)",
    @"\"INIT_FLOAT4v4(v0, v1, v2, v3)\"" : @"\"float4(v0, v1, v2, v3)\"",
    @"INIT_INT(value)" : @"int(value)",
    @"\"INIT_INT2v2(v0, v1)\"" : @"\"int2(v0, v1)\"",
    @"\"INIT_INT4v4(v0, v1, v2, v3)\"" : @"\"int4(v0, v1, v2, v3)\"",
    @"CONVERT_TO_INT4(value)" : @"int4(value)",
  };

  NSString* code =
      [NSString stringWithCString:kernel_code.c_str()
                         encoding:[NSString defaultCStringEncoding]];
  id<MTLComputePipelineState> program;
  RETURN_IF_ERROR(CreateComputeProgram(device->device(), code,
                                       @"ComputeFunction", macros, &program));
  if (!program) {
    return absl::InternalError("Unknown shader compilation error");
  }
  program_ = program;
  return absl::OkStatus();
}

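// Rebinds all source and destination tensors in the Metal argument buffer
// and recomputes the dispatch grid from the operation's current shapes.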
absl::Status ComputeTask::UpdateParams() {
  for (int i = 0; i < operation_->src_tensors_names_.size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->src_[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(metal_args_.SetObjectRef(operation_->src_tensors_names_[i],
                                             *metal_spatial_tensor));
  }
  for (int i = 0; i < operation_->dst_tensors_names_.size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->dst_[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(metal_args_.SetObjectRef(operation_->dst_tensors_names_[i],
                                             *metal_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&metal_args_));
  operation_->grid_size_ = operation_->GetGridSize();
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}

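// Records the dispatch on the encoder: pipeline state and arguments first,
// then work_groups_count_ threadgroups of work_group_size_ threads each.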
void ComputeTask::Encode(id<MTLComputeCommandEncoder> encoder) {
  [encoder setComputePipelineState:program_];
  metal_args_.Encode(encoder, 0);
  MTLSize groupsCount, groupsSize;
  groupsCount.width = operation_->work_groups_count_.x;
  groupsCount.height = operation_->work_groups_count_.y;
  groupsCount.depth = operation_->work_groups_count_.z;
  groupsSize.width = operation_->work_group_size_.x;
  groupsSize.height = operation_->work_group_size_.y;
  groupsSize.depth = operation_->work_group_size_.z;
  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}

void ComputeTask::SetSrcTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetSrc(tensor, index);
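  // The status returned by SetObjectRef is dropped: this setter has no
  // error-reporting path.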
  auto status =
      metal_args_.SetObjectRef(operation_->src_tensors_names_[index], *tensor);
}

void ComputeTask::SetDstTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetDst(tensor, index);
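  // As in SetSrcTensor, the rebinding status is dropped.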
  auto status =
      metal_args_.SetObjectRef(operation_->dst_tensors_names_[index], *tensor);
}

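// Selects the first of the work group sizes suggested for |tuning_type| and
// updates the threadgroup count to match.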
absl::Status ComputeTask::Tune(TuningType tuning_type, MetalDevice* device) {
  std::vector<int3> possible_work_groups;
  KernelInfo kernel_info;
  kernel_info.max_work_group_size = [program_ maxTotalThreadsPerThreadgroup];
  kernel_info.private_memory_size = 0;
  operation_->GetPossibleKernelWorkGroups(tuning_type, device->GetInfo(),
                                          kernel_info, &possible_work_groups);
  if (possible_work_groups.empty()) {
    return absl::NotFoundError(
        "Could not find a work_group size to launch the kernel.");
  }
  operation_->work_group_size_ = possible_work_groups[0];
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite