1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/ntlm/ntlm_buffer_writer.h"
6
7 #include <string.h>
8
9 #include <limits>
10
11 #include "base/check_op.h"
12 #include "base/strings/utf_string_conversions.h"
13 #include "build/build_config.h"
14
15 namespace net::ntlm {
16
// Creates a writer whose backing store |buffer_| holds |buffer_len| elements,
// each initialized to 0.
// NOTE(review): assumes |buffer_| is a vector-like container of bytes and that
// the cursor is initialized to 0 in the header — neither is visible here.
NtlmBufferWriter::NtlmBufferWriter(size_t buffer_len)
    : buffer_(buffer_len, 0) {}
19
// Defaulted destructor, defined out of line; it performs no work beyond
// destroying the members.
NtlmBufferWriter::~NtlmBufferWriter() = default;
21
CanWrite(size_t len) const22 bool NtlmBufferWriter::CanWrite(size_t len) const {
23 if (len == 0)
24 return true;
25
26 if (!GetBufferPtr())
27 return false;
28
29 DCHECK_LE(GetCursor(), GetLength());
30
31 return (len <= GetLength()) && (GetCursor() <= GetLength() - len);
32 }
33
WriteUInt16(uint16_t value)34 bool NtlmBufferWriter::WriteUInt16(uint16_t value) {
35 return WriteUInt<uint16_t>(value);
36 }
37
WriteUInt32(uint32_t value)38 bool NtlmBufferWriter::WriteUInt32(uint32_t value) {
39 return WriteUInt<uint32_t>(value);
40 }
41
WriteUInt64(uint64_t value)42 bool NtlmBufferWriter::WriteUInt64(uint64_t value) {
43 return WriteUInt<uint64_t>(value);
44 }
45
WriteFlags(NegotiateFlags flags)46 bool NtlmBufferWriter::WriteFlags(NegotiateFlags flags) {
47 return WriteUInt32(static_cast<uint32_t>(flags));
48 }
49
WriteBytes(base::span<const uint8_t> bytes)50 bool NtlmBufferWriter::WriteBytes(base::span<const uint8_t> bytes) {
51 if (bytes.size() == 0)
52 return true;
53
54 if (!CanWrite(bytes.size()))
55 return false;
56
57 memcpy(GetBufferPtrAtCursor(), bytes.data(), bytes.size());
58 AdvanceCursor(bytes.size());
59 return true;
60 }
61
WriteZeros(size_t count)62 bool NtlmBufferWriter::WriteZeros(size_t count) {
63 if (count == 0)
64 return true;
65
66 if (!CanWrite(count))
67 return false;
68
69 memset(GetBufferPtrAtCursor(), 0, count);
70 AdvanceCursor(count);
71 return true;
72 }
73
WriteSecurityBuffer(SecurityBuffer sec_buf)74 bool NtlmBufferWriter::WriteSecurityBuffer(SecurityBuffer sec_buf) {
75 return WriteUInt16(sec_buf.length) && WriteUInt16(sec_buf.length) &&
76 WriteUInt32(sec_buf.offset);
77 }
78
WriteAvPairHeader(TargetInfoAvId avid,uint16_t avlen)79 bool NtlmBufferWriter::WriteAvPairHeader(TargetInfoAvId avid, uint16_t avlen) {
80 if (!CanWrite(kAvPairHeaderLen))
81 return false;
82
83 bool result = WriteUInt16(static_cast<uint16_t>(avid)) && WriteUInt16(avlen);
84
85 DCHECK(result);
86 return result;
87 }
88
WriteAvPairTerminator()89 bool NtlmBufferWriter::WriteAvPairTerminator() {
90 return WriteAvPairHeader(TargetInfoAvId::kEol, 0);
91 }
92
WriteAvPair(const AvPair & pair)93 bool NtlmBufferWriter::WriteAvPair(const AvPair& pair) {
94 if (!WriteAvPairHeader(pair))
95 return false;
96
97 if (pair.avid == TargetInfoAvId::kFlags) {
98 if (pair.avlen != sizeof(uint32_t))
99 return false;
100 return WriteUInt32(static_cast<uint32_t>(pair.flags));
101 } else {
102 return WriteBytes(pair.buffer);
103 }
104 }
105
WriteUtf8String(const std::string & str)106 bool NtlmBufferWriter::WriteUtf8String(const std::string& str) {
107 return WriteBytes(base::as_bytes(base::make_span(str)));
108 }
109
WriteUtf16AsUtf8String(const std::u16string & str)110 bool NtlmBufferWriter::WriteUtf16AsUtf8String(const std::u16string& str) {
111 std::string utf8 = base::UTF16ToUTF8(str);
112 return WriteUtf8String(utf8);
113 }
114
WriteUtf8AsUtf16String(const std::string & str)115 bool NtlmBufferWriter::WriteUtf8AsUtf16String(const std::string& str) {
116 std::u16string unicode = base::UTF8ToUTF16(str);
117 return WriteUtf16String(unicode);
118 }
119
WriteUtf16String(const std::u16string & str)120 bool NtlmBufferWriter::WriteUtf16String(const std::u16string& str) {
121 if (str.size() > std::numeric_limits<size_t>::max() / 2)
122 return false;
123
124 size_t num_bytes = str.size() * 2;
125 if (num_bytes == 0)
126 return true;
127
128 if (!CanWrite(num_bytes))
129 return false;
130
131 #if defined(ARCH_CPU_BIG_ENDIAN)
132 uint8_t* ptr = reinterpret_cast<uint8_t*>(GetBufferPtrAtCursor());
133
134 for (int i = 0; i < num_bytes; i += 2) {
135 ptr[i] = str[i / 2] & 0xff;
136 ptr[i + 1] = str[i / 2] >> 8;
137 }
138 #else
139 memcpy(reinterpret_cast<void*>(GetBufferPtrAtCursor()), str.c_str(),
140 num_bytes);
141
142 #endif
143
144 AdvanceCursor(num_bytes);
145 return true;
146 }
147
WriteSignature()148 bool NtlmBufferWriter::WriteSignature() {
149 return WriteBytes(kSignature);
150 }
151
WriteMessageType(MessageType message_type)152 bool NtlmBufferWriter::WriteMessageType(MessageType message_type) {
153 return WriteUInt32(static_cast<uint32_t>(message_type));
154 }
155
WriteMessageHeader(MessageType message_type)156 bool NtlmBufferWriter::WriteMessageHeader(MessageType message_type) {
157 return WriteSignature() && WriteMessageType(message_type);
158 }
159
160 template <typename T>
WriteUInt(T value)161 bool NtlmBufferWriter::WriteUInt(T value) {
162 size_t int_size = sizeof(T);
163 if (!CanWrite(int_size))
164 return false;
165
166 for (size_t i = 0; i < int_size; i++) {
167 GetBufferPtrAtCursor()[i] = static_cast<uint8_t>(value & 0xff);
168 value >>= 8;
169 }
170
171 AdvanceCursor(int_size);
172 return true;
173 }
174
SetCursor(size_t cursor)175 void NtlmBufferWriter::SetCursor(size_t cursor) {
176 DCHECK(GetBufferPtr() && cursor <= GetLength());
177
178 cursor_ = cursor;
179 }
180
181 } // namespace net::ntlm
182