"""
Module for testing prerequisites a student's submission needs to fulfill.
Export functions for asserting properties of student code:
- has_signature -- is a function compatible with the supplied signature?
Might in the future test for static attributes or methods of classes
"""
from __future__ import annotations
__version__ = "0.3"
__all__ = ["has_signature", "writes", "writes_file"]
import importlib.util
import inspect
import io
import logging
import os
import pathlib
from collections.abc import Callable
from typing import Any
import pytest
logger = logging.getLogger()
[docs]
def writes_file(
spec: importlib.machinery.ModuleSpec,
*args: object,
name: str | None = None,
created: set[pathlib.Path] | None = None,
deleted: set[pathlib.Path] | None = None,
modified: set[pathlib.Path] | None = None,
) -> pytest.ExitCode | tuple[pytest.ExitCode, Any, Any]:
"""
Test file writes of module execution as ``name``.
Parameters
----------
spec : importlib.machinery.ModuleSpec
Module specification to be executed.
*args : tuple
Unused positional arguments.
name : str or None, optional
``__name__`` of module at execution time, by default None.
created : set or None, optional
Expected set of created file paths, by default None.
deleted : set or None, optional
Expected set of deleted file paths, by default None.
modified : set or None, optional
Expected set of modified file paths, by default None.
Returns
-------
Enum
``pytest.ExitCode.OK`` if file operations match expectations,
otherwise ``pytest.ExitCode.TESTS_FAILED``.
"""
def recursive_stats(path: pathlib.Path, return_dict: dict[pathlib.Path, os.stat_result] | None = None) -> dict[pathlib.Path, os.stat_result]:
"""
Recursively gather file paths and stats in subdirs.
Parameters
----------
path : pathlib.Path
Root path to start gathering stats from.
return_dict : dict or None, optional
Accumulator dict for recursive calls, by default None.
Returns
-------
dict
Mapping of file paths to their stat results.
"""
if return_dict is None:
return_dict = {}
if path.is_file():
return_dict[path] = path.stat()
else:
for child in path.iterdir():
return_dict = recursive_stats(child, return_dict)
return return_dict
result = None
module = importlib.util.module_from_spec(spec)
if name is not None:
logger.debug("changing name to %s", name)
spec_name, spec.name = spec.name, name
spec_loader_name, spec.loader.name = spec.loader.name, name
module.__name__ = name
pre_exec_stats = recursive_stats(pathlib.Path())
spec.loader.exec_module(module)
post_exec_stats = recursive_stats(pathlib.Path())
pre, post = set(pre_exec_stats.keys()), set(post_exec_stats.keys())
created_files, deleted_files, shared_files = (
post - pre,
pre - post,
pre & post,
)
modified_files = {file for file in shared_files if pre_exec_stats[file] != post_exec_stats[file]}
for mode, expected, actual in [
("created", created, created_files),
("deleted", deleted, deleted_files),
("modified", modified, modified_files),
]:
if expected is not None:
if expected != actual:
logger.warning(
"Test failed: module %s files (%s), but expected this exactly for files (%s)!",
mode,
", ".join(map(str, actual)),
", ".join(map(str, expected)),
)
result = (pytest.ExitCode.TESTS_FAILED, expected, actual)
else:
logger.debug("Test passed: module %s files %s as expected.", mode, expected)
if name is not None:
spec.name, spec.loader.name = spec_name, spec_loader_name
return result or pytest.ExitCode.OK
[docs]
def writes(
spec: importlib.machinery.ModuleSpec,
*args: object,
name: str | None = None,
out: str | None = None,
err: str | None = None,
**kwargs: object,
) -> pytest.ExitCode:
"""
Test stdout and stderr writes of module execution as ``name``.
Parameters
----------
spec : importlib.machinery.ModuleSpec
Module specification to be executed.
*args : tuple
Unused positional arguments.
name : str or None, optional
``__name__`` of module at execution time, by default None.
out : str or None, optional
Expected stdout output, skipped if None.
err : str or None, optional
Expected stderr output, skipped if None.
**kwargs : dict
Unused keyword arguments.
Returns
-------
Enum
``pytest.ExitCode.OK`` if stdout/stderr match expectations,
otherwise ``pytest.ExitCode.TESTS_FAILED``.
"""
from contextlib import ExitStack, redirect_stderr, redirect_stdout
def message(name: str, output: str, actual: str, expected: str) -> str:
"""
Format message for warning.
Parameters
----------
name : str
Module name.
output : str
Output stream name (stdout or stderr).
actual : str
Actual output written.
expected : str
Expected output.
Returns
-------
str
Formatted warning message.
"""
return (
f"Importing the module {name} wrote"
f" {repr(actual) if actual else 'nothing'} to {output},"
f" expected {repr(expected) if expected else 'nothing'}"
)
result = None
module = importlib.util.module_from_spec(spec)
if name is not None:
spec_name, spec.name = spec.name, name
spec_loader_name, spec.loader.name = spec.loader.name, name
module.__name__ = name
outputs = {
redirect_stdout: ("stdout", out, io.StringIO()),
redirect_stderr: ("stderr", err, io.StringIO()),
}
with ExitStack() as stack:
for redirect, (_output, expected, target) in outputs.items():
if expected is not None:
stack.enter_context(redirect(target))
spec.loader.exec_module(module)
for output, expected, actual in outputs.values():
if expected is None:
continue
actual = actual.getvalue()
if actual != expected:
logging.warning(message(spec.name, output, actual, expected))
result = pytest.ExitCode.TESTS_FAILED
else:
logging.debug(message(spec.name, output, actual, expected))
if name is not None:
spec.name, spec.loader.name = spec_name, spec_loader_name
return result or pytest.ExitCode.OK
[docs]
def has_signature(
function: Callable[..., Any],
ref_sig: inspect.Signature,
*strict_comparisons: str,
compare_names: Callable[[list[str], list[str]], bool] = list.__eq__,
**comparisons: Callable[[Any, Any], bool],
) -> pytest.ExitCode:
"""
Test if function is compatible with passed signature.
Parameters
----------
function : callable
The function whose signature is to be tested.
ref_sig : inspect.Signature
Reference signature to compare against.
*strict_comparisons : str
Parameter attributes to compare using strict equality.
compare_names : callable, optional
Function to compare parameter name lists, by default ``list.__eq__``.
**comparisons : callable
Mapping of parameter attributes to comparison functions.
Returns
-------
Enum
``pytest.ExitCode.OK`` if signature matches,
otherwise ``pytest.ExitCode.TESTS_FAILED``.
"""
def invalid_signature(expected: object, actual: object) -> str:
"""
Format message for warnings.
Parameters
----------
expected : object
Expected signature or parameter.
actual : object
Actual signature or parameter.
Returns
-------
str
Formatted warning message.
"""
return f"Function signature is not valid.\n{expected = },\n {actual = }."
def pretty_par(par: inspect.Parameter) -> str:
"""
Pretty formatting of function parameters.
Parameters
----------
par : inspect.Parameter
The parameter to format.
Returns
-------
str
Human-readable parameter description.
"""
string = f"{par.kind.name} parameter <{par.name}"
if par.annotation is not inspect.Parameter.empty:
string += f": {par.annotation}"
if par.default is not inspect.Parameter.empty:
string += f" = {par.default}"
string += ">"
return string
fun_sig = inspect.signature(function)
result = None
if not compare_names(list(fun_sig.parameters), list(ref_sig.parameters)):
logger.warning(invalid_signature(list(ref_sig.parameters), list(fun_sig.parameters)))
result = pytest.ExitCode.TESTS_FAILED
for name, fun_par in fun_sig.parameters.items():
ref_par = ref_sig.parameters.get(name)
if ref_par is not None:
for attr in strict_comparisons:
comparisons[attr] = type(getattr(fun_par, attr)).__eq__
for attr, comp in comparisons.items():
fun_value, ref_value = getattr(fun_par, attr), getattr(ref_par, attr)
if comp(fun_value, ref_value) is not True:
logger.warning(invalid_signature(pretty_par(ref_par), pretty_par(fun_par)))
result = pytest.ExitCode.TESTS_FAILED
if "annotation" in strict_comparisons:
comparisons["annotation"] = type(ref_sig.return_annotation).__eq__
if "annotation" in comparisons:
ref_return, fun_return = (
ref_sig.return_annotation,
fun_sig.return_annotation,
)
if comparisons["annotation"](ref_return, fun_return) is not True:
logger.warning("Return annotation of %s", invalid_signature(ref_return, fun_return))
result = pytest.ExitCode.TESTS_FAILED
return result or pytest.ExitCode.OK