1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 17 #define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 21 namespace Eigen { 22 23 // Noise mode used when padding. 24 enum ExtractGlimpsesNoiseMode { 25 UNIFORM = 0, 26 GAUSSIAN = 1, 27 ZERO = 2, 28 }; 29 30 /** ExtractGlimpses 31 * \ingroup CXX11_NeuralNetworks_Module 32 * 33 * \brief Extract glimpses from an input tensor. 34 * 35 * The input parameter is expected to be a col-major tensor with a rank of 4 36 * (depth, x, y, and batch). The width and height parameters specify the 37 * extension of the returned glimpses. The offsets parameter specifies the x, y 38 * locations of the center of the glimpses relative to the center of the input 39 * image. The vector is expected to contain one IndexPair for each image in the 40 * batch dimension. The normalized boolean indicates if incoming coordinates are 41 * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each 42 * height and width dimension. The centered boolean indicates if incoming 43 * coordinates are centered relative to the image, in which case -1.0 and 1.0 44 * correspond to minimum and maximum of each dimension while 0.0 corresponds to 45 * the center. 46 * 47 * The result can be assigned to a tensor of rank equal to that of the input. 48 * The result will be laid out in col-major order (depth, x, y, batch). The 49 * dimensions of the result will be equal to the dimensions of the input except 50 * for width and height which will be equal to the requested glimpse size. 51 */ 52 namespace { 53 54 template <typename Index> 55 struct GlimpseExtractionOp { GlimpseExtractionOpGlimpseExtractionOp56 GlimpseExtractionOp(const Index width, const Index height, 57 const std::vector<IndexPair<float> >& offsets, 58 const bool normalized, const bool centered, 59 const ExtractGlimpsesNoiseMode noise, const int version) 60 : width_(width), 61 height_(height), 62 offsets_(offsets), 63 normalized_(normalized), 64 centered_(centered), 65 noise_(noise), 66 version_(version) {} 67 68 template <typename Input> dimensionsGlimpseExtractionOp69 DSizes<Index, 4> dimensions(const Input& input) const { 70 typedef typename internal::traits<Input>::Index IndexType; 71 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, 72 internal::traits<Input>::Layout, IndexType> > 73 Ref; 74 Ref in(input); 75 76 DSizes<Index, 4> dims = in.dimensions(); 77 78 dims[0] = in.dimension(0); 79 dims[1] = width_; 80 dims[2] = height_; 81 dims[3] = in.dimension(3); 82 return dims; 83 } 84 85 template <typename Input, typename Output, typename Device> evalGlimpseExtractionOp86 EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output, 87 const Device& device) const { 88 typedef typename internal::traits<Input>::Index IndexType; 89 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, 90 internal::traits<Input>::Layout, IndexType> > 91 Ref; 92 Ref in(input); 93 const Index num_channels = in.dimension(0); 94 const Index input_width = in.dimension(1); 95 const Index input_height = in.dimension(2); 96 const Index batch_size = in.dimension(3); 97 eigen_assert(input_width > 0); 98 eigen_assert(input_height > 0); 99 internal::NormalRandomGenerator<float> gen; 100 internal::UniformRandomGenerator<float> unigen; 101 102 for (Index i = 0; i < batch_size; ++i) { 103 float x = offsets_[i].first, y = offsets_[i].second; 104 105 if (version_ == 1) { 106 // Un-normalize coordinates back to pixel space if normalized. 107 if (normalized_) { 108 x *= input_width; 109 y *= input_height; 110 } 111 // Un-center if coordinates are centered on the image center. 112 if (centered_) { 113 x /= 2.0f; 114 y /= 2.0f; 115 x += input_width / 2.0f; 116 y += input_height / 2.0f; 117 } 118 // Remove half of the glimpse window. 119 x -= width_ / 2.0f; 120 y -= height_ / 2.0f; 121 } else { 122 if (normalized_) { 123 // Un-normalize coordinates back to pixel space if normalized. 124 x *= input_width; 125 y *= input_height; 126 if (centered_) { 127 // Un-center if coordinates are centered on the image center. 128 x /= 2.0f; 129 y /= 2.0f; 130 x += input_width / 2.0f; 131 y += input_height / 2.0f; 132 // Remove half of the glimpse window. 133 x -= width_ / 2.0f; 134 y -= height_ / 2.0f; 135 } 136 } else { 137 if (centered_) { 138 x += input_width / 2.0f; 139 y += input_height / 2.0f; 140 } 141 } 142 } 143 144 const Index offset_x = (Index)x; 145 const Index offset_y = (Index)y; 146 Index glimpse_width = width_; 147 Index glimpse_height = height_; 148 bool partial_overlap = false; 149 DSizes<Index, 3> slice_offset(0, offset_x, offset_y); 150 DSizes<Index, 3> slice_extent(num_channels, width_, height_); 151 DSizes<Index, 3> base_offset(0, 0, 0); 152 153 if (offset_x < 0) { 154 slice_offset[1] = 0; 155 glimpse_width = (std::max<Index>)(0, width_ + offset_x); 156 slice_extent[1] = glimpse_width; 157 base_offset[1] = width_ - glimpse_width; 158 partial_overlap = true; 159 } else if (offset_x + width_ >= input_width) { 160 glimpse_width = (std::max<Index>)(0, input_width - offset_x); 161 slice_extent[1] = glimpse_width; 162 partial_overlap = true; 163 } 164 if (offset_y < 0) { 165 slice_offset[2] = 0; 166 glimpse_height = (std::max<Index>)(0, height_ + offset_y); 167 slice_extent[2] = glimpse_height; 168 base_offset[2] = height_ - glimpse_height; 169 partial_overlap = true; 170 } else if (offset_y + height_ >= input_height) { 171 glimpse_height = (std::max<Index>)(0, input_height - offset_y); 172 slice_extent[2] = glimpse_height; 173 partial_overlap = true; 174 } 175 slice_extent[1] = std::min<Index>(input_width, slice_extent[1]); 176 slice_extent[2] = std::min<Index>(input_height, slice_extent[2]); 177 178 if (partial_overlap) { 179 switch (noise_) { 180 case ZERO: { 181 // Initialize the glimpse with zero noise. 182 output.template chip<3>(i).device(device) = 183 output.template chip<3>(i).constant(0); 184 } break; 185 case UNIFORM: { 186 // Initialize the glimpse with uniform noise. 187 typedef typename internal::remove_const< 188 typename internal::traits<Input>::Scalar>::type Scalar; 189 TensorFixedSize<Scalar, Sizes<> > mini; 190 mini.device(device) = input.template chip<3>(i).minimum(); 191 TensorFixedSize<float, Sizes<> > range; 192 range.device(device) = (input.template chip<3>(i).maximum() - mini) 193 .template cast<float>(); 194 195 DSizes<Index, 3> glimpse_size(num_channels, width_, height_); 196 TensorMap<Tensor<float, 3> > tmp(NULL, glimpse_size); 197 output.template chip<3>(i).device(device) = 198 mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) + 199 (tmp.random(unigen) * 200 range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size)) 201 .template cast<Scalar>(); 202 } break; 203 case GAUSSIAN: { 204 // Initialize the glimpse with white noise: compute the mean and 205 // sigma 206 // of each channel, and use them to shape the gaussian. 207 DSizes<Index, 2> glimpse_size(width_, height_); 208 DSizes<Index, 2> input_size(input_width, input_height); 209 typedef typename internal::remove_const< 210 typename internal::traits<Input>::Scalar>::type Scalar; 211 212 for (int j = 0; j < num_channels; ++j) { 213 TensorFixedSize<Scalar, Sizes<> > mean; 214 mean.device(device) = input.template chip<3>(i) 215 .template chip<0>(j) 216 .template cast<float>() 217 .mean(); 218 TensorFixedSize<float, Sizes<> > sigma; 219 sigma.device(device) = 220 (input.template chip<3>(i) 221 .template chip<0>(j) 222 .template cast<float>() - 223 mean.reshape(Sizes<1, 1>()).broadcast(input_size)) 224 .square() 225 .mean() 226 .sqrt(); 227 TensorFixedSize<Scalar, Sizes<> > mini; 228 mini.device(device) = 229 input.template chip<3>(i).template chip<0>(j).minimum(); 230 TensorFixedSize<float, Sizes<> > maxi; 231 maxi.device(device) = 232 input.template chip<3>(i).template chip<0>(j).maximum(); 233 234 TensorMap<Tensor<float, 2> > tmp(NULL, glimpse_size); 235 output.template chip<3>(i).template chip<0>(j).device(device) = 236 (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) + 237 (tmp.random(gen) * 238 sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) 239 .template cast<Scalar>()) 240 .cwiseMin( 241 maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) 242 .cwiseMax( 243 mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size)); 244 } 245 } break; 246 } 247 248 // Copy the part of the glimpse that cover the input image if any. 249 if (glimpse_width == 0 || glimpse_height == 0) { 250 continue; 251 } 252 output.template chip<3>(i) 253 .slice(base_offset, slice_extent) 254 .device(device) = 255 input.template chip<3>(i).slice(slice_offset, slice_extent); 256 } else { 257 output.template chip<3>(i).device(device) = 258 input.template chip<3>(i).slice(slice_offset, slice_extent); 259 } 260 } 261 } 262 263 private: 264 const Index width_; 265 const Index height_; 266 const std::vector<IndexPair<float> > offsets_; 267 const bool normalized_; 268 const bool centered_; 269 const ExtractGlimpsesNoiseMode noise_; 270 const int version_; 271 }; 272 } // namespace 273 274 template <typename Input> 275 EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp< 276 const GlimpseExtractionOp<typename internal::traits<Input>::Index>, 277 const Input> 278 ExtractGlimpses( 279 const Input& input, const typename internal::traits<Input>::Index width, 280 const typename internal::traits<Input>::Index height, 281 const std::vector<IndexPair<float> >& offsets, const bool normalized = true, 282 const bool centered = true, 283 const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM, 284 const int version = 2) { 285 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, 286 YOU_MADE_A_PROGRAMMING_MISTAKE); 287 EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, 288 YOU_MADE_A_PROGRAMMING_MISTAKE); 289 290 typedef typename internal::traits<Input>::Index Index; 291 const GlimpseExtractionOp<Index> op(width, height, offsets, normalized, 292 centered, noise, version); 293 return input.customOp(op); 294 } 295 296 } // end namespace Eigen 297 298 #endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 299