#!/usr/bin/env python3
"""
Télécharge le modèle HuggingFace EleutherAI/pythia-70m-deduped
et le sauvegarde dans le dossier `models/cache/EleutherAI/pythia-70m-deduped`.

Usage:
    python download_pythia_model.py
    python download_pythia_model.py --repo EleutherAI/pythia-70m-deduped --out models/cache

Notes:
- Exécutez ce script sur une machine avec accès internet.
- Pour pousser le modèle dans un repo GitHub, utilisez Git LFS (les fichiers peuvent dépasser la limite normale).
"""

import argparse
import os
import sys
import shutil
from pathlib import Path

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
except Exception as e:
    print("Le paquet 'transformers' est requis. Installez-le via `pip install transformers`.", file=sys.stderr)
    raise

def sanitize_repo_id(repo_id: str):
    # transforme "EleutherAI/pythia-70m-deduped" en ("EleutherAI", "pythia-70m-deduped")
    parts = repo_id.split("/")
    if len(parts) == 1:
        return ("", parts[0])
    return (parts[0], "/".join(parts[1:]))

def download_and_save(repo_id: str, out_base: str, trust_remote_code: bool = False, revision: str | None = None):
    owner, name = sanitize_repo_id(repo_id)
    # destination: {out_base}/{owner}/{name} or {out_base}/{name} if no owner
    if owner:
        dest = Path(out_base) / owner / name
    else:
        dest = Path(out_base) / name

    if dest.exists():
        print(f"[i] Le dossier {dest} existe déjà. Le contenu sera écrasé si download succeed.")
    else:
        dest.mkdir(parents=True, exist_ok=True)

    print(f"[i] Téléchargement du tokenizer depuis '{repo_id}'...")
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, revision=revision, trust_remote_code=trust_remote_code)
    print(f"[i] Téléchargement du modèle (AutoModelForCausalLM) depuis '{repo_id}'...")
    model = AutoModelForCausalLM.from_pretrained(repo_id, revision=revision, trust_remote_code=trust_remote_code)

    # sauvegarde
    print(f"[i] Sauvegarde dans '{dest}' ...")
    tokenizer.save_pretrained(dest)
    model.save_pretrained(dest)

    # vérification rapide
    expected_files = list(dest.glob("*"))
    print(f"[i] Fichiers présents dans {dest}:")
    for p in expected_files:
        print("    -", p.name)

    print("\n[i] Terminé. Le modèle est maintenant disponible localement dans:")
    print("    ", dest.resolve())
    return dest

def main():
    p = argparse.ArgumentParser(description="Télécharge un modèle HF et le place dans models/cache/")
    p.add_argument("--repo", "-r", default="EleutherAI/pythia-70m-deduped", help="ID du repo HuggingFace (owner/name)")
    p.add_argument("--out", "-o", default="models/cache", help="Répertoire racine de destination (par défaut: models/cache)")
    p.add_argument("--revision", help="Révision / commit / tag à télécharger (optionnel)")
    p.add_argument("--trust-remote-code", action="store_true", help="Passer trust_remote_code=True (utile si le modèle a du code custom)")
    args = p.parse_args()

    try:
        dest = download_and_save(args.repo, args.out, trust_remote_code=args.trust_remote_code, revision=args.revision)
    except KeyboardInterrupt:
        print("\n[!] Interrompu par l'utilisateur.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"[!] Erreur pendant le téléchargement: {e}", file=sys.stderr)
        sys.exit(2)

    # Conseils pour GitHub (affiché, sans actions automatiques)
    print("\nIMPORTANT:")
    print(" - Les fichiers du modèle peuvent être volumineux. Pour les suivre dans git, activez Git LFS:")
    print("     git lfs install")
    print("     git lfs track \"models/cache/**\"")
    print("     git add .gitattributes")
    print(" - Puis commit/push normalement (attention aux limites du repo).")
    print(" - Si vous préférez ne pas mettre le modèle dans le repo, conservez `models/cache` hors du dépôt et documentez le chemin dans la README.")

if __name__ == "__main__":
    main()