import re, csv, sys, pptx, fitz                   # fitz=PyMuPDF
import mysql.connector, os, json
from docx import Document
import pandas as pd
from openai import OpenAI
from dotenv import load_dotenv
import chromadb
from datetime import datetime
import nltk
from nltk.tokenize import sent_tokenize

# Ensure NLTK tokenizer resources are present.
# NOTE(review): recent NLTK releases need 'punkt_tab' (not just 'punkt') for
# sent_tokenize — confirm before enabling USE_NLTK.
nltk.download('punkt')

# Load configuration from the environment (.env file supported).
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai = OpenAI(api_key=OPENAI_API_KEY)

# MySQL connection settings, all taken from the environment.
MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")

# Shared connection kwargs for every mysql.connector.connect(**cfg) call below.
cfg = dict(
    host=MYSQL_HOST,
    port=3306,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    database=MYSQL_DATABASE
)

# When True, sentence splitting uses NLTK's Italian tokenizer; when False a
# simple punctuation regex is used instead (see extract_sentences).
USE_NLTK = False

# Persistent ChromaDB client and the collection holding sentence embeddings.
chroma_client = chromadb.PersistentClient(path="/var/www/html/ai_worker/chroma", settings=chromadb.Settings(allow_reset=True))
collection = chroma_client.get_or_create_collection("case_studies")

# --- Sentence extraction -----------------------------------------------------

# Pre-compiled fallback splitter: break after '.', '!' or '?' followed by
# whitespace (compiled once instead of on every call).
_SENT_SPLIT_RE = re.compile(r'(?<=[\.\!\?])\s+')

def extract_sentences(text, min_len=30, use_nltk=None):
    """Split *text* into sentences and keep only the sufficiently long ones.

    Args:
        text: raw document text.
        min_len: keep only sentences whose stripped length is strictly
            greater than this (default 30, matching the original behavior).
        use_nltk: force the NLTK Italian tokenizer on/off; when None, the
            module-level USE_NLTK flag decides.

    Returns:
        List of stripped sentence strings.
    """
    if use_nltk is None:
        use_nltk = USE_NLTK
    if use_nltk:
        parts = sent_tokenize(text, language='italian')
    else:
        # Regex fallback that needs no NLTK corpora.
        parts = _SENT_SPLIT_RE.split(text)
    return [s.strip() for s in parts if len(s.strip()) > min_len]

# --- KPI metadata from MySQL -------------------------------------------------

def get_kpi_labels():
    """Return the labels of every active row in business_case_factors."""
    with mysql.connector.connect(**cfg) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT label FROM business_case_factors WHERE is_active = 1")
        rows = cursor.fetchall()
    return [row[0] for row in rows]

def get_factor_id_map():
    """Map each active factor's lower-cased label to its database id."""
    with mysql.connector.connect(**cfg) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, label FROM business_case_factors WHERE is_active = 1")
        rows = cursor.fetchall()
    return {label.lower(): factor_id for factor_id, label in rows}

# --- GPT inference of sector and KPIs for the whole document -----------------

def infer_sector_and_kpis(text, valid_kpi_labels, log_path=None):
    """Ask GPT to identify, per case study, the business sector and the KPIs
    (restricted to *valid_kpi_labels*) found in *text*.

    Returns the raw model answer (expected to be JSON, but not parsed here).
    When *log_path* is given, the prompt/response pair is appended to that
    file, creating the parent directory if needed.
    """
    prompt = f"""
Analizza il seguente testo aziendale. Indica per ogni caso studio distinto:
1. Il settore aziendale principale (una sola parola, es: logistica, sanità, tecnologia, produzione, ecc.).
2. Gli indicatori chiave di performance (KPI) presenti, scegliendoli solo da questa lista:
{valid_kpi_labels}
Per ogni KPI presente, indica anche il valore numerico se c'è (es: +12%).

Testo:
{text}

Rispondi in JSON nel formato:
[
  {{"sector": "settore", "kpis": [{{"kpi": "Produttività", "valore": "+12%"}}, ...]}},
  ...
]
"""
    completion = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    answer = completion.choices[0].message.content

    if log_path:
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        transcript = (
            "\n===== PROMPT =====\n"
            + prompt
            + "\n===== RISPOSTA GPT =====\n"
            + answer
            + "\n========================\n"
        )
        with open(log_path, "a", encoding="utf-8") as log_handle:
            log_handle.write(transcript)

    return answer

# --- Whole-file text extraction ----------------------------------------------

def parse_full_text(path):
    """Extract the full plain text of a .docx, .pdf or .pptx file.

    Dispatches on the (case-insensitive) file extension; any other extension
    yields "". PDF documents are explicitly closed so the file handle is not
    leaked.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".docx":
        doc = Document(path)
        return "\n".join(p.text.strip() for p in doc.paragraphs if p.text.strip())
    if ext == ".pdf":
        doc = fitz.open(path)
        try:
            return "\n".join(page.get_text() for page in doc)
        finally:
            # fitz keeps the file open until closed; original leaked it.
            doc.close()
    if ext == ".pptx":
        prs = pptx.Presentation(path)
        texts = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if shape.has_text_frame:
                    texts.append(shape.text_frame.text.strip())
        return "\n".join(texts)
    return ""

# --- MySQL import ------------------------------------------------------------

def load_into_db(rows):
    """Upsert (sector, factor_id, best_value) tuples into case_study_kpi.

    Relies on a unique key over (sector, factor_id): duplicates overwrite
    best_value instead of raising.
    """
    with mysql.connector.connect(**cfg) as c:
        cur = c.cursor()
        for row in rows:
            cur.execute("""
              INSERT INTO case_study_kpi (sector, factor_id, best_value)
              VALUES (%s,%s,%s)
              ON DUPLICATE KEY UPDATE best_value = VALUES(best_value)
            """, row)
        # mysql.connector does not autocommit by default; without this the
        # inserts are discarded when the connection closes.
        c.commit()
    print("📥 Import terminato")

def populate_multi_sector():
    """Refresh the synthetic 'multi' sector: for each factor, upsert the
    average best_value across all real sectors."""
    with mysql.connector.connect(**cfg) as cnx:
        cur = cnx.cursor()

        # Per-factor average, ignoring the synthetic 'multi'/'unknown' buckets
        # so the aggregate never feeds back into itself.
        cur.execute("""
            SELECT factor_id, AVG(best_value)
            FROM case_study_kpi
            WHERE sector NOT IN ('multi', 'unknown')
            GROUP BY factor_id
        """)
        rows = cur.fetchall()

        for factor_id, avg_value in rows:
            cur.execute("""
                INSERT INTO case_study_kpi (sector, factor_id, best_value)
                VALUES ('multi', %s, %s)
                ON DUPLICATE KEY UPDATE best_value = VALUES(best_value)
            """, (factor_id, avg_value))

        # mysql.connector does not autocommit by default; without this the
        # upserts are discarded when the connection closes.
        cnx.commit()

    print("✅ Valori 'multi' aggiornati con medie settoriali.")


# --- Main pipeline -----------------------------------------------------------

def main():
    """Pipeline entry point.

    Reads a .docx/.pdf/.pptx path from argv[1], asks GPT to extract sector +
    KPI values from the full text, upserts the numeric KPI rows into MySQL,
    embeds every long-enough sentence into ChromaDB, and finally refreshes
    the aggregated 'multi' sector averages.
    """
    path = sys.argv[1]
    # One timestamped GPT transcript per run, plus a sibling file for KPI
    # entries that fail validation.
    log_file = f"log/gpt_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
    err_log = log_file.replace("gpt_log_", "gpt_kpi_invalid_")

    text = parse_full_text(path)
    if not text:
        print("❌ Nessun testo estratto dal file")
        return

    kpi_labels = get_kpi_labels()
    fid_map = get_factor_id_map()

    print("🧠 Analisi GPT su documento completo...")
    json_response = infer_sector_and_kpis(text, kpi_labels, log_path=log_file)
    # Strip an optional Markdown ```json fence around the model's answer,
    # then cut everything after the last ']' to drop trailing chatter.
    cleaned = json_response.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned.removeprefix("```json").strip()
    if cleaned.endswith("```"):
        cleaned = cleaned.removesuffix("```").strip()
    last = cleaned.rfind("]")
    if last != -1:
        cleaned = cleaned[:last + 1]

    try:
        parsed_list = json.loads(cleaned)
        # Normalize a single-object answer into a one-element list.
        if isinstance(parsed_list, dict):
            parsed_list = [parsed_list]
    except json.JSONDecodeError as e:
        print("⚠️ Errore nel parsing JSON:", e)
        print("Risposta GPT grezza:\n", repr(json_response))
        return

    all_rows = []
    skipped_kpis = []
    # Placeholder "values" that carry no usable number.
    simboli_ignorati = {"+", "-", "n.d.", "n/a", "", None}

    for parsed in parsed_list:
        sector = parsed.get("sector", "multi")
        for kpi in parsed.get("kpis", []):
            kpiname = kpi.get("kpi", "").lower()
            raw = kpi.get("valore")
            if not isinstance(raw, str) or raw.strip().lower() in simboli_ignorati:
                skipped_kpis.append((sector, kpiname, raw, "valore assente o simbolico"))
                continue
            # Extract the first signed number from strings like "+12%" or "-3.5€".
            # NOTE(review): the optional lookahead (?=%|€)? always succeeds, so it
            # does not actually require a %/€ suffix — confirm the intent.
            m = re.search(r"(-?\d+(\.\d+)?)(?=%|€)?", raw)
            if m and kpiname in fid_map:
                value = float(m.group(1))
                factor_id = fid_map[kpiname]
                all_rows.append((sector, factor_id, value))
            else:
                skipped_kpis.append((sector, kpiname, raw, "regex non valida o kpi ignoto"))


    if skipped_kpis:
        # Tab-separated report of every KPI entry that was not imported.
        with open(err_log, "w", encoding="utf-8") as log:
            log.write("SETTORE\tKPI\tVALORE\tERRORE\n")
            for s in skipped_kpis:
                log.write("\t".join(map(str, s)) + "\n")

    load_into_db(all_rows)

    # Embed each extracted sentence and store it in ChromaDB, tagging it with
    # the first KPI label that appears verbatim inside the sentence (if any).
    sentences = extract_sentences(text)
    total = len(sentences)
    for i, sent in enumerate(sentences, 1):
        percent = (i / total) * 100
        print(f"\r🔄 Embedding frasi: {i}/{total} ({percent:.1f}%)", end="")
        emb = openai.embeddings.create(
            model="text-embedding-3-small",
            input=sent
        ).data[0].embedding

        matched_factor_id = None
        for label, fid in fid_map.items():
            if label in sent.lower():
                matched_factor_id = fid
                break

        # NOTE(review): hash() of a str is randomized per process
        # (PYTHONHASHSEED), so these ids differ between runs and can collide
        # modulo 100000 — consider a stable digest (e.g. hashlib) instead.
        collection.add(
            documents=[sent],
            embeddings=[emb],
            metadatas=[{
                "source": os.path.basename(path),
                "sector": "multi",
                "factor_id": matched_factor_id if matched_factor_id is not None else "unclassified"
            }],
            ids=[f"{os.path.basename(path)}_{hash(sent)%100000}"]
        )

    print("\n✅ Dati importati su MySQL e ChromaDB.")

    populate_multi_sector()
    print("📊 Popolamento 'multi' completato.")

if __name__ == "__main__":
    main()
