• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/convert.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
29 
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33 namespace {
34 
35 class FullyConnectedBuffers : public NodeShader {
36  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const37   absl::Status GenerateCode(const GenerationContext& ctx,
38                             GeneratedCode* generated_code) const final {
39     const auto& attr =
40         absl::any_cast<const FullyConnectedAttributes&>(ctx.op_attr);
41 
42     const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
43     const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
44 
45     // This shader can work with any workgroup size, the values below work well
46     // for OpenGL.
47     constexpr int kWorkgroupHintX = 4;
48     constexpr int kWorkgroupHintY = 4;
49 
50     // TODO(akulik): check that input has h,w == 1,1
51     std::vector<Variable> parameters = {
52         {"src_depth", src_depth},
53         {"dst_depth", dst_depth},
54     };
55 
56     // TODO(akulik): refactor indexed access to weights.
57     std::vector<std::pair<std::string, Object>> objects = {
58         {"weights", MakeReadonlyObject(ConvertToPHWO4I4(attr.weights))}};
59 
60     std::string source = R"(
61   const int threads = int(gl_WorkGroupSize.y);
62   const int workers = int(gl_WorkGroupSize.x);
63   ivec3 tid = ivec3(gl_LocalInvocationID);
64 
65   if (gid.x < $dst_depth$) {
66     int offset = 4 * gid.x * $src_depth$ + 4 * tid.y;
67     int iterations = ($src_depth$ + threads-1) / threads;
68     for (int d = 0; d < iterations; d++, offset += 4 * threads) {
69       vec4 src = $input_data_0[0, 0, d * threads + tid.y]$;
70       value_0.x += dot(src, $weights[offset + 0]$);
71       value_0.y += dot(src, $weights[offset + 1]$);
72       value_0.z += dot(src, $weights[offset + 2]$);
73       value_0.w += dot(src, $weights[offset + 3]$);
74     }
75     sh_mem[workers * tid.y + tid.x] = value_0;
76   }
77   memoryBarrierShared();
78   barrier();
79 
80   if (tid.y > 0 || gid.x >= $dst_depth$) {
81     return;
82   }
83 
84   for (int t = 1; t < threads; t++) {
85     value_0 += sh_mem[workers * t + tid.x];
86   }
87 )";
88     if (!attr.bias.data.empty()) {
89       source += "  value_0 += $bias[gid.x]$;\n";
90       objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)});
91     }
92     source += "  $output_data_0[0, 0, gid.x] = value_0$;";
93 
94     std::vector<Variable> shared_variables = {
95 #ifdef __APPLE__
96         // MoltenVK has problems with shared memory sized using the workgroup
97         // size. Fortunately with Metal a fixed workgroup size of 32 seems to
98         // give optimal results.
99         {"sh_mem", std::vector<float4>(32)},
100 #else
101         // The actual size of sh_mem depends on the WorkgroupSize
102         {"sh_mem", std::vector<float4>(0)},
103 #endif
104     };
105 
106     *generated_code = {
107         /*parameters=*/std::move(parameters),
108         /*objects=*/std::move(objects),
109         /*shared_variables=*/std::move(shared_variables),
110         /*workload=*/uint3(dst_depth, kWorkgroupHintY, 1),
111         /*workgroup=*/uint3(kWorkgroupHintX, kWorkgroupHintY, 1),
112         /*source_code=*/std::move(source),
113         /*input=*/IOStructure::ONLY_DEFINITIONS,
114         /*output=*/IOStructure::ONLY_DEFINITIONS,
115     };
116     return absl::OkStatus();
117   }
118 };
119 
120 }  // namespace
121 
NewFullyConnectedNodeShader()122 std::unique_ptr<NodeShader> NewFullyConnectedNodeShader() {
123   return absl::make_unique<FullyConnectedBuffers>();
124 }
125 
126 }  // namespace gl
127 }  // namespace gpu
128 }  // namespace tflite
129