/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/concat_z.h"

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace {

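// Returns true if every input's channel count is a multiple of 4, i.e. every
// source tensor occupies only whole FLT4 slices.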
bool IsAllChannelsX4(const std::vector<int>& channels) {
  for (int channel : channels) {
    if (channel % 4 != 0) {
      return false;
    }
  }
  return true;
}

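// Builds the kernel source that concatenates the source tensors along the
// channel (slice) axis into dst_tensor; channels[i] is the channel count of
// the i-th source tensor.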
std::string GetConcatKernelCode(const OperationDef& op_def,
                                const std::vector<int>& channels) {
  std::vector<std::string> tensor_names(op_def.src_tensors.size());
  for (int i = 0; i < op_def.src_tensors.size(); ++i) {
    tensor_names[i] = "src_tensor_" + std::to_string(i);
  }

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  c += "  int X = GLOBAL_ID_0;\n";
  c += "  int Y = GLOBAL_ID_1;\n";
  std::string coords = "X, Y";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += "  int Z = GLOBAL_ID_2;\n";
    c += "  if (Z >= args.dst_tensor.Depth()) return;\n";
    coords = "X, Y, Z";
  }
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
       "return; \n";

  if (IsAllChannelsX4(channels)) {
    // When all channel counts are multiples of 4, we can read/assign/write
    // whole FLT4 elements, and a simple loop per input keeps the generated
    // kernel short.
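    // As an illustration (not emitted verbatim; tensor names come from
    // tensor_names and coords may also include Z), the fragment generated
    // below for one input with an even slice count looks roughly like:
    //   for (int i = 0; i < args.src_tensor_0.Slices(); i += 2) {
    //     FLT4 result0 = args.src_tensor_0.Read(X, Y, i);
    //     FLT4 result1 = args.src_tensor_0.Read(X, Y, i + 1);
    //     args.dst_tensor.Write(result0, X, Y, S);
    //     args.dst_tensor.Write(result1, X, Y, S + 1);
    //     S += 2;
    //   }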
    c += "  int S = 0;\n";
    for (int i = 0; i < channels.size(); ++i) {
      std::string t_name = "args." + tensor_names[i];
      const int depth = DivideRoundUp(channels[i], 4);
      if (depth % 2 == 0) {
        // When depth % 2 == 0 we can read two slices per loop iteration,
        // which should hide read latency better.
        c += "  for (int i = 0; i < " + t_name + ".Slices(); i += 2) {\n";
        c += "    FLT4 result0 = " + t_name + ".Read(" + coords + ", i);\n";
        c += "    FLT4 result1 = " + t_name + ".Read(" + coords + ", i + 1);\n";
        c += "    args.dst_tensor.Write(result0, " + coords + ", S);\n";
        c += "    args.dst_tensor.Write(result1, " + coords + ", S + 1);\n";
        c += "    S += 2;\n";
        c += "  }\n";
      } else {
        c += "  for (int i = 0; i < " + t_name + ".Slices(); ++i) {\n";
        c += "    FLT4 result = " + t_name + ".Read(" + coords + ", i);\n";
        c += "    args.dst_tensor.Write(result, " + coords + ", S);\n";
        c += "    S++;\n";
        c += "  }\n";
      }
    }
  } else {
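    // Channel counts are not all multiples of 4, so input channels can cross
    // FLT4 slice boundaries in the output. Gather input channels one by one
    // into 'result' and write a full slice every time four output channels
    // have been accumulated; a trailing partial slice is written after the
    // loops. For example, with channels = {3, 5} the first output slice
    // combines t0.x, t0.y, t0.z and t1.x.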
    c += "  FLT4 result = INIT_FLT4(0.0);\n";
    int out_channel = 0;
    int read_index = 0;
    int z = 0;
    const std::string postfix[] = {".x", ".y", ".z", ".w"};
    for (int i = 0; i < channels.size(); ++i) {
      std::string tensor_name = "args." + tensor_names[i];
      const int depth = DivideRoundUp(channels[i], 4);
      for (int d = 0; d < depth; ++d) {
        const int channels_in_group = std::min(4, channels[i] - d * 4);
        const std::string temp_name = "t" + std::to_string(read_index);
        c += "  FLT4 " + temp_name + " = " + tensor_name + ".Read(" + coords +
             ", " + std::to_string(d) + ");\n";
        for (int ch = 0; ch < channels_in_group; ++ch) {
          c += "  result" + postfix[out_channel] + " = ";
          c += temp_name + postfix[ch] + ";\n";
          out_channel++;
          if (out_channel == 4) {
            out_channel = 0;
            c += "  args.dst_tensor.Write(result, " + coords + ", " +
                 std::to_string(z) + ");\n";
            z++;
          }
        }
        read_index++;
      }
    }
    if (out_channel != 0) {
      c += "  args.dst_tensor.Write(result, " + coords + ", " +
           std::to_string(z) + ");\n";
    }
  }
  c += "}\n";
  return c;
}

}  // namespace

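// Creates a GPUOperation that concatenates the source tensors along the
// channel axis. For some PowerVR and AMD configurations, compiler
// optimizations are disabled (kClDisableOptimizations) to work around the
// driver bugs noted below.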
GPUOperation CreateConcatZ(const OperationDef& definition,
                           const std::vector<int>& channels,
                           const GpuInfo& gpu_info) {
  GPUOperation op(definition);
  for (int i = 0; i < definition.src_tensors.size(); ++i) {
    const std::string name = "src_tensor_" + std::to_string(i);
    auto src_desc = definition.src_tensors[i];
    if (definition.IsBatchSupported()) {
      src_desc.SetStateVar("BatchedWidth", "true");
    }
    op.AddSrcTensor(name, src_desc);
  }
  auto dst_desc = definition.dst_tensors[0];
  if (definition.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  op.AddDstTensor("dst_tensor", dst_desc);
  op.code_ = GetConcatKernelCode(definition, channels);
  if (gpu_info.IsPowerVR() &&
      definition.precision == CalculationsPrecision::F32 &&
      !IsAllChannelsX4(channels)) {
    // BUG: some PowerVR GPUs (e.g. GE8320) produce incorrect results without
    // this.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }
  if (gpu_info.IsAMD() && definition.precision != CalculationsPrecision::F32 &&
      definition.src_tensors[0].storage_type != TensorStorageType::BUFFER &&
      !IsAllChannelsX4(channels)) {
    // BUG: some AMD GPUs crash without this.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HToY_DToZ;
  return op;
}

}  // namespace gpu
}  // namespace tflite