From 64ba9efa349bb54a86ca01f497329393c0ba11a6 Mon Sep 17 00:00:00 2001 From: profitroll Date: Sun, 26 May 2024 21:39:55 +0200 Subject: [PATCH] Replaced hasattr in dumps with supports_argument --- src/libbot/__main__.py | 3 ++- src/libbot/_utils.py | 22 ++++++++++++++++++++++ src/libbot/sync/__main__.py | 3 ++- tests/test_utils.py | 24 ++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 src/libbot/_utils.py create mode 100644 tests/test_utils.py diff --git a/src/libbot/__main__.py b/src/libbot/__main__.py index b7f03b3..eb696e1 100644 --- a/src/libbot/__main__.py +++ b/src/libbot/__main__.py @@ -8,6 +8,7 @@ try: except ImportError: from json import dumps, loads +from ._utils import supports_argument from .sync._nested import nested_delete, nested_set @@ -36,7 +37,7 @@ async def json_write(data: Any, path: Union[str, Path]) -> None: async with aiofiles.open(str(path), mode="w", encoding="utf-8") as f: await f.write( dumps(data, ensure_ascii=False, escape_forward_slashes=False, indent=4) - if hasattr(dumps, "escape_forward_slashes") + if supports_argument(dumps, "escape_forward_slashes") else dumps(data, ensure_ascii=False, indent=4) ) diff --git a/src/libbot/_utils.py b/src/libbot/_utils.py new file mode 100644 index 0000000..122d750 --- /dev/null +++ b/src/libbot/_utils.py @@ -0,0 +1,22 @@ +import inspect +from typing import Callable + + +def supports_argument(func: Callable, arg_name: str) -> bool: + """Check whether a function has a specific argument + + ### Args: + * func (`Callable`): Function to be inspected + * arg_name (`str`): Argument to be checked + + ### Returns: + * `bool`: `True` if argument is supported and `False` if not + """ + if hasattr(func, "__code__"): + return arg_name in inspect.signature(func).parameters + elif hasattr(func, "__doc__"): + if doc := func.__doc__: + first_line = doc.splitlines()[0] + return arg_name in first_line + + return False diff --git a/src/libbot/sync/__main__.py b/src/libbot/sync/__main__.py index 4745ab8..7ece6cb 100644 --- a/src/libbot/sync/__main__.py +++ b/src/libbot/sync/__main__.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Any, Union +from .._utils import supports_argument from ._nested import nested_delete, nested_set try: @@ -34,7 +35,7 @@ def json_write(data: Any, path: Union[str, Path]) -> None: with open(str(path), mode="w", encoding="utf-8") as f: f.write( dumps(data, ensure_ascii=False, escape_forward_slashes=False, indent=4) - if hasattr(dumps, "escape_forward_slashes") + if supports_argument(dumps, "escape_forward_slashes") else dumps(data, ensure_ascii=False, indent=4) ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c0aef75 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,24 @@ +from typing import Callable + +import pytest + +from libbot._utils import supports_argument + + +def func1(foo: str, bar: str): + pass + + +def func2(foo: str): + pass + + +@pytest.mark.parametrize( + "func, arg_name, result", + [ + (func1, "foo", True), + (func2, "bar", False), + ], +) +def test_supports_argument(func: Callable, arg_name: str, result: bool): + assert supports_argument(func, arg_name) == result