/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

namespace internal {

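// Device kernel that fills the kChannelSize x kChannelSize transformation
// matrix for the HSV-in-YIQ adjustment from the scalar parameters delta_h,
// scale_s, and scale_v.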
__global__ void compute_tranformation_matrix_cuda(const float* const delta_h,
                                                  const float* const scale_s,
                                                  const float* const scale_v,
                                                  float* const matrix,
                                                  const int matrix_size) {
  if (matrix_size == kChannelSize * kChannelSize) {
    compute_tranformation_matrix<kChannelSize * kChannelSize>(
        *delta_h, *scale_s, *scale_v, matrix);
  }
}
}  // namespace internal

namespace functor {

void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count,
                                   const Tensor* const input,
                                   const float* const delta_h,
                                   const float* const scale_s,
                                   const float* const scale_v,
                                   Tensor* const output) {
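  // View the input as an (m x k) matrix and multiply it by the (k x n)
  // transformation matrix to produce the output, where k == n == kChannelSize.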
  const uint64 m = channel_count;
  const uint64 k = kChannelSize;
  const uint64 n = kChannelSize;
  auto* cu_stream = ctx->eigen_device<GPUDevice>().stream();
  OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available."));
  Tensor tranformation_matrix;
  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                          DT_FLOAT, TensorShape({kChannelSize * kChannelSize}),
                          &tranformation_matrix));
  // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix
  // with one thread. Improve its performance if necessary.
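  // Launch with a single block and a single thread (no shared memory); the
  // matrix holds only kChannelSize * kChannelSize floats.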
  TF_CHECK_OK(CudaLaunchKernel(internal::compute_tranformation_matrix_cuda, 1,
                               1, 0, cu_stream, delta_h, scale_s, scale_v,
                               tranformation_matrix.flat<float>().data(),
                               tranformation_matrix.flat<float>().size()));
  // Call cuBlas C = A * B directly.
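  // Note: cuBLAS assumes column-major storage, so the row-major product
  // C = A * B is issued as the column-major product C^T = B^T * A^T; this is
  // why the dimensions and operands below appear in (n, m, k) / (B, A) order.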
  auto no_transpose = se::blas::Transpose::kNoTranspose;
  auto a_ptr =
      AsDeviceMemory(input->flat<float>().data(), input->flat<float>().size());
  auto b_ptr = AsDeviceMemory(tranformation_matrix.flat<float>().data(),
                              tranformation_matrix.flat<float>().size());
  auto c_ptr = AsDeviceMemory(output->flat<float>().data(),
                              output->flat<float>().size());
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
  // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op.
  bool blas_launch_status =
      stream
          ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
                         a_ptr, k, 0.0f, &c_ptr, n)
          .ok();
  if (!blas_launch_status) {
    ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                    ", n=", n, ", k=", k));
  }
}
}  // namespace functor
}  // namespace tensorflow
#endif  // GOOGLE_CUDA