1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9#[cfg(feature = "__rustls-post-quantum-test")]
10use rustls::NamedGroup;
11use rustls::{
12 self, CipherSuite,
13 client::danger::ServerCertVerifier,
14 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
15 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
16};
17#[cfg(feature = "platform-verifier")]
18use rustls_platform_verifier::BuilderVerifierExt;
19
20use crate::{
21 ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
22 crypto::{
23 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
24 },
25 transport_parameters::TransportParameters,
26};
27
28impl From<Side> for rustls::Side {
29 fn from(s: Side) -> Self {
30 match s {
31 Side::Client => Self::Client,
32 Side::Server => Self::Server,
33 }
34 }
35}
36
37pub struct TlsSession {
39 version: Version,
40 got_handshake_data: bool,
41 next_secrets: Option<Secrets>,
42 inner: Connection,
43 suite: Suite,
44}
45
46impl TlsSession {
47 fn side(&self) -> Side {
48 match self.inner {
49 Connection::Client(_) => Side::Client,
50 Connection::Server(_) => Side::Server,
51 }
52 }
53}
54
55impl crypto::Session for TlsSession {
56 fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
57 initial_keys(self.version, *dst_cid, side, &self.suite)
58 }
59
60 fn handshake_data(&self) -> Option<Box<dyn Any>> {
61 if !self.got_handshake_data {
62 return None;
63 }
64 Some(Box::new(HandshakeData {
65 protocol: self.inner.alpn_protocol().map(|x| x.into()),
66 server_name: match self.inner {
67 Connection::Client(_) => None,
68 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
69 },
70 #[cfg(feature = "__rustls-post-quantum-test")]
71 negotiated_key_exchange_group: self
72 .inner
73 .negotiated_key_exchange_group()
74 .expect("key exchange group is negotiated")
75 .name(),
76 }))
77 }
78
79 fn peer_identity(&self) -> Option<Box<dyn Any>> {
81 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
82 Box::new(
83 v.iter()
84 .map(|v| v.clone().into_owned())
85 .collect::<Vec<CertificateDer<'static>>>(),
86 )
87 })
88 }
89
90 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
91 let keys = self.inner.zero_rtt_keys()?;
92 Some((Box::new(keys.header), Box::new(keys.packet)))
93 }
94
95 fn early_data_accepted(&self) -> Option<bool> {
96 match self.inner {
97 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
98 _ => None,
99 }
100 }
101
102 fn is_handshaking(&self) -> bool {
103 self.inner.is_handshaking()
104 }
105
106 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
107 self.inner.read_hs(buf).map_err(|e| {
108 if let Some(alert) = self.inner.alert() {
109 TransportError {
110 code: TransportErrorCode::crypto(alert.into()),
111 frame: None,
112 reason: e.to_string(),
113 }
114 } else {
115 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
116 }
117 })?;
118 if !self.got_handshake_data {
119 let have_server_name = match self.inner {
123 Connection::Client(_) => false,
124 Connection::Server(ref session) => session.server_name().is_some(),
125 };
126 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
127 self.got_handshake_data = true;
128 return Ok(true);
129 }
130 }
131 Ok(false)
132 }
133
134 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
135 match self.inner.quic_transport_parameters() {
136 None => Ok(None),
137 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
138 Ok(params) => Ok(Some(params)),
139 Err(e) => Err(e.into()),
140 },
141 }
142 }
143
144 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
145 let keys = match self.inner.write_hs(buf)? {
146 KeyChange::Handshake { keys } => keys,
147 KeyChange::OneRtt { keys, next } => {
148 self.next_secrets = Some(next);
149 keys
150 }
151 };
152
153 Some(Keys {
154 header: KeyPair {
155 local: Box::new(keys.local.header),
156 remote: Box::new(keys.remote.header),
157 },
158 packet: KeyPair {
159 local: Box::new(keys.local.packet),
160 remote: Box::new(keys.remote.packet),
161 },
162 })
163 }
164
165 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
166 let secrets = self.next_secrets.as_mut()?;
167 let keys = secrets.next_packet_keys();
168 Some(KeyPair {
169 local: Box::new(keys.local),
170 remote: Box::new(keys.remote),
171 })
172 }
173
174 fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
175 let tag_start = match payload.len().checked_sub(16) {
176 Some(x) => x,
177 None => return false,
178 };
179
180 let mut pseudo_packet =
181 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
182 pseudo_packet.push(orig_dst_cid.len() as u8);
183 pseudo_packet.extend_from_slice(orig_dst_cid);
184 pseudo_packet.extend_from_slice(header);
185 let tag_start = tag_start + pseudo_packet.len();
186 pseudo_packet.extend_from_slice(payload);
187
188 let (nonce, key) = match self.version {
189 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
190 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
191 _ => unreachable!(),
192 };
193
194 let nonce = aead::Nonce::assume_unique_for_key(nonce);
195 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
196
197 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
198 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
199 }
200
201 fn export_keying_material(
202 &self,
203 output: &mut [u8],
204 label: &[u8],
205 context: &[u8],
206 ) -> Result<(), ExportKeyingMaterialError> {
207 self.inner
208 .export_keying_material(output, label, Some(context))
209 .map_err(|_| ExportKeyingMaterialError)?;
210 Ok(())
211 }
212}
213
/// AES-128-GCM key for Retry integrity tags under `Version::V1Draft`.
///
/// `interpret_version` maps draft versions 0xff00_001d..=0xff00_0020 to that version.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
/// AES-128-GCM nonce paired with [`RETRY_INTEGRITY_KEY_DRAFT`].
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

/// AES-128-GCM key for Retry integrity tags under `Version::V1`.
///
/// These values match the Retry integrity key in RFC 9001 §5.8.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// AES-128-GCM nonce paired with [`RETRY_INTEGRITY_KEY_V1`] (RFC 9001 §5.8).
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
227
228impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
229 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
230 let (header, sample) = packet.split_at_mut(pn_offset + 4);
231 let (first, rest) = header.split_at_mut(1);
232 let pn_end = Ord::min(pn_offset + 3, rest.len());
233 self.decrypt_in_place(
234 &sample[..self.sample_size()],
235 &mut first[0],
236 &mut rest[pn_offset - 1..pn_end],
237 )
238 .unwrap();
239 }
240
241 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
242 let (header, sample) = packet.split_at_mut(pn_offset + 4);
243 let (first, rest) = header.split_at_mut(1);
244 let pn_end = Ord::min(pn_offset + 3, rest.len());
245 self.encrypt_in_place(
246 &sample[..self.sample_size()],
247 &mut first[0],
248 &mut rest[pn_offset - 1..pn_end],
249 )
250 .unwrap();
251 }
252
253 fn sample_size(&self) -> usize {
254 self.sample_len()
255 }
256}
257
/// Handshake details that a TLS session returns, boxed as `dyn Any`, from its
/// `handshake_data` method.
pub struct HandshakeData {
    /// The negotiated ALPN protocol, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name as seen by the server side of the connection, if any.
    ///
    /// Always `None` on the client side.
    pub server_name: Option<String>,
    /// The key exchange group negotiated during the handshake (test-only feature).
    #[cfg(feature = "__rustls-post-quantum-test")]
    pub negotiated_key_exchange_group: NamedGroup,
}
272
273pub struct QuicClientConfig {
292 pub(crate) inner: Arc<rustls::ClientConfig>,
293 initial: Suite,
294}
295
296impl QuicClientConfig {
297 #[cfg(feature = "platform-verifier")]
298 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
299 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
301 .with_protocol_versions(&[&rustls::version::TLS13])
302 .unwrap() .with_platform_verifier()?
304 .with_no_client_auth();
305
306 inner.enable_early_data = true;
307 Ok(Self {
308 initial: initial_suite_from_provider(inner.crypto_provider())
310 .expect("no initial cipher suite found"),
311 inner: Arc::new(inner),
312 })
313 }
314
315 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
320 let inner = Self::inner(verifier);
321 Self {
322 initial: initial_suite_from_provider(inner.crypto_provider())
324 .expect("no initial cipher suite found"),
325 inner: Arc::new(inner),
326 }
327 }
328
329 pub fn with_initial(
333 inner: Arc<rustls::ClientConfig>,
334 initial: Suite,
335 ) -> Result<Self, NoInitialCipherSuite> {
336 match initial.suite.common.suite {
337 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
338 _ => Err(NoInitialCipherSuite { specific: true }),
339 }
340 }
341
342 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
343 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
345 .with_protocol_versions(&[&rustls::version::TLS13])
346 .unwrap() .dangerous()
348 .with_custom_certificate_verifier(verifier)
349 .with_no_client_auth();
350
351 config.enable_early_data = true;
352 config
353 }
354}
355
356impl crypto::ClientConfig for QuicClientConfig {
357 fn start_session(
358 self: Arc<Self>,
359 version: u32,
360 server_name: &str,
361 params: &TransportParameters,
362 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
363 let version = interpret_version(version)?;
364 Ok(Box::new(TlsSession {
365 version,
366 got_handshake_data: false,
367 next_secrets: None,
368 inner: rustls::quic::Connection::Client(
369 rustls::quic::ClientConnection::new(
370 self.inner.clone(),
371 version,
372 ServerName::try_from(server_name)
373 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
374 .to_owned(),
375 to_vec(params),
376 )
377 .unwrap(),
378 ),
379 suite: self.initial,
380 }))
381 }
382}
383
384impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
385 type Error = NoInitialCipherSuite;
386
387 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
388 Arc::new(inner).try_into()
389 }
390}
391
392impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
393 type Error = NoInitialCipherSuite;
394
395 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
396 Ok(Self {
397 initial: initial_suite_from_provider(inner.crypto_provider())
398 .ok_or(NoInitialCipherSuite { specific: false })?,
399 inner,
400 })
401 }
402}
403
/// Error returned when a TLS configuration has no acceptable Initial cipher suite.
///
/// Only `TLS13_AES_128_GCM_SHA256` is accepted for protecting Initial packets.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` when an explicitly supplied suite was rejected; `false` when the crypto provider
    /// offered no usable suite.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
427
428pub struct QuicServerConfig {
441 inner: Arc<rustls::ServerConfig>,
442 initial: Suite,
443}
444
445impl QuicServerConfig {
446 pub(crate) fn new(
447 cert_chain: Vec<CertificateDer<'static>>,
448 key: PrivateKeyDer<'static>,
449 ) -> Result<Self, rustls::Error> {
450 let inner = Self::inner(cert_chain, key)?;
451 Ok(Self {
452 initial: initial_suite_from_provider(inner.crypto_provider())
454 .expect("no initial cipher suite found"),
455 inner: Arc::new(inner),
456 })
457 }
458
459 pub fn with_initial(
463 inner: Arc<rustls::ServerConfig>,
464 initial: Suite,
465 ) -> Result<Self, NoInitialCipherSuite> {
466 match initial.suite.common.suite {
467 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
468 _ => Err(NoInitialCipherSuite { specific: true }),
469 }
470 }
471
472 pub(crate) fn inner(
478 cert_chain: Vec<CertificateDer<'static>>,
479 key: PrivateKeyDer<'static>,
480 ) -> Result<rustls::ServerConfig, rustls::Error> {
481 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
482 .with_protocol_versions(&[&rustls::version::TLS13])
483 .unwrap() .with_no_client_auth()
485 .with_single_cert(cert_chain, key)?;
486
487 inner.max_early_data_size = u32::MAX;
488 Ok(inner)
489 }
490}
491
492impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
493 type Error = NoInitialCipherSuite;
494
495 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
496 Arc::new(inner).try_into()
497 }
498}
499
500impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
501 type Error = NoInitialCipherSuite;
502
503 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
504 Ok(Self {
505 initial: initial_suite_from_provider(inner.crypto_provider())
506 .ok_or(NoInitialCipherSuite { specific: false })?,
507 inner,
508 })
509 }
510}
511
512impl crypto::ServerConfig for QuicServerConfig {
513 fn start_session(
514 self: Arc<Self>,
515 version: u32,
516 params: &TransportParameters,
517 ) -> Box<dyn crypto::Session> {
518 let version = interpret_version(version).unwrap();
520 Box::new(TlsSession {
521 version,
522 got_handshake_data: false,
523 next_secrets: None,
524 inner: rustls::quic::Connection::Server(
525 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
526 .unwrap(),
527 ),
528 suite: self.initial,
529 })
530 }
531
532 fn initial_keys(
533 &self,
534 version: u32,
535 dst_cid: &ConnectionId,
536 ) -> Result<Keys, UnsupportedVersion> {
537 let version = interpret_version(version)?;
538 Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
539 }
540
541 fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
542 let version = interpret_version(version).unwrap();
544 let (nonce, key) = match version {
545 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
546 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
547 _ => unreachable!(),
548 };
549
550 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
551 pseudo_packet.push(orig_dst_cid.len() as u8);
552 pseudo_packet.extend_from_slice(orig_dst_cid);
553 pseudo_packet.extend_from_slice(packet);
554
555 let nonce = aead::Nonce::assume_unique_for_key(nonce);
556 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
557
558 let tag = key
559 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
560 .unwrap();
561 let mut result = [0; 16];
562 result.copy_from_slice(tag.as_ref());
563 result
564 }
565}
566
567pub(crate) fn initial_suite_from_provider(
568 provider: &Arc<rustls::crypto::CryptoProvider>,
569) -> Option<Suite> {
570 provider
571 .cipher_suites
572 .iter()
573 .find_map(|cs| match (cs.suite(), cs.tls13()) {
574 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
575 Some(suite.quic_suite())
576 }
577 _ => None,
578 })
579 .flatten()
580}
581
582pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
583 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
584 let provider = rustls::crypto::aws_lc_rs::default_provider();
585 #[cfg(feature = "rustls-ring")]
586 let provider = rustls::crypto::ring::default_provider();
587 Arc::new(provider)
588}
589
590fn to_vec(params: &TransportParameters) -> Vec<u8> {
591 let mut bytes = Vec::new();
592 params.write(&mut bytes);
593 bytes
594}
595
596pub(crate) fn initial_keys(
597 version: Version,
598 dst_cid: ConnectionId,
599 side: Side,
600 suite: &Suite,
601) -> Keys {
602 let keys = suite.keys(&dst_cid, side.into(), version);
603 Keys {
604 header: KeyPair {
605 local: Box::new(keys.local.header),
606 remote: Box::new(keys.remote.header),
607 },
608 packet: KeyPair {
609 local: Box::new(keys.local.packet),
610 remote: Box::new(keys.remote.packet),
611 },
612 }
613}
614
615impl crypto::PacketKey for Box<dyn PacketKey> {
616 fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
617 let (header, payload_tag) = buf.split_at_mut(header_len);
618 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
619 let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
620 tag_storage.copy_from_slice(tag.as_ref());
621 }
622
623 fn decrypt(
624 &self,
625 packet: u64,
626 header: &[u8],
627 payload: &mut BytesMut,
628 ) -> Result<(), CryptoError> {
629 let plain = self
630 .decrypt_in_place(packet, header, payload.as_mut())
631 .map_err(|_| CryptoError)?;
632 let plain_len = plain.len();
633 payload.truncate(plain_len);
634 Ok(())
635 }
636
637 fn tag_len(&self) -> usize {
638 (**self).tag_len()
639 }
640
641 fn confidentiality_limit(&self) -> u64 {
642 (**self).confidentiality_limit()
643 }
644
645 fn integrity_limit(&self) -> u64 {
646 (**self).integrity_limit()
647 }
648}
649
650fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
651 match version {
652 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
653 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
654 _ => Err(UnsupportedVersion),
655 }
656}