• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2025 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.util.proto;
18 
19 import java.io.IOException;
20 import java.io.InputStream;
21 import java.io.OutputStream;
22 import java.util.function.Predicate;
23 
24 /**
25  * A utility class that reads raw protobuf data from an InputStream
26  * and copies only those fields for which a given predicate returns true.
27  *
28  * <p>
29  * This is a low-level approach that does not fully decode fields
30  * (unless necessary to determine lengths). It simply:
31  * <ul>
32  *   <li>Parses each field's tag (varint for field number & wire type)</li>
33  *   <li>If {@code includeFn(fieldNumber) == true}, copies
34  *       the tag bytes and the field bytes directly to the output</li>
35  *   <li>Otherwise, skips that field in the input</li>
36  * </ul>
37  * </p>
38  *
39  * <p>
40  * Because we do not re-encode, unknown or unrecognized fields are copied
41  * <i>verbatim</i> and remain exactly as in the input (useful for partial
42  * parsing or partial transformations).
43  * </p>
44  *
45  * <p>
46  * Note: This class only filters based on top-level field numbers. For length-delimited
47  * fields (including nested messages), the entire contents are either copied or skipped
48  * as a single unit. The class is not capable of nested filtering.
49  * </p>
50  *
51  * @hide
52  */
53 @android.ravenwood.annotation.RavenwoodKeepWholeClass
54 public class ProtoFieldFilter {
55 
56     private static final int BUFFER_SIZE_BYTES = 4096;
57 
58     private final Predicate<Integer> mFieldPredicate;
59     // General-purpose buffer for reading proto fields and their data
60     private final byte[] mBuffer;
61     // Buffer specifically designated to hold varint values (max 10 bytes in protobuf encoding)
62     private final byte[] mVarIntBuffer = new byte[10];
63 
64     /**
65     * Constructs a ProtoFieldFilter with a predicate that considers depth.
66     *
67     * @param fieldPredicate A predicate returning true if the given fieldNumber should be
68     *                       included in the output.
69     * @param bufferSize The size of the internal buffer used for processing proto fields.
70     *                   Larger buffers may improve performance when processing large
71     *                   length-delimited fields.
72     */
ProtoFieldFilter(Predicate<Integer> fieldPredicate, int bufferSize)73     public ProtoFieldFilter(Predicate<Integer> fieldPredicate, int bufferSize) {
74         this.mFieldPredicate = fieldPredicate;
75         this.mBuffer = new byte[bufferSize];
76     }
77 
78     /**
79     * Constructs a ProtoFieldFilter with a predicate that considers depth and
80     * uses a default buffer size.
81     *
82     * @param fieldPredicate A predicate returning true if the given fieldNumber should be
83     *                       included in the output.
84     */
ProtoFieldFilter(Predicate<Integer> fieldPredicate)85     public ProtoFieldFilter(Predicate<Integer> fieldPredicate) {
86         this(fieldPredicate, BUFFER_SIZE_BYTES);
87     }
88 
89     /**
90      * Reads raw protobuf data from {@code in} and writes only those fields
91      * passing {@code includeFn} to {@code out}. The predicate is given
92      * (fieldNumber, wireType) for each encountered field.
93      *
94      * @param in        The input stream of protobuf data
95      * @param out       The output stream to which we write the filtered protobuf
96      * @throws IOException If reading or writing fails, or if the protobuf data is corrupted
97      */
filter(InputStream in, OutputStream out)98     public void filter(InputStream in, OutputStream out) throws IOException {
99         int tagBytesLength;
100         while ((tagBytesLength = readRawVarint(in)) > 0) {
101             // Parse the varint loaded in mVarIntBuffer, through readRawVarint
102             long tagVal = parseVarint(mVarIntBuffer, tagBytesLength);
103             int fieldNumber = (int) (tagVal >>> ProtoStream.FIELD_ID_SHIFT);
104             int wireType   = (int) (tagVal & ProtoStream.WIRE_TYPE_MASK);
105 
106             if (fieldNumber == 0) {
107                 break;
108             }
109             if (mFieldPredicate.test(fieldNumber)) {
110                 out.write(mVarIntBuffer, 0, tagBytesLength);
111                 copyFieldData(in, out, wireType);
112             } else {
113                 skipFieldData(in, wireType);
114             }
115         }
116     }
117 
118     /**
119      * Reads a varint (up to 10 bytes) from the stream as raw bytes
120      * and returns it in a byte array. If the stream is at EOF, returns null.
121      *
122      * @param in The input stream
123      * @return the size of the varint bytes moved to mVarIntBuffer
124      * @throws IOException If an error occurs, or if we detect a malformed varint
125      */
readRawVarint(InputStream in)126     private int readRawVarint(InputStream in) throws IOException {
127         // We attempt to read 1 byte. If none available => null
128         int b = in.read();
129         if (b < 0) {
130             return 0;
131         }
132         int count = 0;
133         mVarIntBuffer[count++] = (byte) b;
134         // If the continuation bit is set, we continue
135         while ((b & 0x80) != 0) {
136             // read next byte
137             b = in.read();
138             // EOF
139             if (b < 0) {
140                 throw new IOException("Malformed varint: reached EOF mid-varint");
141             }
142             // max 10 bytes for varint 64
143             if (count >= 10) {
144                 throw new IOException("Malformed varint: too many bytes (max 10)");
145             }
146             mVarIntBuffer[count++] = (byte) b;
147         }
148         return count;
149     }
150 
151     /**
152      * Parses a varint from the given raw bytes and returns it as a long.
153      *
154      * @param rawVarint The bytes representing the varint
155      * @param byteLength The number of bytes to read from rawVarint
156      * @return The decoded long value
157      */
parseVarint(byte[] rawVarint, int byteLength)158     private static long parseVarint(byte[] rawVarint, int byteLength) throws IOException {
159         long result = 0;
160         int shift = 0;
161         for (int i = 0; i < byteLength; i++) {
162             result |= ((rawVarint[i] & 0x7F) << shift);
163             shift += 7;
164             if (shift > 63) {
165                 throw new IOException("Malformed varint: exceeds 64 bits");
166             }
167         }
168         return result;
169     }
170 
171     /**
172      * Copies the wire data for a single field from {@code in} to {@code out},
173      * assuming we have already read the field's tag.
174      *
175      * @param in       The input stream (protobuf data)
176      * @param out      The output stream
177      * @param wireType The wire type (0=varint, 1=fixed64, 2=length-delim, 5=fixed32)
178      * @throws IOException if reading/writing fails or data is malformed
179      */
copyFieldData(InputStream in, OutputStream out, int wireType)180     private void copyFieldData(InputStream in, OutputStream out, int wireType)
181             throws IOException {
182         switch (wireType) {
183             case ProtoStream.WIRE_TYPE_VARINT:
184                 copyVarint(in, out);
185                 break;
186             case ProtoStream.WIRE_TYPE_FIXED64:
187                 copyFixed(in, out, 8);
188                 break;
189             case ProtoStream.WIRE_TYPE_LENGTH_DELIMITED:
190                 copyLengthDelimited(in, out);
191                 break;
192             case ProtoStream.WIRE_TYPE_FIXED32:
193                 copyFixed(in, out, 4);
194                 break;
195             // case WIRE_TYPE_START_GROUP:
196                 // Not Supported
197             // case WIRE_TYPE_END_GROUP:
198                 // Not Supported
199             default:
200                 // Error or unrecognized wire type
201                 throw new IOException("Unknown or unsupported wire type: " + wireType);
202         }
203     }
204 
205     /**
206      * Skips the wire data for a single field from {@code in},
207      * assuming the field's tag was already read.
208      */
skipFieldData(InputStream in, int wireType)209     private void skipFieldData(InputStream in, int wireType) throws IOException {
210         switch (wireType) {
211             case ProtoStream.WIRE_TYPE_VARINT:
212                 skipVarint(in);
213                 break;
214             case ProtoStream.WIRE_TYPE_FIXED64:
215                 skipBytes(in, 8);
216                 break;
217             case ProtoStream.WIRE_TYPE_LENGTH_DELIMITED:
218                 skipLengthDelimited(in);
219                 break;
220             case ProtoStream.WIRE_TYPE_FIXED32:
221                 skipBytes(in, 4);
222                 break;
223              // case WIRE_TYPE_START_GROUP:
224                 // Not Supported
225             // case WIRE_TYPE_END_GROUP:
226                 // Not Supported
227             default:
228                 throw new IOException("Unknown or unsupported wire type: " + wireType);
229         }
230     }
231 
232     /** Copies a varint (the field's value) from in to out. */
copyVarint(InputStream in, OutputStream out)233     private static void copyVarint(InputStream in, OutputStream out) throws IOException {
234         while (true) {
235             int b = in.read();
236             if (b < 0) {
237                 throw new IOException("EOF while copying varint");
238             }
239             out.write(b);
240             if ((b & 0x80) == 0) {
241                 break;
242             }
243         }
244     }
245 
246     /**
247      * Copies exactly {@code length} bytes from {@code in} to {@code out}.
248      */
copyFixed(InputStream in, OutputStream out, int length)249     private void copyFixed(InputStream in, OutputStream out,
250                 int length) throws IOException {
251         int toRead = length;
252         while (toRead > 0) {
253             int chunk = Math.min(toRead, mBuffer.length);
254             int readCount = in.read(mBuffer, 0, chunk);
255             if (readCount < 0) {
256                 throw new IOException("EOF while copying fixed" + (length * 8) + " field");
257             }
258             out.write(mBuffer, 0, readCount);
259             toRead -= readCount;
260         }
261     }
262 
263     /** Copies a length-delimited field */
copyLengthDelimited(InputStream in, OutputStream out)264     private void copyLengthDelimited(InputStream in,
265                     OutputStream out) throws IOException {
266         // 1) read length varint (and copy)
267         int lengthVarintLength = readRawVarint(in);
268         if (lengthVarintLength <= 0) {
269             throw new IOException("EOF reading length for length-delimited field");
270         }
271         out.write(mVarIntBuffer, 0, lengthVarintLength);
272 
273         long lengthVal = parseVarint(mVarIntBuffer, lengthVarintLength);
274         if (lengthVal < 0 || lengthVal > Integer.MAX_VALUE) {
275             throw new IOException("Invalid length for length-delimited field: " + lengthVal);
276         }
277 
278         // 2) copy that many bytes
279         copyFixed(in, out, (int) lengthVal);
280     }
281 
282     /** Skips a varint in the input (does not write anything). */
skipVarint(InputStream in)283     private static void skipVarint(InputStream in) throws IOException {
284         int bytesSkipped = 0;
285         while (true) {
286             int b = in.read();
287             if (b < 0) {
288                 throw new IOException("EOF while skipping varint");
289             }
290             if ((b & 0x80) == 0) {
291                 break;
292             }
293             bytesSkipped++;
294             if (bytesSkipped > 10) {
295                 throw new IOException("Malformed varint: exceeds maximum length of 10 bytes");
296             }
297         }
298     }
299 
300     /** Skips exactly n bytes. */
skipBytes(InputStream in, long n)301     private void skipBytes(InputStream in, long n) throws IOException {
302         long skipped = in.skip(n);
303         // If skip fails, fallback to reading the remaining bytes
304         if (skipped < n) {
305             long bytesRemaining = n - skipped;
306 
307             while (bytesRemaining > 0) {
308                 int bytesToRead = (int) Math.min(bytesRemaining, mBuffer.length);
309                 int bytesRead = in.read(mBuffer, 0, bytesToRead);
310                 if (bytesRemaining < 0) {
311                     throw new IOException("EOF while skipping bytes");
312                 }
313                 bytesRemaining -= bytesRead;
314             }
315         }
316     }
317 
318     /**
319      * Skips a length-delimited field.
320      * 1) read the length as varint,
321      * 2) skip that many bytes
322      */
skipLengthDelimited(InputStream in)323     private void skipLengthDelimited(InputStream in) throws IOException {
324         int lengthVarintLength = readRawVarint(in);
325         if (lengthVarintLength <= 0) {
326             throw new IOException("EOF reading length for length-delimited field");
327         }
328         long lengthVal = parseVarint(mVarIntBuffer, lengthVarintLength);
329         if (lengthVal < 0 || lengthVal > Integer.MAX_VALUE) {
330             throw new IOException("Invalid length to skip: " + lengthVal);
331         }
332         skipBytes(in, lengthVal);
333     }
334 
335 }
336