#!/usr/bin/env bash
#
# E2E test for `mise mcp`: checks that the server rejects malformed input,
# then drives the MCP protocol (initialize, tools/list, tools/call) from a
# small Python client against a project with tools, tasks and env vars.
set -euo pipefail

# `mise mcp` is gated behind experimental mode.
export MISE_EXPERIMENTAL=1

# Project fixture with tools, tasks and env vars. The heredoc delimiter is
# quoted so the shell writes `$@` in the echo-args task literally instead of
# expanding this script's (empty) positional parameters into mise.toml.
cat >mise.toml <<'EOF'
[tools]
node = "20.11.0"
python = "3.12"

[env]
TEST_MCP_VAR = "test_value"
PROJECT_ENV = "development"
API_KEY = "secret123"

[tasks.test-task]
run = "echo 'Running test task'"
description = "A test task for MCP"

[tasks.echo-args]
run = 'echo "args: $@"'
description = "Echo arguments back"

[tasks.build]
run = "echo 'Building...'"
description = "Build the project"
depends = ["test-task"]
alias = ["b"]
dir = "."
quiet = true

[tasks.lint]
run = "echo 'Linting code'"
description = "Run linters"
usage = "lint [files...]"
EOF

echo "Testing MCP server startup and error handling..."

# A non-JSON line on stdin must make the server fail; `|| true` keeps the
# expected non-zero exit from aborting the script under `set -e`.
output=$(echo "not json" | mise mcp 2>&1 || true)
if [[ $output == *"Failed to create service"* ]]; then
	echo "SUCCESS: MCP server correctly rejected invalid JSON"
else
	echo "ERROR: Expected 'Failed to create service' in output, got: $output" >&2
	exit 1
fi

printf '\n%s\n' "Testing MCP protocol with actual requests..."

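# Write a standalone Python MCP client that talks newline-delimited JSON-RPC
# to `mise mcp` over stdio; it is kept deliberately simple to work around a
# connection issue seen with rmcp 0.3.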
cat >test_mcp.py <<'EOF'
#!/usr/bin/env python3
import json
import os
import select
import shutil
import subprocess
import sys


def find_mise():
    """Locate the ``mise`` executable to drive.

    Prefers a local debug build at ``../../../target/debug/mise`` relative to
    this script's directory, then falls back to ``mise`` on ``PATH``.

    Returns:
        Path to an executable ``mise``, or None when neither candidate exists,
        so the caller can report a clear error instead of ``Popen`` raising.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    local_build = os.path.join(here, "..", "..", "..", "target", "debug", "mise")
    if os.path.isfile(local_build) and os.access(local_build, os.X_OK):
        return local_build
    # shutil.which resolves against PATH and returns None if mise isn't there.
    return shutil.which("mise")


def send(proc, msg):
    """Write ``msg`` to the server's stdin as a single JSON line and flush it."""
    payload = json.dumps(msg) + "\n"
    stream = proc.stdin
    stream.write(payload.encode())
    stream.flush()


def recv(proc, timeout=5.0):
    """Read one JSON message line from the server's stdout.

    Waits up to ``timeout`` seconds for stdout to become readable, then reads
    a single line and decodes it as JSON. Returns None on timeout, EOF, or a
    blank line.

    NOTE(review): select() sees only the OS pipe, not data already buffered by
    ``proc.stdout``; this is fine for strict one-line-per-request exchanges.
    """
    readable = select.select([proc.stdout], [], [], timeout)[0]
    if readable:
        text = proc.stdout.readline().decode().strip()
        if text:
            return json.loads(text)
    return None


def handshake(proc):
    """Run the MCP initialize exchange and return the server's result object.

    Sends ``initialize`` (id 1), waits for the reply, then sends the
    ``notifications/initialized`` notification. Prints an error and returns
    None when the reply is missing or has no ``result``.
    """
    init_request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-03-26",
            "capabilities": {},
            "clientInfo": {"name": "test-client", "version": "1.0.0"},
        },
    }
    send(proc, init_request)
    response = recv(proc)
    if response and "result" in response:
        # Notifications carry no "id"; this one tells the server we're ready.
        send(proc, {"jsonrpc": "2.0", "method": "notifications/initialized"})
        return response["result"]
    print(f"ERROR: Bad initialize response: {response}")
    return None


def test_initialize(proc):
    """Handshake, then verify protocol version, capabilities and server name.

    Returns 1 if the handshake fails, 0 when every check passes; a failed
    check raises AssertionError.
    """
    init = handshake(proc)
    if not init:
        return 1

    # Checks are deferred in lambdas so each one is evaluated (and may raise)
    # only after the previous check's success line has been printed.
    checks = (
        (lambda: init["protocolVersion"] == "2025-03-26",
         "wrong protocol version", "✓ Protocol version correct"),
        (lambda: "resources" in init["capabilities"],
         "resources capability missing", "✓ Resources capability present"),
        (lambda: "tools" in init["capabilities"],
         "tools capability missing", "✓ Tools capability present"),
        (lambda: init["serverInfo"]["name"] == "rmcp",
         "wrong server name", "✓ Server info correct"),
    )
    for holds, failure, success in checks:
        assert holds(), failure
        print(success)
    return 0


def _request(proc, req_id, method, params, label, timeout=5.0):
    """Send one JSON-RPC request and return the raw response dict.

    Prints ``ERROR: Bad <label> response: ...`` and returns None when nothing
    arrives within ``timeout`` seconds or the response has no ``result``.
    """
    send(proc, {"jsonrpc": "2.0", "id": req_id, "method": method, "params": params})
    resp = recv(proc, timeout=timeout)
    if resp and "result" in resp:
        return resp
    print(f"ERROR: Bad {label} response: {resp}")
    return None


def _call_tool(proc, req_id, tool, arguments):
    """Invoke ``tool`` through ``tools/call`` with a 30 second timeout."""
    return _request(
        proc,
        req_id,
        "tools/call",
        {"name": tool, "arguments": arguments},
        label=tool,
        timeout=30.0,
    )


def test_tools_list(proc):
    """tools/list must expose run_task and install_tool; run_task takes task/args."""
    resp = _request(proc, 2, "tools/list", {}, "tools/list")
    if resp is None:
        return 1

    tools = resp["result"]["tools"]
    tool_names = {t["name"] for t in tools}
    for required in ("run_task", "install_tool"):
        assert required in tool_names, f"{required} not in tools: {tool_names}"
    print(f"✓ tools/list returned {len(tools)} tools: {tool_names}")

    # Inspect the first run_task entry's JSON schema for its parameters.
    run_task = next(t for t in tools if t["name"] == "run_task")
    props = run_task["inputSchema"]["properties"]
    for param in ("task", "args"):
        assert param in props, f"run_task missing '{param}' param"
    print("✓ run_task schema has task and args parameters")
    return 0


def test_run_task_success(proc):
    """run_task on a known task succeeds and returns the task's output."""
    resp = _call_tool(proc, 3, "run_task", {"task": "test-task"})
    if resp is None:
        return 1

    outcome = resp["result"]
    assert not outcome.get("isError", False), f"run_task returned error: {outcome}"
    text = outcome["content"][0]["text"]
    assert "Running test task" in text, f"unexpected output: {text}"
    print(f"✓ run_task success: {text.strip()}")
    return 0


def test_run_task_with_args(proc):
    """run_task forwards extra args to the task command."""
    resp = _call_tool(proc, 4, "run_task", {"task": "echo-args", "args": ["hello", "world"]})
    if resp is None:
        return 1

    outcome = resp["result"]
    assert not outcome.get("isError", False), f"run_task returned error: {outcome}"
    text = outcome["content"][0]["text"]
    assert all(word in text for word in ("hello", "world")), f"args not forwarded: {text}"
    print(f"✓ run_task with args forwarded correctly: {text.strip()}")
    return 0


def test_run_task_unknown(proc):
    """run_task on a task that does not exist reports isError."""
    resp = _call_tool(proc, 5, "run_task", {"task": "nonexistent-task-xyz"})
    if resp is None:
        return 1

    assert resp["result"].get("isError", False), "expected error for unknown task"
    print("✓ run_task with unknown task returns error")
    return 0


def test_install_tool_stub(proc):
    """install_tool is still a stub that answers with a not-implemented error."""
    resp = _call_tool(proc, 6, "install_tool", {"tool": "node"})
    if resp is None:
        return 1

    outcome = resp["result"]
    assert outcome.get("isError", False), "expected error from stub"
    text = outcome["content"][0]["text"]
    assert "not yet implemented" in text, f"unexpected stub message: {text}"
    print("✓ install_tool stub returns not-yet-implemented error")
    return 0


def _stop_server(proc, timeout=5.0):
    """Shut down the ``mise mcp`` child process and release its pipes.

    Closes stdin (EOF) and sends SIGTERM, then waits up to ``timeout``
    seconds before escalating to SIGKILL, so a wedged server cannot hang
    the test run.
    """
    try:
        proc.stdin.close()
    except OSError:
        # e.g. BrokenPipeError if the server already exited.
        pass
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
    proc.stdout.close()


def main():
    """Start ``mise mcp`` and run the protocol tests against it in order.

    Stops at the first failing test. Returns 0 when all tests pass and 1 on
    any failure: a missing binary, a failed spawn, a failing test, or an
    unexpected exception.
    """
    mise_path = find_mise()
    if not mise_path:
        print("ERROR: Could not find mise binary")
        return 1
    print(f"Using mise binary: {mise_path}")

    try:
        proc = subprocess.Popen(
            [mise_path, "mcp"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # Inherit stderr instead of piping it: nothing here would ever
            # drain a stderr pipe, so a chatty server could fill it and block,
            # and passing it through keeps server diagnostics in the test log.
            stderr=None,
            env={**os.environ, "MISE_EXPERIMENTAL": "1"},
        )
    except OSError as e:
        print(f"ERROR: Could not start '{mise_path} mcp': {e}")
        return 1

    tests = [
        ("initialize", test_initialize),
        ("tools/list", test_tools_list),
        ("run_task success", test_run_task_success),
        ("run_task with args", test_run_task_with_args),
        ("run_task unknown task", test_run_task_unknown),
        ("install_tool stub", test_install_tool_stub),
    ]
    try:
        for name, test_fn in tests:
            print(f"\n--- {name} ---")
            rc = test_fn(proc)
            if rc != 0:
                print(f"FAIL: {name}")
                return 1
        print("\n✓ All MCP protocol tests passed!")
        return 0
    except Exception as e:
        # Include the type: str() of e.g. a KeyError is only the missing key.
        print(f"ERROR: {type(e).__name__}: {e}")
        return 1
    finally:
        _stop_server(proc)


if __name__ == "__main__":
    sys.exit(main())
EOF

# Run the Python protocol client. Capture its status instead of exiting
# immediately so the generated script is removed on failure as well.
status=0
python3 test_mcp.py || status=$?
rm -f test_mcp.py
if ((status != 0)); then
	echo "ERROR: MCP protocol tests failed (exit status $status)" >&2
	exit "$status"
fi

echo "All tests completed successfully!"
