1 // Copyright 2012 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "polo/pairing/pairingsession.h"
16
17 #include <glog/logging.h>
18 #include "polo/encoding/hexadecimalencoder.h"
19 #include "polo/util/poloutil.h"
20
21 namespace polo {
22 namespace pairing {
23
PairingSession(wire::PoloWireAdapter * wire,PairingContext * context,PoloChallengeResponse * challenge)24 PairingSession::PairingSession(wire::PoloWireAdapter* wire,
25 PairingContext* context,
26 PoloChallengeResponse* challenge)
27 : state_(kUninitialized),
28 wire_(wire),
29 context_(context),
30 challenge_(challenge),
31 configuration_(NULL),
32 encoder_(NULL),
33 nonce_(NULL),
34 secret_(NULL) {
35 wire_->set_listener(this);
36
37 local_options_.set_protocol_role_preference(context->is_server() ?
38 message::OptionsMessage::kDisplayDevice
39 : message::OptionsMessage::kInputDevice);
40 }
41
~PairingSession()42 PairingSession::~PairingSession() {
43 if (configuration_) {
44 delete configuration_;
45 }
46
47 if (encoder_) {
48 delete encoder_;
49 }
50
51 if (nonce_) {
52 delete nonce_;
53 }
54
55 if (secret_) {
56 delete secret_;
57 }
58 }
59
AddInputEncoding(const encoding::EncodingOption & encoding)60 void PairingSession::AddInputEncoding(
61 const encoding::EncodingOption& encoding) {
62 if (state_ != kUninitialized) {
63 LOG(ERROR) << "Attempt to add input encoding to active session";
64 return;
65 }
66
67 if (!IsValidEncodingOption(encoding)) {
68 LOG(ERROR) << "Invalid input encoding: " << encoding.ToString();
69 return;
70 }
71
72 local_options_.AddInputEncoding(encoding);
73 }
74
AddOutputEncoding(const encoding::EncodingOption & encoding)75 void PairingSession::AddOutputEncoding(
76 const encoding::EncodingOption& encoding) {
77 if (state_ != kUninitialized) {
78 LOG(ERROR) << "Attempt to add output encoding to active session";
79 return;
80 }
81
82 if (!IsValidEncodingOption(encoding)) {
83 LOG(ERROR) << "Invalid output encoding: " << encoding.ToString();
84 return;
85 }
86
87 local_options_.AddOutputEncoding(encoding);
88 }
89
SetSecret(const Gamma & secret)90 bool PairingSession::SetSecret(const Gamma& secret) {
91 secret_ = new Gamma(secret);
92
93 if (!IsInputDevice() || state_ != kWaitingForSecret) {
94 LOG(ERROR) << "Invalid state: unexpected secret";
95 return false;
96 }
97
98 if (!challenge().CheckGamma(secret)) {
99 LOG(ERROR) << "Secret failed local check";
100 return false;
101 }
102
103 nonce_ = challenge().ExtractNonce(secret);
104 if (!nonce_) {
105 LOG(ERROR) << "Failed to extract nonce";
106 return false;
107 }
108
109 const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
110 if (!gen_alpha) {
111 LOG(ERROR) << "Failed to get alpha";
112 return false;
113 }
114
115 message::SecretMessage secret_message(*gen_alpha);
116 delete gen_alpha;
117
118 wire_->SendSecretMessage(secret_message);
119
120 LOG(INFO) << "Waiting for SecretAck...";
121 wire_->GetNextMessage();
122
123 return true;
124 }
125
DoPair(PairingListener * listener)126 void PairingSession::DoPair(PairingListener *listener) {
127 listener_ = listener;
128 listener_->OnSessionCreated();
129
130 if (context_->is_server()) {
131 LOG(INFO) << "Pairing started (SERVER mode)";
132 } else {
133 LOG(INFO) << "Pairing started (CLIENT mode)";
134 }
135 LOG(INFO) << "Local options: " << local_options_.ToString();
136
137 set_state(kInitializing);
138 DoInitializationPhase();
139 }
140
DoPairingPhase()141 void PairingSession::DoPairingPhase() {
142 if (IsInputDevice()) {
143 DoInputPairing();
144 } else {
145 DoOutputPairing();
146 }
147 }
148
DoInputPairing()149 void PairingSession::DoInputPairing() {
150 set_state(kWaitingForSecret);
151 listener_->OnPerformInputDeviceRole();
152 }
153
DoOutputPairing()154 void PairingSession::DoOutputPairing() {
155 size_t nonce_length = configuration_->encoding().symbol_length() / 2;
156 size_t bytes_needed = nonce_length / encoder_->symbols_per_byte();
157
158 uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed);
159 nonce_ = new Nonce(random, random + bytes_needed);
160 delete[] random;
161
162 const Gamma* gamma = challenge().GetGamma(*nonce_);
163 if (!gamma) {
164 LOG(ERROR) << "Failed to get gamma";
165 wire()->SendErrorMessage(kErrorProtocol);
166 listener()->OnError(kErrorProtocol);
167 return;
168 }
169
170 listener_->OnPerformOutputDeviceRole(*gamma);
171 delete gamma;
172
173 set_state(kWaitingForSecret);
174
175 LOG(INFO) << "Waiting for Secret...";
176 wire_->GetNextMessage();
177 }
178
set_state(ProtocolState state)179 void PairingSession::set_state(ProtocolState state) {
180 LOG(INFO) << "New state: " << state;
181 state_ = state;
182 }
183
SetConfiguration(const message::ConfigurationMessage & message)184 bool PairingSession::SetConfiguration(
185 const message::ConfigurationMessage& message) {
186 const encoding::EncodingOption& encoding = message.encoding();
187
188 if (!IsValidEncodingOption(encoding)) {
189 LOG(ERROR) << "Invalid configuration: " << encoding.ToString();
190 return false;
191 }
192
193 if (encoder_) {
194 delete encoder_;
195 encoder_ = NULL;
196 }
197
198 switch (encoding.encoding_type()) {
199 case encoding::EncodingOption::kHexadecimal:
200 encoder_ = new encoding::HexadecimalEncoder();
201 break;
202 default:
203 LOG(ERROR) << "Unsupported encoding type: "
204 << encoding.encoding_type();
205 return false;
206 }
207
208 if (configuration_) {
209 delete configuration_;
210 }
211 configuration_ = new message::ConfigurationMessage(message.encoding(),
212 message.client_role());
213 return true;
214 }
215
OnSecretMessage(const message::SecretMessage & message)216 void PairingSession::OnSecretMessage(const message::SecretMessage& message) {
217 if (state() != kWaitingForSecret) {
218 LOG(ERROR) << "Invalid state: unexpected secret message";
219 wire()->SendErrorMessage(kErrorProtocol);
220 listener()->OnError(kErrorProtocol);
221 return;
222 }
223
224 if (!VerifySecret(message.secret())) {
225 wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
226 listener_->OnError(kErrorInvalidChallengeResponse);
227 return;
228 }
229
230 const Alpha* alpha = challenge().GetAlpha(*nonce_);
231 if (!alpha) {
232 LOG(ERROR) << "Failed to get alpha";
233 wire()->SendErrorMessage(kErrorProtocol);
234 listener()->OnError(kErrorProtocol);
235 return;
236 }
237
238 message::SecretAckMessage ack(*alpha);
239 delete alpha;
240
241 wire_->SendSecretAckMessage(ack);
242
243 listener_->OnPairingSuccess();
244 }
245
OnSecretAckMessage(const message::SecretAckMessage & message)246 void PairingSession::OnSecretAckMessage(
247 const message::SecretAckMessage& message) {
248 if (kVerifySecretAck && !VerifySecret(message.secret())) {
249 wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
250 listener_->OnError(kErrorInvalidChallengeResponse);
251 return;
252 }
253
254 listener_->OnPairingSuccess();
255 }
256
OnError(pairing::PoloError error)257 void PairingSession::OnError(pairing::PoloError error) {
258 listener_->OnError(error);
259 }
260
VerifySecret(const Alpha & secret) const261 bool PairingSession::VerifySecret(const Alpha& secret) const {
262 if (!nonce_) {
263 LOG(ERROR) << "Nonce not set";
264 return false;
265 }
266
267 const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
268 if (!gen_alpha) {
269 LOG(ERROR) << "Failed to get alpha";
270 return false;
271 }
272
273 bool valid = (secret == *gen_alpha);
274
275 if (!valid) {
276 LOG(ERROR) << "Inband secret did not match. Expected ["
277 << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size())
278 << "], got ["
279 << util::PoloUtil::BytesToHexString(&secret[0], secret.size())
280 << "]";
281 }
282
283 delete gen_alpha;
284 return valid;
285 }
286
GetLocalRole() const287 message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const {
288 if (!configuration_) {
289 return message::OptionsMessage::kUnknown;
290 }
291
292 if (context_->is_client()) {
293 return configuration_->client_role();
294 } else {
295 return configuration_->client_role() ==
296 message::OptionsMessage::kDisplayDevice ?
297 message::OptionsMessage::kInputDevice
298 : message::OptionsMessage::kDisplayDevice;
299 }
300 }
301
IsInputDevice() const302 bool PairingSession::IsInputDevice() const {
303 return GetLocalRole() == message::OptionsMessage::kInputDevice;
304 }
305
IsValidEncodingOption(const encoding::EncodingOption & option) const306 bool PairingSession::IsValidEncodingOption(
307 const encoding::EncodingOption& option) const {
308 // Legal values of GAMMALEN must be an even number of at least 2 bytes.
309 return option.encoding_type() != encoding::EncodingOption::kUnknown
310 && (option.symbol_length() % 2 == 0)
311 && (option.symbol_length() >= 2);
312 }
313
314 } // namespace pairing
315 } // namespace polo
316