1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
18 #define EIGEN_USE_GPU
19 #endif
20
21 #include "tensorflow/core/kernels/image/adjust_saturation_op.h"
22
23 #include <memory>
24
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/util/work_sharder.h"
34
35 namespace tensorflow {
36
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 typedef Eigen::GpuDevice GPUDevice;
39
40 class AdjustSaturationOpBase : public OpKernel {
41 protected:
AdjustSaturationOpBase(OpKernelConstruction * context)42 explicit AdjustSaturationOpBase(OpKernelConstruction* context)
43 : OpKernel(context) {}
44
45 struct ComputeOptions {
46 const Tensor* input;
47 const Tensor* scale;
48 Tensor* output;
49 int64_t channel_count;
50 };
51
52 virtual void DoCompute(OpKernelContext* context,
53 const ComputeOptions& options) = 0;
54
Compute(OpKernelContext * context)55 void Compute(OpKernelContext* context) override {
56 const Tensor& input = context->input(0);
57 const Tensor& scale = context->input(1);
58 OP_REQUIRES(context, input.dims() >= 3,
59 errors::InvalidArgument("input must be at least 3-D, got shape",
60 input.shape().DebugString()));
61 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale.shape()),
62 errors::InvalidArgument("scale must be scalar: ",
63 scale.shape().DebugString()));
64 auto channels = input.dim_size(input.dims() - 1);
65 OP_REQUIRES(
66 context, channels == 3,
67 errors::InvalidArgument("input must have 3 channels but instead has ",
68 channels, " channels."));
69
70 Tensor* output = nullptr;
71 OP_REQUIRES_OK(context,
72 context->allocate_output(0, input.shape(), &output));
73
74 if (input.NumElements() > 0) {
75 const int64_t channel_count = input.NumElements() / channels;
76 ComputeOptions options;
77 options.input = &input;
78 options.scale = &scale;
79 options.output = output;
80 options.channel_count = channel_count;
81 DoCompute(context, options);
82 }
83 }
84 };
85
// Device- and type-parameterized AdjustSaturation kernel. Only declared here;
// the usable implementations are the specializations defined further down in
// this file.
template <class Device, typename T>
class AdjustSaturationOp;
88
89 namespace internal {
// Converts an RGB color to HSV.
//
// Outputs:
//   *v  max(r, g, b).
//   *s  (max - min) / max, or 0 when max is not positive.
//   *h  hue as a fraction of a full turn, wrapped into [0, 1); 0 for gray
//       colors (max == min).
static void rgb_to_hsv(float r, float g, float b, float* h, float* s,
                       float* v) {
  const float vv = std::max(r, std::max(g, b));
  const float range = vv - std::min(r, std::min(g, b));
  if (vv > 0) {
    *s = range / vv;
  } else {
    *s = 0;
  }

  float hh;
  if (range <= 0.0f) {
    // Gray: hue is undefined, report 0. Testing this first also keeps the
    // normalization below from dividing by zero.
    hh = 0;
  } else {
    const float norm = 1.0f / (6.0f * range);
    // The sector offsets are intentionally double literals so the sum is
    // evaluated in double and rounded once on assignment.
    if (r == vv) {
      hh = norm * (g - b);
    } else if (g == vv) {
      hh = norm * (b - r) + 2.0 / 6.0;
    } else {
      hh = norm * (r - g) + 4.0 / 6.0;
    }
  }
  if (hh < 0.0f) {
    hh = hh + 1;
  }
  // Adding 1 to a tiny negative hue rounds to exactly 1.0f. Hue 1 is the same
  // color as hue 0, but hsv_to_rgb() only maps sectors 0..5, so fold it back
  // to keep the result inside [0, 1).
  if (hh >= 1.0f) {
    hh = 0;
  }
  *v = vv;
  *h = hh;
}
117
// Converts an HSV color to RGB.
// Algorithm from wikipedia, https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV
//
// `h` is the hue as a fraction of a full turn and is expected to lie in
// [0, 1]; h == 1 denotes the same color as h == 0. Hues that truncate to a
// sector outside 0..6 yield the gray (v - s * v) in all three channels.
static void hsv_to_rgb(float h, float s, float v, float* r, float* g,
                       float* b) {
  const float c = s * v;  // Chroma.
  const float m = v - c;  // Offset added to every channel.
  const float dh = h * 6;
  const int h_category = static_cast<int>(dh);

  // dh modulo 2, computed by repeated addition/subtraction.
  float fmodu = dh;
  while (fmodu <= 0) {
    fmodu += 2.0f;
  }
  while (fmodu >= 2.0f) {
    fmodu -= 2.0f;
  }
  const float x = c * (1 - std::abs(fmodu - 1));

  float rr, gg, bb;
  switch (h_category) {
    case 0:
    case 6:  // h == 1 wraps around to the first sector (x is 0 there).
      rr = c;
      gg = x;
      bb = 0;
      break;
    case 1:
      rr = x;
      gg = c;
      bb = 0;
      break;
    case 2:
      rr = 0;
      gg = c;
      bb = x;
      break;
    case 3:
      rr = 0;
      gg = x;
      bb = c;
      break;
    case 4:
      rr = x;
      gg = 0;
      bb = c;
      break;
    case 5:
      rr = c;
      gg = 0;
      bb = x;
      break;
    default:
      rr = 0;
      gg = 0;
      bb = 0;
  }
  *r = rr + m;
  *g = gg + m;
  *b = bb + m;
}
174
175 } // namespace internal
176
177 template <>
178 class AdjustSaturationOp<CPUDevice, float> : public AdjustSaturationOpBase {
179 public:
AdjustSaturationOp(OpKernelConstruction * context)180 explicit AdjustSaturationOp(OpKernelConstruction* context)
181 : AdjustSaturationOpBase(context) {}
182
DoCompute(OpKernelContext * context,const ComputeOptions & options)183 void DoCompute(OpKernelContext* context,
184 const ComputeOptions& options) override {
185 const Tensor* input = options.input;
186 const Tensor* scale = options.scale;
187 Tensor* output = options.output;
188 const int64_t channel_count = options.channel_count;
189 static const int kChannelSize = 3;
190 auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
191 const float scale_h = scale->scalar<float>()();
192 auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
193 const int kCostPerChannel = 10;
194 const DeviceBase::CpuWorkerThreads& worker_threads =
195 *context->device()->tensorflow_cpu_worker_threads();
196 Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
197 kCostPerChannel,
198 [&input_data, &output_data, scale_h](int64_t start_channel,
199 int64_t end_channel) {
200 const float* p = input_data.data() + start_channel * kChannelSize;
201 float* q = output_data.data() + start_channel * kChannelSize;
202 for (int i = start_channel; i < end_channel; i++) {
203 float h, s, v;
204 // Convert the RGB color to Hue/V-range.
205 internal::rgb_to_hsv(p[0], p[1], p[2], &h, &s, &v);
206 s = std::min(1.0f, std::max(0.0f, s * scale_h));
207 // Convert the hue and v-range back into RGB.
208 internal::hsv_to_rgb(h, s, v, q, q + 1, q + 2);
209 p += kChannelSize;
210 q += kChannelSize;
211 }
212 });
213 }
214 };
215
// Registers the float CPU kernel for the "AdjustSaturation" op.
REGISTER_KERNEL_BUILDER(
    Name("AdjustSaturation").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    AdjustSaturationOp<CPUDevice, float>);
219
220 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
221 template <typename T>
222 class AdjustSaturationOp<GPUDevice, T> : public AdjustSaturationOpBase {
223 public:
AdjustSaturationOp(OpKernelConstruction * context)224 explicit AdjustSaturationOp(OpKernelConstruction* context)
225 : AdjustSaturationOpBase(context) {}
226
DoCompute(OpKernelContext * context,const ComputeOptions & options)227 void DoCompute(OpKernelContext* context,
228 const ComputeOptions& options) override {
229 const Tensor* input = options.input;
230 const Tensor* scale = options.scale;
231 Tensor* output = options.output;
232 const int64_t number_of_elements = input->NumElements();
233 GPUDevice device = context->eigen_gpu_device();
234 const auto stream = device.stream();
235 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
236 if (number_of_elements > 0) {
237 const T* input_data = input->flat<T>().data();
238 const float* scale_data = scale->flat<float>().data();
239 T* const output_data = output->flat<T>().data();
240 functor::AdjustSaturationGPU<T>()(&device, number_of_elements, input_data,
241 scale_data, output_data);
242 }
243 }
244 };
245
// Registers the GPU kernel for the "AdjustSaturation" op with element type T.
#define REGISTER_GPU(T)                                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("AdjustSaturation").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      AdjustSaturationOp<GPUDevice, T>);

// GPU kernels are registered for float and half precision.
REGISTER_GPU(float)
REGISTER_GPU(Eigen::half)

// The helper macro is only needed for the registrations above.
#undef REGISTER_GPU
255
256 #endif
257
258 } // namespace tensorflow
259