• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // See "SSPI Sample Application" at
6 // http://msdn.microsoft.com/en-us/library/aa918273.aspx
7 
8 #include "net/http/http_auth_sspi_win.h"
9 
10 #include "base/base64.h"
11 #include "base/logging.h"
12 #include "base/string_util.h"
13 #include "net/base/net_errors.h"
14 #include "net/base/net_util.h"
15 #include "net/http/http_auth.h"
16 
17 namespace net {
18 
HttpAuthSSPI(const std::string & scheme,SEC_WCHAR * security_package)19 HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme,
20                            SEC_WCHAR* security_package)
21     : scheme_(scheme),
22       security_package_(security_package),
23       max_token_length_(0) {
24   SecInvalidateHandle(&cred_);
25   SecInvalidateHandle(&ctxt_);
26 }
27 
~HttpAuthSSPI()28 HttpAuthSSPI::~HttpAuthSSPI() {
29   ResetSecurityContext();
30   if (SecIsValidHandle(&cred_)) {
31     FreeCredentialsHandle(&cred_);
32     SecInvalidateHandle(&cred_);
33   }
34 }
35 
NeedsIdentity() const36 bool HttpAuthSSPI::NeedsIdentity() const {
37   return decoded_server_auth_token_.empty();
38 }
39 
IsFinalRound() const40 bool HttpAuthSSPI::IsFinalRound() const {
41   return !decoded_server_auth_token_.empty();
42 }
43 
ResetSecurityContext()44 void HttpAuthSSPI::ResetSecurityContext() {
45   if (SecIsValidHandle(&ctxt_)) {
46     DeleteSecurityContext(&ctxt_);
47     SecInvalidateHandle(&ctxt_);
48   }
49 }
50 
ParseChallenge(std::string::const_iterator challenge_begin,std::string::const_iterator challenge_end)51 bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin,
52                                   std::string::const_iterator challenge_end) {
53   // Verify the challenge's auth-scheme.
54   HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end);
55   if (!challenge_tok.valid() ||
56       !LowerCaseEqualsASCII(challenge_tok.scheme(),
57                             StringToLowerASCII(scheme_).c_str()))
58     return false;
59   // Extract the auth-data.  We can't use challenge_tok.GetNext() because
60   // auth-data is base64-encoded and may contain '=' padding at the end,
61   // which would be mistaken for a name=value pair.
62   challenge_begin += scheme_.length();  // Skip over scheme name.
63   HttpUtil::TrimLWS(&challenge_begin, &challenge_end);
64   std::string encoded_auth_token(challenge_begin, challenge_end);
65   int encoded_length = encoded_auth_token.length();
66   // Strip off any padding.
67   // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.)
68   //
69   // Our base64 decoder requires that the length be a multiple of 4.
70   while (encoded_length > 0 && encoded_length % 4 != 0 &&
71          encoded_auth_token[encoded_length - 1] == '=')
72     encoded_length--;
73   encoded_auth_token.erase(encoded_length);
74 
75   std::string decoded_auth_token;
76   bool rv = base::Base64Decode(encoded_auth_token, &decoded_auth_token);
77   if (rv) {
78     decoded_server_auth_token_ = decoded_auth_token;
79   }
80   return rv;
81 }
82 
GenerateCredentials(const std::wstring & username,const std::wstring & password,const GURL & origin,const HttpRequestInfo * request,const ProxyInfo * proxy,std::string * out_credentials)83 int HttpAuthSSPI::GenerateCredentials(const std::wstring& username,
84                                       const std::wstring& password,
85                                       const GURL& origin,
86                                       const HttpRequestInfo* request,
87                                       const ProxyInfo* proxy,
88                                       std::string* out_credentials) {
89   // |username| may be in the form "DOMAIN\user".  Parse it into the two
90   // components.
91   std::wstring domain;
92   std::wstring user;
93   SplitDomainAndUser(username, &domain, &user);
94 
95   // Initial challenge.
96   if (!IsFinalRound()) {
97     int rv = OnFirstRound(domain, user, password);
98     if (rv != OK)
99       return rv;
100   }
101 
102   void* out_buf;
103   int out_buf_len;
104   int rv = GetNextSecurityToken(
105       origin,
106       static_cast<void *>(const_cast<char *>(
107           decoded_server_auth_token_.c_str())),
108       decoded_server_auth_token_.length(),
109       &out_buf,
110       &out_buf_len);
111   if (rv != OK)
112     return rv;
113 
114   // Base64 encode data in output buffer and prepend the scheme.
115   std::string encode_input(static_cast<char*>(out_buf), out_buf_len);
116   std::string encode_output;
117   bool ok = base::Base64Encode(encode_input, &encode_output);
118   // OK, we are done with |out_buf|
119   free(out_buf);
120   if (!ok)
121     return rv;
122   *out_credentials = scheme_ + " " + encode_output;
123   return OK;
124 }
125 
OnFirstRound(const std::wstring & domain,const std::wstring & user,const std::wstring & password)126 int HttpAuthSSPI::OnFirstRound(const std::wstring& domain,
127                                const std::wstring& user,
128                                const std::wstring& password) {
129   int rv = DetermineMaxTokenLength(security_package_, &max_token_length_);
130   if (rv != OK) {
131     return rv;
132   }
133   rv = AcquireCredentials(security_package_, domain, user, password, &cred_);
134   return rv;
135 }
136 
GetNextSecurityToken(const GURL & origin,const void * in_token,int in_token_len,void ** out_token,int * out_token_len)137 int HttpAuthSSPI::GetNextSecurityToken(
138     const GURL& origin,
139     const void * in_token,
140     int in_token_len,
141     void** out_token,
142     int* out_token_len) {
143   SECURITY_STATUS status;
144   TimeStamp expiry;
145 
146   DWORD ctxt_attr;
147   CtxtHandle* ctxt_ptr;
148   SecBufferDesc in_buffer_desc, out_buffer_desc;
149   SecBufferDesc* in_buffer_desc_ptr;
150   SecBuffer in_buffer, out_buffer;
151 
152   if (in_token_len > 0) {
153     // Prepare input buffer.
154     in_buffer_desc.ulVersion = SECBUFFER_VERSION;
155     in_buffer_desc.cBuffers = 1;
156     in_buffer_desc.pBuffers = &in_buffer;
157     in_buffer.BufferType = SECBUFFER_TOKEN;
158     in_buffer.cbBuffer = in_token_len;
159     in_buffer.pvBuffer = const_cast<void*>(in_token);
160     ctxt_ptr = &ctxt_;
161     in_buffer_desc_ptr = &in_buffer_desc;
162   } else {
163     // If there is no input token, then we are starting a new authentication
164     // sequence.  If we have already initialized our security context, then
165     // we're incorrectly reusing the auth handler for a new sequence.
166     if (SecIsValidHandle(&ctxt_)) {
167       LOG(ERROR) << "Cannot restart authentication sequence";
168       return ERR_UNEXPECTED;
169     }
170     ctxt_ptr = NULL;
171     in_buffer_desc_ptr = NULL;
172   }
173 
174   // Prepare output buffer.
175   out_buffer_desc.ulVersion = SECBUFFER_VERSION;
176   out_buffer_desc.cBuffers = 1;
177   out_buffer_desc.pBuffers = &out_buffer;
178   out_buffer.BufferType = SECBUFFER_TOKEN;
179   out_buffer.cbBuffer = max_token_length_;
180   out_buffer.pvBuffer = malloc(out_buffer.cbBuffer);
181   if (!out_buffer.pvBuffer)
182     return ERR_OUT_OF_MEMORY;
183 
184   // The service principal name of the destination server.  See
185   // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx
186   std::wstring target(L"HTTP/");
187   target.append(ASCIIToWide(GetHostAndPort(origin)));
188   wchar_t* target_name = const_cast<wchar_t*>(target.c_str());
189 
190   // This returns a token that is passed to the remote server.
191   status = InitializeSecurityContext(&cred_,  // phCredential
192                                      ctxt_ptr,  // phContext
193                                      target_name,  // pszTargetName
194                                      0,  // fContextReq
195                                      0,  // Reserved1 (must be 0)
196                                      SECURITY_NATIVE_DREP,  // TargetDataRep
197                                      in_buffer_desc_ptr,  // pInput
198                                      0,  // Reserved2 (must be 0)
199                                      &ctxt_,  // phNewContext
200                                      &out_buffer_desc,  // pOutput
201                                      &ctxt_attr,  // pfContextAttr
202                                      &expiry);  // ptsExpiry
203   // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call
204   // and SEC_E_OK on the second call.  On failure, the function returns an
205   // error code.
206   if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) {
207     LOG(ERROR) << "InitializeSecurityContext failed: " << status;
208     ResetSecurityContext();
209     free(out_buffer.pvBuffer);
210     return ERR_UNEXPECTED;  // TODO(wtc): map error code.
211   }
212   if (!out_buffer.cbBuffer) {
213     free(out_buffer.pvBuffer);
214     out_buffer.pvBuffer = NULL;
215   }
216   *out_token = out_buffer.pvBuffer;
217   *out_token_len = out_buffer.cbBuffer;
218   return OK;
219 }
220 
// Splits "DOMAIN\user" at the first backslash. Without a backslash, |*domain|
// is cleared and the whole string becomes |*user|.
void SplitDomainAndUser(const std::wstring& combined,
                        std::wstring* domain,
                        std::wstring* user) {
  const std::wstring::size_type separator = combined.find(L'\\');
  const bool has_domain = separator != std::wstring::npos;
  if (!has_domain) {
    domain->clear();
    user->assign(combined);
    return;
  }
  domain->assign(combined, 0, separator);
  user->assign(combined, separator + 1, std::wstring::npos);
}
233 
DetermineMaxTokenLength(const std::wstring & package,ULONG * max_token_length)234 int DetermineMaxTokenLength(const std::wstring& package,
235                             ULONG* max_token_length) {
236   PSecPkgInfo pkg_info;
237   SECURITY_STATUS status = QuerySecurityPackageInfo(
238       const_cast<wchar_t *>(package.c_str()), &pkg_info);
239   if (status != SEC_E_OK) {
240     LOG(ERROR) << "Security package " << package << " not found";
241     return ERR_UNEXPECTED;
242   }
243   *max_token_length = pkg_info->cbMaxToken;
244   FreeContextBuffer(pkg_info);
245   return OK;
246 }
247 
AcquireCredentials(const SEC_WCHAR * package,const std::wstring & domain,const std::wstring & user,const std::wstring & password,CredHandle * cred)248 int AcquireCredentials(const SEC_WCHAR* package,
249                        const std::wstring& domain,
250                        const std::wstring& user,
251                        const std::wstring& password,
252                        CredHandle* cred) {
253   SEC_WINNT_AUTH_IDENTITY identity;
254   identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
255   identity.User =
256       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str()));
257   identity.UserLength = user.size();
258   identity.Domain =
259       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str()));
260   identity.DomainLength = domain.size();
261   identity.Password =
262       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str()));
263   identity.PasswordLength = password.size();
264 
265   TimeStamp expiry;
266 
267   // Pass the username/password to get the credentials handle.
268   // Note: If the 5th argument is NULL, it uses the default cached credentials
269   // for the logged in user, which can be used for single sign-on.
270   SECURITY_STATUS status = AcquireCredentialsHandle(
271       NULL,  // pszPrincipal
272       const_cast<SEC_WCHAR*>(package),  // pszPackage
273       SECPKG_CRED_OUTBOUND,  // fCredentialUse
274       NULL,  // pvLogonID
275       &identity,  // pAuthData
276       NULL,  // pGetKeyFn (not used)
277       NULL,  // pvGetKeyArgument (not used)
278       cred,  // phCredential
279       &expiry);  // ptsExpiry
280 
281   if (status != SEC_E_OK)
282     return ERR_UNEXPECTED;
283   return OK;
284 }
285 
286 }  // namespace net
287