/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.util.proto;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.function.Predicate;

/**
 * A utility class that reads raw protobuf data from an InputStream
 * and copies only those fields for which a given predicate returns true.
 *
 * <p>
 * This is a low-level approach that does not fully decode fields
 * (unless necessary to determine lengths). It simply:
 * </p>
 * <ul>
 *   <li>Parses each field's tag (varint holding the field number and wire type)</li>
 *   <li>If the predicate accepts the field number, copies
 *       the tag bytes and the field bytes directly to the output</li>
 *   <li>Otherwise, skips that field in the input</li>
 * </ul>
 *
 * <p>
 * Because we do not re-encode, unknown or unrecognized fields are copied
 * <i>verbatim</i> and remain exactly as in the input (useful for partial
 * parsing or partial transformations).
 * </p>
 *
 * <p>
 * Note: This class only filters based on top-level field numbers. For length-delimited
 * fields (including nested messages), the entire contents are either copied or skipped
 * as a single unit. The class is not capable of nested filtering.
 * </p>
 *
 * <p>
 * Instances reuse internal scratch buffers across calls, so a single instance must not
 * be used to filter several streams at the same time.
 * </p>
 *
 * @hide
 */
@android.ravenwood.annotation.RavenwoodKeepWholeClass
public class ProtoFieldFilter {

    private static final int BUFFER_SIZE_BYTES = 4096;
    // A 64-bit value needs at most ceil(64 / 7) = 10 bytes in varint encoding
    private static final int MAX_VARINT_BYTES = 10;

    private final Predicate<Integer> mFieldPredicate;
    // General-purpose buffer for reading proto fields and their data
    private final byte[] mBuffer;
    // Buffer specifically designated to hold the raw bytes of the last varint read
    private final byte[] mVarIntBuffer = new byte[MAX_VARINT_BYTES];

    /**
     * Constructs a ProtoFieldFilter with a field-number predicate and an explicit
     * buffer size.
     *
     * @param fieldPredicate A predicate returning true if the given fieldNumber should be
     *                       included in the output.
     * @param bufferSize     The size, in bytes, of the internal buffer used for processing
     *                       proto fields. Must be positive. Larger buffers may improve
     *                       performance when processing large length-delimited fields.
     * @throws IllegalArgumentException if {@code bufferSize} is not positive
     */
    public ProtoFieldFilter(Predicate<Integer> fieldPredicate, int bufferSize) {
        // A zero-length buffer would make every bulk read return 0 bytes and the
        // copy/skip loops would never make progress.
        if (bufferSize <= 0) {
            throw new IllegalArgumentException("bufferSize must be positive: " + bufferSize);
        }
        this.mFieldPredicate = fieldPredicate;
        this.mBuffer = new byte[bufferSize];
    }

    /**
     * Constructs a ProtoFieldFilter with a field-number predicate and the default
     * buffer size.
     *
     * @param fieldPredicate A predicate returning true if the given fieldNumber should be
     *                       included in the output.
     */
    public ProtoFieldFilter(Predicate<Integer> fieldPredicate) {
        this(fieldPredicate, BUFFER_SIZE_BYTES);
    }

    /**
     * Reads raw protobuf data from {@code in} and writes only those fields
     * accepted by the predicate to {@code out}. The predicate is given the
     * field number of each encountered top-level field.
     *
     * <p>Processing stops at end of stream or when a tag with field number 0 is read.
     *
     * @param in  The input stream of protobuf data
     * @param out The output stream to which we write the filtered protobuf
     * @throws IOException If reading or writing fails, or if the protobuf data is corrupted
     */
    public void filter(InputStream in, OutputStream out) throws IOException {
        int tagBytesLength;
        while ((tagBytesLength = readRawVarint(in)) > 0) {
            // Parse the varint loaded in mVarIntBuffer, through readRawVarint
            long tagVal = parseVarint(mVarIntBuffer, tagBytesLength);
            // The field number is narrowed to an int below; reject wider tags instead of
            // silently truncating them into an unrelated field number.
            if ((tagVal >>> 32) != 0) {
                throw new IOException("Malformed tag: field tag exceeds 32 bits");
            }
            int fieldNumber = (int) (tagVal >>> ProtoStream.FIELD_ID_SHIFT);
            int wireType = (int) (tagVal & ProtoStream.WIRE_TYPE_MASK);

            if (fieldNumber == 0) {
                break;
            }
            if (mFieldPredicate.test(fieldNumber)) {
                // The tag is copied byte-for-byte so the output keeps the input encoding
                out.write(mVarIntBuffer, 0, tagBytesLength);
                copyFieldData(in, out, wireType);
            } else {
                skipFieldData(in, wireType);
            }
        }
    }

    /**
     * Reads a varint (up to 10 bytes) from the stream as raw bytes into
     * {@code mVarIntBuffer}. If the stream is already at EOF, returns 0.
     *
     * @param in The input stream
     * @return the number of varint bytes stored in mVarIntBuffer, or 0 at EOF
     * @throws IOException If an error occurs, or if we detect a malformed varint
     */
    private int readRawVarint(InputStream in) throws IOException {
        // We attempt to read 1 byte. If none available => report EOF with 0
        int b = in.read();
        if (b < 0) {
            return 0;
        }
        int count = 0;
        mVarIntBuffer[count++] = (byte) b;
        // If the continuation bit is set, we continue
        while ((b & 0x80) != 0) {
            // read next byte
            b = in.read();
            // EOF
            if (b < 0) {
                throw new IOException("Malformed varint: reached EOF mid-varint");
            }
            // max 10 bytes for varint 64
            if (count >= MAX_VARINT_BYTES) {
                throw new IOException("Malformed varint: too many bytes (max 10)");
            }
            mVarIntBuffer[count++] = (byte) b;
        }
        return count;
    }

    /**
     * Parses a varint from the given raw bytes and returns it as a long.
     *
     * @param rawVarint  The bytes representing the varint
     * @param byteLength The number of bytes to read from rawVarint
     * @return The decoded long value
     * @throws IOException if the encoded value does not fit in 64 bits
     */
    private static long parseVarint(byte[] rawVarint, int byteLength) throws IOException {
        if (byteLength > MAX_VARINT_BYTES) {
            throw new IOException("Malformed varint: exceeds 64 bits");
        }
        long result = 0;
        for (int i = 0; i < byteLength; i++) {
            // Widen to long BEFORE shifting: an int shift only uses the low 5 bits of the
            // shift distance, so groups beyond the fourth byte would wrap around.
            long payload = rawVarint[i] & 0x7F;
            int shift = 7 * i;
            // The tenth byte starts at bit 63, so only its lowest bit fits in a long
            if (shift == 63 && payload > 1) {
                throw new IOException("Malformed varint: exceeds 64 bits");
            }
            result |= payload << shift;
        }
        return result;
    }

    /**
     * Copies the wire data for a single field from {@code in} to {@code out},
     * assuming we have already read the field's tag.
     *
     * @param in       The input stream (protobuf data)
     * @param out      The output stream
     * @param wireType The wire type (0=varint, 1=fixed64, 2=length-delim, 5=fixed32)
     * @throws IOException if reading/writing fails or data is malformed
     */
    private void copyFieldData(InputStream in, OutputStream out, int wireType)
            throws IOException {
        switch (wireType) {
            case ProtoStream.WIRE_TYPE_VARINT:
                copyVarint(in, out);
                break;
            case ProtoStream.WIRE_TYPE_FIXED64:
                copyFixed(in, out, 8);
                break;
            case ProtoStream.WIRE_TYPE_LENGTH_DELIMITED:
                copyLengthDelimited(in, out);
                break;
            case ProtoStream.WIRE_TYPE_FIXED32:
                copyFixed(in, out, 4);
                break;
            // WIRE_TYPE_START_GROUP and WIRE_TYPE_END_GROUP are not supported
            default:
                // Error or unrecognized wire type
                throw new IOException("Unknown or unsupported wire type: " + wireType);
        }
    }

    /**
     * Skips the wire data for a single field from {@code in},
     * assuming the field's tag was already read.
     *
     * @param in       The input stream (protobuf data)
     * @param wireType The wire type (0=varint, 1=fixed64, 2=length-delim, 5=fixed32)
     * @throws IOException if reading fails or data is malformed
     */
    private void skipFieldData(InputStream in, int wireType) throws IOException {
        switch (wireType) {
            case ProtoStream.WIRE_TYPE_VARINT:
                skipVarint(in);
                break;
            case ProtoStream.WIRE_TYPE_FIXED64:
                skipBytes(in, 8);
                break;
            case ProtoStream.WIRE_TYPE_LENGTH_DELIMITED:
                skipLengthDelimited(in);
                break;
            case ProtoStream.WIRE_TYPE_FIXED32:
                skipBytes(in, 4);
                break;
            // WIRE_TYPE_START_GROUP and WIRE_TYPE_END_GROUP are not supported
            default:
                throw new IOException("Unknown or unsupported wire type: " + wireType);
        }
    }

    /**
     * Copies a varint (the field's value) from in to out, byte by byte.
     *
     * @throws IOException on EOF before the varint ends, or if it is longer than 10 bytes
     */
    private static void copyVarint(InputStream in, OutputStream out) throws IOException {
        for (int bytesCopied = 1; ; bytesCopied++) {
            int b = in.read();
            if (b < 0) {
                throw new IOException("EOF while copying varint");
            }
            out.write(b);
            if ((b & 0x80) == 0) {
                return;
            }
            // Continuation bit still set on the last allowed byte => malformed
            if (bytesCopied >= MAX_VARINT_BYTES) {
                throw new IOException("Malformed varint: exceeds maximum length of 10 bytes");
            }
        }
    }

    /**
     * Copies exactly {@code length} bytes from {@code in} to {@code out}.
     *
     * @throws IOException if the stream ends before {@code length} bytes were copied
     */
    private void copyFixed(InputStream in, OutputStream out,
            int length) throws IOException {
        int toRead = length;
        while (toRead > 0) {
            int chunk = Math.min(toRead, mBuffer.length);
            int readCount = in.read(mBuffer, 0, chunk);
            if (readCount < 0) {
                // 8L keeps the bit count from overflowing for large length-delimited fields
                throw new IOException("EOF while copying fixed" + (length * 8L) + " field");
            }
            out.write(mBuffer, 0, readCount);
            toRead -= readCount;
        }
    }

    /**
     * Copies a length-delimited field: the length varint (verbatim) followed by
     * that many payload bytes.
     */
    private void copyLengthDelimited(InputStream in,
            OutputStream out) throws IOException {
        // 1) read length varint (and copy)
        int lengthVarintLength = readRawVarint(in);
        if (lengthVarintLength <= 0) {
            throw new IOException("EOF reading length for length-delimited field");
        }
        out.write(mVarIntBuffer, 0, lengthVarintLength);

        long lengthVal = parseVarint(mVarIntBuffer, lengthVarintLength);
        if (lengthVal < 0 || lengthVal > Integer.MAX_VALUE) {
            throw new IOException("Invalid length for length-delimited field: " + lengthVal);
        }

        // 2) copy that many bytes
        copyFixed(in, out, (int) lengthVal);
    }

    /**
     * Skips a varint in the input (does not write anything).
     *
     * @throws IOException on EOF before the varint ends, or if it is longer than 10 bytes
     */
    private static void skipVarint(InputStream in) throws IOException {
        for (int bytesRead = 1; ; bytesRead++) {
            int b = in.read();
            if (b < 0) {
                throw new IOException("EOF while skipping varint");
            }
            if ((b & 0x80) == 0) {
                return;
            }
            // Continuation bit still set on the last allowed byte => malformed
            if (bytesRead >= MAX_VARINT_BYTES) {
                throw new IOException("Malformed varint: exceeds maximum length of 10 bytes");
            }
        }
    }

    /**
     * Skips exactly n bytes.
     *
     * @throws IOException if the stream ends before n bytes were skipped
     */
    private void skipBytes(InputStream in, long n) throws IOException {
        // skip() may legitimately skip fewer bytes than requested (even zero), so
        // whatever is left is consumed with plain reads, which report EOF reliably.
        long bytesRemaining = n - Math.max(0L, in.skip(n));

        while (bytesRemaining > 0) {
            int bytesToRead = (int) Math.min(bytesRemaining, mBuffer.length);
            int bytesRead = in.read(mBuffer, 0, bytesToRead);
            if (bytesRead < 0) {
                throw new IOException("EOF while skipping bytes");
            }
            bytesRemaining -= bytesRead;
        }
    }

    /**
     * Skips a length-delimited field.
     * 1) read the length as varint,
     * 2) skip that many bytes
     */
    private void skipLengthDelimited(InputStream in) throws IOException {
        int lengthVarintLength = readRawVarint(in);
        if (lengthVarintLength <= 0) {
            throw new IOException("EOF reading length for length-delimited field");
        }
        long lengthVal = parseVarint(mVarIntBuffer, lengthVarintLength);
        if (lengthVal < 0 || lengthVal > Integer.MAX_VALUE) {
            throw new IOException("Invalid length to skip: " + lengthVal);
        }
        skipBytes(in, lengthVal);
    }

}