regorus/
lexer.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::*;
5use core::cmp;
6use core::fmt::{self, Debug, Formatter};
7use core::iter::Peekable;
8use core::str::CharIndices;
9
10use crate::Value;
11
12use anyhow::{anyhow, bail, Result};
13
/// Backing data of a [`Source`], held behind an `Rc` so that clones of a
/// `Source` refer to the same allocation.
#[derive(Clone)]
#[cfg_attr(feature = "ast", derive(serde::Serialize))]
struct SourceInternal {
    /// Name or path reported for this policy text in diagnostics.
    pub file: String,
    /// The complete text of the policy.
    pub contents: String,
    /// `(start, end)` byte offsets into `contents`, one entry per line, with
    /// the line terminator excluded. Skipped when serializing.
    #[cfg_attr(feature = "ast", serde(skip_serializing))]
    pub lines: Vec<(u32, u32)>,
}
22
/// A policy file.
///
/// The file's data lives behind a reference-counted pointer, so cloning a
/// `Source` does not copy the text. Equality, ordering and hashing of
/// `Source` values are based on the identity of that shared data, not on the
/// text itself.
#[derive(Clone)]
#[cfg_attr(feature = "ast", derive(serde::Serialize))]
pub struct Source {
    /// Shared file name, contents and line table.
    #[cfg_attr(feature = "ast", serde(flatten))]
    src: Rc<SourceInternal>,
}
30
31impl Source {
32    /// The path associated with the policy file.
33    pub fn get_path(&self) -> &String {
34        &self.src.file
35    }
36
37    /// The contents of the policy file.
38    pub fn get_contents(&self) -> &String {
39        &self.src.contents
40    }
41}
42
43impl cmp::Ord for Source {
44    fn cmp(&self, other: &Source) -> cmp::Ordering {
45        Rc::as_ptr(&self.src).cmp(&Rc::as_ptr(&other.src))
46    }
47}
48
49impl cmp::PartialOrd for Source {
50    fn partial_cmp(&self, other: &Source) -> Option<cmp::Ordering> {
51        Some(self.cmp(other))
52    }
53}
54
55impl cmp::PartialEq for Source {
56    fn eq(&self, other: &Source) -> bool {
57        Rc::as_ptr(&self.src) == Rc::as_ptr(&other.src)
58    }
59}
60
// Marker impl: `Source` equality compares the addresses of the shared data.
impl cmp::Eq for Source {}
62
#[cfg(feature = "std")]
impl std::hash::Hash for Source {
    /// Hashes the address of the shared data, matching the pointer-based
    /// `PartialEq` implementation.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let ptr = Rc::as_ptr(&self.src);
        std::hash::Hash::hash(&ptr, state)
    }
}
69
70impl Debug for Source {
71    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
72        self.src.file.fmt(f)
73    }
74}
75
/// A slice of a [`Source`]'s text, identified by byte offsets.
#[derive(Clone)]
pub struct SourceStr {
    /// Source whose contents are sliced.
    source: Source,
    /// Byte offset at which the slice starts.
    start: u32,
    /// Byte offset just past the end of the slice.
    end: u32,
}
82
83impl Debug for SourceStr {
84    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
85        self.text().fmt(f)
86    }
87}
88
89impl fmt::Display for SourceStr {
90    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
91        fmt::Display::fmt(&self.text(), f)
92    }
93}
94
95impl SourceStr {
96    pub fn new(source: Source, start: u32, end: u32) -> Self {
97        Self { source, start, end }
98    }
99
100    pub fn text(&self) -> &str {
101        &self.source.contents()[self.start as usize..self.end as usize]
102    }
103
104    pub fn clone_empty(&self) -> SourceStr {
105        Self {
106            source: self.source.clone(),
107            start: 0,
108            end: 0,
109        }
110    }
111}
112
113impl cmp::PartialEq for SourceStr {
114    fn eq(&self, other: &Self) -> bool {
115        self.text().eq(other.text())
116    }
117}
118
// Marker impl: `SourceStr` equality is plain string equality of the text.
impl cmp::Eq for SourceStr {}
120
121impl cmp::PartialOrd for SourceStr {
122    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
123        Some(self.cmp(other))
124    }
125}
126
127impl cmp::Ord for SourceStr {
128    fn cmp(&self, other: &Self) -> cmp::Ordering {
129        self.text().cmp(other.text())
130    }
131}
132
133impl Source {
134    pub fn from_contents(file: String, contents: String) -> Result<Source> {
135        let max_size = u32::MAX as usize - 2; // Account for rows, cols possibly starting at 1, EOF etc.
136        if contents.len() > max_size {
137            bail!("{file} exceeds maximum allowed policy file size {max_size}");
138        }
139        let mut lines = vec![];
140        let mut prev_ch = ' ';
141        let mut prev_pos = 0u32;
142        let mut start = 0u32;
143        for (i, ch) in contents.char_indices() {
144            if ch == '\n' {
145                let end = match prev_ch {
146                    '\r' => prev_pos,
147                    _ => i as u32,
148                };
149                lines.push((start, end));
150                start = i as u32 + 1;
151            }
152            prev_ch = ch;
153            prev_pos = i as u32;
154        }
155
156        if (start as usize) < contents.len() {
157            lines.push((start, contents.len() as u32));
158        } else if contents.is_empty() {
159            lines.push((0, 0));
160        } else {
161            let s = (contents.len() - 1) as u32;
162            lines.push((s, s));
163        }
164        Ok(Self {
165            src: Rc::new(SourceInternal {
166                file,
167                contents,
168                lines,
169            }),
170        })
171    }
172
173    #[cfg(feature = "std")]
174    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Source> {
175        let contents = match std::fs::read_to_string(&path) {
176            Ok(c) => c,
177            Err(e) => bail!("Failed to read {}. {e}", path.as_ref().display()),
178        };
179        // TODO: retain path instead of converting to string
180        Self::from_contents(path.as_ref().to_string_lossy().to_string(), contents)
181    }
182
183    pub fn file(&self) -> &String {
184        &self.src.file
185    }
186    pub fn contents(&self) -> &String {
187        &self.src.contents
188    }
189    pub fn line(&self, idx: u32) -> &str {
190        let idx = idx as usize;
191        if idx < self.src.lines.len() {
192            let (start, end) = self.src.lines[idx];
193            &self.src.contents[start as usize..end as usize]
194        } else {
195            ""
196        }
197    }
198
199    pub fn message(&self, line: u32, col: u32, kind: &str, msg: &str) -> String {
200        if line as usize > self.src.lines.len() {
201            return format!("{}: invalid line {} specified", self.src.file, line);
202        }
203
204        let line_str = format!("{line}");
205        let line_num_width = line_str.len() + 1;
206        let col_spaces = col as usize - 1;
207
208        format!(
209            "\n--> {}:{}:{}\n{:<line_num_width$}|\n\
210		{:<line_num_width$}| {}\n\
211		{:<line_num_width$}| {:<col_spaces$}^\n\
212		{}: {}",
213            self.src.file,
214            line,
215            col,
216            "",
217            line,
218            self.line(line - 1),
219            "",
220            "",
221            kind,
222            msg
223        )
224    }
225
226    pub fn error(&self, line: u32, col: u32, msg: &str) -> anyhow::Error {
227        anyhow!(self.message(line, col, "error", msg))
228    }
229}
230
/// A region of a [`Source`] together with the line and column where it starts.
#[derive(Clone)]
#[cfg_attr(feature = "ast", derive(serde::Serialize))]
pub struct Span {
    /// Source the span refers to. Skipped when serializing.
    #[cfg_attr(feature = "ast", serde(skip_serializing))]
    pub source: Source,
    /// Line used when reporting diagnostics for this span.
    pub line: u32,
    /// Column used when reporting diagnostics for this span.
    pub col: u32,
    /// Byte offset into the source text where the span starts.
    pub start: u32,
    /// Byte offset into the source text just past the end of the span.
    pub end: u32,
}
241
242impl Span {
243    pub fn text(&self) -> &str {
244        &self.source.contents()[self.start as usize..self.end as usize]
245    }
246
247    pub fn source_str(&self) -> SourceStr {
248        SourceStr::new(self.source.clone(), self.start, self.end)
249    }
250
251    pub fn message(&self, kind: &str, msg: &str) -> String {
252        self.source.message(self.line, self.col, kind, msg)
253    }
254
255    pub fn error(&self, msg: &str) -> anyhow::Error {
256        self.source.error(self.line, self.col, msg)
257    }
258}
259
260impl Debug for Span {
261    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
262        let t = self.text().escape_debug().to_string();
263        let max = 32;
264        let (txt, trailer) = if t.len() > max {
265            (&t[0..max], "...")
266        } else {
267            (t.as_str(), "")
268        };
269
270        f.write_fmt(format_args!(
271            "{}:{}:{}:{}, \"{}{}\"",
272            self.line, self.col, self.start, self.end, txt, trailer
273        ))
274    }
275}
276
/// The category of a [`Token`] produced by the [`Lexer`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum TokenKind {
    /// Punctuation or operator such as `{`, `+`, `:=`, `<=` or `!=`.
    Symbol,
    /// Double-quoted string; the token's span excludes the quotes.
    String,
    /// Backtick-delimited raw string; the token's span excludes the backticks.
    RawString,
    /// Numeric literal.
    Number,
    /// Identifier made of ASCII letters, digits and `_`.
    Ident,
    /// End of input.
    Eof,
}
286
/// A lexed token: its kind and the span of source text it covers.
#[derive(Debug, Clone)]
pub struct Token(pub TokenKind, pub Span);
289
/// Tokenizer over the text of a [`Source`].
#[derive(Clone)]
pub struct Lexer<'source> {
    /// Source being tokenized; cloned into every span that is produced.
    source: Source,
    /// Cursor over the characters of the source text and their byte offsets.
    iter: Peekable<CharIndices<'source>>,
    /// Current line, starting at 1.
    line: u32,
    /// Current column, starting at 1.
    col: u32,
    /// When set, an unrecognized character becomes a `Symbol` token instead
    /// of an "invalid character" error.
    unknown_char_is_symbol: bool,
    /// When set, `\*` is accepted as an escape sequence inside strings.
    allow_slash_star_escape: bool,
    /// When set, comments start with `//` instead of `#`.
    comment_starts_with_double_slash: bool,
    /// When set, `::` is lexed as a single symbol.
    double_colon_token: bool,
}
301
302impl<'source> Lexer<'source> {
303    pub fn new(source: &'source Source) -> Self {
304        Self {
305            source: source.clone(),
306            iter: source.contents().char_indices().peekable(),
307            line: 1,
308            col: 1,
309            unknown_char_is_symbol: false,
310            allow_slash_star_escape: false,
311            comment_starts_with_double_slash: false,
312            double_colon_token: false,
313        }
314    }
315
    /// Chooses whether an unrecognized character is returned as a `Symbol`
    /// token (`true`) or reported as an "invalid character" error (`false`).
    pub fn set_unknown_char_is_symbol(&mut self, b: bool) {
        self.unknown_char_is_symbol = b;
    }
319
    /// Chooses whether `\*` is accepted as an escape sequence inside
    /// double-quoted strings.
    pub fn set_allow_slash_star_escape(&mut self, b: bool) {
        self.allow_slash_star_escape = b;
    }
323
    /// Chooses the comment syntax: `//` when `b` is `true`, `#` otherwise.
    pub fn set_comment_starts_with_double_slash(&mut self, b: bool) {
        self.comment_starts_with_double_slash = b;
    }
327
    /// Chooses whether `::` is lexed as one symbol instead of two `:` symbols.
    pub fn set_double_colon_token(&mut self, b: bool) {
        self.double_colon_token = b;
    }
331
332    fn peek(&mut self) -> (usize, char) {
333        match self.iter.peek() {
334            Some((index, chr)) => (*index, *chr),
335            _ => (self.source.contents().len(), '\x00'),
336        }
337    }
338
339    fn peekahead(&mut self, n: usize) -> (usize, char) {
340        match self.iter.clone().nth(n) {
341            Some((index, chr)) => (index, chr),
342            _ => (self.source.contents().len(), '\x00'),
343        }
344    }
345
346    fn read_ident(&mut self) -> Result<Token> {
347        let start = self.peek().0;
348        let col = self.col;
349        loop {
350            let ch = self.peek().1;
351            if ch.is_ascii_alphanumeric() || ch == '_' {
352                self.iter.next();
353            } else {
354                break;
355            }
356        }
357        let end = self.peek().0;
358        self.col += (end - start) as u32;
359        Ok(Token(
360            TokenKind::Ident,
361            Span {
362                source: self.source.clone(),
363                line: self.line,
364                col,
365                start: start as u32,
366                end: end as u32,
367            },
368        ))
369    }
370
371    fn read_digits(&mut self) {
372        while self.peek().1.is_ascii_digit() {
373            self.iter.next();
374        }
375    }
376
377    // See https://www.json.org/json-en.html for number's grammar
378    fn read_number(&mut self) -> Result<Token> {
379        let (start, chr) = self.peek();
380        let col = self.col;
381        self.iter.next();
382
383        // Read integer part.
384        if chr != '0' {
385            // Starts with 1.. or 9. Read digits.
386            self.read_digits();
387        }
388
389        // Read fraction part
390        // . must be followed by at least 1 digit.
391        if self.peek().1 == '.' && self.peekahead(1).1.is_ascii_digit() {
392            self.iter.next(); // .
393            self.read_digits();
394        }
395
396        // Read exponent part
397        let ch = self.peek().1;
398        if ch == 'e' || ch == 'E' {
399            self.iter.next();
400            // e must be followed by an optional sign and digits
401            if matches!(self.peek().1, '+' | '-') {
402                self.iter.next();
403            }
404            // Read digits. Absence of digit will be validated by serde later.
405            self.read_digits();
406        }
407
408        let end = self.peek().0;
409        self.col += (end - start) as u32;
410
411        // Check for invalid number.Valid number cannot be followed by
412        // these characters:
413        let ch = self.peek().1;
414        if ch == '_' || ch == '.' || ch.is_ascii_alphanumeric() {
415            return Err(self.source.error(self.line, self.col, "invalid number"));
416        }
417
418        // Ensure that the number is parsable in Rust.
419        match serde_json::from_str::<Value>(&self.source.contents()[start..end]) {
420            Ok(_) => (),
421            Err(e) => {
422                let serde_msg = &e.to_string();
423                let msg = match &serde_msg {
424                    m if m.contains("out of range") => "out of range",
425                    m if m.contains("invalid number") => "invalid number",
426                    m if m.contains("expected value") => "expected value",
427                    m if m.contains("trailing characters") => "trailing characters",
428                    m => m.to_owned(),
429                };
430
431                bail!(
432                    "{} {}",
433                    self.source.error(
434                        self.line,
435                        col,
436                        "invalid number. serde_json cannot parse number:"
437                    ),
438                    msg
439                )
440            }
441        }
442
443        Ok(Token(
444            TokenKind::Number,
445            Span {
446                source: self.source.clone(),
447                line: self.line,
448                col,
449                start: start as u32,
450                end: end as u32,
451            },
452        ))
453    }
454
455    fn read_raw_string(&mut self) -> Result<Token> {
456        self.iter.next();
457        self.col += 1;
458        let (start, _) = self.peek();
459        let (line, col) = (self.line, self.col);
460        loop {
461            let (_, ch) = self.peek();
462            self.iter.next();
463            match ch {
464                '`' => {
465                    self.col += 1;
466                    break;
467                }
468                '\x00' => {
469                    return Err(self.source.error(line, col, "unmatched `"));
470                }
471                '\t' => self.col += 4,
472                '\n' => {
473                    self.line += 1;
474                    self.col = 1;
475                }
476                _ => self.col += 1,
477            }
478        }
479        let end = self.peek().0;
480        Ok(Token(
481            TokenKind::RawString,
482            Span {
483                source: self.source.clone(),
484                line,
485                col,
486                start: start as u32,
487                end: end as u32 - 1,
488            },
489        ))
490    }
491
    /// Lexes a double-quoted string and returns a `String` token whose span
    /// excludes both quotes.
    ///
    /// Escape sequences are checked while scanning, and the whole literal
    /// (quotes included) is then validated by parsing it with `serde_json`.
    ///
    /// # Errors
    ///
    /// Fails on an unknown escape sequence, a `\u` escape without four hex
    /// digits, a character below U+0020, a missing closing quote, or a
    /// literal that `serde_json` rejects.
    fn read_string(&mut self) -> Result<Token> {
        // Position of the opening quote, used for "unmatched" errors and the token.
        let (line, col) = (self.line, self.col);
        self.iter.next();
        self.col += 1;
        // Byte offset of the first character after the opening quote.
        let (start, _) = self.peek();
        loop {
            let (offset, ch) = self.peek();
            // Column of the character under inspection (shadows the quote's column).
            let col = self.col + (offset - start) as u32;
            match ch {
                // Closing quote or end of input; told apart after the loop.
                '"' | '\x00' => {
                    break;
                }
                '\\' => {
                    self.iter.next();
                    let (_, ch) = self.peek();
                    self.iter.next();
                    match ch {
                        // JSON escape sequences
                        '"' | '\\' | '/' | 'b' | 'f' | 'n' | 'r' | 't' => (),
                        '*' if self.allow_slash_star_escape => (),
                        'u' => {
                            // \u must be followed by exactly four hex digits.
                            for _i in 0..4 {
                                let (offset, ch) = self.peek();
                                let col = self.col + (offset - start) as u32;
                                if !ch.is_ascii_hexdigit() {
                                    return Err(self.source.error(
                                        line,
                                        col,
                                        "invalid hex escape sequence",
                                    ));
                                }
                                self.iter.next();
                            }
                        }
                        // Reported at the column of the backslash.
                        _ => return Err(self.source.error(line, col, "invalid escape sequence")),
                    }
                }
                _ => {
                    // Check for valid JSON characters: anything below U+0020 is rejected.
                    let col = self.col + (offset - start) as u32;
                    if !('\u{0020}'..='\u{10FFFF}').contains(&ch) {
                        return Err(self.source.error(line, col, "invalid character in string"));
                    }
                    self.iter.next();
                }
            }
        }

        // The loop also stops at end of input; that means the quote was never closed.
        if self.peek().1 != '"' {
            return Err(self.source.error(line, col, "unmatched \""));
        }

        self.iter.next();
        // `end` is the offset just past the closing quote.
        let end = self.peek().0;
        self.col += (end - start) as u32;

        // Ensure that the string is parsable in Rust.
        // `start - 1..end` covers the literal including both quotes.
        match serde_json::from_str::<String>(&self.source.contents()[start - 1..end]) {
            Ok(_) => (),
            Err(e) => {
                let serde_msg = &e.to_string();
                let msg = serde_msg;
                bail!(
                    "{} {}",
                    self.source
                        .error(self.line, col, "serde_json cannot parse string:"),
                    msg
                )
            }
        }

        Ok(Token(
            TokenKind::String,
            Span {
                source: self.source.clone(),
                line,
                // Column of the first character after the opening quote.
                col: col + 1,
                start: start as u32,
                // Exclude the closing quote.
                end: end as u32 - 1,
            },
        ))
    }
574
575    #[inline]
576    fn skip_past_newline(&mut self) -> Result<()> {
577        self.iter.next();
578        loop {
579            match self.peek().1 {
580                '\n' | '\x00' => break,
581                _ => self.iter.next(),
582            };
583        }
584        Ok(())
585    }
586
587    fn skip_ws(&mut self) -> Result<()> {
588        // Only the 4 json whitespace characters are recognized.
589        // https://www.crockford.com/mckeeman.html.
590        // Additionally, comments are also skipped.
591        // A tab is considered 4 space characters.
592        loop {
593            match self.peek().1 {
594                ' ' => self.col += 1,
595                '\t' => self.col += 4,
596                '\r' => {
597                    if self.peekahead(1).1 != '\n' {
598                        return Err(self.source.error(
599                            self.line,
600                            self.col,
601                            "\\r must be followed by \\n",
602                        ));
603                    }
604                }
605                '\n' => {
606                    self.col = 1;
607                    self.line += 1;
608                }
609                '#' if !self.comment_starts_with_double_slash => {
610                    self.skip_past_newline()?;
611                    continue;
612                }
613                '/' if self.comment_starts_with_double_slash && self.peekahead(1).1 == '/' => {
614                    self.skip_past_newline()?;
615                    continue;
616                }
617                _ => break,
618            }
619            self.iter.next();
620        }
621        Ok(())
622    }
623
624    pub fn next_token(&mut self) -> Result<Token> {
625        self.skip_ws()?;
626
627        let (start, chr) = self.peek();
628        let col = self.col;
629
630        match chr {
631	    // Special case for - followed by digit which is a
632	    // negative json number.
633	    // . followed by digit is invalid number.
634	    '-' | '.' if self.peekahead(1).1.is_ascii_digit() => {
635		self.read_number()
636	    }
637	    // grouping characters
638	    '{' | '}' | '[' | ']' | '(' | ')' |
639	    // arith operator
640	    '+' | '-' | '*' | '/' | '%' |
641	    // bin operator
642	    '&' | '|' |
643	    // separators
644	    ',' | ';' | '.' => {
645		self.col += 1;
646		self.iter.next();
647		Ok(Token(TokenKind::Symbol, Span {
648		    source: self.source.clone(),
649		    line: self.line,
650		    col,
651		    start: start as u32,
652		    end: start as u32 + 1,
653		}))
654	    }
655	    ':' => {
656		self.col += 1;
657		self.iter.next();
658		let mut end = start as u32 + 1;
659		if self.peek().1 == '=' || (self.peek().1 == ':' && self.double_colon_token) {
660		    self.col += 1;
661		    self.iter.next();
662		    end += 1;
663		}
664		Ok(Token(TokenKind::Symbol, Span {
665		    source: self.source.clone(),
666		    line: self.line,
667		    col,
668		    start: start as u32,
669		    end
670		}))
671	    }
672	    // < <= > >= = ==
673	    '<' | '>' | '=' => {
674		self.col += 1;
675		self.iter.next();
676		if self.peek().1 == '=' {
677		    self.col += 1;
678		    self.iter.next();
679		};
680		Ok(Token(TokenKind::Symbol, Span {
681		    source: self.source.clone(),
682		    line: self.line,
683		    col,
684		    start: start as u32,
685		    end: self.peek().0 as u32,
686		}))
687	    }
688	    '!' if self.peekahead(1).1 == '=' => {
689		self.col += 2;
690		self.iter.next();
691		self.iter.next();
692		Ok(Token(TokenKind::Symbol, Span {
693		    source: self.source.clone(),
694		    line: self.line,
695		    col,
696		    start: start as u32,
697		    end: self.peek().0 as u32,
698		}))
699	    }
700	    '"' => self.read_string(),
701	    '`' => self.read_raw_string(),
702	    '\x00' => Ok(Token(TokenKind::Eof, Span {
703		source: self.source.clone(),
704		line:self.line,
705		col,
706		start: start as u32,
707		end: start as u32
708	    })),
709	    _ if chr.is_ascii_digit() => self.read_number(),
710	    _ if chr.is_ascii_alphabetic() || chr == '_' => {
711		let mut ident = self.read_ident()?;
712		if ident.1.text() == "set" && self.peek().1 == '(' {
713		    // set immediately followed by ( is treated as set( if
714		    // the next token is ).
715		    let state = (self.iter.clone(), self.line, self.col);
716		    self.iter.next();
717
718		    // Check it next token is ).
719		    let next_tok = self.next_token()?;
720		    let is_setp = next_tok.1.text() == ")";
721
722		    // Restore state
723		    (self.iter, self.line, self.col) = state;
724
725		    if is_setp {
726			self.iter.next();
727			self.col += 1;
728			ident.1.end += 1;
729		    }
730		}
731		Ok(ident)
732	    }
733	    _ if self.unknown_char_is_symbol => {
734		self.col += 1;
735		self.iter.next();
736		Ok(Token(TokenKind::Symbol, Span {
737		    source: self.source.clone(),
738		    line: self.line,
739		    col,
740		    start: start as u32,
741		    end: start as u32 + 1,
742		}))
743	    }
744	    _ => Err(self.source.error(self.line, self.col, "invalid character"))
745	}
746    }
747}