/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/concat_z.h"

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace {

bool IsAllChannelsX4(const std::vector<int>& channels) {
  for (int channel : channels) {
    if (channel % 4 != 0) {
      return false;
    }
  }
  return true;
}

std::string GetConcatKernelCode(const OperationDef& op_def,
                                const std::vector<int>& channels) {
  std::vector<std::string> tensor_names(op_def.src_tensors.size());
  for (int i = 0; i < op_def.src_tensors.size(); ++i) {
    tensor_names[i] = "src_tensor_" + std::to_string(i);
  }

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  c += " int X = GLOBAL_ID_0;\n";
  c += " int Y = GLOBAL_ID_1;\n";
  std::string coords = "X, Y";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += " int Z = GLOBAL_ID_2;\n";
    c += " if (Z >= args.dst_tensor.Depth()) return;\n";
    coords = "X, Y, Z";
  }
  c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
       "return; \n";

  if (IsAllChannelsX4(channels)) {
    // When every input's channel count is a multiple of 4 we can
    // read/assign/write whole FLT4 elements. Emitting a loop in this case
    // also keeps the generated kernel short.
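    // For illustration, with coords == "X, Y" and an even number of slices in
    // src_tensor_0, the loop emitted below looks roughly like:
    //   for (int i = 0; i < args.src_tensor_0.Slices(); i += 2) {
    //     FLT4 result0 = args.src_tensor_0.Read(X, Y, i);
    //     FLT4 result1 = args.src_tensor_0.Read(X, Y, i + 1);
    //     args.dst_tensor.Write(result0, X, Y, S);
    //     args.dst_tensor.Write(result1, X, Y, S + 1);
    //     S += 2;
    //   }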
61 c += " int S = 0;\n";
62 for (int i = 0; i < channels.size(); ++i) {
63 std::string t_name = "args." + tensor_names[i];
64 const int depth = DivideRoundUp(channels[i], 4);
65 if (depth % 2 == 0) {
        // When depth % 2 == 0 we can read two slices per loop iteration, which
        // should hide read latency better.
68 c += " for (int i = 0; i < " + t_name + ".Slices(); i += 2) {\n";
69 c += " FLT4 result0 = " + t_name + ".Read(" + coords + ", i);\n";
70 c += " FLT4 result1 = " + t_name + ".Read(" + coords + ", i + 1);\n";
71 c += " args.dst_tensor.Write(result0, " + coords + ", S);\n";
72 c += " args.dst_tensor.Write(result1, " + coords + ", S + 1);\n";
73 c += " S += 2;\n";
74 c += " }\n";
75 } else {
76 c += " for (int i = 0; i < " + t_name + ".Slices(); ++i) {\n";
77 c += " FLT4 result = " + t_name + ".Read(" + coords + ", i);\n";
78 c += " args.dst_tensor.Write(result, " + coords + ", S);\n";
79 c += " S++;\n";
80 c += " }\n";
81 }
82 }
83 } else {
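    // Channel counts are not all multiples of 4, so the generated code packs
    // source channels one at a time into the components of a single FLT4
    // accumulator and writes it out each time four channels have been filled.
    // For example, with channels == {3, 5} the first destination slice becomes
    // (t0.x, t0.y, t0.z, t1.x) and the second (t1.y, t1.z, t1.w, t2.x), where
    // t0 is src_tensor_0 slice 0 and t1/t2 are src_tensor_1 slices 0 and 1.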
84 c += " FLT4 result = INIT_FLT4(0.0);\n";
85 int out_channel = 0;
86 int read_index = 0;
87 int z = 0;
88 const std::string postfix[] = {".x", ".y", ".z", ".w"};
89 for (int i = 0; i < channels.size(); ++i) {
90 std::string tensor_name = "args." + tensor_names[i];
91 const int depth = DivideRoundUp(channels[i], 4);
92 for (int d = 0; d < depth; ++d) {
93 const int channels_in_group = std::min(4, channels[i] - d * 4);
94 const std::string temp_name = "t" + std::to_string(read_index);
95 c += " FLT4 " + temp_name + " = " + tensor_name + ".Read(" + coords +
96 ", " + std::to_string(d) + ");\n";
97 for (int ch = 0; ch < channels_in_group; ++ch) {
98 c += " result" + postfix[out_channel] + " = ";
99 c += temp_name + postfix[ch] + ";\n";
100 out_channel++;
101 if (out_channel == 4) {
102 out_channel = 0;
103 c += " args.dst_tensor.Write(result, " + coords + ", " +
104 std::to_string(z) + ");\n";
105 z++;
106 }
107 }
108 read_index++;
109 }
110 }
111 if (out_channel != 0) {
112 c += " args.dst_tensor.Write(result, " + coords + ", " +
113 std::to_string(z) + ");\n";
114 }
115 }
116 c += "}\n";
117 return c;
118 }

}  // namespace

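// Creates a GPUOperation that concatenates the source tensors along the
// channel (Z) axis. `channels` holds the per-source channel counts and is
// expected to parallel definition.src_tensors. On some PowerVR and AMD GPUs
// compiler optimizations are disabled as a workaround for driver bugs.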
GPUOperation CreateConcatZ(const OperationDef& definition,
                           const std::vector<int>& channels,
                           const GpuInfo& gpu_info) {
  GPUOperation op(definition);
  for (int i = 0; i < definition.src_tensors.size(); ++i) {
    const std::string name = "src_tensor_" + std::to_string(i);
    auto src_desc = definition.src_tensors[i];
    if (definition.IsBatchSupported()) {
      src_desc.SetStateVar("BatchedWidth", "true");
    }
    op.AddSrcTensor(name, src_desc);
  }
  auto dst_desc = definition.dst_tensors[0];
  if (definition.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  op.AddDstTensor("dst_tensor", dst_desc);
  op.code_ = GetConcatKernelCode(definition, channels);
  if (gpu_info.IsPowerVR() &&
      definition.precision == CalculationsPrecision::F32 &&
      !IsAllChannelsX4(channels)) {
    // BUG: some PowerVR GPUs (e.g. GE8320) produce incorrect results without
    // this option.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }
  if (gpu_info.IsAMD() && definition.precision != CalculationsPrecision::F32 &&
      definition.src_tensors[0].storage_type != TensorStorageType::BUFFER &&
      !IsAllChannelsX4(channels)) {
    // BUG: some AMD GPUs crash without this option.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HToY_DToZ;
  return op;
}
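
// A minimal usage sketch (hypothetical values; in practice the OperationDef,
// per-input channel counts, and GpuInfo come from the GPU delegate's
// operation selection code):
//
//   std::vector<int> channels = {8, 3};  // channels of src_tensor_0/1
//   GPUOperation op = CreateConcatZ(definition, channels, gpu_info);
//   // op.code_ holds the generated kernel; the grid maps (W*B, H, D) to XYZ.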

}  // namespace gpu
}  // namespace tflite