/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;

/** Static utility functions used by tests to build small graphs and manipulate arrays. */
public class TestUtil {

  /**
   * An {@link ArrayList} of {@link AutoCloseable} resources that closes every element when the
   * list itself is closed, so a group of resources can be managed by a single try-with-resources.
   *
   * @param <E> the resource type held by the list
   */
  public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
      implements AutoCloseable {

    /**
     * Creates a list holding the elements of {@code c}, in the collection's iteration order.
     *
     * @param c the resources to manage
     */
    public AutoCloseableList(Collection<? extends E> c) {
      super(c);
    }

    /**
     * Closes every element in list order, continuing past elements whose {@code close()} fails.
     *
     * <p>The first failure becomes the cause of the thrown exception. Any later failures are
     * attached to it as suppressed exceptions instead of being discarded.
     *
     * @throws RuntimeException if closing any element threw; its cause is the first failure
     */
    @Override
    public void close() {
      Exception toThrow = null;
      for (AutoCloseable c : this) {
        try {
          c.close();
        } catch (Exception e) {
          if (toThrow == null) {
            toThrow = e;
          } else if (e != toThrow) {
            // addSuppressed rejects self-suppression, e.g. one shared exception thrown twice.
            toThrow.addSuppressed(e);
          }
        }
      }
      if (toThrow != null) {
        throw new RuntimeException(toThrow);
      }
    }
  }

  /**
   * Adds a {@code Const} operation holding {@code value} to {@code g}.
   *
   * <p>The temporary {@link Tensor} created from {@code value} is closed before returning.
   *
   * @param g the graph to add the operation to
   * @param name the name of the new operation
   * @param value the constant value, in any form accepted by {@code Tensor.create(Object)}
   * @return the newly built operation
   */
  public static GraphOperation constantOp(Graph g, String name, Object value) {
    try (Tensor<?> t = Tensor.create(value)) {
      return g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build();
    }
  }

  /**
   * Adds a {@code Const} operation holding {@code value} to {@code env} and returns its output.
   *
   * <p>The temporary {@link Tensor} created from {@code value} is closed before returning.
   *
   * @param env the environment to build the operation in (a {@link Graph} works, for example)
   * @param name the name of the new operation
   * @param value the constant value, in any form accepted by {@code Tensor.create(Object)}
   * @param <T> the element type the caller expects; it is not checked against {@code value}
   * @return output 0 of the new operation
   */
  public static <T> Output<T> constant(ExecutionEnvironment env, String name, Object value) {
    try (Tensor<?> t = Tensor.create(value)) {
      return env.opBuilder("Const", name)
          .setAttr("dtype", t.dataType())
          .setAttr("value", t)
          .build()
          .<T>output(0);
    }
  }

  /**
   * Adds a {@code Placeholder} operation whose {@code dtype} corresponds to {@code type}.
   *
   * @param g the graph to add the operation to
   * @param name the name of the new operation
   * @param type the Java class mapped to a {@link DataType} via {@code DataType.fromClass}
   * @param <T> the element type of the placeholder
   * @return output 0 of the new operation
   */
  public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
    return g.opBuilder("Placeholder", name)
        .setAttr("dtype", DataType.fromClass(type))
        .build()
        .<T>output(0);
  }

  /**
   * Adds an {@code AddN} operation named {@code "AddN"} summing {@code inputs}.
   *
   * <p>Because the name is fixed, calling this twice on the same graph reuses the name. Use
   * {@link #addN(ExecutionEnvironment, String, Output[])} to choose a distinct name.
   *
   * @param env the environment to build the operation in
   * @param inputs the tensors to sum
   * @param <T> the element type of the result
   * @return output 0 of the new operation
   */
  public static <T> Output<T> addN(ExecutionEnvironment env, Output<?>... inputs) {
    return addN(env, "AddN", inputs);
  }

  /**
   * Adds an {@code AddN} operation with the given name summing {@code inputs}.
   *
   * @param env the environment to build the operation in
   * @param name the name of the new operation
   * @param inputs the tensors to sum
   * @param <T> the element type of the result
   * @return output 0 of the new operation
   */
  public static <T> Output<T> addN(ExecutionEnvironment env, String name, Output<?>... inputs) {
    return env.opBuilder("AddN", name).addInputList(inputs).build().<T>output(0);
  }

  /**
   * Adds a {@code MatMul} operation computing {@code a * b}, optionally transposing either input.
   *
   * @param g the graph to add the operation to
   * @param name the name of the new operation
   * @param a the left operand
   * @param b the right operand
   * @param transposeA value of the {@code transpose_a} attribute
   * @param transposeB value of the {@code transpose_b} attribute
   * @param <T> the element type of the operands and the result
   * @return output 0 of the new operation
   */
  public static <T> Output<T> matmul(
      Graph g, String name, Output<T> a, Output<T> b, boolean transposeA, boolean transposeB) {
    return g.opBuilder("MatMul", name)
        .addInput(a)
        .addInput(b)
        .setAttr("transpose_a", transposeA)
        .setAttr("transpose_b", transposeB)
        .build()
        .<T>output(0);
  }

  /**
   * Adds a {@code Split} operation that splits {@code values} along dimension 0.
   *
   * <p>Also adds two {@code Const} operations with the fixed names {@code "split_dim"} and {@code
   * "values"}, so this should be called at most once per graph.
   *
   * @param g the graph to add the operations to
   * @param name the name of the {@code Split} operation
   * @param values the data to split
   * @param numSplit value of the {@code num_split} attribute
   * @return the {@code Split} operation
   */
  public static Operation split(Graph g, String name, int[] values, int numSplit) {
    return g.opBuilder("Split", name)
        .addInput(constant(g, "split_dim", 0))
        .addInput(constant(g, "values", values))
        .setAttr("num_split", numSplit)
        .build();
  }

  /**
   * Adds a {@code Square} operation applied to {@code value}.
   *
   * @param g the graph to add the operation to
   * @param name the name of the new operation
   * @param value the input tensor
   * @param <T> the element type of the input and result
   * @return output 0 of the new operation
   */
  public static <T> Output<T> square(Graph g, String name, Output<T> value) {
    return g.opBuilder("Square", name)
        .addInput(value)
        .build()
        .<T>output(0);
  }

  /**
   * Builds {@code Y = transpose(A) * X} in {@code g}.
   *
   * <p>{@code A} is a {@code Const} holding {@code a}, {@code X} is an {@link Integer}
   * placeholder, and {@code Y} is a {@code MatMul} with {@code transpose_a = true}.
   *
   * @param g the graph to add the operations to
   * @param a the contents of the constant {@code A}
   */
  public static void transpose_A_times_X(Graph g, int[][] a) {
    Output<Integer> aa = constant(g, "A", a);
    matmul(g, "Y", aa, placeholder(g, "X", Integer.class), true, false);
  }

  /**
   * Counts the total number of elements in an ND array.
   *
   * <p>Nested arrays are descended into. Every other element, including {@code null}, counts as
   * one element.
   *
   * @param array the array to count the elements of
   * @return the number of leaf elements
   * @throws IllegalArgumentException if {@code array} is not an array
   */
  public static int flattenedNumElements(Object array) {
    Class<?> component = array.getClass().getComponentType();
    if (component != null && component.isPrimitive()) {
      // A primitive array cannot contain nested arrays, so its length is its element count.
      return Array.getLength(array);
    }
    int count = 0;
    int length = Array.getLength(array);
    for (int i = 0; i < length; i++) {
      Object e = Array.get(array, i);
      if (e != null && e.getClass().isArray()) {
        count += flattenedNumElements(e);
      } else {
        count += 1;
      }
    }
    return count;
  }

  /**
   * Flattens an ND-array into a 1D-array with the same elements, in row-major order.
   *
   * <p>{@code null} elements are treated as leaves, consistent with {@link
   * #flattenedNumElements(Object)}.
   *
   * @param array the array to flatten
   * @param elementType the element class (e.g. {@code Integer.TYPE} for an {@code int[]})
   * @return a flattened array whose component type is {@code elementType}
   * @throws IllegalArgumentException if a leaf element cannot be stored in an array of {@code
   *     elementType}
   */
  public static Object flatten(Object array, Class<?> elementType) {
    Object out = Array.newInstance(elementType, flattenedNumElements(array));
    flatten(array, out, 0);
    return out;
  }

  /**
   * Copies the leaf elements of {@code array} into {@code out}, starting at index {@code next}.
   *
   * @return the index just past the last element written
   */
  private static int flatten(Object array, Object out, int next) {
    Class<?> component = array.getClass().getComponentType();
    int length = Array.getLength(array);
    if (component != null
        && component.isPrimitive()
        && component == out.getClass().getComponentType()) {
      // Same primitive type on both sides: a bulk copy avoids per-element boxing.
      System.arraycopy(array, 0, out, next, length);
      return next + length;
    }
    for (int i = 0; i < length; i++) {
      Object e = Array.get(array, i);
      if (e != null && e.getClass().isArray()) {
        next = flatten(e, out, next);
      } else {
        Array.set(out, next++, e);
      }
    }
    return next;
  }

  /**
   * Converts a {@code boolean[]} to a {@code byte[]}, mapping {@code true} to 1 and {@code false}
   * to 0.
   *
   * <p>Suitable for creating tensors of type {@link DataType#BOOL} using {@link
   * java.nio.ByteBuffer}.
   */
  public static byte[] bool2byte(boolean[] array) {
    byte[] out = new byte[array.length];
    for (int i = 0; i < array.length; i++) {
      out[i] = array[i] ? (byte) 1 : (byte) 0;
    }
    return out;
  }

  /**
   * Converts a {@code byte[]} to a {@code boolean[]}, mapping any non-zero byte to {@code true}.
   *
   * <p>Suitable for reading tensors of type {@link DataType#BOOL} using {@link
   * java.nio.ByteBuffer}.
   */
  public static boolean[] byte2bool(byte[] array) {
    boolean[] out = new boolean[array.length];
    for (int i = 0; i < array.length; i++) {
      out[i] = array[i] != 0;
    }
    return out;
  }

  private TestUtil() {}
}