1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9use rustls::{
10 self, CipherSuite,
11 client::danger::ServerCertVerifier,
12 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
13 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
14};
15#[cfg(feature = "platform-verifier")]
16use rustls_platform_verifier::BuilderVerifierExt;
17
18use crate::{
19 ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
20 crypto::{
21 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
22 },
23 transport_parameters::TransportParameters,
24};
25
26impl From<Side> for rustls::Side {
27 fn from(s: Side) -> Self {
28 match s {
29 Side::Client => Self::Client,
30 Side::Server => Self::Server,
31 }
32 }
33}
34
35pub struct TlsSession {
37 version: Version,
38 got_handshake_data: bool,
39 next_secrets: Option<Secrets>,
40 inner: Connection,
41 suite: Suite,
42}
43
44impl TlsSession {
45 fn side(&self) -> Side {
46 match self.inner {
47 Connection::Client(_) => Side::Client,
48 Connection::Server(_) => Side::Server,
49 }
50 }
51}
52
53impl crypto::Session for TlsSession {
54 fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
55 initial_keys(self.version, *dst_cid, side, &self.suite)
56 }
57
58 fn handshake_data(&self) -> Option<Box<dyn Any>> {
59 if !self.got_handshake_data {
60 return None;
61 }
62 Some(Box::new(HandshakeData {
63 protocol: self.inner.alpn_protocol().map(|x| x.into()),
64 server_name: match self.inner {
65 Connection::Client(_) => None,
66 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
67 },
68 }))
69 }
70
71 fn peer_identity(&self) -> Option<Box<dyn Any>> {
73 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
74 Box::new(
75 v.iter()
76 .map(|v| v.clone().into_owned())
77 .collect::<Vec<CertificateDer<'static>>>(),
78 )
79 })
80 }
81
82 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
83 let keys = self.inner.zero_rtt_keys()?;
84 Some((Box::new(keys.header), Box::new(keys.packet)))
85 }
86
87 fn early_data_accepted(&self) -> Option<bool> {
88 match self.inner {
89 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
90 _ => None,
91 }
92 }
93
94 fn is_handshaking(&self) -> bool {
95 self.inner.is_handshaking()
96 }
97
98 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
99 self.inner.read_hs(buf).map_err(|e| {
100 if let Some(alert) = self.inner.alert() {
101 TransportError {
102 code: TransportErrorCode::crypto(alert.into()),
103 frame: None,
104 reason: e.to_string(),
105 }
106 } else {
107 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
108 }
109 })?;
110 if !self.got_handshake_data {
111 let have_server_name = match self.inner {
115 Connection::Client(_) => false,
116 Connection::Server(ref session) => session.server_name().is_some(),
117 };
118 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
119 self.got_handshake_data = true;
120 return Ok(true);
121 }
122 }
123 Ok(false)
124 }
125
126 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
127 match self.inner.quic_transport_parameters() {
128 None => Ok(None),
129 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
130 Ok(params) => Ok(Some(params)),
131 Err(e) => Err(e.into()),
132 },
133 }
134 }
135
136 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
137 let keys = match self.inner.write_hs(buf)? {
138 KeyChange::Handshake { keys } => keys,
139 KeyChange::OneRtt { keys, next } => {
140 self.next_secrets = Some(next);
141 keys
142 }
143 };
144
145 Some(Keys {
146 header: KeyPair {
147 local: Box::new(keys.local.header),
148 remote: Box::new(keys.remote.header),
149 },
150 packet: KeyPair {
151 local: Box::new(keys.local.packet),
152 remote: Box::new(keys.remote.packet),
153 },
154 })
155 }
156
157 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
158 let secrets = self.next_secrets.as_mut()?;
159 let keys = secrets.next_packet_keys();
160 Some(KeyPair {
161 local: Box::new(keys.local),
162 remote: Box::new(keys.remote),
163 })
164 }
165
166 fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
167 let tag_start = match payload.len().checked_sub(16) {
168 Some(x) => x,
169 None => return false,
170 };
171
172 let mut pseudo_packet =
173 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
174 pseudo_packet.push(orig_dst_cid.len() as u8);
175 pseudo_packet.extend_from_slice(orig_dst_cid);
176 pseudo_packet.extend_from_slice(header);
177 let tag_start = tag_start + pseudo_packet.len();
178 pseudo_packet.extend_from_slice(payload);
179
180 let (nonce, key) = match self.version {
181 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
182 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
183 _ => unreachable!(),
184 };
185
186 let nonce = aead::Nonce::assume_unique_for_key(nonce);
187 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
188
189 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
190 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
191 }
192
193 fn export_keying_material(
194 &self,
195 output: &mut [u8],
196 label: &[u8],
197 context: &[u8],
198 ) -> Result<(), ExportKeyingMaterialError> {
199 self.inner
200 .export_keying_material(output, label, Some(context))
201 .map_err(|_| ExportKeyingMaterialError)?;
202 Ok(())
203 }
204}
205
/// AES-128-GCM key for Retry Integrity Tags when the session version is `Version::V1Draft`.
///
/// Matches the value published in draft-ietf-quic-tls-29, section 5.8.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
/// AES-128-GCM nonce paired with [`RETRY_INTEGRITY_KEY_DRAFT`].
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

/// AES-128-GCM key for Retry Integrity Tags when the session version is `Version::V1`.
///
/// Matches the value published in RFC 9001, section 5.8.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// AES-128-GCM nonce paired with [`RETRY_INTEGRITY_KEY_V1`].
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
219
220impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
221 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
222 let (header, sample) = packet.split_at_mut(pn_offset + 4);
223 let (first, rest) = header.split_at_mut(1);
224 let pn_end = Ord::min(pn_offset + 3, rest.len());
225 self.decrypt_in_place(
226 &sample[..self.sample_size()],
227 &mut first[0],
228 &mut rest[pn_offset - 1..pn_end],
229 )
230 .unwrap();
231 }
232
233 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
234 let (header, sample) = packet.split_at_mut(pn_offset + 4);
235 let (first, rest) = header.split_at_mut(1);
236 let pn_end = Ord::min(pn_offset + 3, rest.len());
237 self.encrypt_in_place(
238 &sample[..self.sample_size()],
239 &mut first[0],
240 &mut rest[pn_offset - 1..pn_end],
241 )
242 .unwrap();
243 }
244
245 fn sample_size(&self) -> usize {
246 self.sample_len()
247 }
248}
249
/// Handshake information exposed through `TlsSession::handshake_data` once it is available.
pub struct HandshakeData {
    /// The application protocol negotiated via ALPN, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name the peer asked for, as reported by rustls on server-side sessions.
    ///
    /// Always `None` for client-side (outgoing) sessions.
    pub server_name: Option<String>,
}
261
262pub struct QuicClientConfig {
281 pub(crate) inner: Arc<rustls::ClientConfig>,
282 initial: Suite,
283}
284
285impl QuicClientConfig {
286 #[cfg(feature = "platform-verifier")]
287 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
288 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
290 .with_protocol_versions(&[&rustls::version::TLS13])
291 .unwrap() .with_platform_verifier()?
293 .with_no_client_auth();
294
295 inner.enable_early_data = true;
296 Ok(Self {
297 initial: initial_suite_from_provider(inner.crypto_provider())
299 .expect("no initial cipher suite found"),
300 inner: Arc::new(inner),
301 })
302 }
303
304 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
309 let inner = Self::inner(verifier);
310 Self {
311 initial: initial_suite_from_provider(inner.crypto_provider())
313 .expect("no initial cipher suite found"),
314 inner: Arc::new(inner),
315 }
316 }
317
318 pub fn with_initial(
322 inner: Arc<rustls::ClientConfig>,
323 initial: Suite,
324 ) -> Result<Self, NoInitialCipherSuite> {
325 match initial.suite.common.suite {
326 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
327 _ => Err(NoInitialCipherSuite { specific: true }),
328 }
329 }
330
331 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
332 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
334 .with_protocol_versions(&[&rustls::version::TLS13])
335 .unwrap() .dangerous()
337 .with_custom_certificate_verifier(verifier)
338 .with_no_client_auth();
339
340 config.enable_early_data = true;
341 config
342 }
343}
344
345impl crypto::ClientConfig for QuicClientConfig {
346 fn start_session(
347 self: Arc<Self>,
348 version: u32,
349 server_name: &str,
350 params: &TransportParameters,
351 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
352 let version = interpret_version(version)?;
353 Ok(Box::new(TlsSession {
354 version,
355 got_handshake_data: false,
356 next_secrets: None,
357 inner: rustls::quic::Connection::Client(
358 rustls::quic::ClientConnection::new(
359 self.inner.clone(),
360 version,
361 ServerName::try_from(server_name)
362 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
363 .to_owned(),
364 to_vec(params),
365 )
366 .unwrap(),
367 ),
368 suite: self.initial,
369 }))
370 }
371}
372
373impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
374 type Error = NoInitialCipherSuite;
375
376 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
377 Arc::new(inner).try_into()
378 }
379}
380
381impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
382 type Error = NoInitialCipherSuite;
383
384 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
385 Ok(Self {
386 initial: initial_suite_from_provider(inner.crypto_provider())
387 .ok_or(NoInitialCipherSuite { specific: false })?,
388 inner,
389 })
390 }
391}
392
/// Error returned when a TLS configuration cannot supply the TLS13_AES_128_GCM_SHA256 suite
/// that this crate uses for QUIC Initial packets.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` if an explicitly supplied suite was rejected; `false` if the crypto provider
    /// offered no suitable suite.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
416
417pub struct QuicServerConfig {
430 inner: Arc<rustls::ServerConfig>,
431 initial: Suite,
432}
433
434impl QuicServerConfig {
435 pub(crate) fn new(
436 cert_chain: Vec<CertificateDer<'static>>,
437 key: PrivateKeyDer<'static>,
438 ) -> Result<Self, rustls::Error> {
439 let inner = Self::inner(cert_chain, key)?;
440 Ok(Self {
441 initial: initial_suite_from_provider(inner.crypto_provider())
443 .expect("no initial cipher suite found"),
444 inner: Arc::new(inner),
445 })
446 }
447
448 pub fn with_initial(
452 inner: Arc<rustls::ServerConfig>,
453 initial: Suite,
454 ) -> Result<Self, NoInitialCipherSuite> {
455 match initial.suite.common.suite {
456 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
457 _ => Err(NoInitialCipherSuite { specific: true }),
458 }
459 }
460
461 pub(crate) fn inner(
467 cert_chain: Vec<CertificateDer<'static>>,
468 key: PrivateKeyDer<'static>,
469 ) -> Result<rustls::ServerConfig, rustls::Error> {
470 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
471 .with_protocol_versions(&[&rustls::version::TLS13])
472 .unwrap() .with_no_client_auth()
474 .with_single_cert(cert_chain, key)?;
475
476 inner.max_early_data_size = u32::MAX;
477 Ok(inner)
478 }
479}
480
481impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
482 type Error = NoInitialCipherSuite;
483
484 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
485 Arc::new(inner).try_into()
486 }
487}
488
489impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
490 type Error = NoInitialCipherSuite;
491
492 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
493 Ok(Self {
494 initial: initial_suite_from_provider(inner.crypto_provider())
495 .ok_or(NoInitialCipherSuite { specific: false })?,
496 inner,
497 })
498 }
499}
500
501impl crypto::ServerConfig for QuicServerConfig {
502 fn start_session(
503 self: Arc<Self>,
504 version: u32,
505 params: &TransportParameters,
506 ) -> Box<dyn crypto::Session> {
507 let version = interpret_version(version).unwrap();
509 Box::new(TlsSession {
510 version,
511 got_handshake_data: false,
512 next_secrets: None,
513 inner: rustls::quic::Connection::Server(
514 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
515 .unwrap(),
516 ),
517 suite: self.initial,
518 })
519 }
520
521 fn initial_keys(
522 &self,
523 version: u32,
524 dst_cid: &ConnectionId,
525 ) -> Result<Keys, UnsupportedVersion> {
526 let version = interpret_version(version)?;
527 Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
528 }
529
530 fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
531 let version = interpret_version(version).unwrap();
533 let (nonce, key) = match version {
534 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
535 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
536 _ => unreachable!(),
537 };
538
539 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
540 pseudo_packet.push(orig_dst_cid.len() as u8);
541 pseudo_packet.extend_from_slice(orig_dst_cid);
542 pseudo_packet.extend_from_slice(packet);
543
544 let nonce = aead::Nonce::assume_unique_for_key(nonce);
545 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
546
547 let tag = key
548 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
549 .unwrap();
550 let mut result = [0; 16];
551 result.copy_from_slice(tag.as_ref());
552 result
553 }
554}
555
556pub(crate) fn initial_suite_from_provider(
557 provider: &Arc<rustls::crypto::CryptoProvider>,
558) -> Option<Suite> {
559 provider
560 .cipher_suites
561 .iter()
562 .find_map(|cs| match (cs.suite(), cs.tls13()) {
563 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
564 Some(suite.quic_suite())
565 }
566 _ => None,
567 })
568 .flatten()
569}
570
571pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
572 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
573 let provider = rustls::crypto::aws_lc_rs::default_provider();
574 #[cfg(feature = "rustls-ring")]
575 let provider = rustls::crypto::ring::default_provider();
576 Arc::new(provider)
577}
578
579fn to_vec(params: &TransportParameters) -> Vec<u8> {
580 let mut bytes = Vec::new();
581 params.write(&mut bytes);
582 bytes
583}
584
585pub(crate) fn initial_keys(
586 version: Version,
587 dst_cid: ConnectionId,
588 side: Side,
589 suite: &Suite,
590) -> Keys {
591 let keys = suite.keys(&dst_cid, side.into(), version);
592 Keys {
593 header: KeyPair {
594 local: Box::new(keys.local.header),
595 remote: Box::new(keys.remote.header),
596 },
597 packet: KeyPair {
598 local: Box::new(keys.local.packet),
599 remote: Box::new(keys.remote.packet),
600 },
601 }
602}
603
604impl crypto::PacketKey for Box<dyn PacketKey> {
605 fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
606 let (header, payload_tag) = buf.split_at_mut(header_len);
607 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
608 let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
609 tag_storage.copy_from_slice(tag.as_ref());
610 }
611
612 fn decrypt(
613 &self,
614 packet: u64,
615 header: &[u8],
616 payload: &mut BytesMut,
617 ) -> Result<(), CryptoError> {
618 let plain = self
619 .decrypt_in_place(packet, header, payload.as_mut())
620 .map_err(|_| CryptoError)?;
621 let plain_len = plain.len();
622 payload.truncate(plain_len);
623 Ok(())
624 }
625
626 fn tag_len(&self) -> usize {
627 (**self).tag_len()
628 }
629
630 fn confidentiality_limit(&self) -> u64 {
631 (**self).confidentiality_limit()
632 }
633
634 fn integrity_limit(&self) -> u64 {
635 (**self).integrity_limit()
636 }
637}
638
639fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
640 match version {
641 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
642 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
643 _ => Err(UnsupportedVersion),
644 }
645}