"""RDKit-TM (xyz2mol_tm) molecular graph construction for coordination complexes."""

from __future__ import annotations

import logging

import networkx as nx
import numpy as np
from rdkit import Chem

from .data_loader import BOHR_TO_ANGSTROM, DATA

logger = logging.getLogger(__name__)


def build_graph_rdkit_tm(
    xyz_file: str | list[tuple[str, tuple[float, float, float]]],
    charge: int = 0,
    bohr_units: bool = False,
) -> nx.Graph:
    """Build molecular graph using xyz2mol_tm for coordination complexes.

    Combines XYZ coordinates, connectivity from ``xyz2mol_tm``
    (specialised for metal coordination), and graph matching to align
    RDKit atom ordering with the original XYZ ordering.

    Parameters
    ----------
    xyz_file : str or list of (symbol, (x, y, z))
        Path to an XYZ file or coordinates.
    charge : int
        Total molecular charge.
    bohr_units : bool
        Whether input coordinates are in Bohr.

    Returns
    -------
    nx.Graph
        Molecular graph with nodes containing ``symbol``,
        ``atomic_number``, ``position``, ``formal_charge``,
        ``valence``, ``metal_valence`` and ``charges`` (empty dict).
        Edges carry ``bond_order``, ``distance``, ``bond_type`` and
        ``metal_coord``. If ``xyz2mol_tmc`` fails or times out, an empty
        graph carrying only a ``metadata`` note is returned instead.

    Raises
    ------
    ImportError
        If ``xyz2mol_tm`` is not installed.
    TypeError
        If ``xyz_file`` is neither a path nor a list.
    ValueError
        If the RDKit and geometric graphs cannot be matched, or fewer
        than 75% of the RDKit edges are reproduced by the mapping.
    """
    import os
    import signal
    import tempfile

    from networkx.algorithms import isomorphism

    # Import xyz2mol_tm
    try:
        from xyz2mol_tm import xyz2mol_tmc
    except ImportError:
        raise ImportError(
            "xyz2mol_tm not found. Install via:\npip install git+https://github.com/jensengroup/xyz2mol_tm.git"
        ) from None

    # ===== STEP 1: Parse XYZ coordinates =====
    if isinstance(xyz_file, str):
        with open(xyz_file) as f:
            lines = f.readlines()
        nat = int(lines[0].strip())
        atoms: list[tuple[str, tuple[float, float, float]]] = []
        # Line 0 is the atom count, line 1 the comment; atoms start at line 2.
        for line in lines[2 : 2 + nat]:
            parts = line.split()
            sym = parts[0]
            x, y, z = map(float, parts[1:4])
            if bohr_units:
                x, y, z = x * BOHR_TO_ANGSTROM, y * BOHR_TO_ANGSTROM, z * BOHR_TO_ANGSTROM
            atoms.append((sym, (x, y, z)))
        source_label = xyz_file
    elif isinstance(xyz_file, list):
        atoms = xyz_file
        if bohr_units:
            atoms = [(s, (x * BOHR_TO_ANGSTROM, y * BOHR_TO_ANGSTROM, z * BOHR_TO_ANGSTROM)) for s, (x, y, z) in atoms]
        # Log a short label rather than dumping every coordinate on failure.
        source_label = f"<{len(atoms)} in-memory atoms>"
    else:
        raise TypeError("xyz_file must be a path or list of (symbol, position) tuples")

    heavy_idx = [i for i, (s, _) in enumerate(atoms) if s != "H"]

    # ===== STEP 2: Get connectivity from xyz2mol_tm =====
    xyz_lines = [str(len(atoms)), "Generated by build_graph_rdkit_tm"]
    xyz_lines += [f"{s} {x:.6f} {y:.6f} {z:.6f}" for s, (x, y, z) in atoms]
    xyz_block = "\n".join(xyz_lines) + "\n"

    timeout_seconds = 5

    def handler(signum, frame):
        raise TimeoutError("xyz2mol_tmc took too long")

    mol = None
    tmp_path = None
    try:
        # xyz2mol_tmc reads from a path, so the (already Angstrom) coordinates
        # are written to a scratch file. The file is closed before it is read
        # and removed in the outer ``finally`` so it never outlives this call.
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".xyz", delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(xyz_block)

        # SIGALRM only exists on POSIX and handlers can only be installed from
        # the main thread; elsewhere the call simply runs without a timeout.
        alarm_armed = False
        previous_handler = None
        if hasattr(signal, "SIGALRM"):
            try:
                previous_handler = signal.signal(signal.SIGALRM, handler)
            except ValueError:
                logger.debug("SIGALRM unavailable in this thread; running xyz2mol_tmc without a timeout.")
            else:
                signal.alarm(timeout_seconds)
                alarm_armed = True
        else:
            logger.debug("SIGALRM unavailable on this platform; running xyz2mol_tmc without a timeout.")

        try:
            mol = xyz2mol_tmc.get_tmc_mol(tmp_path, overall_charge=charge)
        except TimeoutError:
            logger.warning("xyz2mol_tmc timed out for %s. Skipping RDKit-TM graph.", source_label)
            mol = None
        except Exception as e:
            # Best effort by design: any xyz2mol_tmc failure yields an empty graph.
            logger.warning("xyz2mol_tmc failed for %s: %s", source_label, e)
            mol = None
        finally:
            if alarm_armed:
                signal.alarm(0)
                # Put back whatever handler the caller had installed.
                # ``None`` means it was not installed from Python and
                # cannot be restored through signal.signal().
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                logger.debug("Could not remove temporary file %s", tmp_path)

    if mol is None:
        G = nx.Graph()
        G.graph["metadata"] = {
            "source": "rdkit_tm",
            "note": "xyz2mol_tmc failed or timed out",
        }
        return G

    # Build RDKit connectivity graph (element + bonds only)
    G_rdkit = nx.Graph()
    for i in range(mol.GetNumAtoms()):
        G_rdkit.add_node(i, symbol=mol.GetAtomWithIdx(i).GetSymbol())
    for bond in mol.GetBonds():
        G_rdkit.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())

    # ===== STEP 3: Build XYZ heavy-atom graph =====
    from .graph_builders import build_graph

    G_xyz_heavy = build_graph([atoms[i] for i in heavy_idx], charge=charge, quick=True)
    # Relabel 0..k-1 heavy-atom indices back to their positions in ``atoms``.
    mapping_to_original = dict(enumerate(heavy_idx))
    G_xyz_relabeled = nx.relabel_nodes(G_xyz_heavy, mapping_to_original)

    # ===== STEP 3a: Filter XYZ edges to match RDKit connectivity =====
    # Keep only geometric edges whose element pair also occurs as a bond in
    # the RDKit molecule.
    allowed_pairs: set[frozenset[str]] = set()
    for bond in mol.GetBonds():
        sym_i = mol.GetAtomWithIdx(bond.GetBeginAtomIdx()).GetSymbol()
        sym_j = mol.GetAtomWithIdx(bond.GetEndAtomIdx()).GetSymbol()
        allowed_pairs.add(frozenset([sym_i, sym_j]))

    edges_to_keep = [
        (i, j)
        for i, j in G_xyz_relabeled.edges()
        if frozenset([G_xyz_relabeled.nodes[i]["symbol"], G_xyz_relabeled.nodes[j]["symbol"]]) in allowed_pairs
    ]

    G_xyz_simple = nx.Graph()
    for n in G_xyz_relabeled.nodes():
        G_xyz_simple.add_node(n, symbol=G_xyz_relabeled.nodes[n]["symbol"])
    G_xyz_simple.add_edges_from(edges_to_keep)

    # ===== STEP 4: Match graphs (try perfect first, fall back to partial) =====
    nm = isomorphism.categorical_node_match("symbol", "")
    GM = isomorphism.GraphMatcher(G_rdkit, G_xyz_simple, node_match=nm)

    if GM.is_isomorphic():
        rdkit_to_xyz = GM.mapping
        logger.debug("Indexed against xyzgraph by perfect isomorphism.")
    else:
        logger.debug(
            "Graphs not perfectly isomorphic. RDKit: %d nodes/%d edges, XYZ: %d nodes/%d edges",
            G_rdkit.number_of_nodes(),
            G_rdkit.number_of_edges(),
            G_xyz_simple.number_of_nodes(),
            G_xyz_simple.number_of_edges(),
        )
        rdkit_to_xyz = _partial_graph_matching(G_rdkit, G_xyz_simple)

        # Validate mapping quality. Membership is tested with ``in`` because
        # XYZ index 0 is a valid target and must not be treated as "unmapped".
        mapped_edges = sum(
            1
            for i, j in G_rdkit.edges()
            if i in rdkit_to_xyz and j in rdkit_to_xyz and G_xyz_simple.has_edge(rdkit_to_xyz[i], rdkit_to_xyz[j])
        )
        total_rdkit_edges = G_rdkit.number_of_edges()
        overlap = mapped_edges / total_rdkit_edges if total_rdkit_edges > 0 else 0
        logger.debug("Mapping quality: %d/%d edges match (%.1f%%)", mapped_edges, total_rdkit_edges, overlap * 100)

        if overlap < 0.75:
            raise ValueError(
                f"Insufficient graph overlap ({overlap * 100:.1f}%). "
                f"xyz2mol_tm and geometric methods disagree too much on connectivity."
            )

    # ===== STEP 5: Build final graph with XYZ ordering =====
    periodic_table = Chem.GetPeriodicTable()
    G = nx.Graph()
    for idx, (sym, pos) in enumerate(atoms):
        G.add_node(
            idx,
            symbol=sym,
            atomic_number=periodic_table.GetAtomicNumber(sym),
            position=pos,
            formal_charge=0,
            charges={},
        )

    # Bond types outside this table (e.g. dative) fall back to order 1.0.
    bond_order_by_type = {
        Chem.BondType.SINGLE: 1.0,
        Chem.BondType.DOUBLE: 2.0,
        Chem.BondType.TRIPLE: 3.0,
        Chem.BondType.AROMATIC: 1.5,
    }

    # Add heavy-heavy edges from RDKit, mapped to XYZ indices
    for bond in mol.GetBonds():
        i_xyz = rdkit_to_xyz[bond.GetBeginAtomIdx()]
        j_xyz = rdkit_to_xyz[bond.GetEndAtomIdx()]

        bo = bond_order_by_type.get(bond.GetBondType(), 1.0)

        pos_i = np.array(G.nodes[i_xyz]["position"])
        pos_j = np.array(G.nodes[j_xyz]["position"])

        G.add_edge(
            i_xyz,
            j_xyz,
            bond_order=bo,
            distance=float(np.linalg.norm(pos_i - pos_j)),
            bond_type=(G.nodes[i_xyz]["symbol"], G.nodes[j_xyz]["symbol"]),
            metal_coord=(G.nodes[i_xyz]["symbol"] in DATA.metals or G.nodes[j_xyz]["symbol"] in DATA.metals),
        )

    # Connect hydrogens to nearest heavy atom
    for idx, (sym, pos) in enumerate(atoms):
        if sym == "H":
            pos_arr = np.array(pos)
            dists = [np.linalg.norm(pos_arr - np.array(G.nodes[i]["position"])) for i in heavy_idx]
            nearest = heavy_idx[int(np.argmin(dists))]
            G.add_edge(
                idx,
                nearest,
                bond_order=1.0,
                distance=float(min(dists)),
                bond_type=("H", G.nodes[nearest]["symbol"]),
                metal_coord=(G.nodes[nearest]["symbol"] in DATA.metals),
            )

    # Update valences and formal charges
    for node in G.nodes():
        # Split valence: organic (excludes metal bonds) and metal (coordination bonds)
        organic_val = sum(
            G.edges[node, nbr]["bond_order"] for nbr in G.neighbors(node) if G.nodes[nbr]["symbol"] not in DATA.metals
        )
        metal_val = sum(
            G.edges[node, nbr]["bond_order"] for nbr in G.neighbors(node) if G.nodes[nbr]["symbol"] in DATA.metals
        )
        G.nodes[node]["valence"] = organic_val
        G.nodes[node]["metal_valence"] = metal_val

    # Formal charges come from RDKit; atoms absent from the mapping keep 0.
    for rdkit_idx, xyz_idx in rdkit_to_xyz.items():
        G.nodes[xyz_idx]["formal_charge"] = mol.GetAtomWithIdx(rdkit_idx).GetFormalCharge()

    # Metadata
    from . import __citation__, __version__

    G.graph["metadata"] = {
        "version": __version__,
        "citation": __citation__,
        "source": "rdkit_tm",
    }
    G.graph["total_charge"] = charge
    G.graph["method"] = "rdkit_tm"

    return G


def _partial_graph_matching(G_rdkit: nx.Graph, G_xyz: nx.Graph) -> dict:
    """Graph-distance + neighbour-symbol partial matching for non-isomorphic graphs.

    Nodes are matched element by element. For every candidate pair of
    same-element nodes a score is built from the difference between their
    rows of the all-pairs shortest-path matrices (lower is better) plus a
    bonus of 5 per neighbour element symbol they share. The best one-to-one
    assignment per element is then found with ``linear_sum_assignment``.

    NOTE(review): distance rows are compared position by position, so the
    distance term is only informative to the extent that both graphs list
    their nodes in a similar order -- confirm this holds for callers.

    Parameters
    ----------
    G_rdkit : nx.Graph
        RDKit molecular graph (nodes with ``symbol``).
    G_xyz : nx.Graph
        XYZ-based molecular graph (nodes with ``symbol``).

    Returns
    -------
    dict
        Mapping ``{rdkit_node: xyz_node}`` covering every RDKit node.

    Raises
    ------
    ImportError
        If scipy is not installed.
    ValueError
        If the two graphs do not contain the same number of atoms of
        every element.
    """
    from collections import defaultdict

    try:
        from scipy.optimize import linear_sum_assignment
    except ImportError:
        raise ImportError("scipy not found. Install via:\npip install scipy") from None

    logger.debug("Starting graph-distance + neighbor-symbol partial matching...")

    # Group nodes by element
    rdkit_by_elem: dict[str, list] = defaultdict(list)
    xyz_by_elem: dict[str, list] = defaultdict(list)
    for n in G_rdkit.nodes():
        rdkit_by_elem[G_rdkit.nodes[n]["symbol"]].append(n)
    for n in G_xyz.nodes():
        xyz_by_elem[G_xyz.nodes[n]["symbol"]].append(n)

    # Element counts must agree in BOTH directions: an element that only
    # occurs in the XYZ graph would otherwise slip through and make the two
    # distance matrices different sizes. RDKit elements are checked first so
    # the reported element is unchanged for inputs that already failed here.
    all_elems = list(rdkit_by_elem) + [e for e in xyz_by_elem if e not in rdkit_by_elem]
    for elem in all_elems:
        rdkit_count = len(rdkit_by_elem.get(elem, []))
        xyz_count = len(xyz_by_elem.get(elem, []))
        if rdkit_count != xyz_count:
            raise ValueError(
                f"Cannot perform partial matching: element '{elem}' count mismatch. RDKit has {rdkit_count}, "
                f"XYZ has {xyz_count}. This could be bimetallic and not handled by xyz2mol_tm."
            )

    # Shortest-path distance matrices (rows/columns follow G.nodes() order)
    D_rdkit = np.asarray(nx.floyd_warshall_numpy(G_rdkit))
    D_xyz = np.asarray(nx.floyd_warshall_numpy(G_xyz))

    rdkit_nodes = list(G_rdkit.nodes())
    xyz_nodes = list(G_xyz.nodes())
    rdkit_index = {n: i for i, n in enumerate(rdkit_nodes)}
    xyz_index = {n: i for i, n in enumerate(xyz_nodes)}

    # Unreachable pairs come back as inf; inf - inf is nan and any inf/nan in
    # the score matrix makes linear_sum_assignment fail. Substitute a finite
    # value larger than any real path length (at most n - 1 hops).
    unreachable = float(max(len(rdkit_nodes), len(xyz_nodes)))
    D_rdkit = np.where(np.isinf(D_rdkit), unreachable, D_rdkit)
    D_xyz = np.where(np.isinf(D_xyz), unreachable, D_xyz)

    rdkit_to_xyz: dict = {}

    for elem, rdkit_list in rdkit_by_elem.items():
        xyz_list = xyz_by_elem[elem]

        # Per-candidate data depends only on the XYZ node, so compute it once
        # instead of once per RDKit node.
        xyz_profiles = [
            (
                D_xyz[xyz_index[x_node], :],
                {G_xyz.nodes[n]["symbol"] for n in G_xyz.neighbors(x_node)},
            )
            for x_node in xyz_list
        ]

        scores = np.zeros((len(rdkit_list), len(xyz_list)))
        for i, r_node in enumerate(rdkit_list):
            d_r = D_rdkit[rdkit_index[r_node], :]
            r_symbols = {G_rdkit.nodes[n]["symbol"] for n in G_rdkit.neighbors(r_node)}

            for j, (d_x, x_symbols) in enumerate(xyz_profiles):
                dist_diff = np.sum(np.abs(d_r - d_x))
                score = -dist_diff
                score += len(r_symbols & x_symbols) * 5
                scores[i, j] = score

        # linear_sum_assignment minimises cost, so negate to maximise score.
        row_ind, col_ind = linear_sum_assignment(-scores)
        for i, j in zip(row_ind, col_ind):
            rdkit_to_xyz[rdkit_list[i]] = xyz_list[j]
            logger.debug("Matched %d -> %d (score=%.2f)", rdkit_list[i], xyz_list[j], scores[i, j])

    logger.debug("Finished partial matching. %d atoms mapped.", len(rdkit_to_xyz))
    return rdkit_to_xyz
