• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.examples;
17 
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;
33 
34 /** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
35 public class LabelImage {
printUsage(PrintStream s)36   private static void printUsage(PrintStream s) {
37     final String url =
38         "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
39     s.println(
40         "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
41     s.println("to label JPEG images.");
42     s.println("TensorFlow version: " + TensorFlow.version());
43     s.println();
44     s.println("Usage: label_image <model dir> <image file>");
45     s.println();
46     s.println("Where:");
47     s.println("<model dir> is a directory containing the unzipped contents of the inception model");
48     s.println("            (from " + url + ")");
49     s.println("<image file> is the path to a JPEG image file");
50   }
51 
main(String[] args)52   public static void main(String[] args) {
53     if (args.length != 2) {
54       printUsage(System.err);
55       System.exit(1);
56     }
57     String modelDir = args[0];
58     String imageFile = args[1];
59 
60     byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
61     List<String> labels =
62         readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
63     byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
64 
65     try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
66       float[] labelProbabilities = executeInceptionGraph(graphDef, image);
67       int bestLabelIdx = maxIndex(labelProbabilities);
68       System.out.println(
69           String.format("BEST MATCH: %s (%.2f%% likely)",
70               labels.get(bestLabelIdx),
71               labelProbabilities[bestLabelIdx] * 100f));
72     }
73   }
74 
constructAndExecuteGraphToNormalizeImage(byte[] imageBytes)75   private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
76     try (Graph g = new Graph()) {
77       GraphBuilder b = new GraphBuilder(g);
78       // Some constants specific to the pre-trained model at:
79       // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
80       //
81       // - The model was trained with images scaled to 224x224 pixels.
82       // - The colors, represented as R, G, B in 1-byte each were converted to
83       //   float using (value - Mean)/Scale.
84       final int H = 224;
85       final int W = 224;
86       final float mean = 117f;
87       final float scale = 1f;
88 
89       // Since the graph is being constructed once per execution here, we can use a constant for the
90       // input image. If the graph were to be re-used for multiple input images, a placeholder would
91       // have been more appropriate.
92       final Output<String> input = b.constant("input", imageBytes);
93       final Output<Float> output =
94           b.div(
95               b.sub(
96                   b.resizeBilinear(
97                       b.expandDims(
98                           b.cast(b.decodeJpeg(input, 3), Float.class),
99                           b.constant("make_batch", 0)),
100                       b.constant("size", new int[] {H, W})),
101                   b.constant("mean", mean)),
102               b.constant("scale", scale));
103       try (Session s = new Session(g)) {
104         // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
105         return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
106       }
107     }
108   }
109 
executeInceptionGraph(byte[] graphDef, Tensor<Float> image)110   private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
111     try (Graph g = new Graph()) {
112       g.importGraphDef(graphDef);
113       try (Session s = new Session(g);
114           // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
115           Tensor<Float> result =
116               s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
117         final long[] rshape = result.shape();
118         if (result.numDimensions() != 2 || rshape[0] != 1) {
119           throw new RuntimeException(
120               String.format(
121                   "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
122                   Arrays.toString(rshape)));
123         }
124         int nlabels = (int) rshape[1];
125         return result.copyTo(new float[1][nlabels])[0];
126       }
127     }
128   }
129 
maxIndex(float[] probabilities)130   private static int maxIndex(float[] probabilities) {
131     int best = 0;
132     for (int i = 1; i < probabilities.length; ++i) {
133       if (probabilities[i] > probabilities[best]) {
134         best = i;
135       }
136     }
137     return best;
138   }
139 
readAllBytesOrExit(Path path)140   private static byte[] readAllBytesOrExit(Path path) {
141     try {
142       return Files.readAllBytes(path);
143     } catch (IOException e) {
144       System.err.println("Failed to read [" + path + "]: " + e.getMessage());
145       System.exit(1);
146     }
147     return null;
148   }
149 
readAllLinesOrExit(Path path)150   private static List<String> readAllLinesOrExit(Path path) {
151     try {
152       return Files.readAllLines(path, Charset.forName("UTF-8"));
153     } catch (IOException e) {
154       System.err.println("Failed to read [" + path + "]: " + e.getMessage());
155       System.exit(0);
156     }
157     return null;
158   }
159 
160   // In the fullness of time, equivalents of the methods of this class should be auto-generated from
161   // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
162   // like Python, C++ and Go.
163   static class GraphBuilder {
GraphBuilder(Graph g)164     GraphBuilder(Graph g) {
165       this.g = g;
166     }
167 
div(Output<Float> x, Output<Float> y)168     Output<Float> div(Output<Float> x, Output<Float> y) {
169       return binaryOp("Div", x, y);
170     }
171 
sub(Output<T> x, Output<T> y)172     <T> Output<T> sub(Output<T> x, Output<T> y) {
173       return binaryOp("Sub", x, y);
174     }
175 
resizeBilinear(Output<T> images, Output<Integer> size)176     <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
177       return binaryOp3("ResizeBilinear", images, size);
178     }
179 
expandDims(Output<T> input, Output<Integer> dim)180     <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
181       return binaryOp3("ExpandDims", input, dim);
182     }
183 
cast(Output<T> value, Class<U> type)184     <T, U> Output<U> cast(Output<T> value, Class<U> type) {
185       DataType dtype = DataType.fromClass(type);
186       return g.opBuilder("Cast", "Cast")
187           .addInput(value)
188           .setAttr("DstT", dtype)
189           .build()
190           .<U>output(0);
191     }
192 
decodeJpeg(Output<String> contents, long channels)193     Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
194       return g.opBuilder("DecodeJpeg", "DecodeJpeg")
195           .addInput(contents)
196           .setAttr("channels", channels)
197           .build()
198           .<UInt8>output(0);
199     }
200 
constant(String name, Object value, Class<T> type)201     <T> Output<T> constant(String name, Object value, Class<T> type) {
202       try (Tensor<T> t = Tensor.<T>create(value, type)) {
203         return g.opBuilder("Const", name)
204             .setAttr("dtype", DataType.fromClass(type))
205             .setAttr("value", t)
206             .build()
207             .<T>output(0);
208       }
209     }
constant(String name, byte[] value)210     Output<String> constant(String name, byte[] value) {
211       return this.constant(name, value, String.class);
212     }
213 
constant(String name, int value)214     Output<Integer> constant(String name, int value) {
215       return this.constant(name, value, Integer.class);
216     }
217 
constant(String name, int[] value)218     Output<Integer> constant(String name, int[] value) {
219       return this.constant(name, value, Integer.class);
220     }
221 
constant(String name, float value)222     Output<Float> constant(String name, float value) {
223       return this.constant(name, value, Float.class);
224     }
225 
binaryOp(String type, Output<T> in1, Output<T> in2)226     private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
227       return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
228     }
229 
binaryOp3(String type, Output<U> in1, Output<V> in2)230     private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
231       return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
232     }
233     private Graph g;
234   }
235 }
236