#![allow(clippy::should_implement_trait)]
use crate::grammar::{BinaryExpr, BinaryOperator, Expr, RcExpr, RcExpression};
use crate::math::gcd;
use crate::partial_evaluator::flatten::flatten_expr;
use crate::utils::{get_flattened_binary_args, unflatten_binary_expr, UnflattenStrategy};
use core::cmp::max;
use std::collections::{HashMap, HashSet};
/// A dense univariate polynomial with integer (`isize`) coefficients.
///
/// Equality is derived, so two polynomials compare equal only when their coefficient
/// vectors are identical; `Poly::new` strips trailing zero coefficients to keep that
/// comparison meaningful.
#[derive(Default, Clone, Eq, PartialEq, Debug)]
pub struct Poly {
/// Coefficients in ascending order of degree: `vec[i]` is the coefficient of the
/// degree-`i` term.
pub vec: Vec<isize>,
}
impl From<Vec<isize>> for Poly {
fn from(v: Vec<isize>) -> Poly {
Self::new(v)
}
}
impl From<&Vec<isize>> for Poly {
fn from(v: &Vec<isize>) -> Poly {
Self::new(v.clone())
}
}
/// Builds a [`Poly`] from a comma-separated list of coefficients in ascending order of
/// degree (a trailing comma is accepted). With no arguments it expands to `Poly::empty()`.
///
/// The expansion refers to `Poly` unqualified, so `Poly` must be in scope at the call site.
#[macro_export]
macro_rules! poly {
    () => {
        Poly::empty()
    };
    ($($coeff:expr),+ $(,)?) => {
        Poly::new(vec![$($coeff),+])
    };
}
impl Poly {
/// Creates a polynomial from coefficients given in ascending order of degree.
///
/// Trailing zero coefficients are removed, so the last stored coefficient (if any) is
/// non-zero.
pub fn new(vec: Vec<isize>) -> Self {
    let unnormalized = Self { vec };
    unnormalized.truncate_zeros()
}
pub fn empty() -> Self {
Self::new(vec![]).truncate_zeros()
}
/// Returns `true` when the polynomial has no non-zero coefficient. This covers both the
/// empty coefficient vector and a vector consisting only of zeros.
#[inline]
pub fn is_zero(&self) -> bool {
    !self.vec.iter().any(|&coeff| coeff != 0)
}
/// Returns `true` when the polynomial is the constant `1`: the degree-0 coefficient is 1
/// and every higher coefficient (if stored at all) is 0.
#[inline]
pub fn is_one(&self) -> bool {
    match self.vec.split_first() {
        Some((&1, higher)) => higher.iter().all(|&coeff| coeff == 0),
        _ => false,
    }
}
/// Returns the index of the last stored coefficient, i.e. `len - 1`.
///
/// An empty coefficient vector also reports 0, so this value alone does not distinguish
/// the empty polynomial from a constant.
#[inline]
pub fn deg(&self) -> usize {
    match self.vec.len() {
        0 => 0,
        len => len - 1,
    }
}
/// Divides every coefficient by the GCD of the absolute values of all coefficients.
///
/// A zero polynomial, or one whose coefficient GCD is already 1, is returned unchanged.
/// The sign of the coefficients is not altered.
pub fn primitive(self) -> Self {
    if self.is_zero() {
        return self;
    }
    // Seed with the first coefficient, then fold every coefficient (the first included)
    // into the running GCD.
    let mut content = self.vec[0].abs() as usize;
    for coeff in self.vec.iter() {
        content = gcd(coeff.abs() as usize, content);
    }
    let content = content as isize;
    if content == 1 {
        return self;
    }
    let reduced: Vec<isize> = self.vec.iter().map(|coeff| coeff / content).collect();
    Poly::new(reduced)
}
/// Adds the single term `coeff * x^pow` to this polynomial and returns the result.
///
/// The coefficient vector is grown with zeros when `pow` lies beyond the current length.
/// A `coeff` of 0 returns `self` untouched.
fn add_term(mut self, coeff: isize, pow: usize) -> Self {
    if coeff == 0 {
        return self;
    }
    if self.vec.len() <= pow {
        self.vec.resize(pow + 1, 0);
    }
    self.vec[pow] += coeff;
    // Adding can cancel the leading coefficient (e.g. `2x + (-2)x`); strip the resulting
    // trailing zeros so `deg`, `lc` and `==` keep describing the true leading term.
    self.truncate_zeros()
}
fn mul_term(mut self, coeff: isize, pow: usize) -> Self {
if coeff == 0 || self.is_zero() {
return Poly::empty();
}
for term in self.vec.iter_mut() {
*term *= coeff;
}
for _ in 0..pow {
self.vec.insert(0, 0);
}
self
}
/// Multiplies every coefficient by the scalar `c` and returns the result.
///
/// The result is normalized, so multiplying by 0 yields the empty (zero) polynomial
/// rather than a vector of zeros.
#[inline]
pub fn mul_scalar(mut self, c: isize) -> Self {
    for term in self.vec.iter_mut() {
        *term *= c;
    }
    // `c == 0` turns every coefficient into 0; strip them so the result compares equal
    // to `Poly::empty()` and reports a consistent degree.
    self.truncate_zeros()
}
/// Divides every coefficient by the scalar `c` using integer division, which truncates
/// toward zero, and returns the normalized result.
///
/// # Errors
///
/// Returns an error when `c` is 0.
///
/// # Panics
///
/// Panics if a coefficient is `isize::MIN` and `c` is -1 (integer division overflow).
pub fn div_scalar(mut self, c: isize) -> Result<Self, &'static str> {
    if c == 0 {
        return Err("Cannot divide a polynomial by 0");
    }
    // Divide coefficient-wise. Multiplying by the integer `1 / c` is not an option:
    // that quotient is 0 for every |c| > 1.
    for term in self.vec.iter_mut() {
        *term /= c;
    }
    // Truncating division can zero out the leading coefficients.
    Ok(self.truncate_zeros())
}
/// Subtracts `other` from this polynomial coefficient-wise and returns the normalized
/// difference.
fn sub(mut self, other: Self) -> Self {
    // Pad by length rather than by `deg()`: `deg()` reports 0 for both an empty vector
    // and a constant, so it cannot tell that an empty `self` needs one slot to subtract
    // a constant from.
    if self.vec.len() < other.vec.len() {
        self.vec.resize(other.vec.len(), 0);
    }
    for (lhs, rhs) in self.vec.iter_mut().zip(other.vec) {
        *lhs -= rhs;
    }
    self.truncate_zeros()
}
/// Removes trailing zero coefficients so that the last stored coefficient, if any, is
/// non-zero. An all-zero vector becomes empty.
#[inline]
fn truncate_zeros(mut self) -> Self {
    // Keep everything up to and including the last non-zero coefficient.
    let keep = match self.vec.iter().rposition(|&coeff| coeff != 0) {
        Some(last_nonzero) => last_nonzero + 1,
        None => 0,
    };
    self.vec.truncate(keep);
    self
}
/// Divides this polynomial by `other` over the integers, returning `(quotient, remainder)`.
///
/// Long division proceeds one leading term at a time and stops as soon as
/// - the remainder becomes zero,
/// - the remainder's degree drops below the divisor's degree, or
/// - the remainder's leading coefficient is not divisible by the divisor's leading
///   coefficient (the remainder accumulated so far is returned as-is).
///
/// If `self` has a lower degree than `other`, the quotient is empty and the remainder
/// is `self`.
///
/// # Errors
///
/// Returns an error when `other` is the zero polynomial, or if a division step fails to
/// lower the degree of a non-zero remainder.
pub fn div(self, other: Poly) -> Result<(Self, Self), &'static str> {
    // Normalize both operands so `deg()` and `lc()` describe the true leading term even
    // when a polynomial was built with trailing zero coefficients.
    let dividend = self.truncate_zeros();
    let other = other.truncate_zeros();
    if other.is_zero() {
        return Err("Cannot divide by a 0 polynomial");
    }
    let d_other = other.deg();
    let mut d_rem = dividend.deg();
    if d_rem < d_other {
        return Ok((poly![], dividend));
    }
    let lc_other = other.lc();
    let mut rem_poly = dividend;
    let mut quo = poly![];
    // A zero dividend needs no division steps at all.
    while !rem_poly.is_zero() {
        let lc_rem = rem_poly.lc();
        if lc_rem % lc_other != 0 {
            break;
        }
        let cur_term_coeff = lc_rem / lc_other;
        let shift = d_rem - d_other;
        quo = quo.add_term(cur_term_coeff, shift);
        rem_poly = rem_poly.sub(other.clone().mul_term(cur_term_coeff, shift));
        // An exact division ends here. This must be checked before comparing degrees:
        // `deg()` reports 0 for a zero remainder, which is not "lower" than a constant
        // divisor's degree and would otherwise be mistaken for a stalled division.
        if rem_poly.is_zero() {
            break;
        }
        let d_rem_old = d_rem;
        d_rem = rem_poly.deg();
        if d_rem < d_other {
            break;
        } else if d_rem >= d_rem_old {
            return Err("Unexpected state: remainder degreee not lower after division");
        }
    }
    Ok((quo, rem_poly))
}
/// Returns the largest absolute value among the coefficients, or 0 when there are none.
pub fn max_norm(&self) -> usize {
    self.vec
        .iter()
        .fold(0, |largest, coeff| max(largest, coeff.abs() as usize))
}
/// Returns the last stored coefficient (the leading coefficient of a normalized
/// polynomial), or 0 when the coefficient vector is empty.
#[inline]
pub fn lc(&self) -> isize {
    match self.vec.last() {
        Some(&leading) => leading,
        None => 0,
    }
}
/// Evaluates the polynomial at `x`.
///
/// Coefficients are consumed from the highest degree down, multiplying the running value
/// by `x` before adding each coefficient. An empty polynomial evaluates to 0.
#[inline]
pub fn eval(&self, x: isize) -> isize {
    let mut acc = 0;
    for &coeff in self.vec.iter().rev() {
        acc = acc * x + coeff;
    }
    acc
}
/// Attempts to read `expr` as an integer polynomial in a single term.
///
/// The expression is passed through `flatten_expr` and split into the arguments of its
/// top-level `+`. Each argument is one of:
/// - a constant, which is summed into the degree-0 coefficient;
/// - a product with a constant on either side, read as `coefficient * term`;
/// - anything else, read as a term with coefficient 1.
///
/// A term of the form `base ^ n` with constant `n` contributes to degree `n` of `base`;
/// any other term contributes to degree 1 of itself. Contributions to the same degree
/// are added together.
///
/// `relative_to`, when given, is registered as the expected base, so any other base
/// makes the conversion fail.
///
/// Returns the polynomial together with the base term, or `None` for the base when the
/// expression contains no non-constant part.
///
/// # Errors
///
/// Returns a message when a coefficient or the constant sum is not an integer, when a
/// term's exponent is rejected by `term_and_pow_from_expr`, or when more than one
/// distinct base term is found.
pub fn from_expr(
    expr: RcExpr,
    relative_to: Option<RcExpr>,
) -> Result<(Self, Option<RcExpr>), String> {
    let expr = flatten_expr(expr);
    let poly_parts = get_flattened_binary_args(expr, BinaryOperator::Plus);
    let mut uniq_terms = HashSet::<RcExpr>::new();
    let mut degree_coeffs = HashMap::<usize, isize>::new();
    if let Some(ref term) = relative_to {
        uniq_terms.insert(term.clone());
    }
    let mut konst_f64 = 0.;
    for poly_part in poly_parts.iter() {
        match poly_part.as_ref() {
            Expr::Const(c) => konst_f64 += c,
            Expr::BinaryExpr(BinaryExpr {
                op: BinaryOperator::Mult,
                lhs,
                rhs,
            }) if lhs.is_const() || rhs.is_const() => {
                let (c_f64, term) = if lhs.is_const() {
                    (lhs.get_const().unwrap(), rhs)
                } else {
                    (rhs.get_const().unwrap(), lhs)
                };
                // The coefficient must survive a round trip through `isize`.
                let coeff = c_f64 as isize;
                if (coeff as f64 - c_f64).abs() > std::f64::EPSILON {
                    return Err(format!("Expected an integer coefficient for {}", poly_part));
                }
                let (term, pow) = term_and_pow_from_expr(term.clone())?;
                // Accumulate: the same degree may appear more than once (e.g. `x + 2 * x`).
                *degree_coeffs.entry(pow).or_insert(0) += coeff;
                uniq_terms.insert(term);
            }
            _ => {
                let (term, pow) = term_and_pow_from_expr(poly_part.clone())?;
                *degree_coeffs.entry(pow).or_insert(0) += 1;
                uniq_terms.insert(term);
            }
        }
        // Bail out as soon as a second distinct base shows up.
        if uniq_terms.len() > 1 {
            return Err(format!(
                "Expected a singular unique term, found {}: {:#?}",
                uniq_terms.len(),
                uniq_terms
            ));
        }
    }
    let konst = konst_f64 as isize;
    if (konst as f64 - konst_f64).abs() > std::f64::EPSILON {
        return Err(format!("Expected an integer constant, got {}", konst_f64));
    }
    let max_degree = degree_coeffs.keys().max().cloned();
    match max_degree {
        // Purely constant input; `poly!` normalizes a zero constant to the empty polynomial.
        None => Ok((poly![konst], None)),
        Some(degree) => {
            let mut coeffs = vec![0; degree + 1];
            coeffs[0] = konst;
            for (pow, coeff) in degree_coeffs.into_iter() {
                // `+=` so a degree-0 term (e.g. `x ^ 0`) adds to the constant instead of
                // replacing it.
                coeffs[pow] += coeff;
            }
            Ok((Self::new(coeffs), uniq_terms.into_iter().next()))
        }
    }
}
/// Builds an expression for this polynomial in terms of `relative_to`, tagging every
/// created node with `span`.
///
/// Zero coefficients are skipped. The degree-0 coefficient becomes a constant; a
/// coefficient of 1 at degree 1 or higher is emitted as the bare `relative_to`. The
/// resulting summands are joined with `+` through `unflatten_binary_expr` using
/// `UnflattenStrategy::Left`. A polynomial without non-zero coefficients yields the
/// constant `0`.
///
/// NOTE(review): for a coefficient other than 0 or 1 at degree 2 or higher, the exponent
/// is applied to the whole product, i.e. the node built is `(coeff * relative_to) ^ degree`
/// and not `coeff * relative_to ^ degree`. The `larger_coefficient` test pins this output;
/// confirm whether it is intended before changing it.
pub fn to_expr(&self, relative_to: RcExpr, span: crate::Span) -> RcExpr {
    let mut summands = Vec::with_capacity(self.vec.len());
    for (degree, &coeff) in self.vec.iter().enumerate() {
        if coeff == 0 {
            continue;
        }
        let summand = if degree == 0 {
            rc_expr!(Expr::Const(coeff as f64), span)
        } else {
            // Base of the summand: `relative_to` itself, or `coeff * relative_to`.
            let base = if coeff == 1 {
                relative_to.clone()
            } else {
                rc_expr!(
                    Expr::BinaryExpr(BinaryExpr::mult(
                        rc_expr!(Expr::Const(coeff as f64), span),
                        relative_to.clone()
                    )),
                    span
                )
            };
            if degree == 1 {
                base
            } else {
                rc_expr!(
                    Expr::BinaryExpr(BinaryExpr::exp(
                        base,
                        rc_expr!(Expr::Const(degree as f64), span),
                    )),
                    span
                )
            }
        };
        summands.push(summand);
    }
    if summands.is_empty() {
        return rc_expr!(Expr::Const(0.), span);
    }
    unflatten_binary_expr(
        summands.as_ref(),
        BinaryOperator::Plus,
        UnflattenStrategy::Left,
    )
}
/// Renders the polynomial as a string, using a variable named `var` as the term.
///
/// The text is whatever `to_string()` produces for the expression built by `to_expr`
/// with dummy spans.
pub fn to_string(&self, var: &str) -> String {
    let variable = rc_expr!(Expr::Var(intern_str!(var)), crate::DUMMY_SP);
    let rendered = self.to_expr(variable, crate::DUMMY_SP);
    rendered.to_string()
}
}
/// Splits an expression into `(base, degree)`.
///
/// `base ^ n` with a constant `n` yields `(base, n)`; every other expression is its own
/// base with degree 1.
///
/// # Errors
///
/// Returns a message when the constant exponent does not survive a round trip through
/// `usize` (for example a negative or fractional exponent).
fn term_and_pow_from_expr(expr: RcExpr) -> Result<(RcExpr, usize), String> {
    if let Expr::BinaryExpr(BinaryExpr {
        op: BinaryOperator::Exp,
        lhs,
        rhs,
    }) = expr.as_ref()
    {
        if rhs.is_const() {
            let exponent = rhs.get_const().unwrap();
            let pow = exponent as usize;
            let round_trip_error = (pow as f64 - exponent).abs();
            if round_trip_error > std::f64::EPSILON {
                return Err(format!("Expected a positive term degree for {}", expr));
            }
            return Ok((lhs.clone(), pow));
        }
    }
    Ok((expr, 1))
}
// Unit tests for `Poly`: term arithmetic, subtraction, division, `is_one`, and the
// conversions to and from expressions.
#[cfg(test)]
mod test {
use super::*;
use crate::parse_expr;
// `add_term`: adding beyond the current degree pads with zeros.
#[test]
fn add_1() {
assert_eq!(poly![-1, 2].add_term(2, 4), poly![-1, 2, 0, 0, 2]);
}
// `add_term`: adding to an existing coefficient.
#[test]
fn add_2() {
assert_eq!(poly![-1, 2].add_term(2, 1), poly![-1, 4]);
}
// `add_term`: filling an interior zero coefficient.
#[test]
fn add_3() {
assert_eq!(poly![5, 0, 3].add_term(2, 1), poly![5, 2, 3]);
}
// `mul_term`: a zero coefficient yields the empty polynomial.
#[test]
fn mul_1() {
assert_eq!(poly![5, 0, 3].mul_term(0, 2), poly![]);
}
// `mul_term`: coefficients are scaled and shifted up by the power.
#[test]
fn mul_2() {
assert_eq!(poly![5, 0, 3].mul_term(2, 2), poly![0, 0, 10, 0, 6]);
}
// `sub`: operands of equal length.
#[test]
fn sub_1() {
assert_eq!(poly![5, 0, 3].sub(poly![1, 0, 1]), poly![4, 0, 2]);
}
// `sub`: shorter right-hand side.
#[test]
fn sub_2() {
assert_eq!(poly![-1, 0, 1].sub(poly![-2, 1]), poly![1, -1, 1]);
}
// `sub`: the leading coefficients cancel, so the result is truncated.
#[test]
fn sub_3() {
assert_eq!(poly![-1, 0, 1].sub(poly![0, -1, 1]), poly![-1, 1]);
}
#[test]
fn sub_4() {
assert_eq!(poly![-1, 0, 1].sub(poly![3, 2]), poly![-4, -2, 1]);
}
// `sub`: shorter left-hand side is padded before subtracting.
#[test]
fn sub_5() {
assert_eq!(poly![3, 2].sub(poly![-1, 0, 1]), poly![4, 2, -1]);
}
// `div`: the leading coefficient 1 is not divisible by 2, so nothing is divided and the
// dividend comes back as the remainder.
#[test]
fn div_1() {
assert_eq!(
poly![-1, 0, 1].div(poly![-4, 2]).unwrap(),
(poly![], poly![-1, 0, 1])
);
}
// `div`: exact division, (x^2 - 1) / (x - 1) = x + 1.
#[test]
fn div_2() {
assert_eq!(
poly![-1, 0, 1].div(poly![-1, 1]).unwrap(),
(poly![1, 1], poly![])
)
}
#[test]
fn is_one() {
let cases = [
(poly![1], true),
(poly![1, 0, 0, 0, 0], true),
(poly![1, 0, 0, 1, 0], false),
(poly![0, 0, 0, 0, 0], false),
(poly![2], false),
(poly![], false),
];
for (poly, is_one) in cases.iter() {
assert_eq!(poly.is_one(), *is_one);
}
}
// Generates one `#[test]` per case for `Poly::from_expr`. Each case parses the expression
// (and the optional `relative_to` term), runs the conversion, and compares against
// `Some((poly, term_as_string))`, or `None` when the conversion is expected to fail.
// The first form is shorthand for "no relative term".
macro_rules! poly_from_expr_tests {
($($case:ident: $expr:expr => $expected:expr)*) => {
$(
poly_from_expr_tests!($case: $expr, None => $expected);
)*
};
($($case:ident: $expr:expr, $relative:expr => $expected:expr)*) => {
$(
#[test]
fn $case() {
let expr = parse_expr!($expr);
let relative: Option<&str> = $relative;
let has_relative = relative.is_some();
let rel = relative.map(|r: &str| parse_expr!(r)).unwrap_or(RcExpr::empty(crate::DUMMY_SP));
let rel_opt = if has_relative { Some(rel) } else { None };
let poly = Poly::from_expr(expr, rel_opt).ok().map(|(p, t)| (p, t.map(|expr| expr.to_string())));
assert_eq!(poly, $expected);
}
)*
};
}
// Conversions without a `relative_to` term.
poly_from_expr_tests! {
empty: "0" => Some((poly![], None))
konst: "1 + 2" => Some((poly![3], None))
single_deg: "x" => Some((poly![0,1], Some("x".to_string())))
multi_deg: "10 + x + x^2 + x^4 + x^8" => Some((poly![10, 1, 1, 0, 1, 0, 0, 0, 1], Some("x".to_string())))
with_coeff: "2x + x^3 + 10x^2 + 5x^4" => Some((poly![0, 2, 10, 1, 5], Some("x".to_string())))
complex_term: "2(x + y ^ z) + 5(x + y ^ z)^3" => Some((poly![0, 2, 0, 5], Some("x + y ^ z".to_string())))
multi_term: "10 + x + y^2" => None
}
// Conversions with an explicit `relative_to` term; a mismatching term makes them fail.
poly_from_expr_tests! {
relative: "10 + x + x^2 + x^4", Some("x") => Some((poly![10, 1, 1, 0, 1], Some("x".to_string())))
relative_fails: "10 + x + x^2 + x^4", Some("y") => None
}
// Generates one `#[test]` per case for `Poly::to_expr`, comparing the string form of the
// produced expression.
macro_rules! poly_to_expr_tests {
($($case:ident: $poly:expr, $relative:expr => $expected:expr)*) => {
$(
#[test]
fn $case() {
let rel = parse_expr!($relative);
let expr = $poly.to_expr(rel, crate::DUMMY_SP);
assert_eq!(expr.to_string(), $expected);
}
)*
};
}
// NOTE(review): `larger_coefficient` expects `(3 * x) ^ 2` for the coefficient 3 at
// degree 2, which is not the same value as `3 * x ^ 2`. It pins what `to_expr` currently
// builds; confirm whether that output is intended.
poly_to_expr_tests! {
to_empty: poly![], "x" => "0"
to_empty_all_zeros: poly![0, 0, 0, 0], "x" => "0"
zero_coefficient: poly![10, 0], "x" => "10"
one_coefficient: poly![5, 1, 1, 0, 1], "x" => "5 + x + x ^ 2 + x ^ 4"
larger_coefficient: poly![1, 2, 3, 4, 5], "x" => "1 + 2 * x + (3 * x) ^ 2 + (4 * x) ^ 3 + (5 * x) ^ 4"
}
}