1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // See "SSPI Sample Application" at
6 // http://msdn.microsoft.com/en-us/library/aa918273.aspx
7
8 #include "net/http/http_auth_sspi_win.h"
9
10 #include "base/base64.h"
11 #include "base/logging.h"
12 #include "base/string_util.h"
13 #include "net/base/net_errors.h"
14 #include "net/base/net_util.h"
15 #include "net/http/http_auth.h"
16
17 namespace net {
18
HttpAuthSSPI(const std::string & scheme,SEC_WCHAR * security_package)19 HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme,
20 SEC_WCHAR* security_package)
21 : scheme_(scheme),
22 security_package_(security_package),
23 max_token_length_(0) {
24 SecInvalidateHandle(&cred_);
25 SecInvalidateHandle(&ctxt_);
26 }
27
~HttpAuthSSPI()28 HttpAuthSSPI::~HttpAuthSSPI() {
29 ResetSecurityContext();
30 if (SecIsValidHandle(&cred_)) {
31 FreeCredentialsHandle(&cred_);
32 SecInvalidateHandle(&cred_);
33 }
34 }
35
NeedsIdentity() const36 bool HttpAuthSSPI::NeedsIdentity() const {
37 return decoded_server_auth_token_.empty();
38 }
39
IsFinalRound() const40 bool HttpAuthSSPI::IsFinalRound() const {
41 return !decoded_server_auth_token_.empty();
42 }
43
ResetSecurityContext()44 void HttpAuthSSPI::ResetSecurityContext() {
45 if (SecIsValidHandle(&ctxt_)) {
46 DeleteSecurityContext(&ctxt_);
47 SecInvalidateHandle(&ctxt_);
48 }
49 }
50
ParseChallenge(std::string::const_iterator challenge_begin,std::string::const_iterator challenge_end)51 bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin,
52 std::string::const_iterator challenge_end) {
53 // Verify the challenge's auth-scheme.
54 HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end);
55 if (!challenge_tok.valid() ||
56 !LowerCaseEqualsASCII(challenge_tok.scheme(),
57 StringToLowerASCII(scheme_).c_str()))
58 return false;
59 // Extract the auth-data. We can't use challenge_tok.GetNext() because
60 // auth-data is base64-encoded and may contain '=' padding at the end,
61 // which would be mistaken for a name=value pair.
62 challenge_begin += scheme_.length(); // Skip over scheme name.
63 HttpUtil::TrimLWS(&challenge_begin, &challenge_end);
64 std::string encoded_auth_token(challenge_begin, challenge_end);
65 int encoded_length = encoded_auth_token.length();
66 // Strip off any padding.
67 // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.)
68 //
69 // Our base64 decoder requires that the length be a multiple of 4.
70 while (encoded_length > 0 && encoded_length % 4 != 0 &&
71 encoded_auth_token[encoded_length - 1] == '=')
72 encoded_length--;
73 encoded_auth_token.erase(encoded_length);
74
75 std::string decoded_auth_token;
76 bool rv = base::Base64Decode(encoded_auth_token, &decoded_auth_token);
77 if (rv) {
78 decoded_server_auth_token_ = decoded_auth_token;
79 }
80 return rv;
81 }
82
GenerateCredentials(const std::wstring & username,const std::wstring & password,const GURL & origin,const HttpRequestInfo * request,const ProxyInfo * proxy,std::string * out_credentials)83 int HttpAuthSSPI::GenerateCredentials(const std::wstring& username,
84 const std::wstring& password,
85 const GURL& origin,
86 const HttpRequestInfo* request,
87 const ProxyInfo* proxy,
88 std::string* out_credentials) {
89 // |username| may be in the form "DOMAIN\user". Parse it into the two
90 // components.
91 std::wstring domain;
92 std::wstring user;
93 SplitDomainAndUser(username, &domain, &user);
94
95 // Initial challenge.
96 if (!IsFinalRound()) {
97 int rv = OnFirstRound(domain, user, password);
98 if (rv != OK)
99 return rv;
100 }
101
102 void* out_buf;
103 int out_buf_len;
104 int rv = GetNextSecurityToken(
105 origin,
106 static_cast<void *>(const_cast<char *>(
107 decoded_server_auth_token_.c_str())),
108 decoded_server_auth_token_.length(),
109 &out_buf,
110 &out_buf_len);
111 if (rv != OK)
112 return rv;
113
114 // Base64 encode data in output buffer and prepend the scheme.
115 std::string encode_input(static_cast<char*>(out_buf), out_buf_len);
116 std::string encode_output;
117 bool ok = base::Base64Encode(encode_input, &encode_output);
118 // OK, we are done with |out_buf|
119 free(out_buf);
120 if (!ok)
121 return rv;
122 *out_credentials = scheme_ + " " + encode_output;
123 return OK;
124 }
125
OnFirstRound(const std::wstring & domain,const std::wstring & user,const std::wstring & password)126 int HttpAuthSSPI::OnFirstRound(const std::wstring& domain,
127 const std::wstring& user,
128 const std::wstring& password) {
129 int rv = DetermineMaxTokenLength(security_package_, &max_token_length_);
130 if (rv != OK) {
131 return rv;
132 }
133 rv = AcquireCredentials(security_package_, domain, user, password, &cred_);
134 return rv;
135 }
136
GetNextSecurityToken(const GURL & origin,const void * in_token,int in_token_len,void ** out_token,int * out_token_len)137 int HttpAuthSSPI::GetNextSecurityToken(
138 const GURL& origin,
139 const void * in_token,
140 int in_token_len,
141 void** out_token,
142 int* out_token_len) {
143 SECURITY_STATUS status;
144 TimeStamp expiry;
145
146 DWORD ctxt_attr;
147 CtxtHandle* ctxt_ptr;
148 SecBufferDesc in_buffer_desc, out_buffer_desc;
149 SecBufferDesc* in_buffer_desc_ptr;
150 SecBuffer in_buffer, out_buffer;
151
152 if (in_token_len > 0) {
153 // Prepare input buffer.
154 in_buffer_desc.ulVersion = SECBUFFER_VERSION;
155 in_buffer_desc.cBuffers = 1;
156 in_buffer_desc.pBuffers = &in_buffer;
157 in_buffer.BufferType = SECBUFFER_TOKEN;
158 in_buffer.cbBuffer = in_token_len;
159 in_buffer.pvBuffer = const_cast<void*>(in_token);
160 ctxt_ptr = &ctxt_;
161 in_buffer_desc_ptr = &in_buffer_desc;
162 } else {
163 // If there is no input token, then we are starting a new authentication
164 // sequence. If we have already initialized our security context, then
165 // we're incorrectly reusing the auth handler for a new sequence.
166 if (SecIsValidHandle(&ctxt_)) {
167 LOG(ERROR) << "Cannot restart authentication sequence";
168 return ERR_UNEXPECTED;
169 }
170 ctxt_ptr = NULL;
171 in_buffer_desc_ptr = NULL;
172 }
173
174 // Prepare output buffer.
175 out_buffer_desc.ulVersion = SECBUFFER_VERSION;
176 out_buffer_desc.cBuffers = 1;
177 out_buffer_desc.pBuffers = &out_buffer;
178 out_buffer.BufferType = SECBUFFER_TOKEN;
179 out_buffer.cbBuffer = max_token_length_;
180 out_buffer.pvBuffer = malloc(out_buffer.cbBuffer);
181 if (!out_buffer.pvBuffer)
182 return ERR_OUT_OF_MEMORY;
183
184 // The service principal name of the destination server. See
185 // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx
186 std::wstring target(L"HTTP/");
187 target.append(ASCIIToWide(GetHostAndPort(origin)));
188 wchar_t* target_name = const_cast<wchar_t*>(target.c_str());
189
190 // This returns a token that is passed to the remote server.
191 status = InitializeSecurityContext(&cred_, // phCredential
192 ctxt_ptr, // phContext
193 target_name, // pszTargetName
194 0, // fContextReq
195 0, // Reserved1 (must be 0)
196 SECURITY_NATIVE_DREP, // TargetDataRep
197 in_buffer_desc_ptr, // pInput
198 0, // Reserved2 (must be 0)
199 &ctxt_, // phNewContext
200 &out_buffer_desc, // pOutput
201 &ctxt_attr, // pfContextAttr
202 &expiry); // ptsExpiry
203 // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call
204 // and SEC_E_OK on the second call. On failure, the function returns an
205 // error code.
206 if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) {
207 LOG(ERROR) << "InitializeSecurityContext failed: " << status;
208 ResetSecurityContext();
209 free(out_buffer.pvBuffer);
210 return ERR_UNEXPECTED; // TODO(wtc): map error code.
211 }
212 if (!out_buffer.cbBuffer) {
213 free(out_buffer.pvBuffer);
214 out_buffer.pvBuffer = NULL;
215 }
216 *out_token = out_buffer.pvBuffer;
217 *out_token_len = out_buffer.cbBuffer;
218 return OK;
219 }
220
// Splits |combined| at its first backslash: the part before goes to
// |*domain| and the part after goes to |*user|. Without a backslash,
// |*domain| is cleared and |*user| receives all of |combined|.
void SplitDomainAndUser(const std::wstring& combined,
                        std::wstring* domain,
                        std::wstring* user) {
  const std::wstring::size_type separator = combined.find(L'\\');
  if (separator != std::wstring::npos) {
    domain->assign(combined, 0, separator);
    user->assign(combined, separator + 1, std::wstring::npos);
    return;
  }
  domain->clear();
  user->assign(combined);
}
233
DetermineMaxTokenLength(const std::wstring & package,ULONG * max_token_length)234 int DetermineMaxTokenLength(const std::wstring& package,
235 ULONG* max_token_length) {
236 PSecPkgInfo pkg_info;
237 SECURITY_STATUS status = QuerySecurityPackageInfo(
238 const_cast<wchar_t *>(package.c_str()), &pkg_info);
239 if (status != SEC_E_OK) {
240 LOG(ERROR) << "Security package " << package << " not found";
241 return ERR_UNEXPECTED;
242 }
243 *max_token_length = pkg_info->cbMaxToken;
244 FreeContextBuffer(pkg_info);
245 return OK;
246 }
247
AcquireCredentials(const SEC_WCHAR * package,const std::wstring & domain,const std::wstring & user,const std::wstring & password,CredHandle * cred)248 int AcquireCredentials(const SEC_WCHAR* package,
249 const std::wstring& domain,
250 const std::wstring& user,
251 const std::wstring& password,
252 CredHandle* cred) {
253 SEC_WINNT_AUTH_IDENTITY identity;
254 identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
255 identity.User =
256 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str()));
257 identity.UserLength = user.size();
258 identity.Domain =
259 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str()));
260 identity.DomainLength = domain.size();
261 identity.Password =
262 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str()));
263 identity.PasswordLength = password.size();
264
265 TimeStamp expiry;
266
267 // Pass the username/password to get the credentials handle.
268 // Note: If the 5th argument is NULL, it uses the default cached credentials
269 // for the logged in user, which can be used for single sign-on.
270 SECURITY_STATUS status = AcquireCredentialsHandle(
271 NULL, // pszPrincipal
272 const_cast<SEC_WCHAR*>(package), // pszPackage
273 SECPKG_CRED_OUTBOUND, // fCredentialUse
274 NULL, // pvLogonID
275 &identity, // pAuthData
276 NULL, // pGetKeyFn (not used)
277 NULL, // pvGetKeyArgument (not used)
278 cred, // phCredential
279 &expiry); // ptsExpiry
280
281 if (status != SEC_E_OK)
282 return ERR_UNEXPECTED;
283 return OK;
284 }
285
286 } // namespace net
287