From 46aadf245408cb8e5ddba7e4fda32148082bf650 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Thu, 19 Mar 2026 11:59:59 -0400 Subject: [PATCH] feat: add tests for dependency / package management --- test/package/_check_dep_isolation.py | 321 ++++++++++++++++++++++ test/package/test_dependency_isolation.py | 289 +++++++++++++++++++ 2 files changed, 610 insertions(+) create mode 100644 test/package/_check_dep_isolation.py create mode 100644 test/package/test_dependency_isolation.py diff --git a/test/package/_check_dep_isolation.py b/test/package/_check_dep_isolation.py new file mode 100644 index 000000000..bef629f16 --- /dev/null +++ b/test/package/_check_dep_isolation.py @@ -0,0 +1,321 @@ +"""Subprocess helper for dependency isolation tests. + +Usage: python _check_dep_isolation.py [module2 ...] + +Exits 0 if all imports are from declared dependencies, 1 if violations found. +""" + +import importlib +import importlib.metadata +import re +import sys +import tomllib +from pathlib import Path + +# Packages that are part of the Python standard library or otherwise +# should never be flagged as undeclared dependencies. +STDLIB_AND_INFRASTRUCTURE = { + # Build/install infrastructure that leaks into sys.modules + "_distutils_hack", + "pkg_resources", + "setuptools", + "pip", + "wheel", + "distutils", +} + +# Packages that third-party libraries opportunistically import via +# `try/except ImportError` when installed. These are extras of core +# networking and serialization libraries — not declared by mellea, but +# they appear in sys.modules when present in the environment. +OPPORTUNISTIC_IMPORTS = { + # urllib3 / httpx extras (compression & protocol upgrades) + "brotli", + "brotlicffi", + "zstandard", + "h2", + "hpack", + "hyperframe", + "socksio", + # Widely used utility imported opportunistically by many packages + "packaging", + # Fast JSON — used by pydantic/fastapi when available + "orjson", +} + + +def parse_dep_name(dep_spec: str) -> str | None: + """Extract the distribution name from a dependency specifier. + + Strips version constraints, extras, and environment markers. + Returns None for self-references like 'mellea[hooks]'. + """ + # Remove environment markers (e.g., "; sys_platform != 'darwin'") + dep_spec = dep_spec.split(";")[0].strip() + # Extract just the package name (before any version/extras specifiers) + match = re.match(r"^([A-Za-z0-9]([A-Za-z0-9._-]*[A-Za-z0-9])?)", dep_spec) + if not match: + return None + name = match.group(1).lower() + # Skip self-references (handled separately by extract_self_ref_groups) + if name == "mellea": + return None + return name + + +def extract_self_ref_groups(dep_spec: str) -> list[str]: + """Extract optional-dependency group names from self-references. + + e.g. 'mellea[hooks]' → ['hooks'], 'mellea[watsonx,hf,vllm]' → ['watsonx', 'hf', 'vllm'] + Returns an empty list for non-self-references. + """ + dep_spec = dep_spec.split(";")[0].strip() + match = re.match(r"^mellea\[([^\]]+)\]", dep_spec, re.IGNORECASE) + if not match: + return [] + return [g.strip() for g in match.group(1).split(",")] + + +def get_top_level_names(dist_name: str) -> set[str]: + """Get the importable top-level module names for a distribution.""" + try: + dist = importlib.metadata.distribution(dist_name) + except importlib.metadata.PackageNotFoundError: + return set() + + # Try top_level.txt first + top_level = dist.read_text("top_level.txt") + if top_level: + return {line.strip() for line in top_level.splitlines() if line.strip()} + + # Fall back to packages listed in RECORD + names = set() + if dist.files: + for f in dist.files: + parts = str(f).split("/") + if len(parts) > 1 and not parts[0].endswith(".dist-info"): + name = parts[0].replace(".py", "") + if name and not name.startswith("_") and name != "__pycache__": + names.add(name) + if names: + return names + + # Last resort: normalize the dist name itself + return {dist_name.replace("-", "_").lower()} + + +def get_transitive_deps(dist_name: str, seen: set[str] | None = None) -> set[str]: + """Recursively resolve all transitive dependencies of a distribution. + + Returns a set of normalized distribution names. + """ + if seen is None: + seen = set() + + normalized = dist_name.lower().replace("-", "_") + if normalized in seen: + return set() + seen.add(normalized) + + result = {normalized} + try: + dist = importlib.metadata.distribution(dist_name) + except importlib.metadata.PackageNotFoundError: + return result + + reqs = dist.requires + if not reqs: + return result + + for req in reqs: + # For extras-only requirements (e.g., 'brotli ; extra == "brotli"'), + # include them if actually installed. These are legitimate transitive + # deps of declared packages — e.g., urllib3[brotli] pulls in brotli, + # datasets[s3] pulls in boto3, transformers[torch] pulls in torchvision. + if "extra ==" in req: + dep = parse_dep_name(req) + if dep: + try: + importlib.metadata.distribution(dep) + result |= get_transitive_deps(dep, seen) + except importlib.metadata.PackageNotFoundError: + pass + continue + dep = parse_dep_name(req) + if dep: + result |= get_transitive_deps(dep, seen) + + return result + + +def build_allowed_set( + group_name: str, also_allow_groups: list[str] | None = None +) -> set[str]: + """Build the set of allowed top-level import names for a dependency group. + + Args: + group_name: The optional-dependency group (or "core" for base only). + also_allow_groups: Extra optional-dependency groups whose packages + should also be allowed. Use this for groups that are imported + opportunistically via ``try/except ImportError`` guards — the + code works without them, but they *will* appear in + ``sys.modules`` when installed. + """ + # Parse pyproject.toml + pyproject_path = Path(__file__).resolve().parent.parent.parent / "pyproject.toml" + + with open(pyproject_path, "rb") as f: + pyproject = tomllib.load(f) + + # Collect declared distribution names: core + specified group + core_deps = pyproject.get("project", {}).get("dependencies", []) + optional_deps = pyproject.get("project", {}).get("optional-dependencies", {}) + # "core" is a special pseudo-group meaning core deps only + group_deps = [] if group_name == "core" else optional_deps.get(group_name, []) + + # Include deps from additionally-allowed groups + extra_deps: list[str] = [] + for g in also_allow_groups or []: + extra_deps.extend(optional_deps.get(g, [])) + + # Expand self-references like 'mellea[hooks]' into that group's deps + all_dep_specs = core_deps + group_deps + extra_deps + expanded: list[str] = [] + seen_groups: set[str] = set() + queue = list(all_dep_specs) + while queue: + spec = queue.pop(0) + refs = extract_self_ref_groups(spec) + if refs: + for ref in refs: + if ref not in seen_groups: + seen_groups.add(ref) + queue.extend(optional_deps.get(ref, [])) + else: + expanded.append(spec) + + declared_dists: set[str] = set() + for dep_spec in expanded: + name = parse_dep_name(dep_spec) + if name: + declared_dists.add(name) + + # Resolve transitive dependencies + all_allowed_dists: set[str] = set() + for dist_name in declared_dists: + all_allowed_dists |= get_transitive_deps(dist_name) + + # Map all allowed distributions to their importable top-level names + allowed_imports: set[str] = set() + for dist_name in all_allowed_dists: + allowed_imports |= get_top_level_names(dist_name) + + # Also add the normalized dist names themselves (common pattern) + for dist_name in all_allowed_dists: + allowed_imports.add(dist_name.replace("-", "_").lower()) + + return allowed_imports + + +def is_third_party(module_name: str) -> bool: + """Check if a module name appears to be third-party (not stdlib, not local).""" + top = module_name.split(".")[0] + + if top in STDLIB_AND_INFRASTRUCTURE or top in OPPORTUNISTIC_IMPORTS: + return False + + # Skip internal/private modules + if top.startswith("_"): + return False + + # Skip mellea and cli (our own packages) + if top in ("mellea", "cli", "test"): + return False + + # Check if it's a known distribution + try: + importlib.metadata.distribution(top) + return True + except importlib.metadata.PackageNotFoundError: + pass + + # Try with hyphens replaced + try: + importlib.metadata.distribution(top.replace("_", "-")) + return True + except importlib.metadata.PackageNotFoundError: + pass + + # Not a known distribution — likely stdlib + return False + + +def main() -> int: + # Parse --allow-group flags before positional args + also_allow: list[str] = [] + positional: list[str] = [] + args = sys.argv[1:] + while args: + # Iterates over all the args until the list is empty. + if args[0] == "--allow-group" and len(args) >= 2: + # Grabs from ["--allow-group", "", ...] + also_allow.append(args[1]) + args = args[2:] + else: + positional.append(args[0]) + args = args[1:] + + if len(positional) < 2: + print( + f"Usage: {sys.argv[0]} [--allow-group GROUP ...] [module2 ...]", + file=sys.stderr, + ) + return 2 + + group_name = positional[0] + target_modules = positional[1:] + + # Build allowed set + allowed = build_allowed_set(group_name, also_allow_groups=also_allow) + + # Snapshot modules before import + before = set(sys.modules.keys()) + + # Import target modules + for mod in target_modules: + try: + importlib.import_module(mod) + except ImportError as e: + print(f"IMPORT_ERROR: Could not import {mod}: {e}", file=sys.stderr) + return 2 + + # Find new third-party modules + after = set(sys.modules.keys()) + new_modules = after - before + + violations: list[str] = [] + for mod in sorted(new_modules): + top = mod.split(".")[0] + if not is_third_party(top): + # It's a standard python package. + continue + if top.lower() in allowed or top.replace("-", "_").lower() in allowed: + # It's allowed by the current group or an explicitly allowed group. + continue + violations.append(top) + + # Deduplicate + violations = sorted(set(violations)) + + if violations: + print(f"VIOLATIONS for group '{group_name}':") + for v in violations: + print(f" - {v}") + return 1 + + print(f"OK: group '{group_name}' imports only declared dependencies") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/test/package/test_dependency_isolation.py b/test/package/test_dependency_isolation.py new file mode 100644 index 000000000..5e4ac4960 --- /dev/null +++ b/test/package/test_dependency_isolation.py @@ -0,0 +1,289 @@ +"""Dependency isolation tests. + +Verify that: +1. Core mellea modules work with only base dependencies (no extras). +2. Each optional dependency group only imports packages from + core dependencies + that group's declared optional dependencies. +3. Every optional-dependency group in pyproject.toml has a test mapping. + +Each test spawns a fresh subprocess to get a clean sys.modules snapshot. +""" + +import json +import subprocess +import sys +import tomllib +from pathlib import Path + +import pytest + +# Core modules that must work with only the base dependencies declared in +# pyproject.toml [project.dependencies] — no extras installed. +# +# Notably excluded (these eagerly import optional-extra packages): +# - mellea.backends.huggingface / vllm / litellm / watsonx (backend extras) +# - mellea.telemetry (telemetry extra) +# - mellea.stdlib.tools (sandbox extra — __init__ imports interpreter) +# - mellea.stdlib.components.docs.richdocument (docling extra) +# - mellea.formatters.granite.retrievers (granite_retriever — __init__ imports elasticsearch + numpy) +# - mellea.plugins.hooks (hooks extra) +# - cli.serve.app (server extra) +# - cli.m (imports cli.serve.app) +CORE_MODULES: list[str] = [ + # Top-level package + "mellea", + # Core abstractions + "mellea.core", + # Backends (core-only — openai, ollama, bedrock, tools, adapters, etc.) + "mellea.backends", + "mellea.backends.backend", + "mellea.backends.bedrock", + "mellea.backends.cache", + "mellea.backends.dummy", + "mellea.backends.model_ids", + "mellea.backends.model_options", + "mellea.backends.ollama", + "mellea.backends.openai", + "mellea.backends.tools", + "mellea.backends.utils", + "mellea.backends.adapters", + # Formatters (core-only — no retrievers/__init__ which pulls elasticsearch) + "mellea.formatters", + "mellea.formatters.chat_formatter", + "mellea.formatters.template_formatter", + "mellea.formatters.granite", + # Helpers + "mellea.helpers", + # Plugin system (core infra, not the hooks extra) + "mellea.plugins", + # Standard library (core components, sessions, sampling) + "mellea.stdlib", + "mellea.stdlib.components", + "mellea.stdlib.components.docs", + "mellea.stdlib.context", + "mellea.stdlib.functional", + "mellea.stdlib.session", + "mellea.stdlib.sampling", + # CLI (excluding serve/app and m which depend on server extra) + "cli", + "cli.alora.commands", + "cli.decompose", + "cli.eval.commands", +] + +# Map each pyproject optional-dependency group to the mellea modules it covers. +GROUP_MODULES: dict[str, list[str]] = { + "hf": ["mellea.backends.huggingface"], + "vllm": ["mellea.backends.vllm"], + "litellm": ["mellea.backends.litellm"], + "watsonx": ["mellea.backends.watsonx"], + "tools": ["mellea.backends.tools"], + "telemetry": ["mellea.telemetry"], + "docling": ["mellea.stdlib.components.docs.richdocument"], + "granite_retriever": ["mellea.formatters.granite.retrievers.elasticsearch"], + "server": ["cli.serve.app"], + "sandbox": ["mellea.stdlib.tools.interpreter"], + "hooks": ["mellea.plugins"], +} + +# Aggregate/meta groups that just combine other groups — no modules of their own. +# These don't need isolation tests; they're tested via their constituent groups. +META_GROUPS: set[str] = {"all", "backends"} + +# Optional-dependency groups whose packages are imported opportunistically by +# core code via `try/except ImportError` guards. The code works without them, +# but when they're installed they *will* appear in sys.modules. Every test +# allows these so we don't flag guarded imports as violations. +# +# Currently only "hooks": mellea.plugins.{manager,registry,base} guard-import cpex. +GUARDED_GROUPS: list[str] = ["hooks"] + +CHECKER_SCRIPT = Path(__file__).parent / "_check_dep_isolation.py" + +PYPROJECT_PATH = Path(__file__).resolve().parent.parent.parent / "pyproject.toml" + +# Maximum allowed wall-clock time for `import mellea` in a fresh interpreter. +# Current baseline is ~140-565ms depending on hardware. 750ms catches +# heavy-dep regressions (torch ~2s, transformers ~800ms+) without flaking. +IMPORT_TIME_LIMIT_MS = 750 + + +def _read_pyproject_groups() -> set[str]: + """Return all optional-dependency group names from pyproject.toml.""" + with open(PYPROJECT_PATH, "rb") as f: + pyproject = tomllib.load(f) + return set(pyproject.get("project", {}).get("optional-dependencies", {}).keys()) + + +def _run_checker(group: str, modules: list[str]) -> subprocess.CompletedProcess[str]: + """Spawn the checker script in a fresh subprocess. + + Runs ``_check_dep_isolation.py`` with the given group and modules, + automatically adding ``--allow-group`` flags for each entry in + GUARDED_GROUPS (skipping the group being tested to avoid redundancy). + + Return codes from the checker: + 0 — all imports are within declared dependencies. + 1 — undeclared dependency violations found (details in stdout). + 2 — one or more target modules could not be imported (details in stderr). + """ + allow_flags: list[str] = [] + for g in GUARDED_GROUPS: + if g != group: # Don't redundantly allow the group being tested + allow_flags.extend(["--allow-group", g]) + return subprocess.run( + [sys.executable, str(CHECKER_SCRIPT), *allow_flags, group, *modules], + capture_output=True, + text=True, + timeout=120, + ) + + +def _find_untested_groups( + group_modules: dict[str, list[str]], meta_groups: set[str] +) -> set[str]: + """Return pyproject optional-dependency groups that lack a test mapping. + + Compares the groups declared in pyproject.toml against those covered by + ``group_modules`` and ``meta_groups``, returning any that are missing. + """ + pyproject_groups = _read_pyproject_groups() + tested_groups = set(group_modules.keys()) | meta_groups + return pyproject_groups - tested_groups + + +# --------------------------------------------------------------------------- +# Core import tests +# --------------------------------------------------------------------------- + + +def test_core_modules_only_use_declared_dependencies() -> None: + """Core modules must import successfully and only use declared base dependencies.""" + result = _run_checker("core", CORE_MODULES) + + if result.returncode == 2: + pytest.fail( + f"Core module import failed (no extras should be needed):\n{result.stderr.strip()}" + ) + + if result.returncode != 0: + violations = result.stdout.strip() + pytest.fail( + f"Core modules pull in undeclared packages " + f"(these should be added to [project.dependencies] " + f"or the import should be made lazy):\n{violations}" + ) + + +# --------------------------------------------------------------------------- +# Per-group isolation tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("group", sorted(GROUP_MODULES.keys())) +def test_dependency_isolation(group: str) -> None: + """Each optional group should only import its declared dependencies.""" + modules = GROUP_MODULES[group] + result = _run_checker(group, modules) + + # Import errors mean the group's deps aren't installed — skip + if result.returncode == 2: + pytest.skip( + f"Could not import modules for group '{group}': {result.stderr.strip()}" + ) + + if result.returncode != 0: + violations = result.stdout.strip() + pytest.fail(f"Undeclared dependency imports in group '{group}':\n{violations}") + + +# --------------------------------------------------------------------------- +# Guard: all pyproject groups must be mapped +# --------------------------------------------------------------------------- + + +def test_all_groups_have_isolation_tests() -> None: + """Every optional-dependency group in pyproject.toml must have a test mapping. + + If this test fails, a new group was added to [project.optional-dependencies] + without a corresponding entry in GROUP_MODULES above. To fix: + 1. Add the group -> module list mapping to GROUP_MODULES. + 2. If it's a meta/aggregate group (like 'all' or 'backends'), add it to META_GROUPS instead. + """ + untested = _find_untested_groups(GROUP_MODULES, META_GROUPS) + + if untested: + pytest.fail( + f"New optional-dependency group(s) in pyproject.toml missing isolation tests: " + f"{sorted(untested)}. Add them to GROUP_MODULES in test_dependency_isolation.py " + f"(or to META_GROUPS if they are aggregate groups)." + ) + + +def test_isolation_detects_undeclared_import() -> None: + """Verify the checker flags a module imported under the wrong group. + + Imports mellea.backends.watsonx under the hf group — watsonx's dep + (ibm-watsonx-ai) is not declared in hf and has no transitive overlap + via extras chains, so the checker must report violations. + """ + result = _run_checker("hf", ["mellea.backends.watsonx"]) + + if result.returncode == 2: + pytest.skip(f"watsonx extras not installed: {result.stderr.strip()}") + + assert result.returncode == 1, ( + f"Expected checker to flag undeclared deps (exit 1) but got exit {result.returncode}. " + f"stdout: {result.stdout.strip()}" + ) + assert "VIOLATIONS" in result.stdout + + +def test_guard_detects_missing_group() -> None: + """Verify the guard actually catches unmapped groups. + + Simulates a new pyproject group by removing 'hooks' from the tested set. + The guard logic should flag it as untested. + """ + incomplete_modules = {k: v for k, v in GROUP_MODULES.items() if k != "hooks"} + untested = _find_untested_groups(incomplete_modules, META_GROUPS) + assert "hooks" in untested, ( + "Guard failed to detect that 'hooks' was missing from GROUP_MODULES" + ) + + +# --------------------------------------------------------------------------- +# Import time budget +# --------------------------------------------------------------------------- + + +# TODO: Test is marked as qualitative to prevent false regressions. Once we are confident, +# we can have this run on nightlies / github actions. +@pytest.mark.qualitative +def test_import_mellea_time() -> None: + """``import mellea`` must complete within the time budget. + + Uses a single fresh subprocess rather than averaging multiple runs, + because the OS page cache warms after the first invocation and would + make subsequent runs artificially fast. The threshold provides enough + headroom over the baseline to absorb normal single-run variance. + """ + timing_script = ( + "import json, time; " + "s = time.perf_counter(); " + "import mellea; " + "print(json.dumps({'ms': (time.perf_counter() - s) * 1000}))" + ) + result = subprocess.run( + [sys.executable, "-c", timing_script], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0, f"Import crashed: {result.stderr}" + elapsed_ms = json.loads(result.stdout)["ms"] + + assert elapsed_ms < IMPORT_TIME_LIMIT_MS, ( + f"import mellea took {elapsed_ms:.0f}ms (limit: {IMPORT_TIME_LIMIT_MS}ms). " + f"A heavy dependency may have been added to the import chain." + )