/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

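// GatherTreeOpKernel reconstructs full beam trajectories from the raw output
// of beam search.  Each CUDA thread handles one (batch, beam) pair: it starts
// at the last valid time step for that batch, then walks backwards through
// parent_ids to find which beam each token was taken from, copying the chosen
// step_ids into beams.  step_ids, parent_ids, and beams are laid out as
// [max_time, batch_size, beam_width].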
template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
                                   const int32 beam_width, const T* step_ids,
                                   const T* parent_ids,
                                   const int32* max_sequence_lengths,
                                   const T end_token, T* beams) {
  CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
    const int32 batch = i / beam_width;
    const int32 beam = i % beam_width;

    const int32 max_seq_len_b =
        Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
    if (max_seq_len_b <= 0) {
      continue;
    }

#define GET_IX(time_ix, beam_ix) \
  (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
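    // GET_IX computes the flat row-major offset into a
    // [max_time, batch_size, beam_width] tensor for this thread's batch.
    // For example (hypothetical sizes), with batch_size = 2, beam_width = 3,
    // and batch = 1: GET_IX(4, 2) = 2 * 3 * 4 + 3 * 1 + 2 = 29.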
    const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
    beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
    int32 parent = ldg(parent_ids + initial_beam_ix);
    bool found_bad = false;
    for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
      const int32 level_beam_ix = GET_IX(level, beam);
      const int32 level_parent_ix = GET_IX(level, parent);
      // A valid parent must index a beam in [0, beam_width); anything else
      // indicates a corrupt trajectory, so emit -1 from this level down.
      if (parent < 0 || parent >= beam_width) {
        beams[level_beam_ix] = -1;
        parent = -1;
        found_bad = true;
      } else {
        beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
        parent = ldg(parent_ids + level_parent_ix);
      }
    }
    // Not necessary when using a BeamSearchDecoder, but necessary
    // when a user feeds in a possibly broken trajectory (i.e., non-eos
    // entries in a beam following eos entries).
    if (!found_bad) {
      bool finished = false;
      for (int32 time = 0; time < max_seq_len_b; ++time) {
        const int32 level_beam_ix = GET_IX(time, beam);
        if (finished) {
          beams[level_beam_ix] = end_token;
        } else if (beams[level_beam_ix] == end_token) {
          finished = true;
        }
      }
    }
#undef GET_IX
  }
}
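
// Worked example (illustrative values, not from the original source): with
// beam_width = 3 and max_seq_len_b = 3 for one batch entry,
//
//   step_ids[t][b]   = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
//   parent_ids[t][b] = [[0, 0, 0], [0, 1, 2], [2, 0, 1]]
//
// the thread for beam 0 emits step_ids[2][0] = 7 at the last step, follows
// parent 2 to emit step_ids[1][2] = 6, then parent_ids[1][2] = 2 again to
// emit step_ids[0][2] = 3, so the gathered trajectory for beam 0 is [3, 6, 7].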

template <typename T>
struct GatherTree<GPUDevice, T> {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  typename TTypes<T, 3>::ConstTensor step_ids,
                  typename TTypes<T, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_length,
                  const T end_token, typename TTypes<T, 3>::Tensor beams) {
    const int32 max_time = parent_ids.dimension(0);
    const int32 batch_size = parent_ids.dimension(1);
    const int32 beam_width = parent_ids.dimension(2);
    // Initialize all output entries to end_token before tracing trajectories;
    // this Eigen device assignment runs as its own fill kernel.
    beams.device(d) = beams.constant(end_token);

    CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
    TF_CHECK_OK(CudaLaunchKernel(
        GatherTreeOpKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), batch_size, max_time, beam_width, step_ids.data(),
        parent_ids.data(), max_sequence_length.data(), end_token,
        beams.data()));
  }
};
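
// Illustrative only (not part of this file): the GatherTree op wrapper in
// beam_search_ops.cc is expected to call this functor along these lines,
// where the *_t names are hypothetical local tensor maps matching the
// operator() signature above:
//
//   functor::GatherTree<GPUDevice, T>()(ctx, ctx->eigen_device<GPUDevice>(),
//                                       step_ids_t, parent_ids_t,
//                                       max_sequence_lengths_t, end_token,
//                                       beams_t);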

#define DEFINE_GPU_SPECS(T) template struct GatherTree<GPUDevice, T>;

DEFINE_GPU_SPECS(int32);
#undef DEFINE_GPU_SPECS

}  // end namespace functor
}  // end namespace tensorflow
#endif  // GOOGLE_CUDA