"""Model loading and generation module with baseline and optimized modes."""
import torch
import torch.nn as nn
import platform
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.config import MODEL_PATH, MODEL_NAME

# Per-mode model cache: get_model() stores each loaded model under its mode string
_model_cache = {}
# Shared tokenizer: populated on the first load_tokenizer_model() call and reused afterwards
_tokenizer = None


def load_tokenizer_model():
    """Return the shared tokenizer and a freshly loaded FP32 model on CPU.

    Each artifact is first looked up under ``MODEL_PATH`` with
    ``local_files_only=True``. If that raises ``OSError`` or ``ValueError``,
    it is loaded by ``MODEL_NAME`` instead (without the local-only
    restriction). The tokenizer is loaded once and kept in the module-level
    ``_tokenizer``; the model is loaded anew on every call.

    Returns:
        tuple: (tokenizer, model) with the model on CPU and in eval mode
    """
    global _tokenizer

    def _local_first(auto_cls, **load_kwargs):
        # Prefer on-disk files. OSError (which includes FileNotFoundError) or
        # ValueError signals missing/invalid local files, so load by name.
        try:
            return auto_cls.from_pretrained(
                MODEL_PATH, local_files_only=True, **load_kwargs
            )
        except (OSError, ValueError):
            return auto_cls.from_pretrained(MODEL_NAME, **load_kwargs)

    if _tokenizer is None:
        _tokenizer = _local_first(AutoTokenizer)

        # Reuse the EOS token for padding when the tokenizer defines none
        if _tokenizer.pad_token is None:
            _tokenizer.pad_token = _tokenizer.eos_token

    # Full-precision weights, loaded without the low-memory loading path
    fp32_model = _local_first(
        AutoModelForCausalLM,
        dtype=torch.float32,
        low_cpu_mem_usage=False,
    )

    fp32_model = fp32_model.to('cpu')
    fp32_model.eval()

    return _tokenizer, fp32_model


def _quantize_dynamic_preserve_head(model: nn.Module) -> nn.Module:
    """Apply dynamic INT8 quantization to Linear layers while preserving lm_head in FP32.

    This targets the main transformer blocks for speedups but keeps the final
    projection (`lm_head`) in full precision to protect generation quality.
    """
    # Preserve sensitive output head in FP32
    has_head = hasattr(model, "lm_head") and isinstance(getattr(model, "lm_head"), nn.Module)
    float_head = model.lm_head if has_head else None

    # Use torch.ao.quantization API as recommended
    try:
        from torch.ao.quantization import quantize_dynamic as ao_quantize_dynamic
    except Exception:
        # Fallback to legacy namespace if necessary
        from torch.quantization import quantize_dynamic as ao_quantize_dynamic  # type: ignore

    def _quantize_linears_selective(module: nn.Module, prefix: str = ""):
        """Recursively quantize only MLP submodules' Linear layers, skip attention and heads."""
        for child_name, child in module.named_children():
            full_name = f"{prefix}.{child_name}" if prefix else child_name
            # Skip output head explicitly
            if child is getattr(model, "lm_head", None):
                continue
            # Quantize only MLP blocks (common pattern in GPTNeoX/Pythia)
            if "mlp" in full_name.lower():
                try:
                    qchild = ao_quantize_dynamic(child, {nn.Linear}, dtype=torch.qint8)
                    setattr(module, child_name, qchild)
                except Exception:
                    # If selective quantization fails on this child, leave it untouched
                    pass
            else:
                _quantize_linears_selective(child, full_name)

    # First try selective quantization focused on MLPs
    quantized = model
    try:
        _quantize_linears_selective(quantized)
    except Exception:
        # If the model structure is unexpected, fall back to broad quantization
        quantized = ao_quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

    # Restore output head in FP32 to reduce degradation
    if has_head and float_head is not None:
        quantized.lm_head = float_head

    quantized.eval()
    return quantized


def get_model(mode):
    """Get the tokenizer and model for the specified mode, loading on first use.

    The mode is validated before anything is loaded, so an invalid value
    fails immediately instead of after a full model load. Models are cached
    per mode in ``_model_cache``; later calls with the same mode return the
    cached instance.

    Args:
        mode (str): Either "baseline" or "optimized"

    Returns:
        tuple: (tokenizer, model) configured for the specified mode

    Raises:
        ValueError: If ``mode`` is neither "baseline" nor "optimized".
    """
    import warnings

    # Reject unknown modes up front so a typo never triggers a model load
    if mode not in ("baseline", "optimized"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'baseline' or 'optimized'.")

    # Check if model is already cached
    if mode in _model_cache:
        return _tokenizer, _model_cache[mode]

    # Load tokenizer and a fresh base model
    tokenizer, model = load_tokenizer_model()

    if mode == "optimized":
        # Optimized: dynamic INT8 quantization with preserved FP32 head
        model = _quantize_dynamic_preserve_head(model)

        # Optionally apply torch.compile if it is available.
        # Avoid on Windows where it may be slower/unstable for CPU quantized graphs
        if hasattr(torch, "compile") and platform.system() != "Windows":
            try:
                model = torch.compile(model, mode="reduce-overhead")
            except Exception as exc:
                # Compilation is optional: continue uncompiled, but report it.
                # NOTE(review): torch.compile is lazy, so backend errors may
                # only surface on the first forward pass rather than here.
                warnings.warn(
                    f"torch.compile unavailable, using uncompiled model: {exc}",
                    RuntimeWarning,
                )
    # Baseline: FP32 strict, no optimizations (the model is cached as loaded)

    _model_cache[mode] = model
    return tokenizer, model


def clear_cache():
    """Empty the per-mode model cache (useful for testing).

    The cache dict is emptied in place rather than rebound, so any code that
    already holds a reference to ``_model_cache`` sees the cleared state and
    the cached models are released. The cached tokenizer is left untouched.
    """
    _model_cache.clear()
