• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 
6     http://www.apache.org/licenses/LICENSE-2.0
7 
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #define EIGEN_USE_THREADS
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 #define EIGEN_USE_GPU
18 #endif
19 
#include "tensorflow/core/kernels/image/adjust_hue_op.h"

#include <cmath>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
34 
35 namespace tensorflow {
36 
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 typedef Eigen::GpuDevice GPUDevice;
39 
40 class AdjustHueOpBase : public OpKernel {
41  protected:
AdjustHueOpBase(OpKernelConstruction * context)42   explicit AdjustHueOpBase(OpKernelConstruction* context) : OpKernel(context) {}
43 
44   struct ComputeOptions {
45     const Tensor* input;
46     const Tensor* delta;
47     Tensor* output;
48     int64 channel_count;
49   };
50 
51   virtual void DoCompute(OpKernelContext* context,
52                          const ComputeOptions& options) = 0;
53 
Compute(OpKernelContext * context)54   void Compute(OpKernelContext* context) override {
55     const Tensor& input = context->input(0);
56     const Tensor& delta = context->input(1);
57     OP_REQUIRES(context, input.dims() >= 3,
58                 errors::InvalidArgument("input must be at least 3-D, got shape",
59                                         input.shape().DebugString()));
60     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta.shape()),
61                 errors::InvalidArgument("delta must be scalar: ",
62                                         delta.shape().DebugString()));
63     auto channels = input.dim_size(input.dims() - 1);
64     OP_REQUIRES(
65         context, channels == 3,
66         errors::InvalidArgument("input must have 3 channels but instead has ",
67                                 channels, " channels."));
68 
69     Tensor* output = nullptr;
70     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
71                                 {0}, 0, input.shape(), &output));
72 
73     if (input.NumElements() > 0) {
74       const int64 channel_count = input.NumElements() / channels;
75       ComputeOptions options;
76       options.input = &input;
77       options.delta = &delta;
78       options.output = output;
79       options.channel_count = channel_count;
80       DoCompute(context, options);
81     }
82   }
83 };
84 
// Primary template, declared only. Usable kernels come from the explicit
// device/type specializations defined below.
template <typename Device, typename T>
class AdjustHueOp;
87 
88 namespace internal {
89 
// Helper function to convert an RGB color to a hue plus the range of its
// components. The hue `h` is expressed in sectors and spans [0, 6] instead of
// the usual [0, 1]. Its integer part records which component is largest and
// which is smallest. Its fractional part locates the middle component between
// them. `v_min` / `v_max` receive the smallest / largest component.
static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
                            float* v_max) {
  // Sector layout follows the figures in:
  // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
  // When two components are equal it is fine to count the tie on either side.
  float largest;
  float middle;
  float smallest;
  int sector;
  if (!(r < g)) {
    // g <= r
    if (b < g) {
      // b < g <= r
      largest = r;
      middle = g;
      smallest = b;
      sector = 0;
    } else if (b > r) {
      // g <= r < b
      largest = b;
      middle = r;
      smallest = g;
      sector = 4;
    } else {
      // g <= b <= r
      largest = r;
      middle = b;
      smallest = g;
      sector = 5;
    }
  } else if (b < r) {
    // b < r < g
    largest = g;
    middle = r;
    smallest = b;
    sector = 1;
  } else if (b > g) {
    // r < g < b
    largest = b;
    middle = g;
    smallest = r;
    sector = 3;
  } else {
    // r <= b <= g
    largest = g;
    middle = b;
    smallest = r;
    sector = 2;
  }

  *v_max = largest;
  *v_min = smallest;
  if (largest == smallest) {
    // All components equal: report a hue of 0.
    *h = 0;
    return;
  }
  const float fraction = (middle - smallest) / (largest - smallest);
  // In even sectors the middle component grows with hue; in odd ones it
  // shrinks, so the fraction is mirrored.
  *h = sector + (sector % 2 == 0 ? fraction : 1 - fraction);
}
150 
// Helper function to convert from H-and-V-range back to RGB; the inverse of
// rgb_to_hv_range. The integer part of `h` selects one of six hue sectors and
// the fractional part gives the position within that sector. `v_min` and
// `v_max` become the smallest and largest output components. The remaining
// component is interpolated between them.
static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
                            float* g, float* b) {
  // Converting NaN to int is undefined behavior. A NaN hue can arrive here,
  // e.g. when an input pixel contains NaN, so route it to sector 5 explicitly.
  // The ratio below stays NaN, so the interpolated component is NaN too.
  const int h_category = std::isnan(h) ? 5 : static_cast<int>(h);
  float ratio = h - h_category;
  // Even sectors interpolate upward, odd sectors downward.
  const bool increase = ((h_category & 0x1) == 0);
  if (!increase) {
    ratio = 1 - ratio;
  }
  const float v_mid = v_min + ratio * (v_max - v_min);
  // According to the figures in:
  // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
  switch (h_category) {
    case 0:
      *r = v_max;
      *g = v_mid;
      *b = v_min;
      break;
    case 1:
      *r = v_mid;
      *g = v_max;
      *b = v_min;
      break;
    case 2:
      *r = v_min;
      *g = v_max;
      *b = v_mid;
      break;
    case 3:
      *r = v_min;
      *g = v_mid;
      *b = v_max;
      break;
    case 4:
      *r = v_mid;
      *g = v_min;
      *b = v_max;
      break;
    case 5:
    default:
      *r = v_max;
      *g = v_min;
      *b = v_mid;
  }
}
196 }  // namespace internal
197 
198 template <>
199 class AdjustHueOp<CPUDevice, float> : public AdjustHueOpBase {
200  public:
AdjustHueOp(OpKernelConstruction * context)201   explicit AdjustHueOp(OpKernelConstruction* context)
202       : AdjustHueOpBase(context) {}
203 
DoCompute(OpKernelContext * context,const ComputeOptions & options)204   void DoCompute(OpKernelContext* context,
205                  const ComputeOptions& options) override {
206     const Tensor* input = options.input;
207     const Tensor* delta = options.delta;
208     Tensor* output = options.output;
209     const int64 channel_count = options.channel_count;
210     static const int kChannelSize = 3;
211     auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
212     const float delta_h = delta->scalar<float>()();
213     auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
214     const int kCostPerChannel = 10;
215     const DeviceBase::CpuWorkerThreads& worker_threads =
216         *context->device()->tensorflow_cpu_worker_threads();
217     Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
218           kCostPerChannel,
219           [&input_data, &output_data, delta_h](int64 start_channel,
220                                                int64 end_channel) {
221             const float* p = input_data.data() + start_channel * kChannelSize;
222             float* q = output_data.data() + start_channel * kChannelSize;
223             for (int i = start_channel; i < end_channel; i++) {
224               float h, v_min, v_max;
225               // Convert the RGB color to Hue/V-range.
226               internal::rgb_to_hv_range(p[0], p[1], p[2], &h, &v_min, &v_max);
227               static const int kChannelRange = 6;
228               // Adjust the hue value. And adjust the hue back into the valid
229               // range of [0, 6). It is faster than a fmod by avoiding
230               // a float-point division since h is often very close to this
231               // range.
232               h += delta_h * kChannelRange;
233               while (h < 0) {
234                 h += kChannelRange;
235               }
236               while (h >= kChannelRange) {
237                 h -= kChannelRange;
238               }
239               // Convert the hue and v-range back into RGB.
240               internal::hv_range_to_rgb(h, v_min, v_max, q, q + 1, q + 2);
241               p += kChannelSize;
242               q += kChannelSize;
243             }
244           });
245   }
246 };
247 
248 REGISTER_KERNEL_BUILDER(
249     Name("AdjustHue").Device(DEVICE_CPU).TypeConstraint<float>("T"),
250     AdjustHueOp<CPUDevice, float>);
251 
252 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
253 template <typename T>
254 class AdjustHueOp<GPUDevice, T> : public AdjustHueOpBase {
255  public:
AdjustHueOp(OpKernelConstruction * context)256   explicit AdjustHueOp(OpKernelConstruction* context)
257       : AdjustHueOpBase(context) {}
258 
DoCompute(OpKernelContext * context,const ComputeOptions & options)259   void DoCompute(OpKernelContext* context,
260                  const ComputeOptions& options) override {
261     const Tensor* input = options.input;
262     const Tensor* delta = options.delta;
263     Tensor* output = options.output;
264     const int64 number_of_elements = input->NumElements();
265     GPUDevice device = context->eigen_gpu_device();
266     const auto stream = device.stream();
267     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
268     if (number_of_elements > 0) {
269       const T* input_data = input->flat<T>().data();
270       const float* delta_h = delta->flat<float>().data();
271       T* const output_data = output->flat<T>().data();
272       functor::AdjustHueGPU<T>()(&device, number_of_elements, input_data,
273                                  delta_h, output_data);
274     }
275   }
276 };
277 
278 #define REGISTER_GPU(T)                                            \
279   REGISTER_KERNEL_BUILDER(                                         \
280       Name("AdjustHue").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
281       AdjustHueOp<GPUDevice, T>);
282 
283 REGISTER_GPU(float)
284 REGISTER_GPU(Eigen::half)
285 
286 #undef REGISTER_GPU
287 
288 #endif
289 
290 //} // namespace functor
291 }  // namespace tensorflow
292