1use alloc::{string::String, vec::Vec};
17use core::char;
18use core::fmt::Write;
19use core::marker::PhantomData;
20
// Bootstring parameters for Punycode (RFC 3492, section 5).

/// Number of digit symbols: `a`–`z` (values 0–25) followed by `0`–`9` (26–35).
const BASE: u32 = 36;
/// Lower clamp for the per-position digit threshold `t`.
const T_MIN: u32 = 1;
/// Upper clamp for the per-position digit threshold `t`.
const T_MAX: u32 = 26;
/// Skew term used by bias adaptation.
const SKEW: u32 = 38;
/// Divisor applied to `delta` on the first bias adaptation.
const DAMP: u32 = 700;
/// Bias in effect before the first adaptation.
const INITIAL_BIAS: u32 = 72;
/// First code point considered for insertion (128, the first non-ASCII code point).
const INITIAL_N: u32 = 128;
29
30#[inline]
31fn adapt(mut delta: u32, num_points: u32, first_time: bool) -> u32 {
32    delta /= if first_time { DAMP } else { 2 };
33    delta += delta / num_points;
34    let mut k = 0;
35    while delta > ((BASE - T_MIN) * T_MAX) / 2 {
36        delta /= BASE - T_MIN;
37        k += BASE;
38    }
39    k + (((BASE - T_MIN + 1) * delta) / (delta + SKEW))
40}
41
42#[inline]
48pub fn decode_to_string(input: &str) -> Option<String> {
49    Some(
50        Decoder::default()
51            .decode::<u8, ExternalCaller>(input.as_bytes())
52            .ok()?
53            .collect(),
54    )
55}
56
57pub fn decode(input: &str) -> Option<Vec<char>> {
63    Some(
64        Decoder::default()
65            .decode::<u8, ExternalCaller>(input.as_bytes())
66            .ok()?
67            .collect(),
68    )
69}
70
/// Compile-time flag telling the codec who is calling it.
///
/// When `EXTERNAL_CALLER` is `true`:
/// - decoding rejects non-ASCII code units in the basic part and yields basic
///   code units unchanged;
/// - encoding checks its arithmetic step by step.
///
/// Internal callers instead get ASCII-lower-cased basic code units from the
/// decoder and a single up-front bound check in the encoder.
pub(crate) trait PunycodeCaller {
    /// `true` for the public entry points of this module.
    const EXTERNAL_CALLER: bool;
}

/// Marker for callers elsewhere in the crate.
pub(crate) struct InternalCaller;

/// Marker for the public functions of this module.
struct ExternalCaller;

impl PunycodeCaller for InternalCaller {
    const EXTERNAL_CALLER: bool = false;
}

impl PunycodeCaller for ExternalCaller {
    const EXTERNAL_CALLER: bool = true;
}
99
/// A code unit the Punycode decoder can read; implemented for `u8` and `char`.
pub(crate) trait PunycodeCodeUnit {
    /// Whether this is the `-` separating the basic part from the encoded part.
    fn is_delimiter(&self) -> bool;
    /// Whether this code unit is ASCII.
    fn is_ascii(&self) -> bool;
    /// Value as a base-36 Punycode digit (`a`–`z` → 0–25, `0`–`9` → 26–35),
    /// or `None` if it is not a digit.
    fn digit(&self) -> Option<u32>;
    /// This code unit as a `char`, unchanged.
    fn char(&self) -> char;
    /// This code unit as a `char`, ASCII-lower-cased where the impl does so.
    fn char_ascii_lower_case(&self) -> char;
}

impl PunycodeCodeUnit for u8 {
    fn is_delimiter(&self) -> bool {
        b'-' == *self
    }

    fn is_ascii(&self) -> bool {
        *self <= 0x7F
    }

    fn digit(&self) -> Option<u32> {
        // Letters are accepted in either case.
        let value = match *self {
            b @ b'a'..=b'z' => b - b'a',
            b @ b'A'..=b'Z' => b - b'A',
            b @ b'0'..=b'9' => b - b'0' + 26,
            _ => return None,
        };
        Some(u32::from(value))
    }

    fn char(&self) -> char {
        char::from(*self)
    }

    fn char_ascii_lower_case(&self) -> char {
        char::from(self.to_ascii_lowercase())
    }
}

impl PunycodeCodeUnit for char {
    fn is_delimiter(&self) -> bool {
        '-' == *self
    }

    fn is_ascii(&self) -> bool {
        // Not expected to be called for `char` input.
        debug_assert!(false);
        true
    }

    fn digit(&self) -> Option<u32> {
        let c = *self;
        if c.is_ascii_lowercase() {
            Some(u32::from(c) - u32::from('a'))
        } else if c.is_ascii_digit() {
            Some(u32::from(c) - u32::from('0') + 26)
        } else {
            // Unlike the `u8` impl, upper-case letters are not accepted.
            None
        }
    }

    fn char(&self) -> char {
        // Not expected to be called for `char` input.
        debug_assert!(false);
        *self
    }

    fn char_ascii_lower_case(&self) -> char {
        // No lower-casing is performed; presumably `char` input is already
        // lower case (NOTE(review): confirm against internal callers).
        *self
    }
}
158
159#[derive(Default)]
160pub(crate) struct Decoder {
161    insertions: smallvec::SmallVec<[(usize, char); 59]>,
162}
163
164impl Decoder {
165    pub(crate) fn decode<'a, T: PunycodeCodeUnit + Copy, C: PunycodeCaller>(
167        &'a mut self,
168        input: &'a [T],
169    ) -> Result<Decode<'a, T, C>, ()> {
170        self.insertions.clear();
171        let (base, input) = if let Some(position) = input.iter().rposition(|c| c.is_delimiter()) {
174            (
175                &input[..position],
176                if position > 0 {
177                    &input[position + 1..]
178                } else {
179                    input
180                },
181            )
182        } else {
183            (&input[..0], input)
184        };
185
186        if C::EXTERNAL_CALLER && !base.iter().all(|c| c.is_ascii()) {
187            return Err(());
188        }
189
190        let base_len = base.len();
191        let mut length = base_len as u32;
192        let mut code_point = INITIAL_N;
193        let mut bias = INITIAL_BIAS;
194        let mut i = 0u32;
195        let mut iter = input.iter();
196        loop {
197            let previous_i = i;
198            let mut weight = 1;
199            let mut k = BASE;
200            let mut byte = match iter.next() {
201                None => break,
202                Some(byte) => byte,
203            };
204
205            loop {
208                let digit = if let Some(digit) = byte.digit() {
209                    digit
210                } else {
211                    return Err(());
212                };
213                let product = digit.checked_mul(weight).ok_or(())?;
214                i = i.checked_add(product).ok_or(())?;
215                let t = if k <= bias {
216                    T_MIN
217                } else if k >= bias + T_MAX {
218                    T_MAX
219                } else {
220                    k - bias
221                };
222                if digit < t {
223                    break;
224                }
225                weight = weight.checked_mul(BASE - t).ok_or(())?;
226                k += BASE;
227                byte = match iter.next() {
228                    None => return Err(()), Some(byte) => byte,
230                };
231            }
232
233            bias = adapt(i - previous_i, length + 1, previous_i == 0);
234
235            code_point = code_point.checked_add(i / (length + 1)).ok_or(())?;
238            i %= length + 1;
239            let c = match char::from_u32(code_point) {
240                Some(c) => c,
241                None => return Err(()),
242            };
243
244            for (idx, _) in &mut self.insertions {
246                if *idx >= i as usize {
247                    *idx += 1;
248                }
249            }
250            self.insertions.push((i as usize, c));
251            length += 1;
252            i += 1;
253        }
254
255        self.insertions.sort_by_key(|(i, _)| *i);
256        Ok(Decode {
257            base: base.iter(),
258            insertions: &self.insertions,
259            inserted: 0,
260            position: 0,
261            len: base_len + self.insertions.len(),
262            phantom: PhantomData::<C>,
263        })
264    }
265}
266
267pub(crate) struct Decode<'a, T, C>
268where
269    T: PunycodeCodeUnit + Copy,
270    C: PunycodeCaller,
271{
272    base: core::slice::Iter<'a, T>,
273    pub(crate) insertions: &'a [(usize, char)],
274    inserted: usize,
275    position: usize,
276    len: usize,
277    phantom: PhantomData<C>,
278}
279
280impl<T: PunycodeCodeUnit + Copy, C: PunycodeCaller> Iterator for Decode<'_, T, C> {
281    type Item = char;
282
283    fn next(&mut self) -> Option<Self::Item> {
284        loop {
285            match self.insertions.get(self.inserted) {
286                Some((pos, c)) if *pos == self.position => {
287                    self.inserted += 1;
288                    self.position += 1;
289                    return Some(*c);
290                }
291                _ => {}
292            }
293            if let Some(c) = self.base.next() {
294                self.position += 1;
295                return Some(if C::EXTERNAL_CALLER {
296                    c.char()
297                } else {
298                    c.char_ascii_lower_case()
299                });
300            } else if self.inserted >= self.insertions.len() {
301                return None;
302            }
303        }
304    }
305
306    fn size_hint(&self) -> (usize, Option<usize>) {
307        let len = self.len - self.position;
308        (len, Some(len))
309    }
310}
311
312impl<T: PunycodeCodeUnit + Copy, C: PunycodeCaller> ExactSizeIterator for Decode<'_, T, C> {
313    fn len(&self) -> usize {
314        self.len - self.position
315    }
316}
317
318#[inline]
322pub fn encode_str(input: &str) -> Option<String> {
323    if input.len() > u32::MAX as usize {
324        return None;
325    }
326    let mut buf = String::with_capacity(input.len());
327    encode_into::<_, _, ExternalCaller>(input.chars(), &mut buf)
328        .ok()
329        .map(|()| buf)
330}
331
332pub fn encode(input: &[char]) -> Option<String> {
337    if input.len() > u32::MAX as usize {
338        return None;
339    }
340    let mut buf = String::with_capacity(input.len());
341    encode_into::<_, _, ExternalCaller>(input.iter().copied(), &mut buf)
342        .ok()
343        .map(|()| buf)
344}
345
/// Ways [`encode_into`] can fail.
pub(crate) enum PunycodeEncodeError {
    /// The input was too long, or an intermediate value did not fit in a `u32`.
    Overflow,
    /// Writing to the output sink returned a formatting error.
    Sink,
}

/// Lets `?` turn a failed write to the output sink into
/// [`PunycodeEncodeError::Sink`].
impl From<core::fmt::Error> for PunycodeEncodeError {
    fn from(_error: core::fmt::Error) -> Self {
        PunycodeEncodeError::Sink
    }
}
356
/// Encodes the `char`s yielded by `input` as Punycode and writes them to `output`.
///
/// The basic (ASCII) code points are copied first, in input order. A `-`
/// delimiter follows if there was at least one of them. The remaining code
/// points are then encoded as generalized variable-length integers, in
/// increasing code point order.
///
/// `input` is cloned and iterated once up front and twice per distinct
/// non-ASCII code point, so it should be cheap to clone.
///
/// # Errors
///
/// - `PunycodeEncodeError::Overflow` in these cases:
///   - `input` yields more than `u32::MAX` items;
///   - for internal callers, `(input_length + 1) * (char::MAX - INITIAL_N)`
///     does not fit in a `u32`;
///   - for external callers, `delta` overflows `u32` while encoding.
/// - `PunycodeEncodeError::Sink` if writing to `output` fails.
///
/// On error, `output` may already hold partial output, for example the basic
/// code points written before the error was detected.
pub(crate) fn encode_into<I, W, C>(input: I, output: &mut W) -> Result<(), PunycodeEncodeError>
where
    I: Iterator<Item = char> + Clone,
    W: Write + ?Sized,
    C: PunycodeCaller,
{
    // Count the input and copy the basic (ASCII) code points as-is.
    let (mut input_length, mut basic_length) = (0u32, 0);
    for c in input.clone() {
        input_length = input_length
            .checked_add(1)
            .ok_or(PunycodeEncodeError::Overflow)?;
        if c.is_ascii() {
            output.write_char(c)?;
            basic_length += 1;
        }
    }

    if !C::EXTERNAL_CALLER {
        // Internal callers get one up-front bound check instead of per-step
        // checks. (input_length + 1) * (char::MAX - INITIAL_N) must fit in a
        // u32; the loop below then uses unchecked arithmetic on `delta` for
        // them, relying on the overflow bound discussed in RFC 3492 section 6.4.
        let len_plus_one = input_length
            .checked_add(1)
            .ok_or(PunycodeEncodeError::Overflow)?;
        len_plus_one
            .checked_mul(u32::from(char::MAX) - INITIAL_N)
            .ok_or(PunycodeEncodeError::Overflow)?;
    }

    if basic_length > 0 {
        output.write_char('-')?;
    }
    let mut code_point = INITIAL_N;
    let mut delta = 0u32;
    let mut bias = INITIAL_BIAS;
    // Number of input code points already represented in the output.
    let mut processed = basic_length;
    while processed < input_length {
        // All code points below `code_point` are already handled. Find the next
        // one to encode. `unwrap` cannot fail: `processed < input_length`
        // means at least one input code point >= `code_point` remains.
        let min_code_point = input
            .clone()
            .map(|c| c as u32)
            .filter(|&c| c >= code_point)
            .min()
            .unwrap();
        // Advance the decoder's <code_point, i> state to <min_code_point, 0>.
        if C::EXTERNAL_CALLER {
            let product = (min_code_point - code_point)
                .checked_mul(processed + 1)
                .ok_or(PunycodeEncodeError::Overflow)?;
            delta = delta
                .checked_add(product)
                .ok_or(PunycodeEncodeError::Overflow)?;
        } else {
            delta += (min_code_point - code_point) * (processed + 1);
        }
        code_point = min_code_point;
        for c in input.clone() {
            let c = c as u32;
            // Each smaller code point already in the output is one more
            // insertion position to skip.
            if c < code_point {
                if C::EXTERNAL_CALLER {
                    delta = delta.checked_add(1).ok_or(PunycodeEncodeError::Overflow)?;
                } else {
                    delta += 1;
                }
            }
            if c == code_point {
                // Emit `delta` as a generalized variable-length integer.
                let mut q = delta;
                let mut k = BASE;
                loop {
                    // Threshold for this digit position, clamped to [T_MIN, T_MAX].
                    let t = if k <= bias {
                        T_MIN
                    } else if k >= bias + T_MAX {
                        T_MAX
                    } else {
                        k - bias
                    };
                    if q < t {
                        break;
                    }
                    let value = t + ((q - t) % (BASE - t));
                    output.write_char(value_to_digit(value))?;
                    q = (q - t) / (BASE - t);
                    k += BASE;
                }
                // Final digit (< t), which terminates the integer.
                output.write_char(value_to_digit(q))?;
                bias = adapt(delta, processed + 1, processed == basic_length);
                delta = 0;
                processed += 1;
            }
        }
        delta += 1;
        code_point += 1;
    }
    Ok(())
}
456
/// Maps a digit value to its Punycode symbol: 0–25 → `a`–`z`, 26–35 → `0`–`9`.
///
/// # Panics
///
/// Panics if `value` is 36 or greater.
#[inline]
fn value_to_digit(value: u32) -> char {
    let symbol = match value {
        0..=25 => b'a' + value as u8,
        26..=35 => b'0' + (value as u8 - 26),
        _ => panic!(),
    };
    char::from(symbol)
}
465
#[test]
#[ignore = "slow"]
#[cfg(target_pointer_width = "64")]
fn huge_encode() {
    // One more non-ASCII char than a u32 can count: the encoder must report an
    // error and leave the sink empty.
    let too_long = core::iter::repeat('ß').take(u32::MAX as usize + 1);
    let mut sink = String::new();
    let outcome = encode_into::<_, _, ExternalCaller>(too_long, &mut sink);
    assert!(outcome.is_err());
    assert!(sink.is_empty());
}