1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
17
18 #include <string>
19
20 #include "absl/types/variant.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 #include "tensorflow/lite/delegates/gpu/common/types.h"
23 #include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
24 #include "tensorflow/lite/delegates/gpu/gl/gl_errors.h"
25 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
26
27 namespace tflite {
28 namespace gpu {
29 namespace gl {
30 namespace {
31
CreateNewProgramId(GLuint * program_id)32 Status CreateNewProgramId(GLuint* program_id) {
33 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateProgram, program_id));
34 if (!*program_id) {
35 return UnknownError("Can't create opengl program: 0 program_id");
36 }
37 return OkStatus();
38 }
39
CheckProgramLinked(GLuint program_id)40 Status CheckProgramLinked(GLuint program_id) {
41 GLint linked;
42 glGetProgramiv(program_id, GL_LINK_STATUS, &linked);
43 if (linked == GL_TRUE) {
44 return OkStatus();
45 }
46 GLint info_size;
47 glGetProgramiv(program_id, GL_INFO_LOG_LENGTH, &info_size);
48 std::string errors;
49 errors.resize(info_size + 1 /* plus \0 */);
50 glGetProgramInfoLog(program_id, info_size + 1, nullptr, &errors[0]);
51 // TODO(akulik): use glValidateProgram to gather more info.
52 return UnavailableError("Program is not properly linked: " + errors);
53 }
54
55 struct ParameterSetter {
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter56 Status operator()(int value) {
57 return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id,
58 value);
59 }
60
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter61 Status operator()(const int2& value) {
62 return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id,
63 value.x, value.y);
64 }
65
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter66 Status operator()(const int4& value) {
67 return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id,
68 value.x, value.y, value.z, value.w);
69 }
70
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter71 Status operator()(const std::vector<int2>& value) {
72 std::vector<GLint> ints(value.size() * 2, 0);
73 for (int i = 0; i < value.size(); ++i) {
74 ints[i * 2] = value[i].x;
75 ints[i * 2 + 1] = value[i].y;
76 }
77 return TFLITE_GPU_CALL_GL(glProgramUniform2iv, program_id, uniform_id,
78 ints.size(), ints.data());
79 }
80
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter81 Status operator()(unsigned int value) {
82 return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id,
83 value);
84 }
85
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter86 Status operator()(const uint4& value) {
87 return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id,
88 value.x, value.y, value.z, value.w);
89 }
90
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter91 Status operator()(float value) {
92 return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id,
93 value);
94 }
95
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter96 Status operator()(const float2& value) {
97 return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id,
98 value.x, value.y);
99 }
100
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter101 Status operator()(const float4& value) {
102 return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id,
103 value.x, value.y, value.z, value.w);
104 }
105
operator ()tflite::gpu::gl::__anoncbd5b9f50111::ParameterSetter106 Status operator()(const std::vector<float4>& value) {
107 std::vector<GLfloat> floats(value.size() * 4, 0);
108 for (int i = 0; i < value.size(); ++i) {
109 floats[i * 4] = value[i].x;
110 floats[i * 4 + 1] = value[i].y;
111 floats[i * 4 + 2] = value[i].z;
112 floats[i * 4 + 3] = value[i].w;
113 }
114 return TFLITE_GPU_CALL_GL(glProgramUniform4fv, program_id, uniform_id,
115 floats.size(), floats.data());
116 }
117
118 const GLuint program_id;
119 const GLint uniform_id;
120 };
121
122 } // namespace
123
CreateWithShader(const GlShader & shader,GlProgram * gl_program)124 Status GlProgram::CreateWithShader(const GlShader& shader,
125 GlProgram* gl_program) {
126 GLuint program_id;
127 RETURN_IF_ERROR(CreateNewProgramId(&program_id));
128
129 // program_id needs to be properly deleted if there will be an error, hense
130 // wrap program_id into Program.
131 GlProgram program(program_id);
132
133 RETURN_IF_ERROR(
134 TFLITE_GPU_CALL_GL(glAttachShader, program.id(), shader.id()));
135 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glLinkProgram, program.id()));
136 RETURN_IF_ERROR(CheckProgramLinked(program.id()));
137
138 *gl_program = std::move(program);
139 return OkStatus();
140 }
141
CreateWithBinaryShader(const BinaryShader & shader,GlProgram * gl_program)142 Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader,
143 GlProgram* gl_program) {
144 GLuint program_id;
145 RETURN_IF_ERROR(CreateNewProgramId(&program_id));
146
147 // program_id needs to be properly deleted if there will be an error, hense
148 // wrap program_id into Program.
149 GlProgram program(program_id);
150
151 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glProgramBinary, program.id(),
152 shader.format(), shader.binary().data(),
153 shader.binary().size()));
154 RETURN_IF_ERROR(CheckProgramLinked(program.id()));
155
156 *gl_program = std::move(program);
157 return OkStatus();
158 }
159
GetBinary(BinaryShader * binary_shader)160 Status GlProgram::GetBinary(BinaryShader* binary_shader) {
161 GLint size = 0;
162 RETURN_IF_ERROR(
163 TFLITE_GPU_CALL_GL(glGetProgramiv, id_, GL_PROGRAM_BINARY_LENGTH, &size));
164 if (!size) {
165 return InternalError("Getting binary size failed.");
166 }
167 // TODO(akulik): call
168 // glProgramParameteri(id_, GL_PROGRAM_BINARY_RETRIEVABLE_HINT, GL_TRUE)
169 // before linking a program to increase chances of retrieving a binary.
170 std::vector<uint8_t> binary(size);
171 GLsizei returned_size;
172 GLenum format;
173 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetProgramBinary, id_, size,
174 &returned_size, &format,
175 reinterpret_cast<void*>(&binary[0])));
176 if (size != returned_size) {
177 return InternalError("Getting binary is failed.");
178 }
179 *binary_shader = BinaryShader(format, std::move(binary));
180 return OkStatus();
181 }
182
GlProgram(GlProgram && program)183 GlProgram::GlProgram(GlProgram&& program) : id_(program.id_) {
184 program.id_ = 0;
185 }
186
Invalidate()187 void GlProgram::Invalidate() {
188 if (id_) {
189 glDeleteProgram(id_);
190 id_ = 0;
191 }
192 }
193
operator =(GlProgram && program)194 GlProgram& GlProgram::operator=(GlProgram&& program) {
195 if (this != &program) {
196 Invalidate();
197 std::swap(id_, program.id_);
198 }
199 return *this;
200 }
201
~GlProgram()202 GlProgram::~GlProgram() { Invalidate(); }
203
SetParameter(const Variable & param)204 Status GlProgram::SetParameter(const Variable& param) {
205 GLint uniform_location;
206 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_location,
207 id_, param.name.c_str()));
208 return absl::visit(ParameterSetter{id_, uniform_location}, param.value);
209 }
210
Dispatch(const uint3 & workgroups) const211 Status GlProgram::Dispatch(const uint3& workgroups) const {
212 if (workgroups.x == 0 || workgroups.y == 0 || workgroups.z == 0) {
213 return InvalidArgumentError("Invalid workgroups");
214 }
215 RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glUseProgram, id_));
216 return TFLITE_GPU_CALL_GL(glDispatchCompute, workgroups.x, workgroups.y,
217 workgroups.z);
218 }
219
220 } // namespace gl
221 } // namespace gpu
222 } // namespace tflite
223