/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
#include <memory>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class AdjustHsvInYiqOpBase : public OpKernel {
 protected:
  explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  struct ComputeOptions {
    const Tensor* input = nullptr;
    Tensor* output = nullptr;
    const Tensor* delta_h = nullptr;
    const Tensor* scale_s = nullptr;
    const Tensor* scale_v = nullptr;
    int64 channel_count = 0;
  };

  virtual void DoCompute(OpKernelContext* context,
                         const ComputeOptions& options) = 0;

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& delta_h = context->input(1);
    const Tensor& scale_s = context->input(2);
    const Tensor& scale_v = context->input(3);
    OP_REQUIRES(context, input.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
                errors::InvalidArgument("delta_h must be scalar: ",
                                        delta_h.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
                errors::InvalidArgument("scale_s must be scalar: ",
                                        scale_s.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
                errors::InvalidArgument("scale_v must be scalar: ",
                                        scale_v.shape().DebugString()));
    auto channels = input.dim_size(input.dims() - 1);
    OP_REQUIRES(
        context, channels == kChannelSize,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    if (input.NumElements() > 0) {
      const int64 channel_count = input.NumElements() / channels;
      ComputeOptions options;
      options.input = &input;
      options.delta_h = &delta_h;
      options.scale_s = &scale_s;
      options.scale_v = &scale_v;
      options.output = output;
      options.channel_count = channel_count;
      DoCompute(context, options);
    }
  }
};

template <class Device>
class AdjustHsvInYiqOp;

template <>
class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
 public:
  explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
      : AdjustHsvInYiqOpBase(context) {}

  void DoCompute(OpKernelContext* context,
                 const ComputeOptions& options) override {
    const Tensor* input = options.input;
    Tensor* output = options.output;
    const int64 channel_count = options.channel_count;
    auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
    const float delta_h = options.delta_h->scalar<float>()();
    const float scale_s = options.scale_s->scalar<float>()();
    const float scale_v = options.scale_v->scalar<float>()();
    auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
    float tranformation_matrix[kChannelSize * kChannelSize] = {0};
    internal::compute_tranformation_matrix<kChannelSize * kChannelSize>(
        delta_h, scale_s, scale_v, tranformation_matrix);
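    // Note: the exact contents of the 3x3 matrix are defined by the helper
    // declared in adjust_hsv_in_yiq_op.h. Conceptually (this description is an
    // assumption, not spelled out in this file) it composes the RGB->YIQ
    // conversion, a rotation by delta_h in the I/Q chroma plane together with
    // the saturation/value scaling, and the YIQ->RGB conversion, so each pixel
    // below needs only a single 3x3 matrix multiply. With delta_h == 0,
    // scale_s == 1, and scale_v == 1 the matrix reduces to the identity and
    // the output equals the input.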
    const int kCostPerChannel = 10;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *context->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
          kCostPerChannel,
          [&input_data, &output_data, &tranformation_matrix](
              int64 start_channel, int64 end_channel) {
            // Apply the transformation matrix to each input RGB vector. The
            // matrix is stored column-major: element (row, col) lives at
            // tranformation_matrix[row + kChannelSize * col].
            const float* p = input_data.data() + start_channel * kChannelSize;
            float* q = output_data.data() + start_channel * kChannelSize;
            for (int i = start_channel; i < end_channel; i++) {
              for (int q_index = 0; q_index < kChannelSize; q_index++) {
                q[q_index] = 0;
                for (int p_index = 0; p_index < kChannelSize; p_index++) {
                  q[q_index] +=
                      p[p_index] *
                      tranformation_matrix[q_index + kChannelSize * p_index];
                }
              }
              p += kChannelSize;
              q += kChannelSize;
            }
          });
  }
};

REGISTER_KERNEL_BUILDER(
    Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    AdjustHsvInYiqOp<CPUDevice>);
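
// Usage note (illustrative sketch; the Python wrapper name and its keyword
// arguments are assumptions, not defined in this file): the kernels
// registered here back a contrib image op that takes an RGB image plus three
// scalar tensors, roughly
//   adjusted = adjust_hsv_in_yiq(image, delta_hue, scale_saturation,
//                                scale_value)
// where delta_hue rotates hue in YIQ space and the two scale factors scale
// saturation and value.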

#if GOOGLE_CUDA
template <>
class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase {
 public:
  explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
      : AdjustHsvInYiqOpBase(context) {}

  void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override {
    const int64 number_of_elements = options.input->NumElements();
    if (number_of_elements <= 0) {
      return;
    }
    const float* delta_h = options.delta_h->flat<float>().data();
    const float* scale_s = options.scale_s->flat<float>().data();
    const float* scale_v = options.scale_v->flat<float>().data();
    functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input,
                                 delta_h, scale_s, scale_v, options.output);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    AdjustHsvInYiqOp<GPUDevice>);
#endif

}  // namespace tensorflow