1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/ntlm/ntlm_buffer_writer.h"
11
12 #include <string.h>
13
14 #include <limits>
15
16 #include "base/check_op.h"
17 #include "base/strings/utf_string_conversions.h"
18 #include "build/build_config.h"
19
20 namespace net::ntlm {
21
NtlmBufferWriter(size_t buffer_len)22 NtlmBufferWriter::NtlmBufferWriter(size_t buffer_len)
23 : buffer_(buffer_len, 0) {}
24
25 NtlmBufferWriter::~NtlmBufferWriter() = default;
26
CanWrite(size_t len) const27 bool NtlmBufferWriter::CanWrite(size_t len) const {
28 if (len == 0)
29 return true;
30
31 if (!GetBufferPtr())
32 return false;
33
34 DCHECK_LE(GetCursor(), GetLength());
35
36 return (len <= GetLength()) && (GetCursor() <= GetLength() - len);
37 }
38
WriteUInt16(uint16_t value)39 bool NtlmBufferWriter::WriteUInt16(uint16_t value) {
40 return WriteUInt<uint16_t>(value);
41 }
42
WriteUInt32(uint32_t value)43 bool NtlmBufferWriter::WriteUInt32(uint32_t value) {
44 return WriteUInt<uint32_t>(value);
45 }
46
WriteUInt64(uint64_t value)47 bool NtlmBufferWriter::WriteUInt64(uint64_t value) {
48 return WriteUInt<uint64_t>(value);
49 }
50
WriteFlags(NegotiateFlags flags)51 bool NtlmBufferWriter::WriteFlags(NegotiateFlags flags) {
52 return WriteUInt32(static_cast<uint32_t>(flags));
53 }
54
WriteBytes(base::span<const uint8_t> bytes)55 bool NtlmBufferWriter::WriteBytes(base::span<const uint8_t> bytes) {
56 if (bytes.size() == 0)
57 return true;
58
59 if (!CanWrite(bytes.size()))
60 return false;
61
62 memcpy(GetBufferPtrAtCursor(), bytes.data(), bytes.size());
63 AdvanceCursor(bytes.size());
64 return true;
65 }
66
WriteZeros(size_t count)67 bool NtlmBufferWriter::WriteZeros(size_t count) {
68 if (count == 0)
69 return true;
70
71 if (!CanWrite(count))
72 return false;
73
74 memset(GetBufferPtrAtCursor(), 0, count);
75 AdvanceCursor(count);
76 return true;
77 }
78
WriteSecurityBuffer(SecurityBuffer sec_buf)79 bool NtlmBufferWriter::WriteSecurityBuffer(SecurityBuffer sec_buf) {
80 return WriteUInt16(sec_buf.length) && WriteUInt16(sec_buf.length) &&
81 WriteUInt32(sec_buf.offset);
82 }
83
WriteAvPairHeader(TargetInfoAvId avid,uint16_t avlen)84 bool NtlmBufferWriter::WriteAvPairHeader(TargetInfoAvId avid, uint16_t avlen) {
85 if (!CanWrite(kAvPairHeaderLen))
86 return false;
87
88 bool result = WriteUInt16(static_cast<uint16_t>(avid)) && WriteUInt16(avlen);
89
90 DCHECK(result);
91 return result;
92 }
93
WriteAvPairTerminator()94 bool NtlmBufferWriter::WriteAvPairTerminator() {
95 return WriteAvPairHeader(TargetInfoAvId::kEol, 0);
96 }
97
WriteAvPair(const AvPair & pair)98 bool NtlmBufferWriter::WriteAvPair(const AvPair& pair) {
99 if (!WriteAvPairHeader(pair))
100 return false;
101
102 if (pair.avid == TargetInfoAvId::kFlags) {
103 if (pair.avlen != sizeof(uint32_t))
104 return false;
105 return WriteUInt32(static_cast<uint32_t>(pair.flags));
106 } else {
107 return WriteBytes(pair.buffer);
108 }
109 }
110
WriteUtf8String(const std::string & str)111 bool NtlmBufferWriter::WriteUtf8String(const std::string& str) {
112 return WriteBytes(base::as_byte_span(str));
113 }
114
WriteUtf16AsUtf8String(const std::u16string & str)115 bool NtlmBufferWriter::WriteUtf16AsUtf8String(const std::u16string& str) {
116 std::string utf8 = base::UTF16ToUTF8(str);
117 return WriteUtf8String(utf8);
118 }
119
WriteUtf8AsUtf16String(const std::string & str)120 bool NtlmBufferWriter::WriteUtf8AsUtf16String(const std::string& str) {
121 std::u16string unicode = base::UTF8ToUTF16(str);
122 return WriteUtf16String(unicode);
123 }
124
WriteUtf16String(const std::u16string & str)125 bool NtlmBufferWriter::WriteUtf16String(const std::u16string& str) {
126 if (str.size() > std::numeric_limits<size_t>::max() / 2)
127 return false;
128
129 size_t num_bytes = str.size() * 2;
130 if (num_bytes == 0)
131 return true;
132
133 if (!CanWrite(num_bytes))
134 return false;
135
136 #if defined(ARCH_CPU_BIG_ENDIAN)
137 uint8_t* ptr = reinterpret_cast<uint8_t*>(GetBufferPtrAtCursor());
138
139 for (int i = 0; i < num_bytes; i += 2) {
140 ptr[i] = str[i / 2] & 0xff;
141 ptr[i + 1] = str[i / 2] >> 8;
142 }
143 #else
144 memcpy(reinterpret_cast<void*>(GetBufferPtrAtCursor()), str.c_str(),
145 num_bytes);
146
147 #endif
148
149 AdvanceCursor(num_bytes);
150 return true;
151 }
152
WriteSignature()153 bool NtlmBufferWriter::WriteSignature() {
154 return WriteBytes(kSignature);
155 }
156
WriteMessageType(MessageType message_type)157 bool NtlmBufferWriter::WriteMessageType(MessageType message_type) {
158 return WriteUInt32(static_cast<uint32_t>(message_type));
159 }
160
WriteMessageHeader(MessageType message_type)161 bool NtlmBufferWriter::WriteMessageHeader(MessageType message_type) {
162 return WriteSignature() && WriteMessageType(message_type);
163 }
164
165 template <typename T>
WriteUInt(T value)166 bool NtlmBufferWriter::WriteUInt(T value) {
167 size_t int_size = sizeof(T);
168 if (!CanWrite(int_size))
169 return false;
170
171 for (size_t i = 0; i < int_size; i++) {
172 GetBufferPtrAtCursor()[i] = static_cast<uint8_t>(value & 0xff);
173 value >>= 8;
174 }
175
176 AdvanceCursor(int_size);
177 return true;
178 }
179
SetCursor(size_t cursor)180 void NtlmBufferWriter::SetCursor(size_t cursor) {
181 DCHECK(GetBufferPtr() && cursor <= GetLength());
182
183 cursor_ = cursor;
184 }
185
186 } // namespace net::ntlm
187