• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 
21 namespace Eigen {
22 
// How ExtractGlimpses fills the parts of a glimpse that lie outside the input
// image. The numeric values are fixed (0, 1, 2).
enum ExtractGlimpsesNoiseMode {
  UNIFORM,   // 0: uniform noise spanning the image's [min, max] value range.
  GAUSSIAN,  // 1: per-channel Gaussian noise (channel mean / std deviation),
             //    clamped to the channel's [min, max].
  ZERO,      // 2: zeros.
};
29 
/** ExtractGlimpses
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Extract glimpses from an input tensor.
 *
 * The input parameter is expected to be a col-major tensor with a rank of 4
 * (depth, x, y, and batch). The width and height parameters specify the size
 * of the returned glimpses. The offsets parameter specifies the x, y location
 * of each glimpse; how that location is interpreted depends on the normalized,
 * centered and version parameters. The vector is expected to contain one
 * IndexPair for each image in the batch dimension. The normalized boolean
 * indicates if incoming coordinates are normalized so that 0.0 and 1.0
 * correspond to the minimum and maximum of each height and width dimension.
 * The centered boolean indicates if incoming coordinates are centered relative
 * to the image, in which case -1.0 and 1.0 correspond to minimum and maximum
 * of each dimension while 0.0 corresponds to the center. With version 1 the
 * offset always denotes the center of the glimpse; with version 2 it denotes
 * the center only when coordinates are both normalized and centered, and the
 * top-left corner of the glimpse otherwise. The noise parameter selects how
 * the parts of a glimpse that fall outside the input image are filled.
 *
 * The result can be assigned to a tensor of rank equal to that of the input.
 * The result will be laid out in col-major order (depth, x, y, batch). The
 * dimensions of the result will be equal to the dimensions of the input except
 * for width and height which will be equal to the requested glimpse size.
 */
52 namespace {
53 
54 template <typename Index>
55 struct GlimpseExtractionOp {
GlimpseExtractionOpGlimpseExtractionOp56   GlimpseExtractionOp(const Index width, const Index height,
57                       const std::vector<IndexPair<float> >& offsets,
58                       const bool normalized, const bool centered,
59                       const ExtractGlimpsesNoiseMode noise, const int version)
60       : width_(width),
61         height_(height),
62         offsets_(offsets),
63         normalized_(normalized),
64         centered_(centered),
65         noise_(noise),
66         version_(version) {}
67 
68   template <typename Input>
dimensionsGlimpseExtractionOp69   DSizes<Index, 4> dimensions(const Input& input) const {
70     typedef typename internal::traits<Input>::Index IndexType;
71     typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
72                              internal::traits<Input>::Layout, IndexType> >
73         Ref;
74     Ref in(input);
75 
76     DSizes<Index, 4> dims = in.dimensions();
77 
78     dims[0] = in.dimension(0);
79     dims[1] = width_;
80     dims[2] = height_;
81     dims[3] = in.dimension(3);
82     return dims;
83   }
84 
85   template <typename Input, typename Output, typename Device>
evalGlimpseExtractionOp86   EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output,
87                               const Device& device) const {
88     typedef typename internal::traits<Input>::Index IndexType;
89     typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
90                              internal::traits<Input>::Layout, IndexType> >
91         Ref;
92     Ref in(input);
93     const Index num_channels = in.dimension(0);
94     const Index input_width = in.dimension(1);
95     const Index input_height = in.dimension(2);
96     const Index batch_size = in.dimension(3);
97     eigen_assert(input_width > 0);
98     eigen_assert(input_height > 0);
99     internal::NormalRandomGenerator<float> gen;
100     internal::UniformRandomGenerator<float> unigen;
101 
102     for (Index i = 0; i < batch_size; ++i) {
103       float x = offsets_[i].first, y = offsets_[i].second;
104 
105       if (version_ == 1) {
106         // Un-normalize coordinates back to pixel space if normalized.
107         if (normalized_) {
108           x *= input_width;
109           y *= input_height;
110         }
111         // Un-center if coordinates are centered on the image center.
112         if (centered_) {
113           x /= 2.0f;
114           y /= 2.0f;
115           x += input_width / 2.0f;
116           y += input_height / 2.0f;
117         }
118         // Remove half of the glimpse window.
119         x -= width_ / 2.0f;
120         y -= height_ / 2.0f;
121       } else {
122         if (normalized_) {
123           // Un-normalize coordinates back to pixel space if normalized.
124           x *= input_width;
125           y *= input_height;
126           if (centered_) {
127             // Un-center if coordinates are centered on the image center.
128             x /= 2.0f;
129             y /= 2.0f;
130             x += input_width / 2.0f;
131             y += input_height / 2.0f;
132             // Remove half of the glimpse window.
133             x -= width_ / 2.0f;
134             y -= height_ / 2.0f;
135           }
136         } else {
137           if (centered_) {
138             x += input_width / 2.0f;
139             y += input_height / 2.0f;
140           }
141         }
142       }
143 
144       const Index offset_x = (Index)x;
145       const Index offset_y = (Index)y;
146       Index glimpse_width = width_;
147       Index glimpse_height = height_;
148       bool partial_overlap = false;
149       DSizes<Index, 3> slice_offset(0, offset_x, offset_y);
150       DSizes<Index, 3> slice_extent(num_channels, width_, height_);
151       DSizes<Index, 3> base_offset(0, 0, 0);
152 
153       if (offset_x < 0) {
154         slice_offset[1] = 0;
155         glimpse_width = (std::max<Index>)(0, width_ + offset_x);
156         slice_extent[1] = glimpse_width;
157         base_offset[1] = width_ - glimpse_width;
158         partial_overlap = true;
159       } else if (offset_x + width_ >= input_width) {
160         glimpse_width = (std::max<Index>)(0, input_width - offset_x);
161         slice_extent[1] = glimpse_width;
162         partial_overlap = true;
163       }
164       if (offset_y < 0) {
165         slice_offset[2] = 0;
166         glimpse_height = (std::max<Index>)(0, height_ + offset_y);
167         slice_extent[2] = glimpse_height;
168         base_offset[2] = height_ - glimpse_height;
169         partial_overlap = true;
170       } else if (offset_y + height_ >= input_height) {
171         glimpse_height = (std::max<Index>)(0, input_height - offset_y);
172         slice_extent[2] = glimpse_height;
173         partial_overlap = true;
174       }
175       slice_extent[1] = std::min<Index>(input_width, slice_extent[1]);
176       slice_extent[2] = std::min<Index>(input_height, slice_extent[2]);
177 
178       if (partial_overlap) {
179         switch (noise_) {
180           case ZERO: {
181             // Initialize the glimpse with zero noise.
182             output.template chip<3>(i).device(device) =
183                 output.template chip<3>(i).constant(0);
184           } break;
185           case UNIFORM: {
186             // Initialize the glimpse with uniform noise.
187             typedef typename internal::remove_const<
188                 typename internal::traits<Input>::Scalar>::type Scalar;
189             TensorFixedSize<Scalar, Sizes<> > mini;
190             mini.device(device) = input.template chip<3>(i).minimum();
191             TensorFixedSize<float, Sizes<> > range;
192             range.device(device) = (input.template chip<3>(i).maximum() - mini)
193                                        .template cast<float>();
194 
195             DSizes<Index, 3> glimpse_size(num_channels, width_, height_);
196             TensorMap<Tensor<float, 3> > tmp(NULL, glimpse_size);
197             output.template chip<3>(i).device(device) =
198                 mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) +
199                 (tmp.random(unigen) *
200                  range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size))
201                     .template cast<Scalar>();
202           } break;
203           case GAUSSIAN: {
204             // Initialize the glimpse with white noise: compute the mean and
205             // sigma
206             // of each channel, and use them to shape the gaussian.
207             DSizes<Index, 2> glimpse_size(width_, height_);
208             DSizes<Index, 2> input_size(input_width, input_height);
209             typedef typename internal::remove_const<
210                 typename internal::traits<Input>::Scalar>::type Scalar;
211 
212             for (int j = 0; j < num_channels; ++j) {
213               TensorFixedSize<Scalar, Sizes<> > mean;
214               mean.device(device) = input.template chip<3>(i)
215                                         .template chip<0>(j)
216                                         .template cast<float>()
217                                         .mean();
218               TensorFixedSize<float, Sizes<> > sigma;
219               sigma.device(device) =
220                   (input.template chip<3>(i)
221                        .template chip<0>(j)
222                        .template cast<float>() -
223                    mean.reshape(Sizes<1, 1>()).broadcast(input_size))
224                       .square()
225                       .mean()
226                       .sqrt();
227               TensorFixedSize<Scalar, Sizes<> > mini;
228               mini.device(device) =
229                   input.template chip<3>(i).template chip<0>(j).minimum();
230               TensorFixedSize<float, Sizes<> > maxi;
231               maxi.device(device) =
232                   input.template chip<3>(i).template chip<0>(j).maximum();
233 
234               TensorMap<Tensor<float, 2> > tmp(NULL, glimpse_size);
235               output.template chip<3>(i).template chip<0>(j).device(device) =
236                   (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) +
237                    (tmp.random(gen) *
238                     sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
239                        .template cast<Scalar>())
240                       .cwiseMin(
241                           maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
242                       .cwiseMax(
243                           mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size));
244             }
245           } break;
246         }
247 
248         // Copy the part of the glimpse that cover the input image if any.
249         if (glimpse_width == 0 || glimpse_height == 0) {
250           continue;
251         }
252         output.template chip<3>(i)
253             .slice(base_offset, slice_extent)
254             .device(device) =
255             input.template chip<3>(i).slice(slice_offset, slice_extent);
256       } else {
257         output.template chip<3>(i).device(device) =
258             input.template chip<3>(i).slice(slice_offset, slice_extent);
259       }
260     }
261   }
262 
263  private:
264   const Index width_;
265   const Index height_;
266   const std::vector<IndexPair<float> > offsets_;
267   const bool normalized_;
268   const bool centered_;
269   const ExtractGlimpsesNoiseMode noise_;
270   const int version_;
271 };
272 }  // namespace
273 
274 template <typename Input>
275 EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp<
276     const GlimpseExtractionOp<typename internal::traits<Input>::Index>,
277     const Input>
278 ExtractGlimpses(
279     const Input& input, const typename internal::traits<Input>::Index width,
280     const typename internal::traits<Input>::Index height,
281     const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
282     const bool centered = true,
283     const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
284     const int version = 2) {
285   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
286                       YOU_MADE_A_PROGRAMMING_MISTAKE);
287   EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
288                       YOU_MADE_A_PROGRAMMING_MISTAKE);
289 
290   typedef typename internal::traits<Input>::Index Index;
291   const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
292                                       centered, noise, version);
293   return input.customOp(op);
294 }
295 
296 }  // end namespace Eigen
297 
298 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
299