#!/usr/bin/env python3
"""Run compiler error tests and other validation checks."""

import glob
import os
import subprocess
import sys
import tempfile

# Absolute path of the directory containing this script; it is also used as
# the working directory for every compiler invocation in this file.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Path of the "asy" executable, looked up next to this script.
ASY = os.path.join(SCRIPT_DIR, "asy")
# Directory that is globbed for *.asy error tests and their .errors files.
ERRORTESTS_DIR = os.path.join(SCRIPT_DIR, "errortests")


def run_error_tests():
    """Run each .asy error test and compare stderr to the expected .errors file.

    Every ``errortests/<name>.asy`` file, except a fixed set of helper files
    that are not standalone tests, is compiled with ``asy`` from
    ``SCRIPT_DIR``.  The captured stderr must equal the contents of
    ``errortests/<name>.errors`` exactly; on a mismatch a unified diff is
    written to stdout.

    Returns:
        bool: True if at least one test was found and every test matched.
        False if no tests were found, an expected ``.errors`` file is
        missing, or any output differed.
    """
    # Imported once per call instead of inside the loop's failure branch.
    import difflib

    pattern = os.path.join(ERRORTESTS_DIR, "*.asy")

    # Exclude helper files that are not standalone tests.
    helpers = {
        "errortestNonTemplate.asy",
        "errortestBrokenTemplate.asy",
        "errortestTemplate.asy",
    }
    test_files = [
        path
        for path in sorted(glob.glob(pattern))
        if os.path.basename(path) not in helpers
    ]

    if not test_files:
        print("No error test files found in errortests/", file=sys.stderr)
        return False

    all_passed = True
    for test_file in test_files:
        name = os.path.splitext(os.path.basename(test_file))[0]
        errors_file = os.path.join(ERRORTESTS_DIR, name + ".errors")
        # Module path is relative because the compiler runs with cwd=SCRIPT_DIR.
        module_path = os.path.join("errortests", name)

        sys.stdout.write(f"Testing errors ({name})...")
        sys.stdout.flush()

        # Read the expected output before launching the compiler: a test with
        # no reference file can only fail, so there is no point compiling it.
        # EAFP also avoids the exists()/open() race.
        try:
            with open(errors_file, "r") as f:
                expected = f.read()
        except FileNotFoundError:
            print(f" FAILED.\n  Missing expected errors file: {errors_file}")
            all_passed = False
            continue

        result = subprocess.run(
            [ASY, "-q", "-sysdir", "base", "-noautoplain", "-debug", module_path],
            cwd=SCRIPT_DIR,
            capture_output=True,
            text=True,
        )
        # Only stderr is compared; the exit status is deliberately not
        # inspected here.
        actual = result.stderr

        if actual == expected:
            print(" PASSED.")
            continue

        print(" FAILED.")
        # Show a unified diff for diagnostics.
        diff = difflib.unified_diff(
            expected.splitlines(keepends=True),
            actual.splitlines(keepends=True),
            fromfile=errors_file,
            tofile="actual output",
        )
        sys.stdout.writelines(diff)
        all_passed = False

    return all_passed


def run_deconstruct_test():
    """Run the deconstruct test and compare stderr to the reference file.

    Runs ``asy`` from ``SCRIPT_DIR`` on the inline program
    ``draw(unitsquare); deconstruct()`` and compares the captured stderr with
    the contents of the ``deconstruct`` file next to this script.  The command
    passes ``-outpipe 2``, presumably so the deconstruct output lands on file
    descriptor 2 -- confirm against the asy option documentation.  Any
    ``_*.svg`` files left in ``SCRIPT_DIR`` are removed afterwards.

    Returns:
        bool: True if stderr matched the reference exactly.  False if it
        differed (a unified diff is written to stdout) or if the reference
        file is missing.
    """
    import difflib

    sys.stdout.write("Testing deconstruct...")
    sys.stdout.flush()

    expected_file = os.path.join(SCRIPT_DIR, "deconstruct")
    # Read the reference first and report a missing file as a test failure,
    # matching how run_error_tests treats a missing .errors file, instead of
    # aborting the whole run with a traceback.
    try:
        with open(expected_file, "r") as f:
            expected = f.read()
    except FileNotFoundError:
        print(f" FAILED.\n  Missing expected output file: {expected_file}")
        return False

    try:
        result = subprocess.run(
            [
                ASY,
                "-dvisvgmOptions=-v0",
                "-q",
                "-sysdir",
                "base",
                "-outpipe",
                "2",
                "-xasy",
                "-c",
                "draw(unitsquare); deconstruct()",
            ],
            cwd=SCRIPT_DIR,
            capture_output=True,
            text=True,
        )
    finally:
        # Clean up generated files, even if running the command raised.
        for svg in glob.glob(os.path.join(SCRIPT_DIR, "_*.svg")):
            os.remove(svg)

    actual = result.stderr
    if actual == expected:
        print(" PASSED.")
        return True

    print(" FAILED.")
    # Show a unified diff for diagnostics.
    diff = difflib.unified_diff(
        expected.splitlines(keepends=True),
        actual.splitlines(keepends=True),
        fromfile=expected_file,
        tofile="actual output",
    )
    sys.stdout.writelines(diff)
    return False


def main():
    """Run every check and exit with status 1 if any of them failed."""
    # Both checks are evaluated up front so the second one still runs (and
    # prints its result) when the first one fails.
    outcomes = (run_error_tests(), run_deconstruct_test())
    if not all(outcomes):
        sys.exit(1)


# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
