"""
Calcul de consommation des articles par projet
Gestion des réservations de stock
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import func
from typing import List
from math import ceil
from datetime import datetime

from app.database import get_db
from app.models.projet import Projet, Niveau, Piece
from app.models.article import Article, MouvementStock
from app.models.notification import Notification
from app.models.reservation import ReservationStock

# Router for the consumption and stock-reservation endpoints of this module;
# it is created with the "/api/consommation" prefix and the "consommation" tag.
router = APIRouter(prefix="/api/consommation", tags=["consommation"])


def calculer_consommation_piece_exacte(piece, articles) -> dict:
    """
    Compute the exact (unrounded) article consumption of one piece.

    Four optional sources are read on ``piece``; each references an article
    by id, and ids that match no entry of ``articles`` are ignored:

    * becquet (``becquet_type`` holds the article id):
      ``becquet_longueur * quantite / longueur_unitaire`` when the article
      has a ``longueur_unitaire``, otherwise ``quantite``.
    * ``config_start``: ``total_ml * 100 * quantite / longueur_unitaire``
      when both ``longueur_unitaire`` and ``total_ml`` are set, otherwise
      ``nb_boites * quantite``.
    * ``config_treillis``: either ``{"nappes": [{article_id, nb_panneaux}]}``
      (``nb_panneaux * quantite`` per nappe) or the flat form
      ``{article_id, nb_panneaux, couches}``
      (``nb_panneaux * couches * quantite``, ``couches`` defaulting to 1).
    * ``config_inserts``: ``quantite_ml * 100 * quantite / longueur_unitaire``
      when the article has a ``longueur_unitaire``, otherwise ``quantite``.

    The computation is best effort: a source whose values cannot be
    converted or combined (``ValueError`` / ``TypeError``) is skipped, and a
    ``nappes`` value that is not a list, or a nappe that is not a dict, is
    ignored instead of raising.

    Args:
        piece: object exposing ``quantite``, ``becquet_type``,
            ``becquet_longueur``, ``config_start``, ``config_treillis`` and
            ``config_inserts``.
        articles: iterable of objects exposing ``id`` and
            ``longueur_unitaire``.

    Returns:
        ``{article_id: quantity}``; quantities may be fractional and are
        accumulated when several sources reference the same article.
    """
    # Index the articles once instead of scanning the list for every
    # reference. ``setdefault`` keeps the first article seen for a given id,
    # like a "first match" linear search would.
    index = {}
    for candidat in articles:
        index.setdefault(candidat.id, candidat)

    consommation = {}

    def ajouter(article, qte):
        # Accumulate: several sources may consume the same article.
        consommation[article.id] = consommation.get(article.id, 0) + qte

    # ── BECQUETS ──
    if piece.becquet_type and piece.becquet_longueur:
        try:
            article = index.get(int(piece.becquet_type))
            if article and article.longueur_unitaire:
                ajouter(
                    article,
                    (piece.becquet_longueur * piece.quantite) / article.longueur_unitaire,
                )
            elif article:
                ajouter(article, piece.quantite)
        except (ValueError, TypeError):
            # Malformed becquet data: skip this source, keep the others.
            pass

    # ── STARTS ──
    config_start = piece.config_start
    if config_start and isinstance(config_start, dict):
        art_id = config_start.get('article_id')
        total_ml = config_start.get('total_ml', 0)
        if art_id:
            try:
                article = index.get(int(art_id))
                if article and article.longueur_unitaire and total_ml:
                    # The factor 100 converts ``total_ml`` before dividing by
                    # the unit length (presumably metres -> centimetres).
                    total_cm = total_ml * 100 * piece.quantite
                    ajouter(article, total_cm / article.longueur_unitaire)
                elif article:
                    nb_boites = config_start.get('nb_boites', 0)
                    ajouter(article, nb_boites * piece.quantite)
            except (ValueError, TypeError):
                pass

    # ── TREILLIS ──
    config_treillis = piece.config_treillis
    if config_treillis and isinstance(config_treillis, dict):
        if 'nappes' in config_treillis:
            # Multi-layer structure: {nappes: [{article_id, nb_panneaux, nappe}]}
            nappes = config_treillis['nappes']
            if not isinstance(nappes, (list, tuple)):
                # e.g. ``"nappes": null`` — nothing to consume.
                nappes = ()
            for nappe in nappes:
                if not isinstance(nappe, dict):
                    continue
                art_id = nappe.get('article_id')
                nb_panneaux = nappe.get('nb_panneaux', 0)
                if art_id and nb_panneaux:
                    try:
                        article = index.get(int(art_id))
                        if article:
                            ajouter(article, nb_panneaux * piece.quantite)
                    except (ValueError, TypeError):
                        pass
        else:
            # Flat structure: {article_id, nb_panneaux, couches}
            art_id = config_treillis.get('article_id')
            nb_panneaux = config_treillis.get('nb_panneaux', 0)
            couches = config_treillis.get('couches', 1)
            if art_id and nb_panneaux:
                try:
                    article = index.get(int(art_id))
                    if article:
                        ajouter(article, nb_panneaux * couches * piece.quantite)
                except (ValueError, TypeError):
                    pass

    # ── INSERTS ──
    config_inserts = piece.config_inserts
    if config_inserts and isinstance(config_inserts, dict):
        art_id = config_inserts.get('article_id')
        quantite_ml = config_inserts.get('quantite_ml', 0)
        if art_id and quantite_ml:
            try:
                article = index.get(int(art_id))
                if article and article.longueur_unitaire:
                    total_cm = quantite_ml * 100 * piece.quantite
                    ajouter(article, total_cm / article.longueur_unitaire)
                elif article:
                    ajouter(article, piece.quantite)
            except (ValueError, TypeError):
                pass

    return consommation


def calculer_consommation_par_niveau(projet_id: int, db) -> dict:
    """
    Compute the rounded article needs of a project, level by level.

    For each level (``Niveau``) of the project, the exact consumptions of
    its pieces are summed per article. Each per-level sum is rounded up, the
    article's loss rate (``taux_perte``, a percentage; missing means 0) is
    applied, and the result is rounded up again. The per-level results are
    then added together.

    Args:
        projet_id: id of the project whose levels are processed.
        db: SQLAlchemy session used for the queries.

    Returns:
        ``{article_id: final_quantity}`` with integer quantities; empty when
        the project has no level or no consumption.
    """
    articles = db.query(Article).all()
    articles_map = {a.id: a for a in articles}
    niveaux = db.query(Niveau).filter(Niveau.projet_id == projet_id).all()
    if not niveaux:
        return {}

    # Load the pieces of every level with a single query (instead of one
    # query per level), then group them by level in memory.
    niveau_ids = [niveau.id for niveau in niveaux]
    pieces_par_niveau = {}
    for piece in db.query(Piece).filter(Piece.niveau_id.in_(niveau_ids)).all():
        pieces_par_niveau.setdefault(piece.niveau_id, []).append(piece)

    consommation_totale = {}
    for niveau in niveaux:
        # Exact (fractional) needs of the level, summed before any rounding.
        conso_niveau = {}
        for piece in pieces_par_niveau.get(niveau.id, ()):
            conso = calculer_consommation_piece_exacte(piece, articles)
            for art_id, qte in conso.items():
                conso_niveau[art_id] = conso_niveau.get(art_id, 0) + qte

        for art_id, qte_decimale in conso_niveau.items():
            article = articles_map.get(art_id)
            if not article:
                continue
            # Round up first, then apply the loss rate and round up again.
            qte_arrondie = ceil(qte_decimale)
            taux = (article.taux_perte or 0) / 100
            qte_finale = ceil(qte_arrondie * (1 + taux))
            consommation_totale[art_id] = consommation_totale.get(art_id, 0) + qte_finale

    return consommation_totale


def calculer_consommation_piece(piece, articles) -> dict:
    """Compatibility helper for imputer_stock_piece: each exact quantity is rounded up immediately."""
    quantites_exactes = calculer_consommation_piece_exacte(piece, articles)
    quantites_arrondies = {}
    for article_id, quantite in quantites_exactes.items():
        quantites_arrondies[article_id] = ceil(quantite)
    return quantites_arrondies


def get_stock_disponible(article_id, db, exclure_projet_id=None):
    """
    Return the stock of an article that is not reserved.

    The result is ``article.stock_actuel`` minus the sum of
    ``ReservationStock.quantite`` for this article. It can be negative when
    reservations exceed the physical stock.

    Args:
        article_id: id of the article.
        db: SQLAlchemy session used for the queries.
        exclure_projet_id: when not ``None``, reservations of this project
            are left out of the reserved total.

    Returns:
        The unreserved quantity, or 0 when the article does not exist.
    """
    article = db.query(Article).filter(Article.id == article_id).first()
    if not article:
        return 0
    query = db.query(func.sum(ReservationStock.quantite))\
              .filter(ReservationStock.article_id == article_id)
    # Compare with None rather than relying on truthiness, so that a project
    # id of 0 is excluded like any other id.
    if exclure_projet_id is not None:
        query = query.filter(ReservationStock.projet_id != exclure_projet_id)
    # SUM over no rows yields NULL/None: treat it as nothing reserved.
    total_reserve = query.scalar() or 0
    return article.stock_actuel - total_reserve


@router.get("/projet/{projet_id}")
def get_consommation_projet(projet_id: int, db: Session = Depends(get_db)):
    """
    Report, for each article a project needs, the need versus the stock.

    The available stock of each article is computed without counting the
    reservations already held by this project.

    Args:
        projet_id: id of the project.
        db: SQLAlchemy session (injected).

    Returns:
        A dict with ``projet_id``, ``projet_nom``, ``articles`` (one entry
        per needed article) and ``alerte``, which is true when at least one
        article has insufficient available stock.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    projet = db.query(Projet).filter(Projet.id == projet_id).first()
    if not projet:
        raise HTTPException(404, "Projet introuvable")

    # Index articles by id once; the loop below then does O(1) lookups
    # instead of scanning the whole list for every needed article.
    articles_par_id = {}
    for article in db.query(Article).all():
        articles_par_id.setdefault(article.id, article)
    consommation_totale = calculer_consommation_par_niveau(projet_id, db)

    resultat = []
    for art_id, qte_necessaire in consommation_totale.items():
        article = articles_par_id.get(art_id)
        if not article:
            continue
        stock_dispo = get_stock_disponible(art_id, db, exclure_projet_id=projet_id)
        resultat.append({
            "article_id": art_id,
            "code": article.code,
            "designation": article.designation,
            "unite": article.unite,
            "qte_necessaire": qte_necessaire,
            "stock_physique": article.stock_actuel,
            "stock_reserve": article.stock_actuel - stock_dispo,
            # Displayed availability is clamped at 0; "suffisant" and
            # "manque" use the raw (possibly negative) value.
            "stock_disponible": max(0, stock_dispo),
            "suffisant": stock_dispo >= qte_necessaire,
            "manque": max(0, qte_necessaire - stock_dispo)
        })

    return {
        "projet_id": projet_id,
        "projet_nom": projet.nom,
        "articles": resultat,
        "alerte": any(not a["suffisant"] for a in resultat)
    }


@router.post("/projet/{projet_id}/verifier-et-notifier")
def verifier_et_notifier(projet_id: int, db: Session = Depends(get_db)):
    """
    Rebuild a project's stock reservations and notify about shortages.

    The project's existing reservations are deleted, then one reservation is
    created per needed article for the full computed quantity (even when the
    stock is insufficient). For every article whose available stock, not
    counting this project's reservations, is below the need, an
    ``alerte_stock`` notification is created. Everything is committed at the
    end.

    Args:
        projet_id: id of the project.
        db: SQLAlchemy session (injected).

    Returns:
        ``{"alertes_creees": int, "articles_en_alerte": [designation, ...]}``.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    # Local import: the module-level import only brings ``datetime`` in.
    from datetime import timezone

    projet = db.query(Projet).filter(Projet.id == projet_id).first()
    if not projet:
        raise HTTPException(404, "Projet introuvable")

    db.query(ReservationStock).filter(ReservationStock.projet_id == projet_id).delete()

    consommation_totale = calculer_consommation_par_niveau(projet_id, db)
    # Index articles by id once for O(1) lookups in the loop below.
    articles_par_id = {}
    for article in db.query(Article).all():
        articles_par_id.setdefault(article.id, article)

    # ``datetime.utcnow()`` is deprecated; this yields the same naive-UTC ISO
    # string. One timestamp is shared by every reservation of this run.
    date_creation = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    alertes = []
    for art_id, qte_necessaire in consommation_totale.items():
        article = articles_par_id.get(art_id)
        if not article:
            continue

        stock_dispo = get_stock_disponible(art_id, db, exclure_projet_id=projet_id)

        reservation = ReservationStock(
            article_id=art_id,
            projet_id=projet_id,
            quantite=qte_necessaire,
            date_creation=date_creation
        )
        db.add(reservation)

        if stock_dispo < qte_necessaire:
            manque = qte_necessaire - stock_dispo
            notif = Notification(
                type="alerte_stock",
                module="etudes",
                titre=f"Stock insuffisant - {article.designation}",
                message=f"Projet {projet.nom} : besoin {qte_necessaire} {article.unite}, disponible {stock_dispo} {article.unite} (physique {article.stock_actuel}, reserve {article.stock_actuel - stock_dispo}), manque {manque} {article.unite}.",
                projet_id=projet_id,
                lu=False
            )
            db.add(notif)
            alertes.append(article.designation)

    db.commit()
    return {"alertes_creees": len(alertes), "articles_en_alerte": alertes}


@router.delete("/projet/{projet_id}/reservations")
def liberer_reservations(projet_id: int, db: Session = Depends(get_db)):
    """Delete every stock reservation of the given project, commit, and return ``{"ok": True}``."""
    reservations_du_projet = db.query(ReservationStock).filter(
        ReservationStock.projet_id == projet_id
    )
    reservations_du_projet.delete()
    db.commit()
    return {"ok": True}


@router.post("/piece/{piece_id}/imputer-stock")
def imputer_stock_piece(piece_id: int, db: Session = Depends(get_db)):
    """
    Deduct from stock the articles consumed by manufacturing one piece.

    For each article consumed by the piece (quantities rounded up):
    the physical stock is decreased (never below 0), the article status is
    set to ``stock_faible`` or ``actif``, a ``fabrication`` stock movement is
    recorded, and the project's reservation for that article, if any, is
    decreased and deleted once it reaches 0. Everything is committed at the
    end.

    Args:
        piece_id: id of the piece being manufactured.
        db: SQLAlchemy session (injected).

    Returns:
        ``{"imputations": [...], "nb_articles": int}`` with one entry per
        article actually deducted.

    Raises:
        HTTPException: 404 when the piece does not exist.
    """
    piece = db.query(Piece).filter(Piece.id == piece_id).first()
    if not piece:
        raise HTTPException(404, "Piece introuvable")

    articles = db.query(Article).all()
    consommation = calculer_consommation_piece(piece, articles)

    # Reuse the articles already loaded above instead of re-querying each
    # one by id inside the loop.
    articles_par_id = {}
    for article in articles:
        articles_par_id.setdefault(article.id, article)

    imputations = []
    for art_id, qte in consommation.items():
        article = articles_par_id.get(art_id)
        if not article:
            continue

        article.stock_actuel = max(0, article.stock_actuel - qte)

        # Low stock when empty, or at/below a positive alert threshold.
        # A missing threshold is treated as 0 (no threshold).
        seuil = article.seuil_alerte or 0
        if article.stock_actuel == 0 or (seuil > 0 and article.stock_actuel <= seuil):
            article.statut = 'stock_faible'
        else:
            article.statut = 'actif'

        # NOTE(review): the movement records the full quantity even when the
        # stock was clamped at 0 above — confirm this is intended.
        mouvement = MouvementStock(
            article_id=art_id,
            quantite=-qte,
            type_mouvement='fabrication',
            reference=f"Piece {piece.reference} - Projet {piece.projet_id}",
            commentaire=f"Imputation automatique fabrication piece {piece.reference}"
        )
        db.add(mouvement)

        # Only one reservation row is consumed, so fetch just the first.
        reservation = db.query(ReservationStock)\
                        .filter(ReservationStock.article_id == art_id,
                                ReservationStock.projet_id == piece.projet_id)\
                        .first()
        if reservation is not None:
            reservation.quantite = max(0, reservation.quantite - qte)
            if reservation.quantite == 0:
                db.delete(reservation)

        imputations.append({
            "article": article.designation,
            "qte_imputee": qte,
            "stock_restant": article.stock_actuel
        })

    db.commit()
    return {"imputations": imputations, "nb_articles": len(imputations)}


@router.get("/article/{article_id}/stock-resume")
def get_stock_resume_article(article_id: int, db: Session = Depends(get_db)):
    """
    Summarize the physical, reserved and available stock of one article.

    The reserved stock is the sum of all reservations on the article; the
    available stock is the physical stock minus the reserved stock, clamped
    at 0. Responds with 404 when the article does not exist.
    """
    article = db.query(Article).filter(Article.id == article_id).first()
    if not article:
        raise HTTPException(404, "Article introuvable")

    somme_reservations = db.query(func.sum(ReservationStock.quantite)).filter(
        ReservationStock.article_id == article_id
    ).scalar()
    # No reservation at all gives an empty sum: count it as 0.
    total_reserve = somme_reservations or 0

    resume = {"article_id": article_id}
    resume["stock_physique"] = article.stock_actuel
    resume["stock_reserve"] = total_reserve
    resume["stock_disponible"] = max(0, article.stock_actuel - total_reserve)
    return resume
