/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.examples;

import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;

/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class LabelImage {
  /**
   * Prints command-line usage, including the TensorFlow version in use, to the given stream.
   *
   * @param s stream the usage text is written to
   */
  private static void printUsage(PrintStream s) {
    final String url =
        "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
    s.println(
        "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
    s.println("to label JPEG images.");
    s.println("TensorFlow version: " + TensorFlow.version());
    s.println();
    s.println("Usage: label_image <model dir> <image file>");
    s.println();
    s.println("Where:");
    s.println("<model dir> is a directory containing the unzipped contents of the inception model");
    s.println("            (from " + url + ")");
    s.println("<image file> is the path to a JPEG image file");
  }

  /**
   * Program entry point: loads the model graph, the label list and the image, runs the model and
   * prints the most likely label.
   *
   * <p>Exits with status 1 when the argument count is wrong, when any input file cannot be read,
   * or when the model output cannot be mapped onto the label list.
   *
   * @param args {@code <model dir> <image file>}
   */
  public static void main(String[] args) {
    if (args.length != 2) {
      printUsage(System.err);
      System.exit(1);
    }
    String modelDir = args[0];
    String imageFile = args[1];

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

    try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      // maxIndex returns 0 even for an empty array, and the label file may be shorter than the
      // model output; report either mismatch explicitly instead of failing with a bare
      // index-out-of-bounds exception.
      if (labelProbabilities.length == 0 || bestLabelIdx >= labels.size()) {
        System.err.println(
            String.format(
                "Model produced %d probabilities but the label file has %d entries",
                labelProbabilities.length, labels.size()));
        System.exit(1);
      }
      System.out.println(
          String.format(
              "BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  /**
   * Builds and runs a small graph that decodes a JPEG, converts it to float, adds a leading batch
   * dimension, resizes it bilinearly to 224x224 and applies {@code (value - mean) / scale}.
   *
   * @param imageBytes raw contents of a JPEG file
   * @return the fetched output tensor; the caller is responsible for closing it
   */
  private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Some constants specific to the pre-trained model at:
      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
      //
      // - The model was trained with images scaled to 224x224 pixels.
      // - The colors, represented as R, G, B in 1-byte each were converted to
      //   float using (value - Mean)/Scale.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;

      // Since the graph is being constructed once per execution here, we can use a constant for the
      // input image. If the graph were to be re-used for multiple input images, a placeholder would
      // have been more appropriate.
      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        // Generally, there may be multiple output tensors, all of them must be closed to prevent
        // resource leaks. Only one output is fetched here and it is handed to the caller.
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }
  }

  /**
   * Imports the serialized graph, feeds {@code image} to the node named "input" and fetches the
   * node named "output".
   *
   * @param graphDef serialized graph definition of the model
   * @param image tensor fed to the "input" node
   * @return the single row of the fetched {@code [1, N]} result
   * @throws RuntimeException if the fetched tensor is not two-dimensional with a first dimension
   *     of 1
   */
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          // Generally, there may be multiple output tensors, all of them must be closed to prevent
          // resource leaks.
          Tensor<Float> result =
              s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  /**
   * Returns the index of the largest element, preferring the earliest index on ties.
   *
   * @param probabilities values to scan
   * @return index of the maximum; 0 when the array is empty
   */
  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  /**
   * Reads the whole file as bytes; on an I/O error prints a message to stderr and exits with
   * status 1.
   *
   * @param path file to read
   * @return the file contents
   */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    // Only here to satisfy the compiler; System.exit does not return normally.
    return null;
  }

  /**
   * Reads the whole file as UTF-8 lines; on an I/O error prints a message to stderr and exits
   * with status 1.
   *
   * @param path file to read
   * @return the lines of the file
   */
  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      // A read failure is fatal, so report it with a non-zero status like readAllBytesOrExit does.
      System.exit(1);
    }
    // Only here to satisfy the compiler; System.exit does not return normally.
    return null;
  }

  // In the fullness of time, equivalents of the methods of this class should be auto-generated from
  // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
  // like Python, C++ and Go.
  /** Thin helper that adds operations to a single {@link Graph} and returns their first output. */
  static class GraphBuilder {
    /** Graph that every operation created through this builder is added to. */
    private final Graph g;

    /**
     * Creates a builder that adds operations to {@code g}.
     *
     * @param g graph to build into
     */
    GraphBuilder(Graph g) {
      this.g = g;
    }

    /** Adds a "Div" operation taking {@code x} and {@code y} as inputs. */
    Output<Float> div(Output<Float> x, Output<Float> y) {
      return binaryOp("Div", x, y);
    }

    /** Adds a "Sub" operation taking {@code x} and {@code y} as inputs. */
    <T> Output<T> sub(Output<T> x, Output<T> y) {
      return binaryOp("Sub", x, y);
    }

    /** Adds a "ResizeBilinear" operation taking {@code images} and {@code size} as inputs. */
    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
      return binaryOp3("ResizeBilinear", images, size);
    }

    /** Adds an "ExpandDims" operation taking {@code input} and {@code dim} as inputs. */
    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
      return binaryOp3("ExpandDims", input, dim);
    }

    /**
     * Adds a "Cast" operation whose "DstT" attribute is the data type corresponding to
     * {@code type}.
     */
    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
      DataType dtype = DataType.fromClass(type);
      return g.opBuilder("Cast", "Cast")
          .addInput(value)
          .setAttr("DstT", dtype)
          .build()
          .<U>output(0);
    }

    /** Adds a "DecodeJpeg" operation with its "channels" attribute set to {@code channels}. */
    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .<UInt8>output(0);
    }

    /**
     * Adds a "Const" operation named {@code name} holding {@code value}.
     *
     * <p>The temporary tensor created from {@code value} is closed once the operation is built.
     */
    <T> Output<T> constant(String name, Object value, Class<T> type) {
      try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", DataType.fromClass(type))
            .setAttr("value", t)
            .build()
            .<T>output(0);
      }
    }

    /** Adds a string constant built from a byte array. */
    Output<String> constant(String name, byte[] value) {
      return this.constant(name, value, String.class);
    }

    /** Adds an integer constant built from a single int. */
    Output<Integer> constant(String name, int value) {
      return this.constant(name, value, Integer.class);
    }

    /** Adds an integer constant built from an int array. */
    Output<Integer> constant(String name, int[] value) {
      return this.constant(name, value, Integer.class);
    }

    /** Adds a float constant built from a single float. */
    Output<Float> constant(String name, float value) {
      return this.constant(name, value, Float.class);
    }

    /** Adds a two-input operation named after its type; inputs and output share one type. */
    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }

    /** Adds a two-input operation named after its type; inputs and output types may differ. */
    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }
  }
}