Skip to content

Commit

Permalink
Merge pull request #2543 from sopel-irc/tools.calculation-typing
Browse files Browse the repository at this point in the history
tools, tools.calculation: docs/type-hint improvements, API fixes, better test coverage
  • Loading branch information
dgw committed Nov 3, 2023
2 parents d9b6a74 + 3c4e477 commit 3b91a76
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/source/package/tools/calculation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ sopel.tools.calculation

.. automodule:: sopel.tools.calculation
:members:
:ignore-module-all:
14 changes: 11 additions & 3 deletions sopel/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Useful miscellaneous tools and shortcuts for Sopel plugins
*Availability: 3+*
.. versionadded:: 3.0
"""

# tools.py - Sopel misc tools
Expand Down Expand Up @@ -68,8 +68,7 @@ def get_input(prompt):
.. deprecated:: 7.1
Use of this function will become a warning when Python 2 support is
dropped in Sopel 8.0. The function will be removed in Sopel 8.1.
This function will be removed in Sopel 8.1.
"""
return input(prompt)
Expand Down Expand Up @@ -116,6 +115,11 @@ class OutputRedirect:
"""Redirect the output to the terminal and a log file.
A simplified object used to write to both the terminal and a log file.
.. deprecated:: 8.0
Vestige of old logging system. Will be removed in Sopel 8.1.
"""

@deprecated(
Expand Down Expand Up @@ -200,6 +204,8 @@ def get_hostmask_regex(mask):
:param str mask: the hostmask that the pattern should match
:return: a compiled regex pattern matching the given ``mask``
:rtype: :ref:`re.Pattern <python:re-objects>`
.. versionadded:: 4.4
"""
mask = re.escape(mask)
mask = mask.replace(r'\*', '.*')
Expand Down Expand Up @@ -244,6 +250,8 @@ def chain_loaders(*lazy_loaders):
together into one. It's primarily a helper for lazy rule decorators such as
:func:`sopel.plugin.url_lazy`.
.. versionadded:: 7.1
.. important::
This function doesn't check the uniqueness of regexes generated by
Expand Down
59 changes: 37 additions & 22 deletions sopel/tools/calculation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
"""Tools to help safely do calculations from user input"""
"""Tools to help safely do calculations from user input
.. versionadded:: 5.3
.. note::
Most of this is internal machinery. :func:`eval_equation` is the "public"
part, used by Sopel's built-in ``calc`` plugin.
"""
from __future__ import annotations

import ast
import operator
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Callable, Optional

__all__ = ['eval_equation']

Expand All @@ -21,16 +33,20 @@ class ExpressionEvaluator:

class Error(Exception):
"""Internal exception type for :class:`ExpressionEvaluator`\\s."""
pass

def __init__(self, bin_ops=None, unary_ops=None):
def __init__(
self,
bin_ops: Optional[dict[type[ast.operator], Callable]] = None,
unary_ops: Optional[dict[type[ast.unaryop], Callable]] = None
):
self.binary_ops = bin_ops or {}
self.unary_ops = unary_ops or {}

def __call__(self, expression_str, timeout=5.0):
def __call__(self, expression_str: str, timeout: float = 5.0):
"""Evaluate a Python expression and return the result.
:param str expression_str: the expression to evaluate
:param expression_str: the expression to evaluate
:param timeout: timeout for processing the expression, in seconds
:raise SyntaxError: if the given ``expression_str`` is not a valid
Python statement
:raise ExpressionEvaluator.Error: if the instance of
Expand All @@ -40,14 +56,12 @@ def __call__(self, expression_str, timeout=5.0):
ast_expression = ast.parse(expression_str, mode='eval')
return self._eval_node(ast_expression.body, time.time() + timeout)

def _eval_node(self, node, timeout):
def _eval_node(self, node: ast.AST, timeout: float):
"""Recursively evaluate the given :class:`ast.Node <ast.AST>`.
:param node: the AST node to evaluate
:type node: :class:`ast.AST`
:param timeout: how long the expression is allowed to process before
timing out and failing
:type timeout: int or float
timing out and failing, in seconds
:raise ExpressionEvaluator.Error: if it can't handle the ``node``
Uses :attr:`self.binary_ops` and :attr:`self.unary_ops` for the
Expand Down Expand Up @@ -102,13 +116,11 @@ def _eval_node(self, node, timeout):
)


def guarded_mul(left, right):
def guarded_mul(left: float, right: float):
"""Multiply two values, guarding against overly large inputs.
:param left: the left operand
:type left: int or float
:param right: the right operand
:type right: int or float
:raise ValueError: if the inputs are too large to handle safely
"""
# Only handle ints because floats will overflow anyway.
Expand All @@ -127,13 +139,11 @@ def guarded_mul(left, right):
return operator.mul(left, right)


def pow_complexity(num, exp):
def pow_complexity(num: int, exp: int):
"""Estimate the worst case time :func:`pow` takes to calculate.
:param num: base
:type num: int or float
:param exp: exponent
:type exp: int or float
This function is based on experimental data from the time it takes to
calculate ``num**exp`` in 32-bit CPython 2.7.6 on an Intel Core i7-2670QM
Expand Down Expand Up @@ -195,13 +205,11 @@ def pow_complexity(num, exp):
return exp ** 1.590 * num.bit_length() ** 1.73 / 36864057619.3


def guarded_pow(num, exp):
def guarded_pow(num: float, exp: float):
"""Raise a number to a power, guarding against overly large inputs.
:param num: base
:type num: int or float
:param exp: exponent
:type exp: int or float
:raise ValueError: if the inputs are too large to handle safely
"""
# Only handle ints because floats will overflow anyway.
Expand All @@ -218,7 +226,14 @@ def guarded_pow(num, exp):


class EquationEvaluator(ExpressionEvaluator):
__bin_ops = {
"""Specific subclass of :class:`ExpressionEvaluator` for simple math
This presets the allowed operators to safeguard against user input that
could try to do things that will adversely affect the running bot, while
still letting users pass arbitrary mathematical expressions using the
available (mostly arithmetic) operators.
"""
__bin_ops: dict[type[ast.operator], Callable] = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: guarded_mul,
Expand All @@ -228,7 +243,7 @@ class EquationEvaluator(ExpressionEvaluator):
ast.FloorDiv: operator.floordiv,
ast.BitXor: guarded_pow
}
__unary_ops = {
__unary_ops: dict[type[ast.unaryop], Callable] = {
ast.USub: operator.neg,
ast.UAdd: operator.pos,
}
Expand All @@ -240,8 +255,8 @@ def __init__(self):
unary_ops=self.__unary_ops
)

def __call__(self, expression_str):
result = ExpressionEvaluator.__call__(self, expression_str)
def __call__(self, expression_str: str, timeout: float = 5.0):
result = ExpressionEvaluator.__call__(self, expression_str, timeout)

# This wrapper is here so additional sanity checks could be done
# on the result of the eval, but currently none are done.
Expand Down
68 changes: 68 additions & 0 deletions test/tools/test_tools_calculation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Tests Sopel's calculation tools"""
from __future__ import annotations

import ast
import operator

import pytest

from sopel.tools.calculation import EquationEvaluator, ExpressionEvaluator


def test_expression_eval():
"""Ensure ExpressionEvaluator respects limited operator set."""
OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
}
evaluator = ExpressionEvaluator(bin_ops=OPS)

assert evaluator("1 + 1") == 2
assert evaluator("43 - 1") == 42
assert evaluator("1 + 1 - 2") == 0

with pytest.raises(ExpressionEvaluator.Error) as exc:
evaluator("2 * 2")
assert "Unsupported binary operator" in exc.value.args[0]

with pytest.raises(ExpressionEvaluator.Error) as exc:
evaluator("~2")
assert "Unsupported unary operator" in exc.value.args[0]


def test_equation_eval_invalid_constant():
"""Ensure unsupported constants are rejected."""
evaluator = EquationEvaluator()

with pytest.raises(ExpressionEvaluator.Error) as exc:
evaluator("2 + 'string'")
assert "values are not supported" in exc.value.args[0]


def test_equation_eval_timeout():
"""Ensure EquationEvaluator times out as expected."""
# timeout is added to the current time;
# negative means the timeout is "reached" before even starting
timeout = -1.0
evaluator = EquationEvaluator()

with pytest.raises(ExpressionEvaluator.Error) as exc:
evaluator("1000000**100", timeout)
assert "Time for evaluating" in exc.value.args[0]

with pytest.raises(ExpressionEvaluator.Error) as exc:
evaluator("+42", timeout)
assert "Time for evaluating" in exc.value.args[0]


def test_equation_eval():
"""Test that EquationEvaluator correctly parses input and calculates results."""
evaluator = EquationEvaluator()

assert evaluator("1 + 1") == 2
assert evaluator("43 - 1") == 42
assert evaluator("(((1 + 1 + 2) * 3 / 5) ** 8 - 13) // 21 % 35") == 16.0
assert evaluator("-42") == -42
assert evaluator("-(-42)") == 42
assert evaluator("+42") == 42
assert evaluator("3 ^ 2") == 9

0 comments on commit 3b91a76

Please sign in to comment.