• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "src/core/tsi/alts/frame_protector/frame_handler.h"
20 
21 #include <grpc/support/alloc.h>
22 #include <grpc/support/port_platform.h>
23 #include <limits.h>
24 #include <stdint.h>
25 #include <string.h>
26 
27 #include <algorithm>
28 
29 #include "absl/log/log.h"
30 #include "src/core/util/crash.h"
31 #include "src/core/util/memory.h"
32 
// Decodes four bytes as a little-endian uint32_t: buffer[0] is the least
// significant byte and buffer[3] the most significant. `buffer` must point at
// (at least) four readable bytes.
static uint32_t load_32_le(const unsigned char* buffer) {
  uint32_t result = 0;
  // Walk from the most significant byte down, shifting earlier bytes up.
  for (int i = 3; i >= 0; --i) {
    result = (result << 8) | static_cast<uint32_t>(buffer[i]);
  }
  return result;
}
40 
// Encodes `value` into four bytes in little-endian order: buffer[0] receives
// the least significant byte and buffer[3] the most significant. `buffer`
// must point at (at least) four writable bytes.
static void store_32_le(uint32_t value, unsigned char* buffer) {
  for (int i = 0; i < 4; ++i) {
    buffer[i] = static_cast<unsigned char>((value >> (8 * i)) & 0xFFu);
  }
}
48 
49 // Frame writer implementation.
alts_create_frame_writer()50 alts_frame_writer* alts_create_frame_writer() {
51   return grpc_core::Zalloc<alts_frame_writer>();
52 }
53 
alts_reset_frame_writer(alts_frame_writer * writer,const unsigned char * buffer,size_t length)54 bool alts_reset_frame_writer(alts_frame_writer* writer,
55                              const unsigned char* buffer, size_t length) {
56   if (buffer == nullptr) return false;
57   size_t max_input_size = SIZE_MAX - kFrameLengthFieldSize;
58   if (length > max_input_size) {
59     LOG(ERROR) << "length must be at most " << max_input_size;
60     return false;
61   }
62   writer->input_buffer = buffer;
63   writer->input_size = length;
64   writer->input_bytes_written = 0;
65   writer->header_bytes_written = 0;
66   store_32_le(
67       static_cast<uint32_t>(writer->input_size + kFrameMessageTypeFieldSize),
68       writer->header_buffer);
69   store_32_le(kFrameMessageType, writer->header_buffer + kFrameLengthFieldSize);
70   return true;
71 }
72 
alts_write_frame_bytes(alts_frame_writer * writer,unsigned char * output,size_t * bytes_size)73 bool alts_write_frame_bytes(alts_frame_writer* writer, unsigned char* output,
74                             size_t* bytes_size) {
75   if (bytes_size == nullptr || output == nullptr) return false;
76   if (alts_is_frame_writer_done(writer)) {
77     *bytes_size = 0;
78     return true;
79   }
80   size_t bytes_written = 0;
81   // Write some header bytes, if needed.
82   if (writer->header_bytes_written != sizeof(writer->header_buffer)) {
83     size_t bytes_to_write =
84         std::min(*bytes_size,
85                  sizeof(writer->header_buffer) - writer->header_bytes_written);
86     memcpy(output, writer->header_buffer + writer->header_bytes_written,
87            bytes_to_write);
88     bytes_written += bytes_to_write;
89     *bytes_size -= bytes_to_write;
90     writer->header_bytes_written += bytes_to_write;
91     output += bytes_to_write;
92     if (writer->header_bytes_written != sizeof(writer->header_buffer)) {
93       *bytes_size = bytes_written;
94       return true;
95     }
96   }
97   // Write some non-header bytes.
98   size_t bytes_to_write =
99       std::min(writer->input_size - writer->input_bytes_written, *bytes_size);
100   memcpy(output, writer->input_buffer, bytes_to_write);
101   writer->input_buffer += bytes_to_write;
102   bytes_written += bytes_to_write;
103   writer->input_bytes_written += bytes_to_write;
104   *bytes_size = bytes_written;
105   return true;
106 }
107 
alts_is_frame_writer_done(alts_frame_writer * writer)108 bool alts_is_frame_writer_done(alts_frame_writer* writer) {
109   return writer->input_buffer == nullptr ||
110          writer->input_size == writer->input_bytes_written;
111 }
112 
alts_get_num_writer_bytes_remaining(alts_frame_writer * writer)113 size_t alts_get_num_writer_bytes_remaining(alts_frame_writer* writer) {
114   return (sizeof(writer->header_buffer) - writer->header_bytes_written) +
115          (writer->input_size - writer->input_bytes_written);
116 }
117 
alts_destroy_frame_writer(alts_frame_writer * writer)118 void alts_destroy_frame_writer(alts_frame_writer* writer) { gpr_free(writer); }
119 
120 // Frame reader implementation.
alts_create_frame_reader()121 alts_frame_reader* alts_create_frame_reader() {
122   alts_frame_reader* reader = grpc_core::Zalloc<alts_frame_reader>();
123   return reader;
124 }
125 
alts_is_frame_reader_done(alts_frame_reader * reader)126 bool alts_is_frame_reader_done(alts_frame_reader* reader) {
127   return reader->output_buffer == nullptr ||
128          (reader->header_bytes_read == sizeof(reader->header_buffer) &&
129           reader->bytes_remaining == 0);
130 }
131 
alts_has_read_frame_length(alts_frame_reader * reader)132 bool alts_has_read_frame_length(alts_frame_reader* reader) {
133   return sizeof(reader->header_buffer) == reader->header_bytes_read;
134 }
135 
alts_get_reader_bytes_remaining(alts_frame_reader * reader)136 size_t alts_get_reader_bytes_remaining(alts_frame_reader* reader) {
137   return alts_has_read_frame_length(reader) ? reader->bytes_remaining : 0;
138 }
139 
alts_reset_reader_output_buffer(alts_frame_reader * reader,unsigned char * buffer)140 void alts_reset_reader_output_buffer(alts_frame_reader* reader,
141                                      unsigned char* buffer) {
142   reader->output_buffer = buffer;
143 }
144 
alts_reset_frame_reader(alts_frame_reader * reader,unsigned char * buffer)145 bool alts_reset_frame_reader(alts_frame_reader* reader, unsigned char* buffer) {
146   if (buffer == nullptr) return false;
147   reader->output_buffer = buffer;
148   reader->bytes_remaining = 0;
149   reader->header_bytes_read = 0;
150   reader->output_bytes_read = 0;
151   return true;
152 }
153 
alts_read_frame_bytes(alts_frame_reader * reader,const unsigned char * bytes,size_t * bytes_size)154 bool alts_read_frame_bytes(alts_frame_reader* reader,
155                            const unsigned char* bytes, size_t* bytes_size) {
156   if (bytes_size == nullptr) return false;
157   if (bytes == nullptr) {
158     *bytes_size = 0;
159     return false;
160   }
161   if (alts_is_frame_reader_done(reader)) {
162     *bytes_size = 0;
163     return true;
164   }
165   size_t bytes_processed = 0;
166   // Process the header, if needed.
167   if (reader->header_bytes_read != sizeof(reader->header_buffer)) {
168     size_t bytes_to_write = std::min(
169         *bytes_size, sizeof(reader->header_buffer) - reader->header_bytes_read);
170     memcpy(reader->header_buffer + reader->header_bytes_read, bytes,
171            bytes_to_write);
172     reader->header_bytes_read += bytes_to_write;
173     bytes_processed += bytes_to_write;
174     bytes += bytes_to_write;
175     *bytes_size -= bytes_to_write;
176     if (reader->header_bytes_read != sizeof(reader->header_buffer)) {
177       *bytes_size = bytes_processed;
178       return true;
179     }
180     size_t frame_length = load_32_le(reader->header_buffer);
181     if (frame_length < kFrameMessageTypeFieldSize ||
182         frame_length > kFrameMaxSize) {
183       LOG(ERROR) << "Bad frame length (should be at least "
184                  << kFrameMessageTypeFieldSize << ", and at most "
185                  << kFrameMaxSize << ")";
186       *bytes_size = 0;
187       return false;
188     }
189     size_t message_type =
190         load_32_le(reader->header_buffer + kFrameLengthFieldSize);
191     if (message_type != kFrameMessageType) {
192       LOG(ERROR) << "Unsupported message type " << message_type
193                  << " (should be " << kFrameMessageType << ")";
194       *bytes_size = 0;
195       return false;
196     }
197     reader->bytes_remaining = frame_length - kFrameMessageTypeFieldSize;
198   }
199   // Process the non-header bytes.
200   size_t bytes_to_write = std::min(*bytes_size, reader->bytes_remaining);
201   memcpy(reader->output_buffer, bytes, bytes_to_write);
202   reader->output_buffer += bytes_to_write;
203   bytes_processed += bytes_to_write;
204   reader->bytes_remaining -= bytes_to_write;
205   reader->output_bytes_read += bytes_to_write;
206   *bytes_size = bytes_processed;
207   return true;
208 }
209 
alts_get_output_bytes_read(alts_frame_reader * reader)210 size_t alts_get_output_bytes_read(alts_frame_reader* reader) {
211   return reader->output_bytes_read;
212 }
213 
alts_get_output_buffer(alts_frame_reader * reader)214 unsigned char* alts_get_output_buffer(alts_frame_reader* reader) {
215   return reader->output_buffer;
216 }
217 
alts_destroy_frame_reader(alts_frame_reader * reader)218 void alts_destroy_frame_reader(alts_frame_reader* reader) { gpr_free(reader); }
219