1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_condition_attribute.h"
6
7 #include <algorithm>
8
9 #include "base/lazy_instance.h"
10 #include "base/logging.h"
11 #include "base/strings/string_util.h"
12 #include "base/strings/stringprintf.h"
13 #include "base/values.h"
14 #include "chrome/browser/extensions/api/declarative/deduping_factory.h"
15 #include "chrome/browser/extensions/api/declarative_webrequest/request_stage.h"
16 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_condition.h"
17 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_constants.h"
18 #include "chrome/browser/extensions/api/web_request/web_request_api_helpers.h"
19 #include "content/public/browser/resource_request_info.h"
20 #include "extensions/common/error_utils.h"
21 #include "net/base/net_errors.h"
22 #include "net/base/registry_controlled_domains/registry_controlled_domain.h"
23 #include "net/base/static_cookie_policy.h"
24 #include "net/http/http_request_headers.h"
25 #include "net/http/http_util.h"
26 #include "net/url_request/url_request.h"
27
28 using base::CaseInsensitiveCompareASCII;
29 using base::DictionaryValue;
30 using base::ListValue;
31 using base::StringValue;
32 using base::Value;
33
34 namespace helpers = extension_web_request_api_helpers;
35 namespace keys = extensions::declarative_webrequest_constants;
36
37 namespace extensions {
38
39 namespace {
40 // Error messages.
41 const char kInvalidValue[] = "Condition '*' has an invalid value";
42
43 struct WebRequestConditionAttributeFactory {
44 DedupingFactory<WebRequestConditionAttribute> factory;
45
WebRequestConditionAttributeFactoryextensions::__anon2a2920e40111::WebRequestConditionAttributeFactory46 WebRequestConditionAttributeFactory() : factory(5) {
47 factory.RegisterFactoryMethod(
48 keys::kResourceTypeKey,
49 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
50 &WebRequestConditionAttributeResourceType::Create);
51
52 factory.RegisterFactoryMethod(
53 keys::kContentTypeKey,
54 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
55 &WebRequestConditionAttributeContentType::Create);
56 factory.RegisterFactoryMethod(
57 keys::kExcludeContentTypeKey,
58 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
59 &WebRequestConditionAttributeContentType::Create);
60
61 factory.RegisterFactoryMethod(
62 keys::kRequestHeadersKey,
63 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
64 &WebRequestConditionAttributeRequestHeaders::Create);
65 factory.RegisterFactoryMethod(
66 keys::kExcludeRequestHeadersKey,
67 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
68 &WebRequestConditionAttributeRequestHeaders::Create);
69
70 factory.RegisterFactoryMethod(
71 keys::kResponseHeadersKey,
72 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
73 &WebRequestConditionAttributeResponseHeaders::Create);
74 factory.RegisterFactoryMethod(
75 keys::kExcludeResponseHeadersKey,
76 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
77 &WebRequestConditionAttributeResponseHeaders::Create);
78
79 factory.RegisterFactoryMethod(
80 keys::kThirdPartyKey,
81 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
82 &WebRequestConditionAttributeThirdParty::Create);
83
84 factory.RegisterFactoryMethod(
85 keys::kStagesKey,
86 DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
87 &WebRequestConditionAttributeStages::Create);
88 }
89 };
90
91 base::LazyInstance<WebRequestConditionAttributeFactory>::Leaky
92 g_web_request_condition_attribute_factory = LAZY_INSTANCE_INITIALIZER;
93
94 } // namespace
95
96 //
97 // WebRequestConditionAttribute
98 //
99
WebRequestConditionAttribute()100 WebRequestConditionAttribute::WebRequestConditionAttribute() {}
101
~WebRequestConditionAttribute()102 WebRequestConditionAttribute::~WebRequestConditionAttribute() {}
103
Equals(const WebRequestConditionAttribute * other) const104 bool WebRequestConditionAttribute::Equals(
105 const WebRequestConditionAttribute* other) const {
106 return GetType() == other->GetType();
107 }
108
109 // static
110 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error)111 WebRequestConditionAttribute::Create(
112 const std::string& name,
113 const base::Value* value,
114 std::string* error) {
115 CHECK(value != NULL && error != NULL);
116 bool bad_message = false;
117 return g_web_request_condition_attribute_factory.Get().factory.Instantiate(
118 name, value, error, &bad_message);
119 }
120
121 //
122 // WebRequestConditionAttributeResourceType
123 //
124
125 WebRequestConditionAttributeResourceType::
WebRequestConditionAttributeResourceType(const std::vector<ResourceType::Type> & types)126 WebRequestConditionAttributeResourceType(
127 const std::vector<ResourceType::Type>& types)
128 : types_(types) {}
129
130 WebRequestConditionAttributeResourceType::
~WebRequestConditionAttributeResourceType()131 ~WebRequestConditionAttributeResourceType() {}
132
133 // static
134 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & instance_type,const base::Value * value,std::string * error,bool * bad_message)135 WebRequestConditionAttributeResourceType::Create(
136 const std::string& instance_type,
137 const base::Value* value,
138 std::string* error,
139 bool* bad_message) {
140 DCHECK(instance_type == keys::kResourceTypeKey);
141 const base::ListValue* value_as_list = NULL;
142 if (!value->GetAsList(&value_as_list)) {
143 *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
144 keys::kResourceTypeKey);
145 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
146 }
147
148 size_t number_types = value_as_list->GetSize();
149
150 std::vector<ResourceType::Type> passed_types;
151 passed_types.reserve(number_types);
152 for (size_t i = 0; i < number_types; ++i) {
153 std::string resource_type_string;
154 ResourceType::Type type = ResourceType::LAST_TYPE;
155 if (!value_as_list->GetString(i, &resource_type_string) ||
156 !helpers::ParseResourceType(resource_type_string, &type)) {
157 *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
158 keys::kResourceTypeKey);
159 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
160 }
161 passed_types.push_back(type);
162 }
163
164 return scoped_refptr<const WebRequestConditionAttribute>(
165 new WebRequestConditionAttributeResourceType(passed_types));
166 }
167
GetStages() const168 int WebRequestConditionAttributeResourceType::GetStages() const {
169 return ON_BEFORE_REQUEST | ON_BEFORE_SEND_HEADERS | ON_SEND_HEADERS |
170 ON_HEADERS_RECEIVED | ON_AUTH_REQUIRED | ON_BEFORE_REDIRECT |
171 ON_RESPONSE_STARTED | ON_COMPLETED | ON_ERROR;
172 }
173
IsFulfilled(const WebRequestData & request_data) const174 bool WebRequestConditionAttributeResourceType::IsFulfilled(
175 const WebRequestData& request_data) const {
176 if (!(request_data.stage & GetStages()))
177 return false;
178 const content::ResourceRequestInfo* info =
179 content::ResourceRequestInfo::ForRequest(request_data.request);
180 if (!info)
181 return false;
182 return std::find(types_.begin(), types_.end(), info->GetResourceType()) !=
183 types_.end();
184 }
185
186 WebRequestConditionAttribute::Type
GetType() const187 WebRequestConditionAttributeResourceType::GetType() const {
188 return CONDITION_RESOURCE_TYPE;
189 }
190
GetName() const191 std::string WebRequestConditionAttributeResourceType::GetName() const {
192 return keys::kResourceTypeKey;
193 }
194
Equals(const WebRequestConditionAttribute * other) const195 bool WebRequestConditionAttributeResourceType::Equals(
196 const WebRequestConditionAttribute* other) const {
197 if (!WebRequestConditionAttribute::Equals(other))
198 return false;
199 const WebRequestConditionAttributeResourceType* casted_other =
200 static_cast<const WebRequestConditionAttributeResourceType*>(other);
201 return types_ == casted_other->types_;
202 }
203
204 //
205 // WebRequestConditionAttributeContentType
206 //
207
208 WebRequestConditionAttributeContentType::
WebRequestConditionAttributeContentType(const std::vector<std::string> & content_types,bool inclusive)209 WebRequestConditionAttributeContentType(
210 const std::vector<std::string>& content_types,
211 bool inclusive)
212 : content_types_(content_types),
213 inclusive_(inclusive) {}
214
215 WebRequestConditionAttributeContentType::
~WebRequestConditionAttributeContentType()216 ~WebRequestConditionAttributeContentType() {}
217
218 // static
219 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)220 WebRequestConditionAttributeContentType::Create(
221 const std::string& name,
222 const base::Value* value,
223 std::string* error,
224 bool* bad_message) {
225 DCHECK(name == keys::kContentTypeKey || name == keys::kExcludeContentTypeKey);
226
227 const base::ListValue* value_as_list = NULL;
228 if (!value->GetAsList(&value_as_list)) {
229 *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
230 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
231 }
232 std::vector<std::string> content_types;
233 for (base::ListValue::const_iterator it = value_as_list->begin();
234 it != value_as_list->end(); ++it) {
235 std::string content_type;
236 if (!(*it)->GetAsString(&content_type)) {
237 *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
238 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
239 }
240 content_types.push_back(content_type);
241 }
242
243 return scoped_refptr<const WebRequestConditionAttribute>(
244 new WebRequestConditionAttributeContentType(
245 content_types, name == keys::kContentTypeKey));
246 }
247
GetStages() const248 int WebRequestConditionAttributeContentType::GetStages() const {
249 return ON_HEADERS_RECEIVED;
250 }
251
IsFulfilled(const WebRequestData & request_data) const252 bool WebRequestConditionAttributeContentType::IsFulfilled(
253 const WebRequestData& request_data) const {
254 if (!(request_data.stage & GetStages()))
255 return false;
256 std::string content_type;
257 request_data.original_response_headers->GetNormalizedHeader(
258 net::HttpRequestHeaders::kContentType, &content_type);
259 std::string mime_type;
260 std::string charset;
261 bool had_charset = false;
262 net::HttpUtil::ParseContentType(
263 content_type, &mime_type, &charset, &had_charset, NULL);
264
265 if (inclusive_) {
266 return std::find(content_types_.begin(), content_types_.end(),
267 mime_type) != content_types_.end();
268 } else {
269 return std::find(content_types_.begin(), content_types_.end(),
270 mime_type) == content_types_.end();
271 }
272 }
273
274 WebRequestConditionAttribute::Type
GetType() const275 WebRequestConditionAttributeContentType::GetType() const {
276 return CONDITION_CONTENT_TYPE;
277 }
278
GetName() const279 std::string WebRequestConditionAttributeContentType::GetName() const {
280 return (inclusive_ ? keys::kContentTypeKey : keys::kExcludeContentTypeKey);
281 }
282
Equals(const WebRequestConditionAttribute * other) const283 bool WebRequestConditionAttributeContentType::Equals(
284 const WebRequestConditionAttribute* other) const {
285 if (!WebRequestConditionAttribute::Equals(other))
286 return false;
287 const WebRequestConditionAttributeContentType* casted_other =
288 static_cast<const WebRequestConditionAttributeContentType*>(other);
289 return content_types_ == casted_other->content_types_ &&
290 inclusive_ == casted_other->inclusive_;
291 }
292
293 // Manages a set of tests to be applied to name-value pairs representing
294 // headers. This is a helper class to header-related condition attributes.
295 // It contains a set of test groups. A name-value pair satisfies the whole
296 // set of test groups iff it passes at least one test group.
297 class HeaderMatcher {
298 public:
299 ~HeaderMatcher();
300
301 // Creates an instance based on a list |tests| of test groups, encoded as
302 // dictionaries of the type declarativeWebRequest.HeaderFilter (see
303 // declarative_web_request.json).
304 static scoped_ptr<const HeaderMatcher> Create(const base::ListValue* tests);
305
306 // Does |this| match the header "|name|: |value|"?
307 bool TestNameValue(const std::string& name, const std::string& value) const;
308
309 private:
310 // Represents a single string-matching test.
311 class StringMatchTest {
312 public:
313 enum MatchType { kPrefix, kSuffix, kEquals, kContains };
314
315 // |data| is the pattern to be matched in the position given by |type|.
316 // Note that |data| must point to a StringValue object.
317 static scoped_ptr<StringMatchTest> Create(const base::Value* data,
318 MatchType type,
319 bool case_sensitive);
320 ~StringMatchTest();
321
322 // Does |str| pass |this| StringMatchTest?
323 bool Matches(const std::string& str) const;
324
325 private:
326 StringMatchTest(const std::string& data,
327 MatchType type,
328 bool case_sensitive);
329
330 const std::string data_;
331 const MatchType type_;
332 const bool case_sensitive_;
333 DISALLOW_COPY_AND_ASSIGN(StringMatchTest);
334 };
335
336 // Represents a test group -- a set of string matching tests to be applied to
337 // both the header name and value.
338 class HeaderMatchTest {
339 public:
340 ~HeaderMatchTest();
341
342 // Gets the test group description in |tests| and creates the corresponding
343 // HeaderMatchTest. On failure returns NULL.
344 static scoped_ptr<const HeaderMatchTest> Create(
345 const base::DictionaryValue* tests);
346
347 // Does the header "|name|: |value|" match all tests in |this|?
348 bool Matches(const std::string& name, const std::string& value) const;
349
350 private:
351 // Takes ownership of the content of both |name_match| and |value_match|.
352 HeaderMatchTest(ScopedVector<const StringMatchTest>* name_match,
353 ScopedVector<const StringMatchTest>* value_match);
354
355 // Tests to be passed by a header's name.
356 const ScopedVector<const StringMatchTest> name_match_;
357 // Tests to be passed by a header's value.
358 const ScopedVector<const StringMatchTest> value_match_;
359 DISALLOW_COPY_AND_ASSIGN(HeaderMatchTest);
360 };
361
362 explicit HeaderMatcher(ScopedVector<const HeaderMatchTest>* tests);
363
364 const ScopedVector<const HeaderMatchTest> tests_;
365
366 DISALLOW_COPY_AND_ASSIGN(HeaderMatcher);
367 };
368
369 // HeaderMatcher implementation.
370
~HeaderMatcher()371 HeaderMatcher::~HeaderMatcher() {}
372
373 // static
Create(const base::ListValue * tests)374 scoped_ptr<const HeaderMatcher> HeaderMatcher::Create(
375 const base::ListValue* tests) {
376 ScopedVector<const HeaderMatchTest> header_tests;
377 for (base::ListValue::const_iterator it = tests->begin();
378 it != tests->end(); ++it) {
379 const base::DictionaryValue* tests = NULL;
380 if (!(*it)->GetAsDictionary(&tests))
381 return scoped_ptr<const HeaderMatcher>();
382
383 scoped_ptr<const HeaderMatchTest> header_test(
384 HeaderMatchTest::Create(tests));
385 if (header_test.get() == NULL)
386 return scoped_ptr<const HeaderMatcher>();
387 header_tests.push_back(header_test.release());
388 }
389
390 return scoped_ptr<const HeaderMatcher>(new HeaderMatcher(&header_tests));
391 }
392
TestNameValue(const std::string & name,const std::string & value) const393 bool HeaderMatcher::TestNameValue(const std::string& name,
394 const std::string& value) const {
395 for (size_t i = 0; i < tests_.size(); ++i) {
396 if (tests_[i]->Matches(name, value))
397 return true;
398 }
399 return false;
400 }
401
HeaderMatcher(ScopedVector<const HeaderMatchTest> * tests)402 HeaderMatcher::HeaderMatcher(ScopedVector<const HeaderMatchTest>* tests)
403 : tests_(tests->Pass()) {}
404
405 // HeaderMatcher::StringMatchTest implementation.
406
407 // static
408 scoped_ptr<HeaderMatcher::StringMatchTest>
Create(const base::Value * data,MatchType type,bool case_sensitive)409 HeaderMatcher::StringMatchTest::Create(const base::Value* data,
410 MatchType type,
411 bool case_sensitive) {
412 std::string str;
413 CHECK(data->GetAsString(&str));
414 return scoped_ptr<StringMatchTest>(
415 new StringMatchTest(str, type, case_sensitive));
416 }
417
~StringMatchTest()418 HeaderMatcher::StringMatchTest::~StringMatchTest() {}
419
Matches(const std::string & str) const420 bool HeaderMatcher::StringMatchTest::Matches(
421 const std::string& str) const {
422 switch (type_) {
423 case kPrefix:
424 return StartsWithASCII(str, data_, case_sensitive_);
425 case kSuffix:
426 return EndsWith(str, data_, case_sensitive_);
427 case kEquals:
428 return str.size() == data_.size() &&
429 StartsWithASCII(str, data_, case_sensitive_);
430 case kContains:
431 if (!case_sensitive_) {
432 return std::search(str.begin(), str.end(), data_.begin(), data_.end(),
433 CaseInsensitiveCompareASCII<char>()) != str.end();
434 } else {
435 return str.find(data_) != std::string::npos;
436 }
437 }
438 // We never get past the "switch", but the compiler worries about no return.
439 NOTREACHED();
440 return false;
441 }
442
StringMatchTest(const std::string & data,MatchType type,bool case_sensitive)443 HeaderMatcher::StringMatchTest::StringMatchTest(const std::string& data,
444 MatchType type,
445 bool case_sensitive)
446 : data_(data),
447 type_(type),
448 case_sensitive_(case_sensitive) {}
449
450 // HeaderMatcher::HeaderMatchTest implementation.
451
HeaderMatchTest(ScopedVector<const StringMatchTest> * name_match,ScopedVector<const StringMatchTest> * value_match)452 HeaderMatcher::HeaderMatchTest::HeaderMatchTest(
453 ScopedVector<const StringMatchTest>* name_match,
454 ScopedVector<const StringMatchTest>* value_match)
455 : name_match_(name_match->Pass()),
456 value_match_(value_match->Pass()) {}
457
~HeaderMatchTest()458 HeaderMatcher::HeaderMatchTest::~HeaderMatchTest() {}
459
460 // static
461 scoped_ptr<const HeaderMatcher::HeaderMatchTest>
Create(const base::DictionaryValue * tests)462 HeaderMatcher::HeaderMatchTest::Create(const base::DictionaryValue* tests) {
463 ScopedVector<const StringMatchTest> name_match;
464 ScopedVector<const StringMatchTest> value_match;
465
466 for (base::DictionaryValue::Iterator it(*tests);
467 !it.IsAtEnd(); it.Advance()) {
468 bool is_name = false; // Is this test for header name?
469 StringMatchTest::MatchType match_type;
470 if (it.key() == keys::kNamePrefixKey) {
471 is_name = true;
472 match_type = StringMatchTest::kPrefix;
473 } else if (it.key() == keys::kNameSuffixKey) {
474 is_name = true;
475 match_type = StringMatchTest::kSuffix;
476 } else if (it.key() == keys::kNameContainsKey) {
477 is_name = true;
478 match_type = StringMatchTest::kContains;
479 } else if (it.key() == keys::kNameEqualsKey) {
480 is_name = true;
481 match_type = StringMatchTest::kEquals;
482 } else if (it.key() == keys::kValuePrefixKey) {
483 match_type = StringMatchTest::kPrefix;
484 } else if (it.key() == keys::kValueSuffixKey) {
485 match_type = StringMatchTest::kSuffix;
486 } else if (it.key() == keys::kValueContainsKey) {
487 match_type = StringMatchTest::kContains;
488 } else if (it.key() == keys::kValueEqualsKey) {
489 match_type = StringMatchTest::kEquals;
490 } else {
491 NOTREACHED(); // JSON schema type checking should prevent this.
492 return scoped_ptr<const HeaderMatchTest>();
493 }
494 const base::Value* content = &it.value();
495
496 ScopedVector<const StringMatchTest>* tests =
497 is_name ? &name_match : &value_match;
498 switch (content->GetType()) {
499 case base::Value::TYPE_LIST: {
500 const base::ListValue* list = NULL;
501 CHECK(content->GetAsList(&list));
502 for (base::ListValue::const_iterator it = list->begin();
503 it != list->end(); ++it) {
504 tests->push_back(
505 StringMatchTest::Create(*it, match_type, !is_name).release());
506 }
507 break;
508 }
509 case base::Value::TYPE_STRING: {
510 tests->push_back(
511 StringMatchTest::Create(content, match_type, !is_name).release());
512 break;
513 }
514 default: {
515 NOTREACHED(); // JSON schema type checking should prevent this.
516 return scoped_ptr<const HeaderMatchTest>();
517 }
518 }
519 }
520
521 return scoped_ptr<const HeaderMatchTest>(
522 new HeaderMatchTest(&name_match, &value_match));
523 }
524
Matches(const std::string & name,const std::string & value) const525 bool HeaderMatcher::HeaderMatchTest::Matches(const std::string& name,
526 const std::string& value) const {
527 for (size_t i = 0; i < name_match_.size(); ++i) {
528 if (!name_match_[i]->Matches(name))
529 return false;
530 }
531
532 for (size_t i = 0; i < value_match_.size(); ++i) {
533 if (!value_match_[i]->Matches(value))
534 return false;
535 }
536
537 return true;
538 }
539
540 //
541 // WebRequestConditionAttributeRequestHeaders
542 //
543
544 WebRequestConditionAttributeRequestHeaders::
WebRequestConditionAttributeRequestHeaders(scoped_ptr<const HeaderMatcher> header_matcher,bool positive)545 WebRequestConditionAttributeRequestHeaders(
546 scoped_ptr<const HeaderMatcher> header_matcher,
547 bool positive)
548 : header_matcher_(header_matcher.Pass()),
549 positive_(positive) {}
550
551 WebRequestConditionAttributeRequestHeaders::
~WebRequestConditionAttributeRequestHeaders()552 ~WebRequestConditionAttributeRequestHeaders() {}
553
554 namespace {
555
PrepareHeaderMatcher(const std::string & name,const base::Value * value,std::string * error)556 scoped_ptr<const HeaderMatcher> PrepareHeaderMatcher(
557 const std::string& name,
558 const base::Value* value,
559 std::string* error) {
560 const base::ListValue* value_as_list = NULL;
561 if (!value->GetAsList(&value_as_list)) {
562 *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
563 return scoped_ptr<const HeaderMatcher>();
564 }
565
566 scoped_ptr<const HeaderMatcher> header_matcher(
567 HeaderMatcher::Create(value_as_list));
568 if (header_matcher.get() == NULL)
569 *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
570 return header_matcher.Pass();
571 }
572
573 } // namespace
574
575 // static
576 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)577 WebRequestConditionAttributeRequestHeaders::Create(
578 const std::string& name,
579 const base::Value* value,
580 std::string* error,
581 bool* bad_message) {
582 DCHECK(name == keys::kRequestHeadersKey ||
583 name == keys::kExcludeRequestHeadersKey);
584
585 scoped_ptr<const HeaderMatcher> header_matcher(
586 PrepareHeaderMatcher(name, value, error));
587 if (header_matcher.get() == NULL)
588 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
589
590 return scoped_refptr<const WebRequestConditionAttribute>(
591 new WebRequestConditionAttributeRequestHeaders(
592 header_matcher.Pass(), name == keys::kRequestHeadersKey));
593 }
594
GetStages() const595 int WebRequestConditionAttributeRequestHeaders::GetStages() const {
596 // Currently we only allow matching against headers in the before-send-headers
597 // stage. The headers are accessible in other stages as well, but before
598 // allowing to match against them in further stages, we should consider
599 // caching the match result.
600 return ON_BEFORE_SEND_HEADERS;
601 }
602
IsFulfilled(const WebRequestData & request_data) const603 bool WebRequestConditionAttributeRequestHeaders::IsFulfilled(
604 const WebRequestData& request_data) const {
605 if (!(request_data.stage & GetStages()))
606 return false;
607
608 const net::HttpRequestHeaders& headers =
609 request_data.request->extra_request_headers();
610
611 bool passed = false; // Did some header pass TestNameValue?
612 net::HttpRequestHeaders::Iterator it(headers);
613 while (!passed && it.GetNext())
614 passed |= header_matcher_->TestNameValue(it.name(), it.value());
615
616 return (positive_ ? passed : !passed);
617 }
618
619 WebRequestConditionAttribute::Type
GetType() const620 WebRequestConditionAttributeRequestHeaders::GetType() const {
621 return CONDITION_REQUEST_HEADERS;
622 }
623
GetName() const624 std::string WebRequestConditionAttributeRequestHeaders::GetName() const {
625 return (positive_ ? keys::kRequestHeadersKey
626 : keys::kExcludeRequestHeadersKey);
627 }
628
Equals(const WebRequestConditionAttribute * other) const629 bool WebRequestConditionAttributeRequestHeaders::Equals(
630 const WebRequestConditionAttribute* other) const {
631 // Comparing headers is too heavy, so we skip it entirely.
632 return false;
633 }
634
635 //
636 // WebRequestConditionAttributeResponseHeaders
637 //
638
639 WebRequestConditionAttributeResponseHeaders::
WebRequestConditionAttributeResponseHeaders(scoped_ptr<const HeaderMatcher> header_matcher,bool positive)640 WebRequestConditionAttributeResponseHeaders(
641 scoped_ptr<const HeaderMatcher> header_matcher,
642 bool positive)
643 : header_matcher_(header_matcher.Pass()),
644 positive_(positive) {}
645
646 WebRequestConditionAttributeResponseHeaders::
~WebRequestConditionAttributeResponseHeaders()647 ~WebRequestConditionAttributeResponseHeaders() {}
648
649 // static
650 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)651 WebRequestConditionAttributeResponseHeaders::Create(
652 const std::string& name,
653 const base::Value* value,
654 std::string* error,
655 bool* bad_message) {
656 DCHECK(name == keys::kResponseHeadersKey ||
657 name == keys::kExcludeResponseHeadersKey);
658
659 scoped_ptr<const HeaderMatcher> header_matcher(
660 PrepareHeaderMatcher(name, value, error));
661 if (header_matcher.get() == NULL)
662 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
663
664 return scoped_refptr<const WebRequestConditionAttribute>(
665 new WebRequestConditionAttributeResponseHeaders(
666 header_matcher.Pass(), name == keys::kResponseHeadersKey));
667 }
668
GetStages() const669 int WebRequestConditionAttributeResponseHeaders::GetStages() const {
670 return ON_HEADERS_RECEIVED;
671 }
672
IsFulfilled(const WebRequestData & request_data) const673 bool WebRequestConditionAttributeResponseHeaders::IsFulfilled(
674 const WebRequestData& request_data) const {
675 if (!(request_data.stage & GetStages()))
676 return false;
677
678 const net::HttpResponseHeaders* headers =
679 request_data.original_response_headers;
680 if (headers == NULL) {
681 // Each header of an empty set satisfies (the negation of) everything;
682 // OTOH, there is no header to satisfy even the most permissive test.
683 return !positive_;
684 }
685
686 bool passed = false; // Did some header pass TestNameValue?
687 std::string name;
688 std::string value;
689 void* iter = NULL;
690 while (!passed && headers->EnumerateHeaderLines(&iter, &name, &value)) {
691 passed |= header_matcher_->TestNameValue(name, value);
692 }
693
694 return (positive_ ? passed : !passed);
695 }
696
697 WebRequestConditionAttribute::Type
GetType() const698 WebRequestConditionAttributeResponseHeaders::GetType() const {
699 return CONDITION_RESPONSE_HEADERS;
700 }
701
GetName() const702 std::string WebRequestConditionAttributeResponseHeaders::GetName() const {
703 return (positive_ ? keys::kResponseHeadersKey
704 : keys::kExcludeResponseHeadersKey);
705 }
706
Equals(const WebRequestConditionAttribute * other) const707 bool WebRequestConditionAttributeResponseHeaders::Equals(
708 const WebRequestConditionAttribute* other) const {
709 return false;
710 }
711
712 //
713 // WebRequestConditionAttributeThirdParty
714 //
715
716 WebRequestConditionAttributeThirdParty::
WebRequestConditionAttributeThirdParty(bool match_third_party)717 WebRequestConditionAttributeThirdParty(bool match_third_party)
718 : match_third_party_(match_third_party) {}
719
720 WebRequestConditionAttributeThirdParty::
~WebRequestConditionAttributeThirdParty()721 ~WebRequestConditionAttributeThirdParty() {}
722
723 // static
724 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)725 WebRequestConditionAttributeThirdParty::Create(
726 const std::string& name,
727 const base::Value* value,
728 std::string* error,
729 bool* bad_message) {
730 DCHECK(name == keys::kThirdPartyKey);
731
732 bool third_party = false; // Dummy value, gets overwritten.
733 if (!value->GetAsBoolean(&third_party)) {
734 *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
735 keys::kThirdPartyKey);
736 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
737 }
738
739 return scoped_refptr<const WebRequestConditionAttribute>(
740 new WebRequestConditionAttributeThirdParty(third_party));
741 }
742
GetStages() const743 int WebRequestConditionAttributeThirdParty::GetStages() const {
744 return ON_BEFORE_REQUEST | ON_BEFORE_SEND_HEADERS | ON_SEND_HEADERS |
745 ON_HEADERS_RECEIVED | ON_AUTH_REQUIRED | ON_BEFORE_REDIRECT |
746 ON_RESPONSE_STARTED | ON_COMPLETED | ON_ERROR;
747 }
748
IsFulfilled(const WebRequestData & request_data) const749 bool WebRequestConditionAttributeThirdParty::IsFulfilled(
750 const WebRequestData& request_data) const {
751 if (!(request_data.stage & GetStages()))
752 return false;
753
754 // Request is "1st party" if it gets cookies under 3rd party-blocking policy.
755 const net::StaticCookiePolicy block_third_party_policy(
756 net::StaticCookiePolicy::BLOCK_ALL_THIRD_PARTY_COOKIES);
757 const int can_get_cookies = block_third_party_policy.CanGetCookies(
758 request_data.request->url(),
759 request_data.request->first_party_for_cookies());
760 const bool is_first_party = (can_get_cookies == net::OK);
761
762 return match_third_party_ ? !is_first_party : is_first_party;
763 }
764
765 WebRequestConditionAttribute::Type
GetType() const766 WebRequestConditionAttributeThirdParty::GetType() const {
767 return CONDITION_THIRD_PARTY;
768 }
769
GetName() const770 std::string WebRequestConditionAttributeThirdParty::GetName() const {
771 return keys::kThirdPartyKey;
772 }
773
Equals(const WebRequestConditionAttribute * other) const774 bool WebRequestConditionAttributeThirdParty::Equals(
775 const WebRequestConditionAttribute* other) const {
776 if (!WebRequestConditionAttribute::Equals(other))
777 return false;
778 const WebRequestConditionAttributeThirdParty* casted_other =
779 static_cast<const WebRequestConditionAttributeThirdParty*>(other);
780 return match_third_party_ == casted_other->match_third_party_;
781 }
782
783 //
784 // WebRequestConditionAttributeStages
785 //
786
787 WebRequestConditionAttributeStages::
WebRequestConditionAttributeStages(int allowed_stages)788 WebRequestConditionAttributeStages(int allowed_stages)
789 : allowed_stages_(allowed_stages) {}
790
791 WebRequestConditionAttributeStages::
~WebRequestConditionAttributeStages()792 ~WebRequestConditionAttributeStages() {}
793
794 namespace {
795
796 // Reads strings stored in |value|, which is expected to be a ListValue, and
797 // sets corresponding bits (see RequestStage) in |out_stages|. Returns true on
798 // success, false otherwise.
ParseListOfStages(const base::Value & value,int * out_stages)799 bool ParseListOfStages(const base::Value& value, int* out_stages) {
800 const base::ListValue* list = NULL;
801 if (!value.GetAsList(&list))
802 return false;
803
804 int stages = 0;
805 std::string stage_name;
806 for (base::ListValue::const_iterator it = list->begin();
807 it != list->end(); ++it) {
808 if (!((*it)->GetAsString(&stage_name)))
809 return false;
810 if (stage_name == keys::kOnBeforeRequestEnum) {
811 stages |= ON_BEFORE_REQUEST;
812 } else if (stage_name == keys::kOnBeforeSendHeadersEnum) {
813 stages |= ON_BEFORE_SEND_HEADERS;
814 } else if (stage_name == keys::kOnHeadersReceivedEnum) {
815 stages |= ON_HEADERS_RECEIVED;
816 } else if (stage_name == keys::kOnAuthRequiredEnum) {
817 stages |= ON_AUTH_REQUIRED;
818 } else {
819 NOTREACHED(); // JSON schema checks prevent getting here.
820 return false;
821 }
822 }
823
824 *out_stages = stages;
825 return true;
826 }
827
828 } // namespace
829
830 // static
831 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)832 WebRequestConditionAttributeStages::Create(const std::string& name,
833 const base::Value* value,
834 std::string* error,
835 bool* bad_message) {
836 DCHECK(name == keys::kStagesKey);
837
838 int allowed_stages = 0;
839 if (!ParseListOfStages(*value, &allowed_stages)) {
840 *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
841 keys::kStagesKey);
842 return scoped_refptr<const WebRequestConditionAttribute>(NULL);
843 }
844
845 return scoped_refptr<const WebRequestConditionAttribute>(
846 new WebRequestConditionAttributeStages(allowed_stages));
847 }
848
GetStages() const849 int WebRequestConditionAttributeStages::GetStages() const {
850 return allowed_stages_;
851 }
852
IsFulfilled(const WebRequestData & request_data) const853 bool WebRequestConditionAttributeStages::IsFulfilled(
854 const WebRequestData& request_data) const {
855 // Note: removing '!=' triggers warning C4800 on the VS compiler.
856 return (request_data.stage & GetStages()) != 0;
857 }
858
859 WebRequestConditionAttribute::Type
GetType() const860 WebRequestConditionAttributeStages::GetType() const {
861 return CONDITION_STAGES;
862 }
863
GetName() const864 std::string WebRequestConditionAttributeStages::GetName() const {
865 return keys::kStagesKey;
866 }
867
Equals(const WebRequestConditionAttribute * other) const868 bool WebRequestConditionAttributeStages::Equals(
869 const WebRequestConditionAttribute* other) const {
870 if (!WebRequestConditionAttribute::Equals(other))
871 return false;
872 const WebRequestConditionAttributeStages* casted_other =
873 static_cast<const WebRequestConditionAttributeStages*>(other);
874 return allowed_stages_ == casted_other->allowed_stages_;
875 }
876
877 } // namespace extensions
878