• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/protozero/filtering/message_filter.h"
18 
19 #include "perfetto/base/logging.h"
20 #include "perfetto/protozero/proto_utils.h"
21 #include "src/protozero/filtering/string_filter.h"
22 
23 namespace protozero {
24 
25 namespace {
26 
// Inline helpers to append proto fields to the output. They are the equivalent
// of the protozero::Message::AppendXXX() methods but don't require building and
// maintaining a full protozero::Message object or dealing with scattered
// output slices.
// All these functions assume there is enough space in the output buffer, which
// should always be the case as long as we don't end up generating more output
// than input.
34 
AppendVarInt(uint32_t field_id,uint64_t value,uint8_t ** out)35 inline void AppendVarInt(uint32_t field_id, uint64_t value, uint8_t** out) {
36   *out = proto_utils::WriteVarInt(proto_utils::MakeTagVarInt(field_id), *out);
37   *out = proto_utils::WriteVarInt(value, *out);
38 }
39 
40 // For fixed32 / fixed64.
41 template <typename INT_T /* uint32_t | uint64_t*/>
AppendFixed(uint32_t field_id,INT_T value,uint8_t ** out)42 inline void AppendFixed(uint32_t field_id, INT_T value, uint8_t** out) {
43   *out = proto_utils::WriteVarInt(proto_utils::MakeTagFixed<INT_T>(field_id),
44                                   *out);
45   memcpy(*out, &value, sizeof(value));
46   *out += sizeof(value);
47 }
48 
49 // For length-delimited (string, bytes) fields. Note: this function appends only
50 // the proto preamble and the varint field that states the length of the payload
51 // not the payload itself.
52 // In the case of submessages, the caller needs to re-write the length at the
53 // end in the in the returned memory area.
54 // The problem here is that, because of filtering, the length of a submessage
55 // might be < original length (the original length is still an upper-bound).
56 // Returns a pair with: (1) the pointer where the final length should be written
57 // into, (2) the length of the size field.
58 // The caller must write a redundant varint to match the original size (i.e.
59 // needs to use WriteRedundantVarInt()).
AppendLenDelim(uint32_t field_id,uint32_t len,uint8_t ** out)60 inline std::pair<uint8_t*, uint32_t> AppendLenDelim(uint32_t field_id,
61                                                     uint32_t len,
62                                                     uint8_t** out) {
63   *out = proto_utils::WriteVarInt(proto_utils::MakeTagLengthDelimited(field_id),
64                                   *out);
65   uint8_t* size_field_start = *out;
66   *out = proto_utils::WriteVarInt(len, *out);
67   const size_t size_field_len = static_cast<size_t>(*out - size_field_start);
68   return std::make_pair(size_field_start, size_field_len);
69 }
70 }  // namespace
71 
MessageFilter()72 MessageFilter::MessageFilter() {
73   // Push a state on the stack for the implicit root message.
74   stack_.emplace_back();
75 }
76 
// Copy constructor. Only the filter configuration of |other| is copied:
// |root_msg_index_|, |filter_| and |string_filter_|. Every other member keeps
// its default initialization. Like a freshly constructed filter, the copy
// starts with a single stack entry for the implicit root message.
MessageFilter::MessageFilter(const MessageFilter& other)
    : root_msg_index_(other.root_msg_index_),
      filter_(other.filter_),
      string_filter_(other.string_filter_) {
  stack_.push_back(StackState{});
}

MessageFilter::~MessageFilter() = default;
85 
LoadFilterBytecode(const void * filter_data,size_t len)86 bool MessageFilter::LoadFilterBytecode(const void* filter_data, size_t len) {
87   return filter_.Load(filter_data, len);
88 }
89 
SetFilterRoot(const uint32_t * field_ids,size_t num_fields)90 bool MessageFilter::SetFilterRoot(const uint32_t* field_ids,
91                                   size_t num_fields) {
92   uint32_t root_msg_idx = 0;
93   for (const uint32_t* it = field_ids; it < field_ids + num_fields; ++it) {
94     uint32_t field_id = *it;
95     auto res = filter_.Query(root_msg_idx, field_id);
96     if (!res.allowed || !res.nested_msg_field())
97       return false;
98     root_msg_idx = res.nested_msg_index;
99   }
100   root_msg_index_ = root_msg_idx;
101   return true;
102 }
103 
FilterMessageFragments(const InputSlice * slices,size_t num_slices)104 MessageFilter::FilteredMessage MessageFilter::FilterMessageFragments(
105     const InputSlice* slices,
106     size_t num_slices) {
107   // First compute the upper bound for the output. The filtered message cannot
108   // be > the original message.
109   uint32_t total_len = 0;
110   for (size_t i = 0; i < num_slices; ++i)
111     total_len += slices[i].len;
112   out_buf_.reset(new uint8_t[total_len]);
113   out_ = out_buf_.get();
114   out_end_ = out_ + total_len;
115 
116   // Reset the parser state.
117   tokenizer_ = MessageTokenizer();
118   error_ = false;
119   stack_.clear();
120   stack_.resize(2);
121   // stack_[0] is a sentinel and should never be hit in nominal cases. If we
122   // end up there we will just keep consuming the input stream and detecting
123   // at the end, without hurting the fastpath.
124   stack_[0].in_bytes_limit = UINT32_MAX;
125   stack_[0].eat_next_bytes = UINT32_MAX;
126   // stack_[1] is the actual root message.
127   stack_[1].in_bytes_limit = total_len;
128   stack_[1].msg_index = root_msg_index_;
129 
130   // Process the input data and write the output.
131   for (size_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
132     const InputSlice& slice = slices[slice_idx];
133     const uint8_t* data = static_cast<const uint8_t*>(slice.data);
134     for (size_t i = 0; i < slice.len; ++i)
135       FilterOneByte(data[i]);
136   }
137 
138   // Construct the output object.
139   PERFETTO_CHECK(out_ >= out_buf_.get() && out_ <= out_end_);
140   auto used_size = static_cast<size_t>(out_ - out_buf_.get());
141   FilteredMessage res{std::move(out_buf_), used_size};
142   res.error = error_;
143   if (stack_.size() != 1 || !tokenizer_.idle() ||
144       stack_[0].in_bytes != total_len) {
145     res.error = true;
146   }
147   return res;
148 }
149 
FilterOneByte(uint8_t octet)150 void MessageFilter::FilterOneByte(uint8_t octet) {
151   PERFETTO_DCHECK(!stack_.empty());
152 
153   auto* state = &stack_.back();
154   StackState next_state{};
155   bool push_next_state = false;
156 
157   if (state->eat_next_bytes > 0) {
158     // This is the case where the previous tokenizer_.Push() call returned a
159     // length delimited message which is NOT a submessage (a string or a bytes
160     // field). We just want to consume it, and pass it through/filter strings
161     // if the field was allowed.
162     --state->eat_next_bytes;
163     if (state->action == StackState::kPassthrough) {
164       *(out_++) = octet;
165     } else if (state->action == StackState::kFilterString) {
166       *(out_++) = octet;
167       if (state->eat_next_bytes == 0) {
168         string_filter_.MaybeFilter(
169             reinterpret_cast<char*>(state->filter_string_ptr),
170             static_cast<size_t>(out_ - state->filter_string_ptr));
171       }
172     }
173   } else {
174     MessageTokenizer::Token token = tokenizer_.Push(octet);
175     // |token| will not be valid() in most cases and this is WAI. When pushing
176     // a varint field, only the last byte yields a token, all the other bytes
177     // return an invalid token, they just update the internal tokenizer state.
178     if (token.valid()) {
179       auto filter = filter_.Query(state->msg_index, token.field_id);
180       switch (token.type) {
181         case proto_utils::ProtoWireType::kVarInt:
182           if (filter.allowed && filter.simple_field())
183             AppendVarInt(token.field_id, token.value, &out_);
184           break;
185         case proto_utils::ProtoWireType::kFixed32:
186           if (filter.allowed && filter.simple_field())
187             AppendFixed(token.field_id, static_cast<uint32_t>(token.value),
188                         &out_);
189           break;
190         case proto_utils::ProtoWireType::kFixed64:
191           if (filter.allowed && filter.simple_field())
192             AppendFixed(token.field_id, static_cast<uint64_t>(token.value),
193                         &out_);
194           break;
195         case proto_utils::ProtoWireType::kLengthDelimited:
196           // Here we have two cases:
197           // A. A simple string/bytes field: we just want to consume the next
198           //    bytes (the string payload), optionally passing them through in
199           //    output if the field is allowed.
200           // B. This is a nested submessage. In this case we want to recurse and
201           //    push a new state on the stack.
202           // Note that we can't tell the difference between a
203           // "non-allowed string" and a "non-allowed submessage". But it doesn't
204           // matter because in both cases we just want to skip the next N bytes.
205           const auto submessage_len = static_cast<uint32_t>(token.value);
206           auto in_bytes_left = state->in_bytes_limit - state->in_bytes - 1;
207           if (PERFETTO_UNLIKELY(submessage_len > in_bytes_left)) {
208             // This is a malicious / malformed string/bytes/submessage that
209             // claims to be larger than the outer message that contains it.
210             return SetUnrecoverableErrorState();
211           }
212 
213           if (filter.allowed && filter.nested_msg_field() &&
214               submessage_len > 0) {
215             // submessage_len == 0 is the edge case of a message with a 0-len
216             // (but present) submessage. In this case, if allowed, we don't want
217             // to push any further state (doing so would desync the FSM) but we
218             // still want to emit it.
219             // At this point |submessage_len| is only an upper bound. The
220             // final message written in output can be <= the one in input,
221             // only some of its fields might be allowed (also remember that
222             // this class implicitly removes redundancy varint encoding of
223             // len-delimited field lengths). The final length varint (the
224             // return value of AppendLenDelim()) will be filled when popping
225             // from |stack_|.
226             auto size_field =
227                 AppendLenDelim(token.field_id, submessage_len, &out_);
228             push_next_state = true;
229             next_state.field_id = token.field_id;
230             next_state.msg_index = filter.nested_msg_index;
231             next_state.in_bytes_limit = submessage_len;
232             next_state.size_field = size_field.first;
233             next_state.size_field_len = size_field.second;
234             next_state.out_bytes_written_at_start = out_written();
235           } else {
236             // A string or bytes field, or a 0 length submessage.
237             state->eat_next_bytes = submessage_len;
238             if (filter.allowed && filter.filter_string_field()) {
239               state->action = StackState::kFilterString;
240               AppendLenDelim(token.field_id, submessage_len, &out_);
241               state->filter_string_ptr = out_;
242             } else if (filter.allowed) {
243               state->action = StackState::kPassthrough;
244               AppendLenDelim(token.field_id, submessage_len, &out_);
245             } else {
246               state->action = StackState::kDrop;
247             }
248           }
249           break;
250       }  // switch(type)
251 
252       if (PERFETTO_UNLIKELY(track_field_usage_)) {
253         IncrementCurrentFieldUsage(token.field_id, filter.allowed);
254       }
255     }  // if (token.valid)
256   }    // if (eat_next_bytes == 0)
257 
258   ++state->in_bytes;
259   while (state->in_bytes >= state->in_bytes_limit) {
260     PERFETTO_DCHECK(state->in_bytes == state->in_bytes_limit);
261     push_next_state = false;
262 
263     // We can't possibly write more than we read.
264     const uint32_t msg_bytes_written = static_cast<uint32_t>(
265         out_written() - state->out_bytes_written_at_start);
266     PERFETTO_DCHECK(msg_bytes_written <= state->in_bytes_limit);
267 
268     // Backfill the length field of the
269     proto_utils::WriteRedundantVarInt(msg_bytes_written, state->size_field,
270                                       state->size_field_len);
271 
272     const uint32_t in_bytes_processes_for_last_msg = state->in_bytes;
273     stack_.pop_back();
274     PERFETTO_CHECK(!stack_.empty());
275     state = &stack_.back();
276     state->in_bytes += in_bytes_processes_for_last_msg;
277     if (PERFETTO_UNLIKELY(!tokenizer_.idle())) {
278       // If we hit this case, it means that we got to the end of a submessage
279       // while decoding a field. We can't recover from this and we don't want to
280       // propagate a broken sub-message.
281       return SetUnrecoverableErrorState();
282     }
283   }
284 
285   if (push_next_state) {
286     PERFETTO_DCHECK(tokenizer_.idle());
287     stack_.emplace_back(std::move(next_state));
288     state = &stack_.back();
289   }
290 }
291 
SetUnrecoverableErrorState()292 void MessageFilter::SetUnrecoverableErrorState() {
293   error_ = true;
294   stack_.clear();
295   stack_.resize(1);
296   auto& state = stack_[0];
297   state.eat_next_bytes = UINT32_MAX;
298   state.in_bytes_limit = UINT32_MAX;
299   state.action = StackState::kDrop;
300   out_ = out_buf_.get();  // Reset the write pointer.
301 }
302 
IncrementCurrentFieldUsage(uint32_t field_id,bool allowed)303 void MessageFilter::IncrementCurrentFieldUsage(uint32_t field_id,
304                                                bool allowed) {
305   // Slowpath. Used mainly in offline tools and tests to workout used fields in
306   // a proto.
307   PERFETTO_DCHECK(track_field_usage_);
308 
309   // Field path contains a concatenation of varints, one for each nesting level.
310   // e.g. y in message Root { Sub x = 2; }; message Sub { SubSub y = 7; }
311   // is encoded as [varint(2) + varint(7)].
312   // We use varint to take the most out of SSO (small string opt). In most cases
313   // the path will fit in the on-stack 22 bytes, requiring no heap.
314   std::string field_path;
315 
316   auto append_field_id = [&field_path](uint32_t id) {
317     uint8_t buf[10];
318     uint8_t* end = proto_utils::WriteVarInt(id, buf);
319     field_path.append(reinterpret_cast<char*>(buf),
320                       static_cast<size_t>(end - buf));
321   };
322 
323   // Append all the ancestors IDs from the state stack.
324   // The first entry of the stack has always ID 0 and we skip it (we don't know
325   // the ID of the root message itself).
326   PERFETTO_DCHECK(stack_.size() >= 2 && stack_[1].field_id == 0);
327   for (size_t i = 2; i < stack_.size(); ++i)
328     append_field_id(stack_[i].field_id);
329   // Append the id of the field in the current message.
330   append_field_id(field_id);
331   field_usage_[field_path] += allowed ? 1 : -1;
332 }
333 
334 }  // namespace protozero
335