1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #if GOOGLE_CUDA 16 #define EIGEN_USE_GPU 17 #endif 18 19 #include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" 20 #include <memory> 21 #include "tensorflow/core/framework/register_types.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/tensor_shape.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/platform/logging.h" 26 #include "tensorflow/core/util/work_sharder.h" 27 28 namespace tensorflow { 29 30 typedef Eigen::ThreadPoolDevice CPUDevice; 31 typedef Eigen::GpuDevice GPUDevice; 32 33 class AdjustHsvInYiqOpBase : public OpKernel { 34 protected: AdjustHsvInYiqOpBase(OpKernelConstruction * context)35 explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context) 36 : OpKernel(context) {} 37 38 struct ComputeOptions { 39 const Tensor* input = nullptr; 40 Tensor* output = nullptr; 41 const Tensor* delta_h = nullptr; 42 const Tensor* scale_s = nullptr; 43 const Tensor* scale_v = nullptr; 44 int64 channel_count = 0; 45 }; 46 47 virtual void DoCompute(OpKernelContext* context, 48 const ComputeOptions& options) = 0; 49 Compute(OpKernelContext * context)50 void Compute(OpKernelContext* context) override { 51 const Tensor& input = context->input(0); 52 const Tensor& delta_h = context->input(1); 53 const Tensor& scale_s = context->input(2); 54 const Tensor& scale_v = context->input(3); 55 OP_REQUIRES(context, input.dims() >= 3, 56 errors::InvalidArgument("input must be at least 3-D, got shape", 57 input.shape().DebugString())); 58 OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()), 59 errors::InvalidArgument("delta_h must be scalar: ", 60 delta_h.shape().DebugString())); 61 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()), 62 errors::InvalidArgument("scale_s must be scalar: ", 63 scale_s.shape().DebugString())); 64 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()), 65 errors::InvalidArgument("scale_v must be scalar: ", 66 scale_v.shape().DebugString())); 67 auto channels = input.dim_size(input.dims() - 1); 68 OP_REQUIRES( 69 context, channels == kChannelSize, 70 errors::InvalidArgument("input must have 3 channels but instead has ", 71 channels, " channels.")); 72 73 Tensor* output = nullptr; 74 OP_REQUIRES_OK(context, 75 context->allocate_output(0, input.shape(), &output)); 76 77 if (input.NumElements() > 0) { 78 const int64 channel_count = input.NumElements() / channels; 79 ComputeOptions options; 80 options.input = &input; 81 options.delta_h = &delta_h; 82 options.scale_s = &scale_s; 83 options.scale_v = &scale_v; 84 options.output = output; 85 options.channel_count = channel_count; 86 DoCompute(context, options); 87 } 88 } 89 }; 90 91 template <class Device> 92 class AdjustHsvInYiqOp; 93 94 template <> 95 class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase { 96 public: AdjustHsvInYiqOp(OpKernelConstruction * context)97 explicit AdjustHsvInYiqOp(OpKernelConstruction* context) 98 : AdjustHsvInYiqOpBase(context) {} 99 DoCompute(OpKernelContext * context,const ComputeOptions & options)100 void DoCompute(OpKernelContext* context, 101 const ComputeOptions& options) override { 102 const Tensor* input = options.input; 103 Tensor* output = options.output; 104 const int64 channel_count = options.channel_count; 105 auto input_data = input->shaped<float, 2>({channel_count, kChannelSize}); 106 const float delta_h = options.delta_h->scalar<float>()(); 107 const float scale_s = options.scale_s->scalar<float>()(); 108 const float scale_v = options.scale_v->scalar<float>()(); 109 auto output_data = output->shaped<float, 2>({channel_count, kChannelSize}); 110 float tranformation_matrix[kChannelSize * kChannelSize] = {0}; 111 internal::compute_tranformation_matrix<kChannelSize * kChannelSize>( 112 delta_h, scale_s, scale_v, tranformation_matrix); 113 const int kCostPerChannel = 10; 114 const DeviceBase::CpuWorkerThreads& worker_threads = 115 *context->device()->tensorflow_cpu_worker_threads(); 116 Shard(worker_threads.num_threads, worker_threads.workers, channel_count, 117 kCostPerChannel, 118 [&input_data, &output_data, &tranformation_matrix]( 119 int64 start_channel, int64 end_channel) { 120 // Applying projection matrix to input RGB vectors. 121 const float* p = input_data.data() + start_channel * kChannelSize; 122 float* q = output_data.data() + start_channel * kChannelSize; 123 for (int i = start_channel; i < end_channel; i++) { 124 for (int q_index = 0; q_index < kChannelSize; q_index++) { 125 q[q_index] = 0; 126 for (int p_index = 0; p_index < kChannelSize; p_index++) { 127 q[q_index] += 128 p[p_index] * 129 tranformation_matrix[q_index + kChannelSize * p_index]; 130 } 131 } 132 p += kChannelSize; 133 q += kChannelSize; 134 } 135 }); 136 } 137 }; 138 139 REGISTER_KERNEL_BUILDER( 140 Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"), 141 AdjustHsvInYiqOp<CPUDevice>); 142 143 #if GOOGLE_CUDA 144 template <> 145 class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase { 146 public: AdjustHsvInYiqOp(OpKernelConstruction * context)147 explicit AdjustHsvInYiqOp(OpKernelConstruction* context) 148 : AdjustHsvInYiqOpBase(context) {} 149 DoCompute(OpKernelContext * ctx,const ComputeOptions & options)150 void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { 151 const int64 number_of_elements = options.input->NumElements(); 152 if (number_of_elements <= 0) { 153 return; 154 } 155 const float* delta_h = options.delta_h->flat<float>().data(); 156 const float* scale_s = options.scale_s->flat<float>().data(); 157 const float* scale_v = options.scale_v->flat<float>().data(); 158 functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, 159 delta_h, scale_s, scale_v, options.output); 160 } 161 }; 162 163 REGISTER_KERNEL_BUILDER( 164 Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"), 165 AdjustHsvInYiqOp<GPUDevice>); 166 #endif 167 168 } // namespace tensorflow 169