"""
Routes FastAPI pour le récapitulatif de saisie Xbat
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from datetime import datetime

from app.database import get_db
from app.models.recap_xbat import RecapXbat
from app.models.projet import Projet, Niveau, Piece
from app.models.article import Article
from app.routes.consommation import calculer_consommation_par_niveau

router = APIRouter(prefix="/api/recap-xbat", tags=["recap_xbat"])


@router.get("/projet/{projet_id}")
def get_recap_projet(projet_id: int, db: Session = Depends(get_db)):
    from app.routes.consommation import calculer_consommation_piece_exacte
    from math import ceil

    projet = db.query(Projet).filter(Projet.id == projet_id).first()
    if not projet:
        raise HTTPException(404, "Projet introuvable")

    niveaux = db.query(Niveau).filter(Niveau.projet_id == projet_id).all()
    articles = db.query(Article).all()
    articles_map = {a.id: a for a in articles}

    resultat_niveaux = []

    for niveau in niveaux:
        # Source 1 : pièces fabriquées
        pieces = db.query(Piece).filter(
            Piece.niveau_id == niveau.id,
            Piece.statut_realisation == 'fabriquee'
        ).all()

        conso_niveau = {}
        for piece in pieces:
            conso = calculer_consommation_piece_exacte(piece, articles)
            for art_id, qte in conso.items():
                conso_niveau[art_id] = conso_niveau.get(art_id, 0) + qte

        # Arrondi + taux de perte
        qtes_fabriquees = {}
        for art_id, qte_decimale in conso_niveau.items():
            article = articles_map.get(art_id)
            if article:
                qte_arrondie = ceil(qte_decimale)
                taux = (article.taux_perte or 0) / 100
                qtes_fabriquees[art_id] = ceil(qte_arrondie * (1 + taux))

        # Source 2 : sorties manuelles depuis le stock liées au projet/niveau
        from app.models.article import MouvementStock
        sorties_stock = db.query(MouvementStock).filter(
            MouvementStock.projet_id == projet_id,
            MouvementStock.niveau_id == niveau.id,
            MouvementStock.quantite < 0
        ).all()
        sorties_stock_map = {}
        for s in sorties_stock:
            sorties_stock_map[s.article_id] = sorties_stock_map.get(s.article_id, 0) + abs(s.quantite)

        # Source 3 : lignes déjà validées manuellement dans recap_xbat
        recap_valides = db.query(RecapXbat).filter(
            RecapXbat.projet_id == projet_id,
            RecapXbat.niveau_id == niveau.id
        ).all()
        recap_map = {r.article_id: r for r in recap_valides}

        # Fusionner toutes les sources
        tous_articles = set(
            list(qtes_fabriquees.keys()) +
            list(sorties_stock_map.keys()) +
            list(recap_map.keys())
        )

        lignes = []
        for art_id in tous_articles:
            article = articles_map.get(art_id)
            if not article:
                continue

            qte_fabriquee = qtes_fabriquees.get(art_id, 0)
            qte_sortie_stock = sorties_stock_map.get(art_id, 0)
            recap = recap_map.get(art_id)
            qte_saisie_manuelle = recap.quantite_saisie if recap else 0
            prix_unitaire = recap.prix_unitaire if recap else (article.prix_achat_ht or 0)

            # Quantité totale = pièces fabriquées + sorties stock
            qte_totale = qte_fabriquee + qte_sortie_stock
            # Attente = ce qui n'a pas encore été validé manuellement
            qte_attente = max(0, qte_totale - qte_saisie_manuelle)

            lignes.append({
                "article_id": art_id,
                "code": article.code,
                "designation": article.designation,
                "unite": article.unite,
                "qte_totale": qte_totale,
                "qte_saisie": qte_saisie_manuelle,
                "qte_attente": qte_attente,
                "prix_unitaire": prix_unitaire,
                "total_ht_attente": round(qte_attente * prix_unitaire, 2),
                "total_ht_saisi": round(qte_saisie_manuelle * prix_unitaire, 2),
                "entierement_saisi": qte_attente == 0
            })

        if lignes:
            resultat_niveaux.append({
                "niveau_id": niveau.id,
                "niveau_nom": niveau.nom,
                "lignes": sorted(lignes, key=lambda x: x["designation"])
            })

    toutes_lignes = [l for n in resultat_niveaux for l in n["lignes"]]
    tout_saisi = all(l["entierement_saisi"] for l in toutes_lignes) and len(toutes_lignes) > 0

    return {
        "projet_id": projet_id,
        "projet_nom": projet.nom,
        "niveaux": resultat_niveaux,
        "tout_saisi": tout_saisi
    }


@router.post("/projet/{projet_id}/valider")
def valider_saisie(projet_id: int, lignes: List[dict], db: Session = Depends(get_db)):
    from app.routes.consommation import calculer_consommation_piece_exacte
    from app.models.article import MouvementStock
    from math import ceil

    projet = db.query(Projet).filter(Projet.id == projet_id).first()
    if not projet:
        raise HTTPException(404, "Projet introuvable")

    articles = db.query(Article).all()
    articles_map = {a.id: a for a in articles}

    for ligne in lignes:
        art_id = ligne.get("article_id")
        niveau_id = ligne.get("niveau_id")
        prix_unitaire = ligne.get("prix_unitaire", 0)

        # Qté depuis pièces fabriquées
        pieces = db.query(Piece).filter(
            Piece.niveau_id == niveau_id,
            Piece.statut_realisation == 'fabriquee'
        ).all()
        conso_niveau = {}
        for piece in pieces:
            conso = calculer_consommation_piece_exacte(piece, articles)
            for aid, qte in conso.items():
                conso_niveau[aid] = conso_niveau.get(aid, 0) + qte

        article = articles_map.get(art_id)
        if not article:
            continue

        qte_fabriquee = 0
        if art_id in conso_niveau:
            qte_decimale = conso_niveau[art_id]
            qte_arrondie = ceil(qte_decimale)
            taux = (article.taux_perte or 0) / 100
            qte_fabriquee = ceil(qte_arrondie * (1 + taux))

        # Qté depuis sorties stock
        sorties = db.query(MouvementStock).filter(
            MouvementStock.projet_id == projet_id,
            MouvementStock.niveau_id == niveau_id,
            MouvementStock.article_id == art_id,
            MouvementStock.quantite < 0
        ).all()
        qte_sortie_stock = sum(abs(s.quantite) for s in sorties)

        qte_totale = qte_fabriquee + qte_sortie_stock

        # Récupérer ou créer la ligne recap
        recap = db.query(RecapXbat).filter(
            RecapXbat.projet_id == projet_id,
            RecapXbat.article_id == art_id,
            RecapXbat.niveau_id == niveau_id
        ).first()

        if not recap:
            recap = RecapXbat(
                projet_id=projet_id,
                article_id=art_id,
                niveau_id=niveau_id,
                quantite_totale=qte_totale,
                quantite_saisie=0,
                prix_unitaire=prix_unitaire
            )
            db.add(recap)
            db.flush()

        # Marquer comme entièrement saisi
        recap.quantite_saisie = qte_totale
        recap.quantite_totale = qte_totale
        recap.prix_unitaire = prix_unitaire
        recap.date_derniere_saisie = datetime.utcnow().isoformat()

    db.commit()
    return {"ok": True}


@router.get("/projet/{projet_id}/peut-archiver")
def peut_archiver(projet_id: int, db: Session = Depends(get_db)):
    """
    Vérifie si toutes les lignes du récapitulatif ont été saisies.
    """
    from app.routes.consommation import calculer_consommation_piece_exacte
    from math import ceil

    niveaux = db.query(Niveau).filter(Niveau.projet_id == projet_id).all()
    articles = db.query(Article).all()
    articles_map = {a.id: a for a in articles}

    for niveau in niveaux:
        pieces = db.query(Piece).filter(
            Piece.niveau_id == niveau.id,
            Piece.statut_realisation == 'fabriquee'
        ).all()
        conso_niveau = {}
        for piece in pieces:
            conso = calculer_consommation_piece_exacte(piece, articles)
            for art_id, qte in conso.items():
                conso_niveau[art_id] = conso_niveau.get(art_id, 0) + qte

        for art_id, qte_decimale in conso_niveau.items():
            article = articles_map.get(art_id)
            if not article:
                continue
            qte_arrondie = ceil(qte_decimale)
            taux = (article.taux_perte or 0) / 100
            qte_totale = ceil(qte_arrondie * (1 + taux))

            recap = db.query(RecapXbat).filter(
                RecapXbat.projet_id == projet_id,
                RecapXbat.article_id == art_id,
                RecapXbat.niveau_id == niveau.id
            ).first()

            qte_saisie = recap.quantite_saisie if recap else 0
            if qte_saisie < qte_totale:
                return {"peut_archiver": False, "raison": f"{article.designation} non entièrement saisi"}

    return {"peut_archiver": True}