1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use crate::grammar::visit::*;
use crate::grammar::*;
use crate::{InternedStr, Span};
use std::collections::{HashMap, HashSet};
pub fn collect_var_names(expr: &RcExpr) -> HashSet<InternedStr> {
let mut collector = VarNameCollector::default();
collector.visit_expr(expr);
collector.vars
}
/// Visitor state for `collect_var_names`: accumulates each variable name it visits.
#[derive(Default)]
struct VarNameCollector {
    /// Distinct variable names seen so far during the walk.
    vars: HashSet<InternedStr>,
}

impl<'a> StmtVisitor<'a> for VarNameCollector {
    /// Records the visited variable's name; its source span is not needed here.
    fn visit_var(&mut self, var: &'a InternedStr, _span: Span) {
        // `InternedStr` is copied out of the borrow, so the set owns its entries.
        let name = *var;
        self.vars.insert(name);
    }
}
pub fn collect_var_asgns(program: &StmtList) -> HashMap<InternedStr, Vec<&Assignment>> {
let mut collector = VarAsgnsCollector::default();
collector.visit_stmt_list(program);
collector.defs
}
/// Visitor state for `collect_var_asgns`: buckets assignments by their target variable.
#[derive(Default)]
struct VarAsgnsCollector<'a> {
    /// Assignments seen so far, keyed by the variable on their left-hand side.
    defs: HashMap<InternedStr, Vec<&'a Assignment>>,
}

impl<'a> StmtVisitor<'a> for VarAsgnsCollector<'a> {
    /// Appends `asgn` to the bucket of the variable it assigns.
    ///
    /// Assignments for which `lhs.get_var()` returns `None` are skipped.
    // NOTE(review): this overrides the trait's provided `visit_asgn`, so any
    // default traversal into the assignment's sub-nodes does not run here.
    // Confirm that nested assignments never need to be collected.
    fn visit_asgn(&mut self, asgn: &'a Assignment) {
        if let Some(target) = asgn.lhs.get_var() {
            let bucket = self.defs.entry(target).or_default();
            bucket.push(asgn);
        }
    }
}
pub fn collect_pat_names(expr: &RcExprPat) -> HashSet<&str> {
let mut collector = PatternCollector::default();
collector.visit_expr_pat(expr);
collector.pats
}
/// Visitor state for `collect_pat_names`: gathers the names of every visited pattern.
#[derive(Default)]
struct PatternCollector<'a> {
    /// Pattern names borrowed from the expression pattern being visited.
    pats: HashSet<&'a str>,
}

impl<'a> PatternCollector<'a> {
    /// Adds one pattern name to the set; duplicates are absorbed by the `HashSet`.
    fn insert_name(&mut self, name: &'a str) {
        self.pats.insert(name);
    }
}

impl<'a> ExprPatVisitor<'a> for PatternCollector<'a> {
    /// Records a variable pattern's name.
    fn visit_var_pat(&mut self, var_pat: &'a str, _span: Span) {
        self.insert_name(var_pat);
    }

    /// Records a constant pattern's name.
    fn visit_const_pat(&mut self, const_pat: &'a str, _span: Span) {
        self.insert_name(const_pat);
    }

    /// Records an "any" pattern's name.
    fn visit_any_pat(&mut self, any_pat: &'a str, _span: Span) {
        self.insert_name(any_pat);
    }
}
#[cfg(test)]
mod test {
    use crate::{parse_expr, parse_expression_pattern, scan};
    use std::collections::HashSet;

    #[test]
    fn collect_var_names() {
        let parsed = parse_expr!("a + b + c + a + d / b ^ e");
        // `a` and `b` occur twice in the source but must be reported once each.
        let mut names: Vec<String> = super::collect_var_names(&parsed)
            .into_iter()
            .map(|v| v.to_string())
            .collect();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b", "c", "d", "e"]);
    }

    #[test]
    fn collect_pat_names() {
        let parsed = parse_expression_pattern(scan("$a + _b * (#c - [$d]) / $a").tokens).program;
        // `$a` occurs twice but is reported once. Comparing sets avoids relying on
        // an order-imposing sort of the collected names.
        let found = super::collect_pat_names(&parsed);
        let expected: HashSet<&str> = ["$a", "_b", "#c", "$d"].iter().copied().collect();
        assert_eq!(found, expected);
    }
}