• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
18 #define EIGEN_USE_GPU
19 #endif
20 
21 #include "tensorflow/core/kernels/image/adjust_saturation_op.h"
22 
23 #include <memory>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/util/work_sharder.h"
34 
35 namespace tensorflow {
36 
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 typedef Eigen::GpuDevice GPUDevice;
39 
40 class AdjustSaturationOpBase : public OpKernel {
41  protected:
AdjustSaturationOpBase(OpKernelConstruction * context)42   explicit AdjustSaturationOpBase(OpKernelConstruction* context)
43       : OpKernel(context) {}
44 
45   struct ComputeOptions {
46     const Tensor* input;
47     const Tensor* scale;
48     Tensor* output;
49     int64_t channel_count;
50   };
51 
52   virtual void DoCompute(OpKernelContext* context,
53                          const ComputeOptions& options) = 0;
54 
Compute(OpKernelContext * context)55   void Compute(OpKernelContext* context) override {
56     const Tensor& input = context->input(0);
57     const Tensor& scale = context->input(1);
58     OP_REQUIRES(context, input.dims() >= 3,
59                 errors::InvalidArgument("input must be at least 3-D, got shape",
60                                         input.shape().DebugString()));
61     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale.shape()),
62                 errors::InvalidArgument("scale must be scalar: ",
63                                         scale.shape().DebugString()));
64     auto channels = input.dim_size(input.dims() - 1);
65     OP_REQUIRES(
66         context, channels == 3,
67         errors::InvalidArgument("input must have 3 channels but instead has ",
68                                 channels, " channels."));
69 
70     Tensor* output = nullptr;
71     OP_REQUIRES_OK(context,
72                    context->allocate_output(0, input.shape(), &output));
73 
74     if (input.NumElements() > 0) {
75       const int64_t channel_count = input.NumElements() / channels;
76       ComputeOptions options;
77       options.input = &input;
78       options.scale = &scale;
79       options.output = output;
80       options.channel_count = channel_count;
81       DoCompute(context, options);
82     }
83   }
84 };
85 
// Primary template, intentionally left undefined: each supported
// (device, element type) pair provides its own explicit specialization.
template <typename DeviceType, typename Scalar>
class AdjustSaturationOp;
88 
89 namespace internal {
// Converts an RGB triple to hue/saturation/value.
//
// *v is the largest channel. *s is (max - min) / max, or 0 when max <= 0.
// *h is the hue as a fraction of a full turn: in [0, 1] for finite inputs, and
// 0 for achromatic pixels (max == min).
static void rgb_to_hsv(float r, float g, float b, float* h, float* s,
                       float* v) {
  const float vv = std::max(r, std::max(g, b));
  const float range = vv - std::min(r, std::min(g, b));
  *s = (vv > 0) ? range / vv : 0;

  float hh = 0;
  // Only evaluate the hue when it is defined. The previous code always
  // computed 1 / (6 * range), dividing by zero for every gray pixel before
  // discarding the result. Written as !(range <= 0) so a NaN range still
  // yields a NaN hue exactly as before.
  if (!(range <= 0.0f)) {
    const float norm = 1.0f / (6.0f * range);
    if (r == vv) {
      hh = norm * (g - b);
    } else if (g == vv) {
      hh = norm * (b - r) + 2.0 / 6.0;
    } else {
      hh = norm * (r - g) + 4.0 / 6.0;
    }
    // The red sector can go negative; wrap it into the upper end of the turn.
    if (hh < 0.0) {
      hh = hh + 1;
    }
  }
  *v = vv;
  *h = hh;
}
117 
// Converts a hue/saturation/value triple back to RGB.
// Algorithm from wikipedia, https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV
//
// `h` is a fraction of a full hue turn; h == 1 is the same hue as h == 0.
// A hue whose position h * 6 lies outside (-1, 7) (including NaN and
// infinity) maps to the achromatic colour (m, m, m) with m = v - s * v.
static void hsv_to_rgb(float h, float s, float v, float* r, float* g,
                       float* b) {
  const float c = s * v;  // Chroma.
  const float m = v - c;
  const float dh = h * 6;
  float rr = 0;
  float gg = 0;
  float bb = 0;
  // Only positions in (-1, 7) can select a colour sector. Checking the range
  // first keeps static_cast<int> well-defined (no NaN/inf/overflow) and stops
  // the modulo loops below from spinning forever on a non-finite hue.
  if (dh > -1.0f && dh < 7.0f) {
    int h_category = static_cast<int>(dh);
    // Sector 6 (h in [1, 7/6)) wraps around to sector 0. rgb_to_hsv can
    // return exactly 1.0f when it adds 1 to a tiny negative hue; previously
    // that fell into the default branch and a near-red pixel came out gray.
    if (h_category == 6) {
      h_category = 0;
    }
    // fmodu = dh mod 2, shifted into [0, 2).
    float fmodu = dh;
    while (fmodu <= 0) {
      fmodu += 2.0f;
    }
    while (fmodu >= 2.0f) {
      fmodu -= 2.0f;
    }
    const float x = c * (1 - std::abs(fmodu - 1));
    switch (h_category) {
      case 0:
        rr = c;
        gg = x;
        break;
      case 1:
        rr = x;
        gg = c;
        break;
      case 2:
        gg = c;
        bb = x;
        break;
      case 3:
        gg = x;
        bb = c;
        break;
      case 4:
        rr = x;
        bb = c;
        break;
      case 5:
        rr = c;
        bb = x;
        break;
      default:
        // Negative sector: leave rr = gg = bb = 0 (achromatic).
        break;
    }
  }
  *r = rr + m;
  *g = gg + m;
  *b = bb + m;
}
174 
175 }  // namespace internal
176 
177 template <>
178 class AdjustSaturationOp<CPUDevice, float> : public AdjustSaturationOpBase {
179  public:
AdjustSaturationOp(OpKernelConstruction * context)180   explicit AdjustSaturationOp(OpKernelConstruction* context)
181       : AdjustSaturationOpBase(context) {}
182 
DoCompute(OpKernelContext * context,const ComputeOptions & options)183   void DoCompute(OpKernelContext* context,
184                  const ComputeOptions& options) override {
185     const Tensor* input = options.input;
186     const Tensor* scale = options.scale;
187     Tensor* output = options.output;
188     const int64_t channel_count = options.channel_count;
189     static const int kChannelSize = 3;
190     auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
191     const float scale_h = scale->scalar<float>()();
192     auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
193     const int kCostPerChannel = 10;
194     const DeviceBase::CpuWorkerThreads& worker_threads =
195         *context->device()->tensorflow_cpu_worker_threads();
196     Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
197           kCostPerChannel,
198           [&input_data, &output_data, scale_h](int64_t start_channel,
199                                                int64_t end_channel) {
200             const float* p = input_data.data() + start_channel * kChannelSize;
201             float* q = output_data.data() + start_channel * kChannelSize;
202             for (int i = start_channel; i < end_channel; i++) {
203               float h, s, v;
204               // Convert the RGB color to Hue/V-range.
205               internal::rgb_to_hsv(p[0], p[1], p[2], &h, &s, &v);
206               s = std::min(1.0f, std::max(0.0f, s * scale_h));
207               // Convert the hue and v-range back into RGB.
208               internal::hsv_to_rgb(h, s, v, q, q + 1, q + 2);
209               p += kChannelSize;
210               q += kChannelSize;
211             }
212           });
213   }
214 };
215 
216 REGISTER_KERNEL_BUILDER(
217     Name("AdjustSaturation").Device(DEVICE_CPU).TypeConstraint<float>("T"),
218     AdjustSaturationOp<CPUDevice, float>);
219 
220 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
221 template <typename T>
222 class AdjustSaturationOp<GPUDevice, T> : public AdjustSaturationOpBase {
223  public:
AdjustSaturationOp(OpKernelConstruction * context)224   explicit AdjustSaturationOp(OpKernelConstruction* context)
225       : AdjustSaturationOpBase(context) {}
226 
DoCompute(OpKernelContext * context,const ComputeOptions & options)227   void DoCompute(OpKernelContext* context,
228                  const ComputeOptions& options) override {
229     const Tensor* input = options.input;
230     const Tensor* scale = options.scale;
231     Tensor* output = options.output;
232     const int64_t number_of_elements = input->NumElements();
233     GPUDevice device = context->eigen_gpu_device();
234     const auto stream = device.stream();
235     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
236     if (number_of_elements > 0) {
237       const T* input_data = input->flat<T>().data();
238       const float* scale_data = scale->flat<float>().data();
239       T* const output_data = output->flat<T>().data();
240       functor::AdjustSaturationGPU<T>()(&device, number_of_elements, input_data,
241                                         scale_data, output_data);
242     }
243   }
244 };
245 
246 #define REGISTER_GPU(T)                                                   \
247   REGISTER_KERNEL_BUILDER(                                                \
248       Name("AdjustSaturation").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
249       AdjustSaturationOp<GPUDevice, T>);
250 
251 REGISTER_GPU(float)
252 REGISTER_GPU(Eigen::half)
253 
254 #undef REGISTER_GPU
255 
256 #endif
257 
258 }  // namespace tensorflow
259