• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_condition_attribute.h"
6 
7 #include <algorithm>
8 
9 #include "base/lazy_instance.h"
10 #include "base/logging.h"
11 #include "base/strings/string_util.h"
12 #include "base/strings/stringprintf.h"
13 #include "base/values.h"
14 #include "chrome/browser/extensions/api/declarative/deduping_factory.h"
15 #include "chrome/browser/extensions/api/declarative_webrequest/request_stage.h"
16 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_condition.h"
17 #include "chrome/browser/extensions/api/declarative_webrequest/webrequest_constants.h"
18 #include "chrome/browser/extensions/api/web_request/web_request_api_helpers.h"
19 #include "content/public/browser/resource_request_info.h"
20 #include "extensions/common/error_utils.h"
21 #include "net/base/net_errors.h"
22 #include "net/base/registry_controlled_domains/registry_controlled_domain.h"
23 #include "net/base/static_cookie_policy.h"
24 #include "net/http/http_request_headers.h"
25 #include "net/http/http_util.h"
26 #include "net/url_request/url_request.h"
27 
28 using base::CaseInsensitiveCompareASCII;
29 using base::DictionaryValue;
30 using base::ListValue;
31 using base::StringValue;
32 using base::Value;
33 
34 namespace helpers = extension_web_request_api_helpers;
35 namespace keys = extensions::declarative_webrequest_constants;
36 
37 namespace extensions {
38 
39 namespace {
40 // Error messages.
41 const char kInvalidValue[] = "Condition '*' has an invalid value";
42 
43 struct WebRequestConditionAttributeFactory {
44   DedupingFactory<WebRequestConditionAttribute> factory;
45 
WebRequestConditionAttributeFactoryextensions::__anon2a2920e40111::WebRequestConditionAttributeFactory46   WebRequestConditionAttributeFactory() : factory(5) {
47     factory.RegisterFactoryMethod(
48         keys::kResourceTypeKey,
49         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
50         &WebRequestConditionAttributeResourceType::Create);
51 
52     factory.RegisterFactoryMethod(
53         keys::kContentTypeKey,
54         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
55         &WebRequestConditionAttributeContentType::Create);
56     factory.RegisterFactoryMethod(
57         keys::kExcludeContentTypeKey,
58         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
59         &WebRequestConditionAttributeContentType::Create);
60 
61     factory.RegisterFactoryMethod(
62         keys::kRequestHeadersKey,
63         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
64         &WebRequestConditionAttributeRequestHeaders::Create);
65     factory.RegisterFactoryMethod(
66         keys::kExcludeRequestHeadersKey,
67         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
68         &WebRequestConditionAttributeRequestHeaders::Create);
69 
70     factory.RegisterFactoryMethod(
71         keys::kResponseHeadersKey,
72         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
73         &WebRequestConditionAttributeResponseHeaders::Create);
74     factory.RegisterFactoryMethod(
75         keys::kExcludeResponseHeadersKey,
76         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
77         &WebRequestConditionAttributeResponseHeaders::Create);
78 
79     factory.RegisterFactoryMethod(
80         keys::kThirdPartyKey,
81         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
82         &WebRequestConditionAttributeThirdParty::Create);
83 
84     factory.RegisterFactoryMethod(
85         keys::kStagesKey,
86         DedupingFactory<WebRequestConditionAttribute>::IS_PARAMETERIZED,
87         &WebRequestConditionAttributeStages::Create);
88   }
89 };
90 
91 base::LazyInstance<WebRequestConditionAttributeFactory>::Leaky
92     g_web_request_condition_attribute_factory = LAZY_INSTANCE_INITIALIZER;
93 
94 }  // namespace
95 
96 //
97 // WebRequestConditionAttribute
98 //
99 
WebRequestConditionAttribute()100 WebRequestConditionAttribute::WebRequestConditionAttribute() {}
101 
~WebRequestConditionAttribute()102 WebRequestConditionAttribute::~WebRequestConditionAttribute() {}
103 
Equals(const WebRequestConditionAttribute * other) const104 bool WebRequestConditionAttribute::Equals(
105     const WebRequestConditionAttribute* other) const {
106   return GetType() == other->GetType();
107 }
108 
109 // static
110 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error)111 WebRequestConditionAttribute::Create(
112     const std::string& name,
113     const base::Value* value,
114     std::string* error) {
115   CHECK(value != NULL && error != NULL);
116   bool bad_message = false;
117   return g_web_request_condition_attribute_factory.Get().factory.Instantiate(
118       name, value, error, &bad_message);
119 }
120 
121 //
122 // WebRequestConditionAttributeResourceType
123 //
124 
125 WebRequestConditionAttributeResourceType::
WebRequestConditionAttributeResourceType(const std::vector<ResourceType::Type> & types)126 WebRequestConditionAttributeResourceType(
127     const std::vector<ResourceType::Type>& types)
128     : types_(types) {}
129 
130 WebRequestConditionAttributeResourceType::
~WebRequestConditionAttributeResourceType()131 ~WebRequestConditionAttributeResourceType() {}
132 
133 // static
134 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & instance_type,const base::Value * value,std::string * error,bool * bad_message)135 WebRequestConditionAttributeResourceType::Create(
136     const std::string& instance_type,
137     const base::Value* value,
138     std::string* error,
139     bool* bad_message) {
140   DCHECK(instance_type == keys::kResourceTypeKey);
141   const base::ListValue* value_as_list = NULL;
142   if (!value->GetAsList(&value_as_list)) {
143     *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
144                                             keys::kResourceTypeKey);
145     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
146   }
147 
148   size_t number_types = value_as_list->GetSize();
149 
150   std::vector<ResourceType::Type> passed_types;
151   passed_types.reserve(number_types);
152   for (size_t i = 0; i < number_types; ++i) {
153     std::string resource_type_string;
154     ResourceType::Type type = ResourceType::LAST_TYPE;
155     if (!value_as_list->GetString(i, &resource_type_string) ||
156         !helpers::ParseResourceType(resource_type_string, &type)) {
157       *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
158                                               keys::kResourceTypeKey);
159       return scoped_refptr<const WebRequestConditionAttribute>(NULL);
160     }
161     passed_types.push_back(type);
162   }
163 
164   return scoped_refptr<const WebRequestConditionAttribute>(
165       new WebRequestConditionAttributeResourceType(passed_types));
166 }
167 
GetStages() const168 int WebRequestConditionAttributeResourceType::GetStages() const {
169   return ON_BEFORE_REQUEST | ON_BEFORE_SEND_HEADERS | ON_SEND_HEADERS |
170       ON_HEADERS_RECEIVED | ON_AUTH_REQUIRED | ON_BEFORE_REDIRECT |
171       ON_RESPONSE_STARTED | ON_COMPLETED | ON_ERROR;
172 }
173 
IsFulfilled(const WebRequestData & request_data) const174 bool WebRequestConditionAttributeResourceType::IsFulfilled(
175     const WebRequestData& request_data) const {
176   if (!(request_data.stage & GetStages()))
177     return false;
178   const content::ResourceRequestInfo* info =
179       content::ResourceRequestInfo::ForRequest(request_data.request);
180   if (!info)
181     return false;
182   return std::find(types_.begin(), types_.end(), info->GetResourceType()) !=
183       types_.end();
184 }
185 
186 WebRequestConditionAttribute::Type
GetType() const187 WebRequestConditionAttributeResourceType::GetType() const {
188   return CONDITION_RESOURCE_TYPE;
189 }
190 
GetName() const191 std::string WebRequestConditionAttributeResourceType::GetName() const {
192   return keys::kResourceTypeKey;
193 }
194 
Equals(const WebRequestConditionAttribute * other) const195 bool WebRequestConditionAttributeResourceType::Equals(
196     const WebRequestConditionAttribute* other) const {
197   if (!WebRequestConditionAttribute::Equals(other))
198     return false;
199   const WebRequestConditionAttributeResourceType* casted_other =
200       static_cast<const WebRequestConditionAttributeResourceType*>(other);
201   return types_ == casted_other->types_;
202 }
203 
204 //
205 // WebRequestConditionAttributeContentType
206 //
207 
208 WebRequestConditionAttributeContentType::
WebRequestConditionAttributeContentType(const std::vector<std::string> & content_types,bool inclusive)209 WebRequestConditionAttributeContentType(
210     const std::vector<std::string>& content_types,
211     bool inclusive)
212     : content_types_(content_types),
213       inclusive_(inclusive) {}
214 
215 WebRequestConditionAttributeContentType::
~WebRequestConditionAttributeContentType()216 ~WebRequestConditionAttributeContentType() {}
217 
218 // static
219 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)220 WebRequestConditionAttributeContentType::Create(
221       const std::string& name,
222       const base::Value* value,
223       std::string* error,
224       bool* bad_message) {
225   DCHECK(name == keys::kContentTypeKey || name == keys::kExcludeContentTypeKey);
226 
227   const base::ListValue* value_as_list = NULL;
228   if (!value->GetAsList(&value_as_list)) {
229     *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
230     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
231   }
232   std::vector<std::string> content_types;
233   for (base::ListValue::const_iterator it = value_as_list->begin();
234        it != value_as_list->end(); ++it) {
235     std::string content_type;
236     if (!(*it)->GetAsString(&content_type)) {
237       *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
238       return scoped_refptr<const WebRequestConditionAttribute>(NULL);
239     }
240     content_types.push_back(content_type);
241   }
242 
243   return scoped_refptr<const WebRequestConditionAttribute>(
244       new WebRequestConditionAttributeContentType(
245           content_types, name == keys::kContentTypeKey));
246 }
247 
GetStages() const248 int WebRequestConditionAttributeContentType::GetStages() const {
249   return ON_HEADERS_RECEIVED;
250 }
251 
IsFulfilled(const WebRequestData & request_data) const252 bool WebRequestConditionAttributeContentType::IsFulfilled(
253     const WebRequestData& request_data) const {
254   if (!(request_data.stage & GetStages()))
255     return false;
256   std::string content_type;
257   request_data.original_response_headers->GetNormalizedHeader(
258       net::HttpRequestHeaders::kContentType, &content_type);
259   std::string mime_type;
260   std::string charset;
261   bool had_charset = false;
262   net::HttpUtil::ParseContentType(
263       content_type, &mime_type, &charset, &had_charset, NULL);
264 
265   if (inclusive_) {
266     return std::find(content_types_.begin(), content_types_.end(),
267                      mime_type) != content_types_.end();
268   } else {
269     return std::find(content_types_.begin(), content_types_.end(),
270                      mime_type) == content_types_.end();
271   }
272 }
273 
274 WebRequestConditionAttribute::Type
GetType() const275 WebRequestConditionAttributeContentType::GetType() const {
276   return CONDITION_CONTENT_TYPE;
277 }
278 
GetName() const279 std::string WebRequestConditionAttributeContentType::GetName() const {
280   return (inclusive_ ? keys::kContentTypeKey : keys::kExcludeContentTypeKey);
281 }
282 
Equals(const WebRequestConditionAttribute * other) const283 bool WebRequestConditionAttributeContentType::Equals(
284     const WebRequestConditionAttribute* other) const {
285   if (!WebRequestConditionAttribute::Equals(other))
286     return false;
287   const WebRequestConditionAttributeContentType* casted_other =
288       static_cast<const WebRequestConditionAttributeContentType*>(other);
289   return content_types_ == casted_other->content_types_ &&
290          inclusive_ == casted_other->inclusive_;
291 }
292 
293 // Manages a set of tests to be applied to name-value pairs representing
294 // headers. This is a helper class to header-related condition attributes.
295 // It contains a set of test groups. A name-value pair satisfies the whole
296 // set of test groups iff it passes at least one test group.
297 class HeaderMatcher {
298  public:
299   ~HeaderMatcher();
300 
301   // Creates an instance based on a list |tests| of test groups, encoded as
302   // dictionaries of the type declarativeWebRequest.HeaderFilter (see
303   // declarative_web_request.json).
304   static scoped_ptr<const HeaderMatcher> Create(const base::ListValue* tests);
305 
306   // Does |this| match the header "|name|: |value|"?
307   bool TestNameValue(const std::string& name, const std::string& value) const;
308 
309  private:
310   // Represents a single string-matching test.
311   class StringMatchTest {
312    public:
313     enum MatchType { kPrefix, kSuffix, kEquals, kContains };
314 
315     // |data| is the pattern to be matched in the position given by |type|.
316     // Note that |data| must point to a StringValue object.
317     static scoped_ptr<StringMatchTest> Create(const base::Value* data,
318                                               MatchType type,
319                                               bool case_sensitive);
320     ~StringMatchTest();
321 
322     // Does |str| pass |this| StringMatchTest?
323     bool Matches(const std::string& str) const;
324 
325    private:
326     StringMatchTest(const std::string& data,
327                     MatchType type,
328                     bool case_sensitive);
329 
330     const std::string data_;
331     const MatchType type_;
332     const bool case_sensitive_;
333     DISALLOW_COPY_AND_ASSIGN(StringMatchTest);
334   };
335 
336   // Represents a test group -- a set of string matching tests to be applied to
337   // both the header name and value.
338   class HeaderMatchTest {
339    public:
340     ~HeaderMatchTest();
341 
342     // Gets the test group description in |tests| and creates the corresponding
343     // HeaderMatchTest. On failure returns NULL.
344     static scoped_ptr<const HeaderMatchTest> Create(
345         const base::DictionaryValue* tests);
346 
347     // Does the header "|name|: |value|" match all tests in |this|?
348     bool Matches(const std::string& name, const std::string& value) const;
349 
350    private:
351     // Takes ownership of the content of both |name_match| and |value_match|.
352     HeaderMatchTest(ScopedVector<const StringMatchTest>* name_match,
353                     ScopedVector<const StringMatchTest>* value_match);
354 
355     // Tests to be passed by a header's name.
356     const ScopedVector<const StringMatchTest> name_match_;
357     // Tests to be passed by a header's value.
358     const ScopedVector<const StringMatchTest> value_match_;
359     DISALLOW_COPY_AND_ASSIGN(HeaderMatchTest);
360   };
361 
362   explicit HeaderMatcher(ScopedVector<const HeaderMatchTest>* tests);
363 
364   const ScopedVector<const HeaderMatchTest> tests_;
365 
366   DISALLOW_COPY_AND_ASSIGN(HeaderMatcher);
367 };
368 
369 // HeaderMatcher implementation.
370 
~HeaderMatcher()371 HeaderMatcher::~HeaderMatcher() {}
372 
373 // static
Create(const base::ListValue * tests)374 scoped_ptr<const HeaderMatcher> HeaderMatcher::Create(
375     const base::ListValue* tests) {
376   ScopedVector<const HeaderMatchTest> header_tests;
377   for (base::ListValue::const_iterator it = tests->begin();
378        it != tests->end(); ++it) {
379     const base::DictionaryValue* tests = NULL;
380     if (!(*it)->GetAsDictionary(&tests))
381       return scoped_ptr<const HeaderMatcher>();
382 
383     scoped_ptr<const HeaderMatchTest> header_test(
384         HeaderMatchTest::Create(tests));
385     if (header_test.get() == NULL)
386       return scoped_ptr<const HeaderMatcher>();
387     header_tests.push_back(header_test.release());
388   }
389 
390   return scoped_ptr<const HeaderMatcher>(new HeaderMatcher(&header_tests));
391 }
392 
TestNameValue(const std::string & name,const std::string & value) const393 bool HeaderMatcher::TestNameValue(const std::string& name,
394                                   const std::string& value) const {
395   for (size_t i = 0; i < tests_.size(); ++i) {
396     if (tests_[i]->Matches(name, value))
397       return true;
398   }
399   return false;
400 }
401 
HeaderMatcher(ScopedVector<const HeaderMatchTest> * tests)402 HeaderMatcher::HeaderMatcher(ScopedVector<const HeaderMatchTest>* tests)
403   : tests_(tests->Pass()) {}
404 
405 // HeaderMatcher::StringMatchTest implementation.
406 
407 // static
408 scoped_ptr<HeaderMatcher::StringMatchTest>
Create(const base::Value * data,MatchType type,bool case_sensitive)409 HeaderMatcher::StringMatchTest::Create(const base::Value* data,
410                                        MatchType type,
411                                        bool case_sensitive) {
412   std::string str;
413   CHECK(data->GetAsString(&str));
414   return scoped_ptr<StringMatchTest>(
415       new StringMatchTest(str, type, case_sensitive));
416 }
417 
~StringMatchTest()418 HeaderMatcher::StringMatchTest::~StringMatchTest() {}
419 
Matches(const std::string & str) const420 bool HeaderMatcher::StringMatchTest::Matches(
421     const std::string& str) const {
422   switch (type_) {
423     case kPrefix:
424       return StartsWithASCII(str, data_, case_sensitive_);
425     case kSuffix:
426       return EndsWith(str, data_, case_sensitive_);
427     case kEquals:
428       return str.size() == data_.size() &&
429              StartsWithASCII(str, data_, case_sensitive_);
430     case kContains:
431       if (!case_sensitive_) {
432         return std::search(str.begin(), str.end(), data_.begin(), data_.end(),
433                            CaseInsensitiveCompareASCII<char>()) != str.end();
434       } else {
435         return str.find(data_) != std::string::npos;
436       }
437   }
438   // We never get past the "switch", but the compiler worries about no return.
439   NOTREACHED();
440   return false;
441 }
442 
StringMatchTest(const std::string & data,MatchType type,bool case_sensitive)443 HeaderMatcher::StringMatchTest::StringMatchTest(const std::string& data,
444                                                 MatchType type,
445                                                 bool case_sensitive)
446     : data_(data),
447       type_(type),
448       case_sensitive_(case_sensitive) {}
449 
450 // HeaderMatcher::HeaderMatchTest implementation.
451 
HeaderMatchTest(ScopedVector<const StringMatchTest> * name_match,ScopedVector<const StringMatchTest> * value_match)452 HeaderMatcher::HeaderMatchTest::HeaderMatchTest(
453     ScopedVector<const StringMatchTest>* name_match,
454     ScopedVector<const StringMatchTest>* value_match)
455     : name_match_(name_match->Pass()),
456       value_match_(value_match->Pass()) {}
457 
~HeaderMatchTest()458 HeaderMatcher::HeaderMatchTest::~HeaderMatchTest() {}
459 
460 // static
461 scoped_ptr<const HeaderMatcher::HeaderMatchTest>
Create(const base::DictionaryValue * tests)462 HeaderMatcher::HeaderMatchTest::Create(const base::DictionaryValue* tests) {
463   ScopedVector<const StringMatchTest> name_match;
464   ScopedVector<const StringMatchTest> value_match;
465 
466   for (base::DictionaryValue::Iterator it(*tests);
467        !it.IsAtEnd(); it.Advance()) {
468     bool is_name = false;  // Is this test for header name?
469     StringMatchTest::MatchType match_type;
470     if (it.key() == keys::kNamePrefixKey) {
471       is_name = true;
472       match_type = StringMatchTest::kPrefix;
473     } else if (it.key() == keys::kNameSuffixKey) {
474       is_name = true;
475       match_type = StringMatchTest::kSuffix;
476     } else if (it.key() == keys::kNameContainsKey) {
477       is_name = true;
478       match_type = StringMatchTest::kContains;
479     } else if (it.key() == keys::kNameEqualsKey) {
480       is_name = true;
481       match_type = StringMatchTest::kEquals;
482     } else if (it.key() == keys::kValuePrefixKey) {
483       match_type = StringMatchTest::kPrefix;
484     } else if (it.key() == keys::kValueSuffixKey) {
485       match_type = StringMatchTest::kSuffix;
486     } else if (it.key() == keys::kValueContainsKey) {
487       match_type = StringMatchTest::kContains;
488     } else if (it.key() == keys::kValueEqualsKey) {
489       match_type = StringMatchTest::kEquals;
490     } else {
491       NOTREACHED();  // JSON schema type checking should prevent this.
492       return scoped_ptr<const HeaderMatchTest>();
493     }
494     const base::Value* content = &it.value();
495 
496     ScopedVector<const StringMatchTest>* tests =
497         is_name ? &name_match : &value_match;
498     switch (content->GetType()) {
499       case base::Value::TYPE_LIST: {
500         const base::ListValue* list = NULL;
501         CHECK(content->GetAsList(&list));
502         for (base::ListValue::const_iterator it = list->begin();
503              it != list->end(); ++it) {
504           tests->push_back(
505               StringMatchTest::Create(*it, match_type, !is_name).release());
506         }
507         break;
508       }
509       case base::Value::TYPE_STRING: {
510         tests->push_back(
511             StringMatchTest::Create(content, match_type, !is_name).release());
512         break;
513       }
514       default: {
515         NOTREACHED();  // JSON schema type checking should prevent this.
516         return scoped_ptr<const HeaderMatchTest>();
517       }
518     }
519   }
520 
521   return scoped_ptr<const HeaderMatchTest>(
522       new HeaderMatchTest(&name_match, &value_match));
523 }
524 
Matches(const std::string & name,const std::string & value) const525 bool HeaderMatcher::HeaderMatchTest::Matches(const std::string& name,
526                                              const std::string& value) const {
527   for (size_t i = 0; i < name_match_.size(); ++i) {
528     if (!name_match_[i]->Matches(name))
529       return false;
530   }
531 
532   for (size_t i = 0; i < value_match_.size(); ++i) {
533     if (!value_match_[i]->Matches(value))
534       return false;
535   }
536 
537   return true;
538 }
539 
540 //
541 // WebRequestConditionAttributeRequestHeaders
542 //
543 
544 WebRequestConditionAttributeRequestHeaders::
WebRequestConditionAttributeRequestHeaders(scoped_ptr<const HeaderMatcher> header_matcher,bool positive)545 WebRequestConditionAttributeRequestHeaders(
546     scoped_ptr<const HeaderMatcher> header_matcher,
547     bool positive)
548     : header_matcher_(header_matcher.Pass()),
549       positive_(positive) {}
550 
551 WebRequestConditionAttributeRequestHeaders::
~WebRequestConditionAttributeRequestHeaders()552 ~WebRequestConditionAttributeRequestHeaders() {}
553 
554 namespace {
555 
PrepareHeaderMatcher(const std::string & name,const base::Value * value,std::string * error)556 scoped_ptr<const HeaderMatcher> PrepareHeaderMatcher(
557     const std::string& name,
558     const base::Value* value,
559     std::string* error) {
560   const base::ListValue* value_as_list = NULL;
561   if (!value->GetAsList(&value_as_list)) {
562     *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
563     return scoped_ptr<const HeaderMatcher>();
564   }
565 
566   scoped_ptr<const HeaderMatcher> header_matcher(
567       HeaderMatcher::Create(value_as_list));
568   if (header_matcher.get() == NULL)
569     *error = ErrorUtils::FormatErrorMessage(kInvalidValue, name);
570   return header_matcher.Pass();
571 }
572 
573 }  // namespace
574 
575 // static
576 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)577 WebRequestConditionAttributeRequestHeaders::Create(
578     const std::string& name,
579     const base::Value* value,
580     std::string* error,
581     bool* bad_message) {
582   DCHECK(name == keys::kRequestHeadersKey ||
583          name == keys::kExcludeRequestHeadersKey);
584 
585   scoped_ptr<const HeaderMatcher> header_matcher(
586       PrepareHeaderMatcher(name, value, error));
587   if (header_matcher.get() == NULL)
588     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
589 
590   return scoped_refptr<const WebRequestConditionAttribute>(
591       new WebRequestConditionAttributeRequestHeaders(
592           header_matcher.Pass(), name == keys::kRequestHeadersKey));
593 }
594 
GetStages() const595 int WebRequestConditionAttributeRequestHeaders::GetStages() const {
596   // Currently we only allow matching against headers in the before-send-headers
597   // stage. The headers are accessible in other stages as well, but before
598   // allowing to match against them in further stages, we should consider
599   // caching the match result.
600   return ON_BEFORE_SEND_HEADERS;
601 }
602 
IsFulfilled(const WebRequestData & request_data) const603 bool WebRequestConditionAttributeRequestHeaders::IsFulfilled(
604     const WebRequestData& request_data) const {
605   if (!(request_data.stage & GetStages()))
606     return false;
607 
608   const net::HttpRequestHeaders& headers =
609       request_data.request->extra_request_headers();
610 
611   bool passed = false;  // Did some header pass TestNameValue?
612   net::HttpRequestHeaders::Iterator it(headers);
613   while (!passed && it.GetNext())
614     passed |= header_matcher_->TestNameValue(it.name(), it.value());
615 
616   return (positive_ ? passed : !passed);
617 }
618 
619 WebRequestConditionAttribute::Type
GetType() const620 WebRequestConditionAttributeRequestHeaders::GetType() const {
621   return CONDITION_REQUEST_HEADERS;
622 }
623 
GetName() const624 std::string WebRequestConditionAttributeRequestHeaders::GetName() const {
625   return (positive_ ? keys::kRequestHeadersKey
626                     : keys::kExcludeRequestHeadersKey);
627 }
628 
Equals(const WebRequestConditionAttribute * other) const629 bool WebRequestConditionAttributeRequestHeaders::Equals(
630     const WebRequestConditionAttribute* other) const {
631   // Comparing headers is too heavy, so we skip it entirely.
632   return false;
633 }
634 
635 //
636 // WebRequestConditionAttributeResponseHeaders
637 //
638 
639 WebRequestConditionAttributeResponseHeaders::
WebRequestConditionAttributeResponseHeaders(scoped_ptr<const HeaderMatcher> header_matcher,bool positive)640 WebRequestConditionAttributeResponseHeaders(
641     scoped_ptr<const HeaderMatcher> header_matcher,
642     bool positive)
643     : header_matcher_(header_matcher.Pass()),
644       positive_(positive) {}
645 
646 WebRequestConditionAttributeResponseHeaders::
~WebRequestConditionAttributeResponseHeaders()647 ~WebRequestConditionAttributeResponseHeaders() {}
648 
649 // static
650 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)651 WebRequestConditionAttributeResponseHeaders::Create(
652     const std::string& name,
653     const base::Value* value,
654     std::string* error,
655     bool* bad_message) {
656   DCHECK(name == keys::kResponseHeadersKey ||
657          name == keys::kExcludeResponseHeadersKey);
658 
659   scoped_ptr<const HeaderMatcher> header_matcher(
660       PrepareHeaderMatcher(name, value, error));
661   if (header_matcher.get() == NULL)
662     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
663 
664   return scoped_refptr<const WebRequestConditionAttribute>(
665       new WebRequestConditionAttributeResponseHeaders(
666           header_matcher.Pass(), name == keys::kResponseHeadersKey));
667 }
668 
GetStages() const669 int WebRequestConditionAttributeResponseHeaders::GetStages() const {
670   return ON_HEADERS_RECEIVED;
671 }
672 
IsFulfilled(const WebRequestData & request_data) const673 bool WebRequestConditionAttributeResponseHeaders::IsFulfilled(
674     const WebRequestData& request_data) const {
675   if (!(request_data.stage & GetStages()))
676     return false;
677 
678   const net::HttpResponseHeaders* headers =
679       request_data.original_response_headers;
680   if (headers == NULL) {
681     // Each header of an empty set satisfies (the negation of) everything;
682     // OTOH, there is no header to satisfy even the most permissive test.
683     return !positive_;
684   }
685 
686   bool passed = false;  // Did some header pass TestNameValue?
687   std::string name;
688   std::string value;
689   void* iter = NULL;
690   while (!passed && headers->EnumerateHeaderLines(&iter, &name, &value)) {
691     passed |= header_matcher_->TestNameValue(name, value);
692   }
693 
694   return (positive_ ? passed : !passed);
695 }
696 
697 WebRequestConditionAttribute::Type
GetType() const698 WebRequestConditionAttributeResponseHeaders::GetType() const {
699   return CONDITION_RESPONSE_HEADERS;
700 }
701 
GetName() const702 std::string WebRequestConditionAttributeResponseHeaders::GetName() const {
703   return (positive_ ? keys::kResponseHeadersKey
704                     : keys::kExcludeResponseHeadersKey);
705 }
706 
Equals(const WebRequestConditionAttribute * other) const707 bool WebRequestConditionAttributeResponseHeaders::Equals(
708     const WebRequestConditionAttribute* other) const {
709   return false;
710 }
711 
712 //
713 // WebRequestConditionAttributeThirdParty
714 //
715 
716 WebRequestConditionAttributeThirdParty::
WebRequestConditionAttributeThirdParty(bool match_third_party)717 WebRequestConditionAttributeThirdParty(bool match_third_party)
718     : match_third_party_(match_third_party) {}
719 
720 WebRequestConditionAttributeThirdParty::
~WebRequestConditionAttributeThirdParty()721 ~WebRequestConditionAttributeThirdParty() {}
722 
723 // static
724 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)725 WebRequestConditionAttributeThirdParty::Create(
726     const std::string& name,
727     const base::Value* value,
728     std::string* error,
729     bool* bad_message) {
730   DCHECK(name == keys::kThirdPartyKey);
731 
732   bool third_party = false;  // Dummy value, gets overwritten.
733   if (!value->GetAsBoolean(&third_party)) {
734     *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
735                                                      keys::kThirdPartyKey);
736     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
737   }
738 
739   return scoped_refptr<const WebRequestConditionAttribute>(
740       new WebRequestConditionAttributeThirdParty(third_party));
741 }
742 
GetStages() const743 int WebRequestConditionAttributeThirdParty::GetStages() const {
744   return ON_BEFORE_REQUEST | ON_BEFORE_SEND_HEADERS | ON_SEND_HEADERS |
745       ON_HEADERS_RECEIVED | ON_AUTH_REQUIRED | ON_BEFORE_REDIRECT |
746       ON_RESPONSE_STARTED | ON_COMPLETED | ON_ERROR;
747 }
748 
IsFulfilled(const WebRequestData & request_data) const749 bool WebRequestConditionAttributeThirdParty::IsFulfilled(
750     const WebRequestData& request_data) const {
751   if (!(request_data.stage & GetStages()))
752     return false;
753 
754   // Request is "1st party" if it gets cookies under 3rd party-blocking policy.
755   const net::StaticCookiePolicy block_third_party_policy(
756       net::StaticCookiePolicy::BLOCK_ALL_THIRD_PARTY_COOKIES);
757   const int can_get_cookies = block_third_party_policy.CanGetCookies(
758           request_data.request->url(),
759           request_data.request->first_party_for_cookies());
760   const bool is_first_party = (can_get_cookies == net::OK);
761 
762   return match_third_party_ ? !is_first_party : is_first_party;
763 }
764 
765 WebRequestConditionAttribute::Type
GetType() const766 WebRequestConditionAttributeThirdParty::GetType() const {
767   return CONDITION_THIRD_PARTY;
768 }
769 
GetName() const770 std::string WebRequestConditionAttributeThirdParty::GetName() const {
771   return keys::kThirdPartyKey;
772 }
773 
Equals(const WebRequestConditionAttribute * other) const774 bool WebRequestConditionAttributeThirdParty::Equals(
775     const WebRequestConditionAttribute* other) const {
776   if (!WebRequestConditionAttribute::Equals(other))
777     return false;
778   const WebRequestConditionAttributeThirdParty* casted_other =
779       static_cast<const WebRequestConditionAttributeThirdParty*>(other);
780   return match_third_party_ == casted_other->match_third_party_;
781 }
782 
783 //
784 // WebRequestConditionAttributeStages
785 //
786 
787 WebRequestConditionAttributeStages::
WebRequestConditionAttributeStages(int allowed_stages)788 WebRequestConditionAttributeStages(int allowed_stages)
789     : allowed_stages_(allowed_stages) {}
790 
791 WebRequestConditionAttributeStages::
~WebRequestConditionAttributeStages()792 ~WebRequestConditionAttributeStages() {}
793 
794 namespace {
795 
796 // Reads strings stored in |value|, which is expected to be a ListValue, and
797 // sets corresponding bits (see RequestStage) in |out_stages|. Returns true on
798 // success, false otherwise.
ParseListOfStages(const base::Value & value,int * out_stages)799 bool ParseListOfStages(const base::Value& value, int* out_stages) {
800   const base::ListValue* list = NULL;
801   if (!value.GetAsList(&list))
802     return false;
803 
804   int stages = 0;
805   std::string stage_name;
806   for (base::ListValue::const_iterator it = list->begin();
807        it != list->end(); ++it) {
808     if (!((*it)->GetAsString(&stage_name)))
809       return false;
810     if (stage_name == keys::kOnBeforeRequestEnum) {
811       stages |= ON_BEFORE_REQUEST;
812     } else if (stage_name == keys::kOnBeforeSendHeadersEnum) {
813       stages |= ON_BEFORE_SEND_HEADERS;
814     } else if (stage_name == keys::kOnHeadersReceivedEnum) {
815       stages |= ON_HEADERS_RECEIVED;
816     } else if (stage_name == keys::kOnAuthRequiredEnum) {
817       stages |= ON_AUTH_REQUIRED;
818     } else {
819       NOTREACHED();  // JSON schema checks prevent getting here.
820       return false;
821     }
822   }
823 
824   *out_stages = stages;
825   return true;
826 }
827 
828 }  // namespace
829 
830 // static
831 scoped_refptr<const WebRequestConditionAttribute>
Create(const std::string & name,const base::Value * value,std::string * error,bool * bad_message)832 WebRequestConditionAttributeStages::Create(const std::string& name,
833                                            const base::Value* value,
834                                            std::string* error,
835                                            bool* bad_message) {
836   DCHECK(name == keys::kStagesKey);
837 
838   int allowed_stages = 0;
839   if (!ParseListOfStages(*value, &allowed_stages)) {
840     *error = ErrorUtils::FormatErrorMessage(kInvalidValue,
841                                                      keys::kStagesKey);
842     return scoped_refptr<const WebRequestConditionAttribute>(NULL);
843   }
844 
845   return scoped_refptr<const WebRequestConditionAttribute>(
846       new WebRequestConditionAttributeStages(allowed_stages));
847 }
848 
GetStages() const849 int WebRequestConditionAttributeStages::GetStages() const {
850   return allowed_stages_;
851 }
852 
IsFulfilled(const WebRequestData & request_data) const853 bool WebRequestConditionAttributeStages::IsFulfilled(
854     const WebRequestData& request_data) const {
855   // Note: removing '!=' triggers warning C4800 on the VS compiler.
856   return (request_data.stage & GetStages()) != 0;
857 }
858 
859 WebRequestConditionAttribute::Type
GetType() const860 WebRequestConditionAttributeStages::GetType() const {
861   return CONDITION_STAGES;
862 }
863 
GetName() const864 std::string WebRequestConditionAttributeStages::GetName() const {
865   return keys::kStagesKey;
866 }
867 
Equals(const WebRequestConditionAttribute * other) const868 bool WebRequestConditionAttributeStages::Equals(
869     const WebRequestConditionAttribute* other) const {
870   if (!WebRequestConditionAttribute::Equals(other))
871     return false;
872   const WebRequestConditionAttributeStages* casted_other =
873       static_cast<const WebRequestConditionAttributeStages*>(other);
874   return allowed_stages_ == casted_other->allowed_stages_;
875 }
876 
877 }  // namespace extensions
878