• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "LutShader.h"
17 
18 #include <SkM44.h>
19 #include <SkTileMode.h>
20 #include <common/trace.h>
21 #include <cutils/ashmem.h>
22 #include <math/half.h>
23 #include <sys/mman.h>
24 #include <ui/ColorSpace.h>
25 
26 #include "include/core/SkColorSpace.h"
27 
28 using aidl::android::hardware::graphics::composer3::LutProperties;
29 
30 namespace android {
31 namespace renderengine {
32 namespace skia {
33 
34 static const SkString kShader = SkString(R"(
35     uniform shader image;
36     uniform shader lut;
37     uniform int size;
38     uniform int key;
39     uniform int dimension;
40     uniform vec3 luminanceCoefficients; // for CIE_Y
41     // for hlg/pq transfer function, we need normalize it to [0.0, 1.0]
42     // we use `normalizeScalar` to do so
43     uniform float normalizeScalar;
44 
45     vec4 main(vec2 xy) {
46         float4 rgba = image.eval(xy);
47         float3 linear = toLinearSrgb(rgba.rgb) * normalizeScalar;
48         if (dimension == 1) {
49             // RGB
50             if (key == 0) {
51                 float indexR = linear.r * float(size - 1);
52                 float indexG = linear.g * float(size - 1);
53                 float indexB = linear.b * float(size - 1);
54                 float gainR = lut.eval(vec2(indexR, 0.0) + 0.5).r;
55                 float gainG = lut.eval(vec2(indexG, 0.0) + 0.5).r;
56                 float gainB = lut.eval(vec2(indexB, 0.0) + 0.5).r;
57                 linear = float3(linear.r * gainR, linear.g * gainG, linear.b * gainB);
58             // MAX_RGB
59             } else if (key == 1) {
60                 float maxRGB = max(linear.r, max(linear.g, linear.b));
61                 float index = maxRGB * float(size - 1);
62                 float gain = lut.eval(vec2(index, 0.0) + 0.5).r;
63                 linear = linear * gain;
64             // CIE_Y
65             } else if (key == 2) {
66                 float y = dot(linear, luminanceCoefficients) / 3.0;
67                 float index = y * float(size - 1);
68                 float gain = lut.eval(vec2(index, 0.0) + 0.5).r;
69                 linear = linear * gain;
70             }
71         } else if (dimension == 3) {
72             if (key == 0) {
73                 float tx = linear.r * float(size - 1);
74                 float ty = linear.g * float(size - 1);
75                 float tz = linear.b * float(size - 1);
76 
77                 // calculate lower and upper bounds for each dimension
78                 int x = int(tx);
79                 int y = int(ty);
80                 int z = int(tz);
81 
82                 int i000 = x + y * size + z * size * size;
83                 int i100 = i000 + 1;
84                 int i010 = i000 + size;
85                 int i110 = i000 + size + 1;
86                 int i001 = i000 + size * size;
87                 int i101 = i000 + size * size + 1;
88                 int i011 = i000 + size * size + size;
89                 int i111 = i000 + size * size + size + 1;
90 
91                 // get 1d normalized indices
92                 float c000 = float(i000) / float(size * size * size);
93                 float c100 = float(i100) / float(size * size * size);
94                 float c010 = float(i010) / float(size * size * size);
95                 float c110 = float(i110) / float(size * size * size);
96                 float c001 = float(i001) / float(size * size * size);
97                 float c101 = float(i101) / float(size * size * size);
98                 float c011 = float(i011) / float(size * size * size);
99                 float c111 = float(i111) / float(size * size * size);
100 
101                 //TODO(b/377984618): support Tetrahedral interpolation
102                 // perform trilinear interpolation
103                 float3 c00 = mix(lut.eval(vec2(c000, 0.0) + 0.5).rgb,
104                                  lut.eval(vec2(c100, 0.0) + 0.5).rgb, linear.r);
105                 float3 c01 = mix(lut.eval(vec2(c001, 0.0) + 0.5).rgb,
106                                  lut.eval(vec2(c101, 0.0) + 0.5).rgb, linear.r);
107                 float3 c10 = mix(lut.eval(vec2(c010, 0.0) + 0.5).rgb,
108                                  lut.eval(vec2(c110, 0.0) + 0.5).rgb, linear.r);
109                 float3 c11 = mix(lut.eval(vec2(c011, 0.0) + 0.5).rgb,
110                                  lut.eval(vec2(c111, 0.0) + 0.5).rgb, linear.r);
111 
112                 float3 c0 = mix(c00, c10, linear.g);
113                 float3 c1 = mix(c01, c11, linear.g);
114 
115                 linear = mix(c0, c1, linear.b);
116             }
117         }
118         return float4(fromLinearSrgb(linear), rgba.a);
119     })");
120 
121 // same as shader::toColorSpace function
122 // TODO: put this function in a general place
toColorSpace(ui::Dataspace dataspace)123 static ColorSpace toColorSpace(ui::Dataspace dataspace) {
124     switch (dataspace & HAL_DATASPACE_STANDARD_MASK) {
125         case HAL_DATASPACE_STANDARD_BT709:
126             return ColorSpace::sRGB();
127         case HAL_DATASPACE_STANDARD_DCI_P3:
128             return ColorSpace::DisplayP3();
129         case HAL_DATASPACE_STANDARD_BT2020:
130         case HAL_DATASPACE_STANDARD_BT2020_CONSTANT_LUMINANCE:
131             return ColorSpace::BT2020();
132         case HAL_DATASPACE_STANDARD_ADOBE_RGB:
133             return ColorSpace::AdobeRGB();
134         case HAL_DATASPACE_STANDARD_BT601_625:
135         case HAL_DATASPACE_STANDARD_BT601_625_UNADJUSTED:
136         case HAL_DATASPACE_STANDARD_BT601_525:
137         case HAL_DATASPACE_STANDARD_BT601_525_UNADJUSTED:
138         case HAL_DATASPACE_STANDARD_BT470M:
139         case HAL_DATASPACE_STANDARD_FILM:
140         case HAL_DATASPACE_STANDARD_UNSPECIFIED:
141         default:
142             return ColorSpace::sRGB();
143     }
144 }
145 
generateLutShader(sk_sp<SkShader> input,const std::vector<float> & buffers,const int32_t offset,const int32_t length,const int32_t dimension,const int32_t size,const int32_t samplingKey,ui::Dataspace srcDataspace)146 sk_sp<SkShader> LutShader::generateLutShader(sk_sp<SkShader> input,
147                                              const std::vector<float>& buffers,
148                                              const int32_t offset, const int32_t length,
149                                              const int32_t dimension, const int32_t size,
150                                              const int32_t samplingKey,
151                                              ui::Dataspace srcDataspace) {
152     SFTRACE_NAME("lut shader");
153     std::vector<half> buffer(length * 4); // 4 is for RGBA
154     auto d = static_cast<LutProperties::Dimension>(dimension);
155     if (d == LutProperties::Dimension::ONE_D) {
156         auto it = buffers.begin() + offset;
157         std::generate(buffer.begin(), buffer.end(), [it, i = 0]() mutable {
158             float val = (i++ % 4 == 0) ? *it++ : 0.0f;
159             return half(val);
160         });
161     } else {
162         for (int i = 0; i < length; i++) {
163             buffer[i * 4] = half(buffers[offset + i]);
164             buffer[i * 4 + 1] = half(buffers[offset + length + i]);
165             buffer[i * 4 + 2] = half(buffers[offset + length * 2 + i]);
166             buffer[i * 4 + 3] = half(0);
167         }
168     }
169     /**
170      * 1D Lut RGB/MAX_RGB
171      * (R0, 0, 0, 0)
172      * (R1, 0, 0, 0)
173      *
174      * 1D Lut CIE_Y
175      * (Y0, 0, 0, 0)
176      * (Y1, 0, 0, 0)
177      * ...
178      *
179      * 3D Lut MAX_RGB
180      * (R0, G0, B0, 0)
181      * (R1, G1, B1, 0)
182      * ...
183      */
184     SkImageInfo info = SkImageInfo::Make(length /* the number of rgba */, 1, kRGBA_F16_SkColorType,
185                                          kPremul_SkAlphaType);
186     SkBitmap bitmap;
187     bitmap.allocPixels(info);
188     if (!bitmap.installPixels(info, buffer.data(), info.minRowBytes())) {
189         ALOGW("bitmap.installPixels failed, skip this Lut!");
190         return input;
191     }
192 
193     sk_sp<SkImage> lutImage = SkImages::RasterFromBitmap(bitmap);
194     if (!lutImage) {
195         ALOGW("Got a nullptr from SkImages::RasterFromBitmap, skip this Lut!");
196         return input;
197     }
198 
199     mBuilder->child("image") = input;
200     mBuilder->child("lut") =
201             lutImage->makeRawShader(SkTileMode::kClamp, SkTileMode::kClamp,
202                                     d == LutProperties::Dimension::ONE_D
203                                             ? SkSamplingOptions(SkFilterMode::kLinear)
204                                             : SkSamplingOptions());
205 
206     float normalizeScalar = 1.0;
207     switch (srcDataspace & HAL_DATASPACE_TRANSFER_MASK) {
208         case HAL_DATASPACE_TRANSFER_HLG:
209             normalizeScalar = 0.203;
210             break;
211         case HAL_DATASPACE_TRANSFER_ST2084:
212             normalizeScalar = 0.0203;
213             break;
214         default:
215             normalizeScalar = 1.0;
216     }
217     const int uSize = static_cast<int>(size);
218     const int uKey = static_cast<int>(samplingKey);
219     const int uDimension = static_cast<int>(dimension);
220     const float uNormalizeScalar = static_cast<float>(normalizeScalar);
221 
222     if (static_cast<LutProperties::SamplingKey>(samplingKey) == LutProperties::SamplingKey::CIE_Y) {
223         // Use predefined colorspaces of input dataspace so that we can get D65 illuminant
224         mat3 toXYZMatrix(toColorSpace(srcDataspace).getRGBtoXYZ());
225         mBuilder->uniform("luminanceCoefficients") =
226                 SkV3{toXYZMatrix[0][1], toXYZMatrix[1][1], toXYZMatrix[2][1]};
227     } else {
228         mBuilder->uniform("luminanceCoefficients") = SkV3{1.f, 1.f, 1.f};
229     }
230     mBuilder->uniform("size") = uSize;
231     mBuilder->uniform("key") = uKey;
232     mBuilder->uniform("dimension") = uDimension;
233     mBuilder->uniform("normalizeScalar") = uNormalizeScalar;
234     return mBuilder->makeShader();
235 }
236 
lutShader(sk_sp<SkShader> & input,std::shared_ptr<gui::DisplayLuts> displayLuts,ui::Dataspace srcDataspace,sk_sp<SkColorSpace> outColorSpace)237 sk_sp<SkShader> LutShader::lutShader(sk_sp<SkShader>& input,
238                                      std::shared_ptr<gui::DisplayLuts> displayLuts,
239                                      ui::Dataspace srcDataspace,
240                                      sk_sp<SkColorSpace> outColorSpace) {
241     if (mBuilder == nullptr) {
242         const static SkRuntimeEffect::Result instance = SkRuntimeEffect::MakeForShader(kShader);
243         mBuilder = std::make_unique<SkRuntimeShaderBuilder>(instance.effect);
244     }
245 
246     auto& fd = displayLuts->getLutFileDescriptor();
247     if (fd.ok()) {
248         // de-gamma the image without changing the primaries
249         SkImage* baseImage = input->isAImage((SkMatrix*)nullptr, (SkTileMode*)nullptr);
250         sk_sp<SkColorSpace> baseColorSpace = baseImage && baseImage->colorSpace()
251                 ? baseImage->refColorSpace()
252                 : SkColorSpace::MakeSRGB();
253         sk_sp<SkColorSpace> lutMathColorSpace = baseColorSpace->makeLinearGamma();
254         input = input->makeWithWorkingColorSpace(lutMathColorSpace);
255 
256         auto& offsets = displayLuts->offsets;
257         auto& lutProperties = displayLuts->lutProperties;
258         std::vector<float> buffers;
259         int fullLength = offsets[lutProperties.size() - 1];
260         if (lutProperties[lutProperties.size() - 1].dimension == 1) {
261             fullLength += lutProperties[lutProperties.size() - 1].size;
262         } else {
263             fullLength += (lutProperties[lutProperties.size() - 1].size *
264                            lutProperties[lutProperties.size() - 1].size *
265                            lutProperties[lutProperties.size() - 1].size * 3);
266         }
267         size_t bufferSize = fullLength * sizeof(float);
268 
269         // decode the shared memory of luts
270         float* ptr =
271                 (float*)mmap(NULL, bufferSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0);
272         if (ptr == MAP_FAILED) {
273             LOG_ALWAYS_FATAL("mmap failed");
274         }
275         buffers = std::vector<float>(ptr, ptr + fullLength);
276         munmap(ptr, bufferSize);
277 
278         for (size_t i = 0; i < offsets.size(); i++) {
279             int bufferSizePerLut = (i == offsets.size() - 1) ? buffers.size() - offsets[i]
280                                                              : offsets[i + 1] - offsets[i];
281             // divide by 3 for 3d Lut because of 3 (RGB) channels
282             if (static_cast<LutProperties::Dimension>(lutProperties[i].dimension) ==
283                 LutProperties::Dimension::THREE_D) {
284                 bufferSizePerLut /= 3;
285             }
286             input = generateLutShader(input, buffers, offsets[i], bufferSizePerLut,
287                                       lutProperties[i].dimension, lutProperties[i].size,
288                                       lutProperties[i].samplingKey, srcDataspace);
289         }
290 
291         input = input->makeWithWorkingColorSpace(outColorSpace);
292     }
293     return input;
294 }
295 
296 } // namespace skia
297 } // namespace renderengine
298 } // namespace android