• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 
#include <limits>
#include <memory>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/platform/logging.h"
28 
29 namespace tensorflow {
30 
31 // Encode an image to a JPEG stream
32 class EncodeJpegOp : public OpKernel {
33  public:
EncodeJpegOp(OpKernelConstruction * context)34   explicit EncodeJpegOp(OpKernelConstruction* context) : OpKernel(context) {
35     OP_REQUIRES_OK(context, context->GetAttr("format", &format_));
36     if (format_.empty()) {
37       flags_.format = static_cast<jpeg::Format>(0);
38     } else if (format_ == "grayscale") {
39       flags_.format = jpeg::FORMAT_GRAYSCALE;
40     } else if (format_ == "rgb") {
41       flags_.format = jpeg::FORMAT_RGB;
42     } else {
43       OP_REQUIRES(context, false,
44                   errors::InvalidArgument(
45                       "format must be '', grayscale or rgb, got ", format_));
46     }
47 
48     OP_REQUIRES_OK(context, context->GetAttr("quality", &flags_.quality));
49     OP_REQUIRES(context, 0 <= flags_.quality && flags_.quality <= 100,
50                 errors::InvalidArgument("quality must be in [0,100], got ",
51                                         flags_.quality));
52     OP_REQUIRES_OK(context,
53                    context->GetAttr("progressive", &flags_.progressive));
54     OP_REQUIRES_OK(
55         context, context->GetAttr("optimize_size", &flags_.optimize_jpeg_size));
56     OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling",
57                                              &flags_.chroma_downsampling));
58 
59     string density_unit;
60     OP_REQUIRES_OK(context, context->GetAttr("density_unit", &density_unit));
61     if (density_unit == "in") {
62       flags_.density_unit = 1;
63     } else if (density_unit == "cm") {
64       flags_.density_unit = 2;
65     } else {
66       OP_REQUIRES(context, false,
67                   errors::InvalidArgument("density_unit must be 'in' or 'cm'",
68                                           density_unit));
69     }
70 
71     OP_REQUIRES_OK(context, context->GetAttr("x_density", &flags_.x_density));
72     OP_REQUIRES_OK(context, context->GetAttr("y_density", &flags_.y_density));
73     OP_REQUIRES_OK(context, context->GetAttr("xmp_metadata", &xmp_metadata_));
74     flags_.xmp_metadata = xmp_metadata_;  // StringPiece doesn't own data
75   }
76 
Compute(OpKernelContext * context)77   void Compute(OpKernelContext* context) override {
78     const Tensor& image = context->input(0);
79     OP_REQUIRES(context, image.dims() == 3,
80                 errors::InvalidArgument("image must be 3-dimensional",
81                                         image.shape().DebugString()));
82 
83     OP_REQUIRES(
84         context,
85         FastBoundsCheck(image.NumElements(), std::numeric_limits<int32>::max()),
86         errors::InvalidArgument(
87             "Cannot encode images with >= max int32 elements"));
88 
89     const int32 dim_size0 = static_cast<int32>(image.dim_size(0));
90     const int32 dim_size1 = static_cast<int32>(image.dim_size(1));
91     const int32 dim_size2 = static_cast<int32>(image.dim_size(2));
92 
93     // Autodetect format if desired, otherwise make sure format and
94     // image channels are consistent.
95     int channels;
96     jpeg::CompressFlags adjusted_flags = flags_;
97     if (flags_.format == 0) {
98       channels = dim_size2;
99       if (channels == 1) {
100         adjusted_flags.format = jpeg::FORMAT_GRAYSCALE;
101       } else if (channels == 3) {
102         adjusted_flags.format = jpeg::FORMAT_RGB;
103       } else {
104         OP_REQUIRES(
105             context, false,
106             errors::InvalidArgument("image must have 1 or 3 channels, got ",
107                                     image.shape().DebugString()));
108       }
109     } else {
110       if (flags_.format == jpeg::FORMAT_GRAYSCALE) {
111         channels = 1;
112       } else {  // RGB
113         channels = 3;
114       }
115       OP_REQUIRES(context, channels == dim_size2,
116                   errors::InvalidArgument("format ", format_, " expects ",
117                                           channels, " channels, got ",
118                                           image.shape().DebugString()));
119     }
120 
121     // Encode image to jpeg string
122     Tensor* output = nullptr;
123     OP_REQUIRES_OK(context,
124                    context->allocate_output(0, TensorShape({}), &output));
125     OP_REQUIRES(context,
126                 jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
127                                adjusted_flags, &output->scalar<tstring>()()),
128                 errors::Internal("JPEG encoding failed"));
129   }
130 
131  private:
132   string format_;
133   string xmp_metadata_;  // Owns data referenced by flags_
134   jpeg::CompressFlags flags_;
135 };
136 REGISTER_KERNEL_BUILDER(Name("EncodeJpeg").Device(DEVICE_CPU), EncodeJpegOp);
137 
138 class EncodeJpegVariableQualityOp : public OpKernel {
139  public:
EncodeJpegVariableQualityOp(OpKernelConstruction * context)140   explicit EncodeJpegVariableQualityOp(OpKernelConstruction* context)
141       : OpKernel(context) {}
142 
Compute(OpKernelContext * context)143   void Compute(OpKernelContext* context) override {
144     const Tensor& image = context->input(0);
145     OP_REQUIRES(context, image.dims() == 3,
146                 errors::InvalidArgument("image must be 3-dimensional",
147                                         image.shape().DebugString()));
148 
149     OP_REQUIRES(
150         context,
151         FastBoundsCheck(image.NumElements(), std::numeric_limits<int32>::max()),
152         errors::InvalidArgument(
153             "Cannot encode images with >= max int32 elements"));
154 
155     const int32 dim_size0 = static_cast<int32>(image.dim_size(0));
156     const int32 dim_size1 = static_cast<int32>(image.dim_size(1));
157     const int32 dim_size2 = static_cast<int32>(image.dim_size(2));
158 
159     // Use default jpeg compression flags except for format and quality.
160     jpeg::CompressFlags adjusted_flags;
161 
162     // Get jpeg encoding quality.
163     const Tensor& quality = context->input(1);
164     OP_REQUIRES(context, TensorShapeUtils::IsScalar(quality.shape()),
165                 errors::InvalidArgument("quality must be scalar: ",
166                                         quality.shape().DebugString()));
167     adjusted_flags.quality = quality.scalar<int>()();
168     OP_REQUIRES(context,
169                 0 <= adjusted_flags.quality && adjusted_flags.quality <= 100,
170                 errors::InvalidArgument("quality must be in [0,100], got ",
171                                         adjusted_flags.quality));
172 
173     // Autodetect format.
174     int channels;
175     channels = dim_size2;
176     if (channels == 1) {
177       adjusted_flags.format = jpeg::FORMAT_GRAYSCALE;
178     } else if (channels == 3) {
179       adjusted_flags.format = jpeg::FORMAT_RGB;
180     } else {
181       OP_REQUIRES(
182           context, false,
183           errors::InvalidArgument("image must have 1 or 3 channels, got ",
184                                   image.shape().DebugString()));
185     }
186 
187     // Encode image to jpeg string
188     Tensor* output = nullptr;
189     OP_REQUIRES_OK(context,
190                    context->allocate_output(0, TensorShape({}), &output));
191     OP_REQUIRES(context,
192                 jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
193                                adjusted_flags, &output->scalar<tstring>()()),
194                 errors::Internal("JPEG encoding failed"));
195   }
196 };
197 REGISTER_KERNEL_BUILDER(Name("EncodeJpegVariableQuality").Device(DEVICE_CPU),
198                         EncodeJpegVariableQualityOp);
199 
200 }  // namespace tensorflow
201