• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
17 // inputs or outputs in various ways.
18 
19 // See docs in ../ops/summary_ops.cc.
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/summary.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/png/png_io.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 
29 class SummaryImageOp : public OpKernel {
30  public:
31   typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image;
32 
SummaryImageOp(OpKernelConstruction * context)33   explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) {
34     int64 max_images_tmp;
35     OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_tmp));
36     OP_REQUIRES(context, max_images_tmp < (1LL << 31),
37                 errors::InvalidArgument("max_images must be < 2^31"));
38     max_images_ = static_cast<int32>(max_images_tmp);
39     const TensorProto* proto;
40     OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto));
41     OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto(
42                                 *proto, AllocatorAttributes(), &bad_color_));
43     OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8,
44                 errors::InvalidArgument("bad_color must be uint8, got ",
45                                         DataTypeString(bad_color_.dtype())));
46     OP_REQUIRES(
47         context, TensorShapeUtils::IsVector(bad_color_.shape()),
48         errors::InvalidArgument("bad_color must be a vector, got shape ",
49                                 bad_color_.shape().DebugString()));
50   }
51 
Compute(OpKernelContext * c)52   void Compute(OpKernelContext* c) override {
53     const Tensor& tags = c->input(0);
54     const Tensor& tensor = c->input(1);
55     OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
56                 errors::InvalidArgument("Tags must be a scalar"));
57     OP_REQUIRES(c,
58                 tensor.dims() == 4 &&
59                     (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
60                      tensor.dim_size(3) == 4),
61                 errors::InvalidArgument(
62                     "Tensor must be 4-D with last dim 1, 3, or 4, not ",
63                     tensor.shape().DebugString()));
64     const string& base_tag = tags.scalar<tstring>()();
65 
66     OP_REQUIRES(c,
67                 tensor.dim_size(0) < (1LL << 31) &&
68                     tensor.dim_size(1) < (1LL << 31) &&
69                     tensor.dim_size(2) < (1LL << 31) &&
70                     (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29),
71                 errors::InvalidArgument("Tensor too large for summary ",
72                                         tensor.shape().DebugString()));
73 
74     // The casts and h * w cannot overflow because of the limits above.
75     const int batch_size = static_cast<int>(tensor.dim_size(0));
76     const int h = static_cast<int>(tensor.dim_size(1));
77     const int w = static_cast<int>(tensor.dim_size(2));
78     const int hw = h * w;  // Compact these two dims for simplicity
79     const int depth = static_cast<int>(tensor.dim_size(3));
80 
81     OP_REQUIRES(c, hw > 0 && depth > 0,
82                 errors::InvalidArgument(
83                     "input tensor must have non-zero dims. Found: [",
84                     batch_size, ", ", h, ", ", w, ", ", depth, "]."));
85 
86     Summary s;
87     if (tensor.dtype() == DT_UINT8) {
88       // For uint8 input, no normalization is necessary
89       auto ith_image = [&tensor, batch_size, hw, depth](int i) {
90         auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth});
91         return typename TTypes<uint8>::ConstMatrix(
92             &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
93       };
94       OP_REQUIRES_OK(
95           c, AddImages(base_tag, batch_size, w, h, depth, ith_image, &s));
96     } else if (tensor.dtype() == DT_HALF) {
97       NormalizeAndAddImages<Eigen::half>(c, tensor, h, w, hw, depth, batch_size,
98                                          base_tag, &s);
99     } else if (tensor.dtype() == DT_FLOAT) {
100       NormalizeAndAddImages<float>(c, tensor, h, w, hw, depth, batch_size,
101                                    base_tag, &s);
102     } else {  // tensor.dtype() = DT_DOUBLE
103       NormalizeAndAddImages<double>(c, tensor, h, w, hw, depth, batch_size,
104                                     base_tag, &s);
105     }
106 
107     Tensor* summary_tensor = nullptr;
108     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
109     CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
110   }
111 
112   template <class T>
NormalizeAndAddImages(OpKernelContext * c,const Tensor & tensor,int h,int w,int hw,int depth,int batch_size,const string & base_tag,Summary * s)113   void NormalizeAndAddImages(OpKernelContext* c, const Tensor& tensor, int h,
114                              int w, int hw, int depth, int batch_size,
115                              const string& base_tag, Summary* s) {
116     // For float and half images, nans and infs are replaced with bad_color.
117     OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
118                 errors::InvalidArgument(
119                     "expected depth <= bad_color.size, got depth = ", depth,
120                     ", bad_color.size = ", bad_color_.dim_size(0)));
121     auto bad_color_full = bad_color_.vec<uint8>();
122     typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
123 
124     // Float images must be scaled and translated.
125     Uint8Image image(hw, depth);
126     auto ith_image = [&tensor, &image, bad_color, batch_size, hw,
127                       depth](int i) {
128       auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth});
129       typename TTypes<T>::ConstMatrix values(
130           &tensor_eigen(i, 0, 0),
131           Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
132       NormalizeFloatImage<T>(hw, depth, values, bad_color, &image);
133       return image;
134     };
135     OP_REQUIRES_OK(c,
136                    AddImages(base_tag, batch_size, w, h, depth, ith_image, s));
137   }
138 
139   // Add the sequence of images specified by ith_image to the summary.
140   //
141   // Factoring this loop out into a helper function lets ith_image behave
142   // differently in the float and uint8 cases: the float case needs a temporary
143   // buffer which can be shared across calls to ith_image, but the uint8 case
144   // does not.
AddImages(const string & tag,int batch_size,int w,int h,int depth,const std::function<Uint8Image (int)> & ith_image,Summary * s)145   Status AddImages(const string& tag, int batch_size, int w, int h, int depth,
146                    const std::function<Uint8Image(int)>& ith_image,
147                    Summary* s) {
148     const int N = std::min<int>(max_images_, batch_size);
149     for (int i = 0; i < N; ++i) {
150       Summary::Value* v = s->add_value();
151       // The tag depends on the number of requested images (not the number
152       // produced.)
153       //
154       // Note that later on avisu uses "/" to figure out a consistent naming
155       // convention for display, so we append "/image" to guarantee that the
156       // image(s) won't be displayed in the global scope with no name.
157       if (max_images_ > 1) {
158         v->set_tag(strings::StrCat(tag, "/image/", i));
159       } else {
160         v->set_tag(strings::StrCat(tag, "/image"));
161       }
162 
163       auto image = ith_image(i);
164       Summary::Image* si = v->mutable_image();
165       si->set_height(h);
166       si->set_width(w);
167       si->set_colorspace(depth);
168       const int channel_bits = 8;
169       const int compression = -1;  // Use zlib default
170       if (!png::WriteImageToBuffer(
171               image.data(), w, h, w * depth, depth, channel_bits, compression,
172               si->mutable_encoded_image_string(), nullptr)) {
173         return errors::Internal("PNG encoding failed");
174       }
175     }
176     return Status::OK();
177   }
178 
179   template <class T>
NormalizeFloatImage(int hw,int depth,typename TTypes<T>::ConstMatrix values,typename TTypes<uint8>::ConstVec bad_color,Uint8Image * image)180   static void NormalizeFloatImage(int hw, int depth,
181                                   typename TTypes<T>::ConstMatrix values,
182                                   typename TTypes<uint8>::ConstVec bad_color,
183                                   Uint8Image* image) {
184     if (!image->size()) return;  // Nothing to do for empty images
185 
186     // Rescale the image to uint8 range.
187     //
188     // We are trying to generate an RGB image from a float/half tensor.  We do
189     // not have any info about the expected range of values in the tensor
190     // but the generated image needs to have all RGB values within [0, 255].
191     //
192     // We use two different algorithms to generate these values.  If the
193     // tensor has only positive values we scale them all by 255/max(values).
194     // If the tensor has both negative and positive values we scale them by
195     // the max of their absolute values and center them around 127.
196     //
197     // This works for most cases, but does not respect the relative dynamic
198     // range across different instances of the tensor.
199 
200     // Compute min and max ignoring nonfinite pixels
201     float image_min = std::numeric_limits<float>::infinity();
202     float image_max = -image_min;
203     for (int i = 0; i < hw; i++) {
204       bool finite = true;
205       for (int j = 0; j < depth; j++) {
206         if (!Eigen::numext::isfinite(values(i, j))) {
207           finite = false;
208           break;
209         }
210       }
211       if (finite) {
212         for (int j = 0; j < depth; j++) {
213           float value(values(i, j));
214           image_min = std::min(image_min, value);
215           image_max = std::max(image_max, value);
216         }
217       }
218     }
219 
220     // Pick an affine transform into uint8
221     const float kZeroThreshold = 1e-6;
222     T scale, offset;
223     if (image_min < 0) {
224       float max_val = std::max(std::abs(image_min), std::abs(image_max));
225       scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val);
226       offset = T(128.0f);
227     } else {
228       scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max);
229       offset = T(0.0f);
230     }
231 
232     // Transform image, turning nonfinite values to bad_color
233     for (int i = 0; i < hw; i++) {
234       bool finite = true;
235       for (int j = 0; j < depth; j++) {
236         if (!Eigen::numext::isfinite(values(i, j))) {
237           finite = false;
238           break;
239         }
240       }
241       if (finite) {
242         image->chip<0>(i) = (values.template chip<0>(i) * scale + offset)
243                                 .template cast<uint8>();
244       } else {
245         image->chip<0>(i) = bad_color;
246       }
247     }
248   }
249 
250  private:
251   int32 max_images_;
252   Tensor bad_color_;
253 };
254 
255 REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU),
256                         SummaryImageOp);
257 
258 }  // namespace tensorflow
259