1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5
6 http://www.apache.org/licenses/LICENSE-2.0
7
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
15 #define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
16
17 #if GOOGLE_CUDA
18
19 #define EIGEN_USE_GPU
20
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/framework/types.h"
23
24 namespace tensorflow {
25 namespace internal {
26
// Plain aggregate holding the three colour channels of one RGB pixel.
struct RgbTuple {
  float r;  // red channel
  float g;  // green channel
  float b;  // blue channel
};
32
// Plain aggregate holding the hue, saturation and value of one pixel.
struct HsvTuple {
  float h;  // hue
  float s;  // saturation
  float v;  // value
};
38
// Converts one RGB triple to HSV.
//
// Returns:
//   h: hue as a fraction of a full turn, in [0, 1) for finite inputs;
//      0 when r == g == b (zero chroma).
//   s: chroma / max(r, g, b) when that maximum is positive, otherwise 0.
//   v: max(r, g, b).
inline __device__ HsvTuple rgb2hsv_cuda(const float r, const float g,
                                        const float b) {
  const float M = fmaxf(r, fmaxf(g, b));
  const float m = fminf(r, fminf(g, b));
  const float chroma = M - m;

  float h = 0.0f;
  if (chroma > 0.0f) {
    if (M == r) {
      // (g - b) / chroma lies in [-1, 1]; negative values are wrapped to the
      // top of the hue circle by adding a full turn (6 sectors).
      const float num = (g - b) / chroma;
      const float sign = copysignf(1.0f, num);
      h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f;
      // A tiny negative num (or -0.0f) makes 6.0f - |num| round to exactly
      // 6.0f, i.e. h == 1.0f. That is the same hue as 0 but lies outside the
      // documented [0, 1) output range, so wrap it.
      if (h >= 1.0f) {
        h = 0.0f;
      }
    } else if (M == g) {
      h = ((b - r) / chroma + 2.0f) / 6.0f;
    } else {
      h = ((r - g) / chroma + 4.0f) / 6.0f;
    }
  }

  // Float literal: a bare 0.0 would promote the comparison to double.
  const float s = (M > 0.0f) ? chroma / M : 0.0f;

  HsvTuple tuple;
  tuple.h = h;
  tuple.s = s;
  tuple.v = M;
  return tuple;
}
71
// Converts one HSV triple back to RGB.
//
// h is a fraction of a full turn and is expected in [0, 1]; h == 0 and
// h == 1 denote the same hue. h outside that range (or NaN) falls into none
// of the six hue sectors and yields a grey pixel r == g == b == v - v * s.
inline __device__ RgbTuple hsv2rgb_cuda(const float h, const float s,
                                        const float v) {
  RgbTuple tuple;
  // Scale hue to the six 60-degree sectors [0, 6].
  const float new_h = h * 6.0f;
  const float chroma = v * s;
  // Second-largest component, relative to the minimum new_m.
  const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f));
  const float new_m = v - chroma;
  const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f;
  const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f;
  const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f;
  const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f;
  const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f;
  // Closed at 6 so that h == 1.0f (a full turn) is not turned into grey.
  // At new_h == 6, x == 0, so this sector yields (v, new_m, new_m), which is
  // exactly what sector 0 yields at new_h == 0.
  const bool between_5_and_6 = new_h >= 5.0f && new_h <= 6.0f;
  // Branchless sector selection: each bool contributes 0 or 1.
  tuple.r = chroma * (between_0_and_1 || between_5_and_6) +
            x * (between_1_and_2 || between_4_and_5) + new_m;
  tuple.g = chroma * (between_1_and_2 || between_2_and_3) +
            x * (between_0_and_1 || between_3_and_4) + new_m;
  tuple.b = chroma * (between_3_and_4 || between_4_and_5) +
            x * (between_2_and_3 || between_5_and_6) + new_m;
  return tuple;
}
93
// Adjusts hue, saturation and/or value of interleaved 3-channel (RGB) data,
// element i*3 + c holding channel c of pixel i.
//
// Template parameters choose which adjustments are compiled in; with all three
// false the kernel copies input to output unchanged.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Each thread processes one
// pixel (3 consecutive elements) per iteration of a grid-stride loop, so any
// grid size produces the full result.
//
// Parameters:
//   number_elements:  number of T values in `input` / `output`; expected to be
//                     a multiple of 3. Trailing values that do not form a
//                     complete pixel are not read or written.
//   input, output:    device-accessible buffers of number_elements values.
//   hue_delta:        pointer to one float added to the hue (hue is a fraction
//                     of a full turn; the result is wrapped into [0, 1)).
//                     Read only when AdjustHue.
//   saturation_scale: pointer to one float multiplying the saturation; the
//                     result is clamped to [0, 1]. Read only when
//                     AdjustSaturation.
//   value_scale:      pointer to one float multiplying the value; the result
//                     is not clamped. Read only when AdjustV.
// A null scalar pointer skips the corresponding adjustment.
template <bool AdjustHue, bool AdjustSaturation, bool AdjustV, typename T>
__global__ void adjust_hsv_nhwc(const int64 number_elements,
                                const T* const __restrict__ input,
                                T* const output, const float* const hue_delta,
                                const float* const saturation_scale,
                                const float* const value_scale) {
  // blockDim / blockIdx / threadIdx are unsigned 32-bit. Computing the element
  // offset in that type wraps once the pixel index exceeds ~1.43e9 and yields
  // offsets that are not multiples of 3, so widen to int64 first.
  const int64 first_idx =
      (static_cast<int64>(blockDim.x) * blockIdx.x + threadIdx.x) * 3;
  const int64 stride = static_cast<int64>(blockDim.x) * gridDim.x * 3;
  // Step by 3 because each pixel is 3 contiguous channel values. Requiring
  // idx + 2 < number_elements admits only complete pixels, so all three
  // accesses below stay in bounds.
  for (int64 idx = first_idx; idx + 2 < number_elements; idx += stride) {
    if (!AdjustHue && !AdjustSaturation && !AdjustV) {
      output[idx] = input[idx];
      output[idx + 1] = input[idx + 1];
      output[idx + 2] = input[idx + 2];
      continue;
    }
    const HsvTuple hsv = rgb2hsv_cuda(static_cast<float>(input[idx]),
                                      static_cast<float>(input[idx + 1]),
                                      static_cast<float>(input[idx + 2]));
    float new_h = hsv.h;
    float new_s = hsv.s;
    float new_v = hsv.v;
    // Hue: rotate and wrap into [0, 1). fmodf keeps the sign of its first
    // argument, so a negative remainder is shifted up by a full turn; the
    // second fmodf maps a sum that rounds to exactly 1.0f back to 0.0f.
    if (AdjustHue && hue_delta != nullptr) {
      new_h = fmodf(hsv.h + *hue_delta, 1.0f);
      if (new_h < 0.0f) {
        new_h = fmodf(1.0f + new_h, 1.0f);
      }
    }
    // Saturation: scale, then clamp to [0, 1].
    if (AdjustSaturation && saturation_scale != nullptr) {
      new_s = fminf(1.0f, fmaxf(0.0f, hsv.s * *saturation_scale));
    }
    // Value: scale only, no clamping.
    if (AdjustV && value_scale != nullptr) {
      new_v = hsv.v * *value_scale;
    }
    const RgbTuple rgb = hsv2rgb_cuda(new_h, new_s, new_v);
    output[idx] = static_cast<T>(rgb.r);
    output[idx + 1] = static_cast<T>(rgb.g);
    output[idx + 2] = static_cast<T>(rgb.b);
  }
}
142
143 } // namespace internal
144 } // namespace tensorflow
145
146 #endif // GOOGLE_CUDA
147 #endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
148