• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
17 #define TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
18 
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/tools/accuracy/stage.h"
24 
25 namespace tensorflow {
26 namespace metrics {
27 
28 // A stage that does inception preprocessing.
29 // Inputs: A tensor containing bytes of a JPEG image.
30 // Outputs: A tensor containing rescaled and preprocessed image that has
31 // shape {1, image_height, image_width, 3}, where 3 is the number of channels.
32 class InceptionPreprocessingStage : public Stage {
33  public:
34   // Preprocessing params that govern scaling and normalization of channels of
35   // the image.
36   struct Params {
37     // Input means are subtracted from each channel.
38     // In case of an empty vector this is skipped.
39     std::vector<float> input_means;
40     // Scale is used to divide the input.
41     // A scale of 0 means divison is skipped.
42     float scale;
43     double cropping_fraction;
44   };
45 
46   // Default preprocessing for inception stage based on |output_type|
DefaultParamsForType(DataType output_type)47   static Params DefaultParamsForType(DataType output_type) {
48     const float kCroppingFraction = 0.875;
49     Params params = {};
50     params.cropping_fraction = kCroppingFraction;
51     if (output_type == DT_UINT8) {
52     } else if (output_type == DT_INT8) {
53       params.input_means = {128.0, 128.0, 128.0};
54     } else {
55       // Assume floating point preprocessing.
56       params.input_means = {127.5, 127.5, 127.5};
57       params.scale = 127.5;
58     }
59     return params;
60   }
61 
62   // Creates a new preprocessing stage object with provided |image_width|
63   // |image_height| as the size of output image.
64   // |output_datatype| is the datatype of output of the stage.
InceptionPreprocessingStage(int image_width,int image_height,DataType output_datatype)65   InceptionPreprocessingStage(int image_width, int image_height,
66                               DataType output_datatype)
67       : output_datatype_(output_datatype),
68         image_width_(image_width),
69         image_height_(image_height) {
70     params_ = DefaultParamsForType(output_datatype);
71   }
72 
73   // Creates a new preprocessing stage object with provided |image_width|
74   // |image_height| as the size of output image.
75   // |output_datatype| is the datatype of output of the stage.
InceptionPreprocessingStage(int image_width,int image_height,DataType output_datatype,Params params)76   InceptionPreprocessingStage(int image_width, int image_height,
77                               DataType output_datatype, Params params)
78       : output_datatype_(output_datatype),
79         image_width_(image_width),
80         image_height_(image_height),
81         params_(std::move(params)) {}
82 
name()83   string name() const override { return "stage_inception_preprocess"; }
output_name()84   string output_name() const override {
85     return "stage_inception_preprocess_output";
86   }
87 
88   void AddToGraph(const Scope& scope, const Input& input) override;
89 
90  private:
91   DataType output_datatype_;
92   int image_width_;
93   int image_height_;
94   bool is_quantized_;
95   Params params_;
96 };
97 
98 }  // namespace metrics
99 }  // namespace tensorflow
100 
#endif  // TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
102