diff --git a/minijinja-py/src/typeconv.rs b/minijinja-py/src/typeconv.rs index 8b823606..4078ce07 100644 --- a/minijinja-py/src/typeconv.rs +++ b/minijinja-py/src/typeconv.rs @@ -18,7 +18,11 @@ static AUTO_ESCAPE_CACHE: Mutex> = Mutex::new(BTree static MARK_SAFE: GILOnceCell> = GILOnceCell::new(); fn is_safe_attr(name: &str) -> bool { - !name.starts_with('_') + if matches!(name, "__add__" | "__sub__") { + true + } else { + !name.starts_with('_') + } } fn is_dictish(val: &Bound<'_, PyAny>) -> bool { diff --git a/minijinja-py/tests/test_basic.py b/minijinja-py/tests/test_basic.py index 885eb7ea..2ca8dee7 100644 --- a/minijinja-py/tests/test_basic.py +++ b/minijinja-py/tests/test_basic.py @@ -504,3 +504,18 @@ def test_striptags(): env = Environment() assert env.eval_expr("'foo'|striptags") == "foo" assert env.eval_expr("'ä'|striptags") == "รค" + + +def test_addition(): + class MyObj: + def __add__(self, other): + assert other == 22 + return 42 + + def __sub__(self, other): + assert other == 11 + return 23 + + env = Environment() + assert env.eval_expr("x + 22", x=MyObj()) == 42 + assert env.eval_expr("x - 11", x=MyObj()) == 23 diff --git a/minijinja/src/compiler/ast.rs b/minijinja/src/compiler/ast.rs index b55e4ba5..8f35d41d 100644 --- a/minijinja/src/compiler/ast.rs +++ b/minijinja/src/compiler/ast.rs @@ -217,13 +217,13 @@ impl Expr<'_> { return None; }; match c.op { - BinOpKind::Add => ops::add(&left, &right).ok(), - BinOpKind::Sub => ops::sub(&left, &right).ok(), - BinOpKind::Mul => ops::mul(&left, &right).ok(), - BinOpKind::Div => ops::div(&left, &right).ok(), - BinOpKind::FloorDiv => ops::int_div(&left, &right).ok(), - BinOpKind::Rem => ops::rem(&left, &right).ok(), - BinOpKind::Pow => ops::pow(&left, &right).ok(), + BinOpKind::Add => ops::add(None, &left, &right).ok(), + BinOpKind::Sub => ops::sub(None, &left, &right).ok(), + BinOpKind::Mul => ops::mul(None, &left, &right).ok(), + BinOpKind::Div => ops::div(None, &left, &right).ok(), + BinOpKind::FloorDiv => ops::int_div(None, &left, &right).ok(), + BinOpKind::Rem => ops::rem(None, &left, &right).ok(), + BinOpKind::Pow => ops::pow(None, &left, &right).ok(), BinOpKind::Concat => Some(ops::string_concat(left, &right)), BinOpKind::Eq => Some(Value::from(left == right)), BinOpKind::Ne => Some(Value::from(left != right)), diff --git a/minijinja/src/filters.rs b/minijinja/src/filters.rs index e2614b16..19d7d195 100644 --- a/minijinja/src/filters.rs +++ b/minijinja/src/filters.rs @@ -594,7 +594,7 @@ mod builtins { format!("can only sum numbers, got {}", value.kind()), )); } - rv = ok!(ops::add(&rv, &value)); + rv = ok!(ops::add(Some(state), &rv, &value)); } Ok(rv) diff --git a/minijinja/src/value/ops.rs b/minijinja/src/value/ops.rs index 537d6cb2..b2b94cd9 100644 --- a/minijinja/src/value/ops.rs +++ b/minijinja/src/value/ops.rs @@ -1,5 +1,8 @@ +use core::slice; + use crate::error::{Error, ErrorKind}; use crate::value::{DynObject, ObjectRepr, Value, ValueKind, ValueRepr}; +use crate::vm::State; const MIN_I128_AS_POS_U128: u128 = 170141183460469231731687303715884105728; @@ -178,9 +181,28 @@ fn failed_op(op: &str, lhs: &Value, rhs: &Value) -> Error { ) } +macro_rules! special_method_impl { + ($method:expr, $state:expr, $lhs:expr, $rhs:expr) => { + if let Some(state) = $state { + if let Some(obj) = $lhs.as_object() { + match obj.call_method(state, $method, slice::from_ref($rhs)) { + Ok(rv) => return Ok(rv), + Err(err) => { + if err.kind() != ErrorKind::UnknownMethod { + return Err(err); + } + } + } + } + } + }; +} + macro_rules! math_binop { - ($name:ident, $int:ident, $float:tt) => { - pub fn $name(lhs: &Value, rhs: &Value) -> Result { + ($name:ident, $int:ident, $float:tt, $method:expr) => { + pub fn $name(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!($method, state, lhs, rhs); + match coerce(lhs, rhs, true) { Some(CoerceResult::I128(a, b)) => match a.$int(b) { Some(val) => Ok(int_as_value(val)), @@ -193,7 +215,9 @@ macro_rules! math_binop { } } -pub fn add(lhs: &Value, rhs: &Value) -> Result { +pub fn add(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!("__add__", state, lhs, rhs); + if matches!(lhs.kind(), ValueKind::Seq | ValueKind::Iterable) && matches!(rhs.kind(), ValueKind::Seq | ValueKind::Iterable) { @@ -220,10 +244,12 @@ pub fn add(lhs: &Value, rhs: &Value) -> Result { } } -math_binop!(sub, checked_sub, -); -math_binop!(rem, checked_rem_euclid, %); +math_binop!(sub, checked_sub, -, "__sub__"); +math_binop!(rem, checked_rem_euclid, %, "__mod__"); + +pub fn mul(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!("__mul__", state, lhs, rhs); -pub fn mul(lhs: &Value, rhs: &Value) -> Result { if let Some((s, n)) = lhs .as_str() .map(|s| (s, rhs)) @@ -308,7 +334,8 @@ fn repeat_iterable(n: &Value, seq: &DynObject) -> Result { })) } -pub fn div(lhs: &Value, rhs: &Value) -> Result { +pub fn div(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!("__truediv__", state, lhs, rhs); fn do_it(lhs: &Value, rhs: &Value) -> Option { let a = some!(as_f64(lhs, true)); let b = some!(as_f64(rhs, true)); @@ -317,7 +344,8 @@ pub fn div(lhs: &Value, rhs: &Value) -> Result { do_it(lhs, rhs).ok_or_else(|| impossible_op("/", lhs, rhs)) } -pub fn int_div(lhs: &Value, rhs: &Value) -> Result { +pub fn int_div(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!("__floordiv__", state, lhs, rhs); match coerce(lhs, rhs, true) { Some(CoerceResult::I128(a, b)) => { if b != 0 { @@ -334,7 +362,8 @@ pub fn int_div(lhs: &Value, rhs: &Value) -> Result { } /// Implements a binary `pow` operation on values. -pub fn pow(lhs: &Value, rhs: &Value) -> Result { +pub fn pow(state: Option<&State>, lhs: &Value, rhs: &Value) -> Result { + special_method_impl!("__pow__", state, lhs, rhs); match coerce(lhs, rhs, true) { Some(CoerceResult::I128(a, b)) => { match TryFrom::try_from(b).ok().and_then(|b| a.checked_pow(b)) { @@ -421,22 +450,22 @@ mod tests { #[test] fn test_adding() { - let err = add(&Value::from("a"), &Value::from(42)).unwrap_err(); + let err = add(None, &Value::from("a"), &Value::from(42)).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: tried to use + operator on unsupported types string and number" ); assert_eq!( - add(&Value::from(1), &Value::from(2)).unwrap(), + add(None, &Value::from(1), &Value::from(2)).unwrap(), Value::from(3) ); assert_eq!( - add(&Value::from("foo"), &Value::from("bar")).unwrap(), + add(None, &Value::from("foo"), &Value::from("bar")).unwrap(), Value::from("foobar") ); - let err = add(&Value::from(i128::MAX), &Value::from(1)).unwrap_err(); + let err = add(None, &Value::from(i128::MAX), &Value::from(1)).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: unable to calculate 170141183460469231731687303715884105727 + 1" @@ -445,44 +474,44 @@ mod tests { #[test] fn test_subtracting() { - let err = sub(&Value::from("a"), &Value::from(42)).unwrap_err(); + let err = sub(None, &Value::from("a"), &Value::from(42)).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: tried to use - operator on unsupported types string and number" ); - let err = sub(&Value::from("foo"), &Value::from("bar")).unwrap_err(); + let err = sub(None, &Value::from("foo"), &Value::from("bar")).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: tried to use - operator on unsupported types string and string" ); assert_eq!( - sub(&Value::from(2), &Value::from(1)).unwrap(), + sub(None, &Value::from(2), &Value::from(1)).unwrap(), Value::from(1) ); } #[test] fn test_dividing() { - let err = div(&Value::from("a"), &Value::from(42)).unwrap_err(); + let err = div(None, &Value::from("a"), &Value::from(42)).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: tried to use / operator on unsupported types string and number" ); - let err = div(&Value::from("foo"), &Value::from("bar")).unwrap_err(); + let err = div(None, &Value::from("foo"), &Value::from("bar")).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: tried to use / operator on unsupported types string and string" ); assert_eq!( - div(&Value::from(100), &Value::from(2)).unwrap(), + div(None, &Value::from(100), &Value::from(2)).unwrap(), Value::from(50.0) ); - let err = int_div(&Value::from(i128::MIN), &Value::from(-1i128)).unwrap_err(); + let err = int_div(None, &Value::from(i128::MIN), &Value::from(-1i128)).unwrap_err(); assert_eq!( err.to_string(), "invalid operation: unable to calculate -170141183460469231731687303715884105728 // -1" diff --git a/minijinja/src/vm/mod.rs b/minijinja/src/vm/mod.rs index 59c88275..9f4588ee 100644 --- a/minijinja/src/vm/mod.rs +++ b/minijinja/src/vm/mod.rs @@ -253,7 +253,7 @@ impl<'env> Vm<'env> { ($method:ident) => {{ b = stack.pop(); a = stack.pop(); - stack.push(ctx_ok!(ops::$method(&a, &b))); + stack.push(ctx_ok!(ops::$method(Some(state), &a, &b))); }}; }