1use alloc::boxed::Box;
2use alloc::string::ToString;
3use core::fmt;
4
5use zeroize::Zeroize;
6
7use crate::enums::{ContentType, ProtocolVersion};
8use crate::error::Error;
9use crate::msgs::codec;
10pub use crate::msgs::message::{
11 BorrowedPayload, InboundOpaqueMessage, InboundPlainMessage, OutboundChunks,
12 OutboundOpaqueMessage, OutboundPlainMessage, PlainMessage, PrefixedPayload,
13};
14use crate::suites::ConnectionTrafficSecrets;
15
16pub trait Tls13AeadAlgorithm: Send + Sync {
18 fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter>;
20
21 fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter>;
23
24 fn key_len(&self) -> usize;
26
27 fn extract_keys(
32 &self,
33 key: AeadKey,
34 iv: Iv,
35 ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError>;
36
37 fn fips(&self) -> bool {
39 false
40 }
41}
42
43pub trait Tls12AeadAlgorithm: Send + Sync + 'static {
45 fn encrypter(&self, key: AeadKey, iv: &[u8], extra: &[u8]) -> Box<dyn MessageEncrypter>;
54
55 fn decrypter(&self, key: AeadKey, iv: &[u8]) -> Box<dyn MessageDecrypter>;
61
62 fn key_block_shape(&self) -> KeyBlockShape;
65
66 fn extract_keys(
77 &self,
78 key: AeadKey,
79 iv: &[u8],
80 explicit: &[u8],
81 ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError>;
82
83 fn fips(&self) -> bool {
85 false
86 }
87}
88
89#[derive(Debug, Eq, PartialEq, Clone, Copy)]
91pub struct UnsupportedOperationError;
92
93impl From<UnsupportedOperationError> for Error {
94 fn from(value: UnsupportedOperationError) -> Self {
95 Self::General(value.to_string())
96 }
97}
98
99impl fmt::Display for UnsupportedOperationError {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 write!(f, "operation not supported")
102 }
103}
104
105#[cfg(feature = "std")]
106impl std::error::Error for UnsupportedOperationError {}
107
/// Layout of the TLS 1.2 key block for one AEAD algorithm.
///
/// Returned by `Tls12AeadAlgorithm::key_block_shape`. It is a small plain-data value, so it
/// derives `Clone`, `Copy` and `Debug` for convenient use and diagnostics.
#[derive(Clone, Copy, Debug)]
pub struct KeyBlockShape {
    /// Length of the encryption key taken from the key block.
    ///
    /// NOTE(review): presumably in bytes, mirroring RFC 5246's `enc_key_length`.
    pub enc_key_len: usize,

    /// Length of the fixed (implicit) part of the IV taken from the key block.
    ///
    /// NOTE(review): presumably in bytes, mirroring RFC 5246's `fixed_iv_length`.
    pub fixed_iv_len: usize,

    /// Length of the explicit-nonce portion taken from the key block.
    ///
    /// NOTE(review): presumably passed to `Tls12AeadAlgorithm::encrypter` as its `extra`
    /// argument — confirm against callers.
    pub explicit_nonce_len: usize,
}
135
136pub trait MessageDecrypter: Send + Sync {
138 fn decrypt<'a>(
141 &mut self,
142 msg: InboundOpaqueMessage<'a>,
143 seq: u64,
144 ) -> Result<InboundPlainMessage<'a>, Error>;
145}
146
147pub trait MessageEncrypter: Send + Sync {
149 fn encrypt(
152 &mut self,
153 msg: OutboundPlainMessage<'_>,
154 seq: u64,
155 ) -> Result<OutboundOpaqueMessage, Error>;
156
157 fn encrypted_payload_len(&self, payload_len: usize) -> usize;
160}
161
162impl dyn MessageEncrypter {
163 pub(crate) fn invalid() -> Box<dyn MessageEncrypter> {
164 Box::new(InvalidMessageEncrypter {})
165 }
166}
167
168impl dyn MessageDecrypter {
169 pub(crate) fn invalid() -> Box<dyn MessageDecrypter> {
170 Box::new(InvalidMessageDecrypter {})
171 }
172}
173
/// Fixed initialization vector, exactly [`NONCE_LEN`] bytes long.
///
/// Combined with a per-record sequence number by [`Nonce::new`] / [`Nonce::for_path`].
#[derive(Default)]
pub struct Iv([u8; NONCE_LEN]);

impl Iv {
    /// Wraps an owned `NONCE_LEN`-byte array.
    #[cfg(feature = "tls12")]
    pub fn new(value: [u8; NONCE_LEN]) -> Self {
        Self(value)
    }

    /// Copies `value` into a new `Iv`.
    ///
    /// # Panics
    ///
    /// If `value.len() != NONCE_LEN`: `debug_assert_eq!` fires in debug builds, and
    /// `copy_from_slice` panics in every build.
    #[cfg(feature = "tls12")]
    pub fn copy(value: &[u8]) -> Self {
        debug_assert_eq!(value.len(), NONCE_LEN);
        let mut iv = Self::new(Default::default());
        iv.0.copy_from_slice(value);
        iv
    }
}

impl From<[u8; NONCE_LEN]> for Iv {
    fn from(bytes: [u8; NONCE_LEN]) -> Self {
        Self(bytes)
    }
}

impl AsRef<[u8]> for Iv {
    fn as_ref(&self) -> &[u8] {
        self.0.as_ref()
    }
}

/// Nonce for a single AEAD operation, [`NONCE_LEN`] bytes long.
pub struct Nonce(pub [u8; NONCE_LEN]);

impl Nonce {
    /// Nonce for the record with sequence number `seq`.
    ///
    /// `seq` is written big-endian into the last 8 bytes of a zeroed `NONCE_LEN`-byte block,
    /// which is then XORed with `iv` (the construction of RFC 8446 §5.3).
    #[inline]
    pub fn new(iv: &Iv, seq: u64) -> Self {
        Self::new_from_seq(iv, Self::seq_block(0, seq))
    }

    /// Nonce for packet number `pn` on path `path_id`.
    ///
    /// Same as [`Nonce::new`], except that `path_id` is written big-endian into the first
    /// 4 bytes of the block before the XOR with `iv`. With `path_id == 0` the result equals
    /// `Nonce::new(iv, pn)`.
    ///
    /// NOTE(review): `path_id` is presumably a QUIC multipath path identifier — confirm.
    pub fn for_path(path_id: u32, iv: &Iv, pn: u64) -> Self {
        Self::new_from_seq(iv, Self::seq_block(path_id, pn))
    }

    /// Lays out `prefix` (big-endian, bytes 0..4) and `seq` (big-endian, bytes 4..12).
    ///
    /// This is the single place that builds the pre-XOR block, so `new` and `for_path`
    /// cannot drift apart.
    #[inline]
    fn seq_block(prefix: u32, seq: u64) -> [u8; NONCE_LEN] {
        let mut block = [0u8; NONCE_LEN];
        block[..4].copy_from_slice(&prefix.to_be_bytes());
        block[4..].copy_from_slice(&seq.to_be_bytes());
        block
    }

    /// XORs the prepared `seq` block with `iv` to produce the nonce.
    #[inline]
    fn new_from_seq(iv: &Iv, mut seq: [u8; NONCE_LEN]) -> Self {
        for (out, key) in seq.iter_mut().zip(iv.0.iter()) {
            *out ^= *key;
        }
        Self(seq)
    }
}

/// Length, in bytes, of every [`Iv`] and [`Nonce`] in this module.
pub const NONCE_LEN: usize = 12;
248
249#[inline]
253pub fn make_tls13_aad(payload_len: usize) -> [u8; 5] {
254 let version = ProtocolVersion::TLSv1_2.to_array();
255 [
256 ContentType::ApplicationData.into(),
257 version[0],
259 version[1],
260 (payload_len >> 8) as u8,
261 (payload_len & 0xff) as u8,
262 ]
263}
264
265#[inline]
269pub fn make_tls12_aad(
270 seq: u64,
271 typ: ContentType,
272 vers: ProtocolVersion,
273 len: usize,
274) -> [u8; TLS12_AAD_SIZE] {
275 let mut out = [0; TLS12_AAD_SIZE];
276 codec::put_u64(seq, &mut out[0..]);
277 out[8] = typ.into();
278 codec::put_u16(vers.into(), &mut out[9..]);
279 codec::put_u16(len as u16, &mut out[11..]);
280 out
281}
282
283const TLS12_AAD_SIZE: usize = 8 + 1 + 2 + 2;
284
/// AEAD key material of at most [`AeadKey::MAX_LEN`] bytes.
///
/// Only the first `used` bytes of `buf` form the key; the rest is padding.
pub struct AeadKey {
    buf: [u8; Self::MAX_LEN],
    used: usize,
}

impl AeadKey {
    /// Capacity of the key buffer: the longest key this type can hold.
    pub(crate) const MAX_LEN: usize = 32;

    /// Copies `buf` into a new key of length `buf.len()`.
    ///
    /// # Panics
    ///
    /// If `buf.len() > MAX_LEN`: `debug_assert!` fires in debug builds, and the slice copy
    /// panics in every build.
    #[cfg(feature = "tls12")]
    pub(crate) fn new(buf: &[u8]) -> Self {
        debug_assert!(buf.len() <= Self::MAX_LEN);
        let used = buf.len();
        let mut key = Self {
            buf: [0u8; Self::MAX_LEN],
            used,
        };
        key.buf[..used].copy_from_slice(buf);
        key
    }

    /// Shortens the key to its first `len` bytes; the buffer itself is untouched.
    ///
    /// # Panics
    ///
    /// If `len` exceeds the current key length.
    pub(crate) fn with_length(mut self, len: usize) -> Self {
        assert!(len <= self.used);
        self.used = len;
        self
    }
}
314
315impl Drop for AeadKey {
316 fn drop(&mut self) {
317 self.buf.zeroize();
318 }
319}
320
321impl AsRef<[u8]> for AeadKey {
322 fn as_ref(&self) -> &[u8] {
323 &self.buf[..self.used]
324 }
325}
326
327impl From<[u8; Self::MAX_LEN]> for AeadKey {
328 fn from(bytes: [u8; Self::MAX_LEN]) -> Self {
329 Self {
330 buf: bytes,
331 used: Self::MAX_LEN,
332 }
333 }
334}
335
336struct InvalidMessageEncrypter {}
338
339impl MessageEncrypter for InvalidMessageEncrypter {
340 fn encrypt(
341 &mut self,
342 _m: OutboundPlainMessage<'_>,
343 _seq: u64,
344 ) -> Result<OutboundOpaqueMessage, Error> {
345 Err(Error::EncryptError)
346 }
347
348 fn encrypted_payload_len(&self, payload_len: usize) -> usize {
349 payload_len
350 }
351}
352
353struct InvalidMessageDecrypter {}
355
356impl MessageDecrypter for InvalidMessageDecrypter {
357 fn decrypt<'a>(
358 &mut self,
359 _m: InboundOpaqueMessage<'a>,
360 _seq: u64,
361 ) -> Result<InboundPlainMessage<'a>, Error> {
362 Err(Error::DecryptError)
363 }
364}