regorus/
utils.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
#![allow(
    // Small arithmetic on argument counts (e.g. arity + 1) is intentional.
    clippy::arithmetic_side_effects,
    // Match arms deliberately rebind names such as `refr` while walking
    // reference chains.
    clippy::shadow_unrelated,
    // Patterns match directly against references (e.g. `&Expr`) and rely on
    // default binding modes.
    clippy::pattern_type_mismatch,
    // `as` casts between small integer types are permitted in this module.
    clippy::as_conversions
)]
10
11pub mod limits;
12
13use crate::ast::*;
14use crate::builtins::*;
15use crate::lexer::*;
16use crate::*;
17
18use alloc::collections::BTreeMap;
19
20use anyhow::{bail, Result};
21pub fn get_path_string(refr: &Expr, document: Option<&str>) -> Result<String> {
22    let mut comps: Vec<&str> = vec![];
23    let mut expr = Some(refr);
24    while expr.is_some() {
25        match expr {
26            Some(Expr::RefDot { refr, field, .. }) => {
27                comps.push(field.0.text());
28                expr = Some(refr);
29            }
30            Some(Expr::RefBrack { refr, index, .. }) => {
31                if let Expr::String { span: s, .. } = index.as_ref() {
32                    comps.push(s.text());
33                }
34                expr = Some(refr);
35            }
36            Some(Expr::Var { span: v, .. }) => {
37                comps.push(v.text());
38                expr = None;
39            }
40            _ => bail!("internal error: not a simple ref {expr:?}"),
41        }
42    }
43    if let Some(d) = document {
44        comps.push(d);
45    };
46    comps.reverse();
47    Ok(comps.join("."))
48}
49
50pub type FunctionTable = BTreeMap<String, (Vec<Ref<Rule>>, u8, Ref<Module>)>;
51
52fn get_extra_arg_impl(
53    expr: &Expr,
54    module: Option<&str>,
55    functions: &FunctionTable,
56) -> Result<Option<Ref<Expr>>> {
57    if let Expr::Call { fcn, params, .. } = expr {
58        let full_path = get_path_string(fcn, module)?;
59        let n_args = if let Some((_, n_args, _)) = functions.get(&full_path) {
60            *n_args
61        } else {
62            let path = get_path_string(fcn, None)?;
63            if let Some((_, n_args, _)) = functions.get(&path) {
64                *n_args
65            } else if let Some((_, n_args)) = BUILTINS.get(path.as_str()) {
66                *n_args
67            } else {
68                return Ok(None);
69            }
70        };
71        if (n_args as usize) + 1 == params.len() {
72            return Ok(params.last().cloned());
73        }
74    }
75    Ok(None)
76}
77
78pub fn get_extra_arg(
79    expr: &Expr,
80    module: Option<&str>,
81    functions: &FunctionTable,
82) -> Option<Ref<Expr>> {
83    get_extra_arg_impl(expr, module, functions).unwrap_or_default()
84}
85
86pub fn gather_functions(modules: &[Ref<Module>]) -> Result<FunctionTable> {
87    let mut table = FunctionTable::new();
88
89    for module in modules {
90        let module_path = get_path_string(&module.package.refr, Some("data"))?;
91        for rule in &module.policy {
92            if let Rule::Spec {
93                span,
94                head: RuleHead::Func { refr, args, .. },
95                ..
96            } = rule.as_ref()
97            {
98                let full_path = get_path_string(refr, Some(module_path.as_str()))?;
99
100                if let Some((functions, arity, _)) = table.get_mut(&full_path) {
101                    if args.len() as u8 != *arity {
102                        bail!(span.error(
103                            format!("{full_path} was previously defined with {arity} arguments.")
104                                .as_str()
105                        ));
106                    }
107                    functions.push(rule.clone());
108                } else {
109                    table.insert(
110                        full_path,
111                        (vec![rule.clone()], args.len() as u8, module.clone()),
112                    );
113                }
114            }
115        }
116    }
117    Ok(table)
118}
119
120pub fn get_root_var(mut expr: &Expr) -> Result<SourceStr> {
121    let empty = expr.span().source_str().clone_empty();
122    loop {
123        match expr {
124            Expr::Var { span: v, .. } => return Ok(v.source_str()),
125            Expr::RefDot { refr, .. } | Expr::RefBrack { refr, .. } => expr = refr,
126            _ => return Ok(empty),
127        }
128    }
129}