1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as 17 // inputs or outputs in various ways. 18 19 // See docs in ../ops/summary_ops.cc. 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/summary.pb.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/lib/png/png_io.h" 25 #include "tensorflow/core/platform/logging.h" 26 27 namespace tensorflow { 28 29 class SummaryImageOp : public OpKernel { 30 public: 31 typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image; 32 SummaryImageOp(OpKernelConstruction * context)33 explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) { 34 int64 max_images_tmp; 35 OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_tmp)); 36 OP_REQUIRES(context, max_images_tmp < (1LL << 31), 37 errors::InvalidArgument("max_images must be < 2^31")); 38 max_images_ = static_cast<int32>(max_images_tmp); 39 const TensorProto* proto; 40 OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto)); 41 OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto( 42 *proto, AllocatorAttributes(), &bad_color_)); 43 OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8, 44 errors::InvalidArgument("bad_color must be uint8, got ", 45 DataTypeString(bad_color_.dtype()))); 46 OP_REQUIRES( 47 context, TensorShapeUtils::IsVector(bad_color_.shape()), 48 errors::InvalidArgument("bad_color must be a vector, got shape ", 49 bad_color_.shape().DebugString())); 50 } 51 Compute(OpKernelContext * c)52 void Compute(OpKernelContext* c) override { 53 const Tensor& tags = c->input(0); 54 const Tensor& tensor = c->input(1); 55 OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()), 56 errors::InvalidArgument("Tags must be a scalar")); 57 OP_REQUIRES(c, 58 tensor.dims() == 4 && 59 (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || 60 tensor.dim_size(3) == 4), 61 errors::InvalidArgument( 62 "Tensor must be 4-D with last dim 1, 3, or 4, not ", 63 tensor.shape().DebugString())); 64 const string& base_tag = tags.scalar<tstring>()(); 65 66 OP_REQUIRES(c, 67 tensor.dim_size(0) < (1LL << 31) && 68 tensor.dim_size(1) < (1LL << 31) && 69 tensor.dim_size(2) < (1LL << 31) && 70 (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29), 71 errors::InvalidArgument("Tensor too large for summary ", 72 tensor.shape().DebugString())); 73 74 // The casts and h * w cannot overflow because of the limits above. 75 const int batch_size = static_cast<int>(tensor.dim_size(0)); 76 const int h = static_cast<int>(tensor.dim_size(1)); 77 const int w = static_cast<int>(tensor.dim_size(2)); 78 const int hw = h * w; // Compact these two dims for simplicity 79 const int depth = static_cast<int>(tensor.dim_size(3)); 80 81 OP_REQUIRES(c, hw > 0 && depth > 0, 82 errors::InvalidArgument( 83 "input tensor must have non-zero dims. Found: [", 84 batch_size, ", ", h, ", ", w, ", ", depth, "].")); 85 86 Summary s; 87 if (tensor.dtype() == DT_UINT8) { 88 // For uint8 input, no normalization is necessary 89 auto ith_image = [&tensor, batch_size, hw, depth](int i) { 90 auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth}); 91 return typename TTypes<uint8>::ConstMatrix( 92 &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth)); 93 }; 94 OP_REQUIRES_OK( 95 c, AddImages(base_tag, batch_size, w, h, depth, ith_image, &s)); 96 } else if (tensor.dtype() == DT_HALF) { 97 NormalizeAndAddImages<Eigen::half>(c, tensor, h, w, hw, depth, batch_size, 98 base_tag, &s); 99 } else if (tensor.dtype() == DT_FLOAT) { 100 NormalizeAndAddImages<float>(c, tensor, h, w, hw, depth, batch_size, 101 base_tag, &s); 102 } else { // tensor.dtype() = DT_DOUBLE 103 NormalizeAndAddImages<double>(c, tensor, h, w, hw, depth, batch_size, 104 base_tag, &s); 105 } 106 107 Tensor* summary_tensor = nullptr; 108 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); 109 CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()())); 110 } 111 112 template <class T> NormalizeAndAddImages(OpKernelContext * c,const Tensor & tensor,int h,int w,int hw,int depth,int batch_size,const string & base_tag,Summary * s)113 void NormalizeAndAddImages(OpKernelContext* c, const Tensor& tensor, int h, 114 int w, int hw, int depth, int batch_size, 115 const string& base_tag, Summary* s) { 116 // For float and half images, nans and infs are replaced with bad_color. 117 OP_REQUIRES(c, bad_color_.dim_size(0) >= depth, 118 errors::InvalidArgument( 119 "expected depth <= bad_color.size, got depth = ", depth, 120 ", bad_color.size = ", bad_color_.dim_size(0))); 121 auto bad_color_full = bad_color_.vec<uint8>(); 122 typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth); 123 124 // Float images must be scaled and translated. 125 Uint8Image image(hw, depth); 126 auto ith_image = [&tensor, &image, bad_color, batch_size, hw, 127 depth](int i) { 128 auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth}); 129 typename TTypes<T>::ConstMatrix values( 130 &tensor_eigen(i, 0, 0), 131 Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth)); 132 NormalizeFloatImage<T>(hw, depth, values, bad_color, &image); 133 return image; 134 }; 135 OP_REQUIRES_OK(c, 136 AddImages(base_tag, batch_size, w, h, depth, ith_image, s)); 137 } 138 139 // Add the sequence of images specified by ith_image to the summary. 140 // 141 // Factoring this loop out into a helper function lets ith_image behave 142 // differently in the float and uint8 cases: the float case needs a temporary 143 // buffer which can be shared across calls to ith_image, but the uint8 case 144 // does not. AddImages(const string & tag,int batch_size,int w,int h,int depth,const std::function<Uint8Image (int)> & ith_image,Summary * s)145 Status AddImages(const string& tag, int batch_size, int w, int h, int depth, 146 const std::function<Uint8Image(int)>& ith_image, 147 Summary* s) { 148 const int N = std::min<int>(max_images_, batch_size); 149 for (int i = 0; i < N; ++i) { 150 Summary::Value* v = s->add_value(); 151 // The tag depends on the number of requested images (not the number 152 // produced.) 153 // 154 // Note that later on avisu uses "/" to figure out a consistent naming 155 // convention for display, so we append "/image" to guarantee that the 156 // image(s) won't be displayed in the global scope with no name. 157 if (max_images_ > 1) { 158 v->set_tag(strings::StrCat(tag, "/image/", i)); 159 } else { 160 v->set_tag(strings::StrCat(tag, "/image")); 161 } 162 163 auto image = ith_image(i); 164 Summary::Image* si = v->mutable_image(); 165 si->set_height(h); 166 si->set_width(w); 167 si->set_colorspace(depth); 168 const int channel_bits = 8; 169 const int compression = -1; // Use zlib default 170 if (!png::WriteImageToBuffer( 171 image.data(), w, h, w * depth, depth, channel_bits, compression, 172 si->mutable_encoded_image_string(), nullptr)) { 173 return errors::Internal("PNG encoding failed"); 174 } 175 } 176 return Status::OK(); 177 } 178 179 template <class T> NormalizeFloatImage(int hw,int depth,typename TTypes<T>::ConstMatrix values,typename TTypes<uint8>::ConstVec bad_color,Uint8Image * image)180 static void NormalizeFloatImage(int hw, int depth, 181 typename TTypes<T>::ConstMatrix values, 182 typename TTypes<uint8>::ConstVec bad_color, 183 Uint8Image* image) { 184 if (!image->size()) return; // Nothing to do for empty images 185 186 // Rescale the image to uint8 range. 187 // 188 // We are trying to generate an RGB image from a float/half tensor. We do 189 // not have any info about the expected range of values in the tensor 190 // but the generated image needs to have all RGB values within [0, 255]. 191 // 192 // We use two different algorithms to generate these values. If the 193 // tensor has only positive values we scale them all by 255/max(values). 194 // If the tensor has both negative and positive values we scale them by 195 // the max of their absolute values and center them around 127. 196 // 197 // This works for most cases, but does not respect the relative dynamic 198 // range across different instances of the tensor. 199 200 // Compute min and max ignoring nonfinite pixels 201 float image_min = std::numeric_limits<float>::infinity(); 202 float image_max = -image_min; 203 for (int i = 0; i < hw; i++) { 204 bool finite = true; 205 for (int j = 0; j < depth; j++) { 206 if (!Eigen::numext::isfinite(values(i, j))) { 207 finite = false; 208 break; 209 } 210 } 211 if (finite) { 212 for (int j = 0; j < depth; j++) { 213 float value(values(i, j)); 214 image_min = std::min(image_min, value); 215 image_max = std::max(image_max, value); 216 } 217 } 218 } 219 220 // Pick an affine transform into uint8 221 const float kZeroThreshold = 1e-6; 222 T scale, offset; 223 if (image_min < 0) { 224 float max_val = std::max(std::abs(image_min), std::abs(image_max)); 225 scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val); 226 offset = T(128.0f); 227 } else { 228 scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max); 229 offset = T(0.0f); 230 } 231 232 // Transform image, turning nonfinite values to bad_color 233 for (int i = 0; i < hw; i++) { 234 bool finite = true; 235 for (int j = 0; j < depth; j++) { 236 if (!Eigen::numext::isfinite(values(i, j))) { 237 finite = false; 238 break; 239 } 240 } 241 if (finite) { 242 image->chip<0>(i) = (values.template chip<0>(i) * scale + offset) 243 .template cast<uint8>(); 244 } else { 245 image->chip<0>(i) = bad_color; 246 } 247 } 248 } 249 250 private: 251 int32 max_images_; 252 Tensor bad_color_; 253 }; 254 255 REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU), 256 SummaryImageOp); 257 258 } // namespace tensorflow 259