"""Metrics tracking module for energy, latency, and memory measurements."""
import time
import json
import os
from contextlib import contextmanager
from pathlib import Path
import psutil
from codecarbon import EmissionsTracker

# Metric artifacts live under ./metrics, resolved against the current working
# directory. The directory is created at import time if it is missing.
METRICS_DIR = Path("metrics")
HISTORY_FILE = METRICS_DIR.joinpath("history.jsonl")
METRICS_DIR.mkdir(exist_ok=True)


class MetricsTracker:
    """Context manager that measures energy, latency, and memory of a code block.

    On entry it records the process RSS, starts a CodeCarbon
    ``EmissionsTracker`` and starts a wall-clock timer. On exit it stops them
    and stores the results on the instance.

    Attributes:
        energy_Wh: Energy reported by CodeCarbon for the block, in watt-hours.
        latency_ms: Wall-clock duration of the block, in milliseconds.
        memory_MiB: Growth of the process RSS across the block, in MiB.
            This is a before/after delta (clamped at 0), not a peak value.
    """

    def __init__(self):
        self.energy_Wh = 0.0
        self.latency_ms = 0.0
        self.memory_MiB = 0.0
        self.tracker = None
        self.start_time = None
        # Defined up front so the attribute always exists, even before __enter__
        self.start_memory = 0.0
        self.process = psutil.Process()

    def __enter__(self):
        """Start tracking metrics and return this tracker."""
        # Baseline RSS before the measured block, converted from bytes to MiB
        self.start_memory = self.process.memory_info().rss / (2**20)

        # Start energy tracking; nothing is written to disk by CodeCarbon itself
        self.tracker = EmissionsTracker(
            measure_power_secs=1,
            save_to_file=False,
            logging_logger=None,
            log_level="error"
        )
        self.tracker.start()

        # Start the timer last so tracker start-up cost is not counted as latency
        self.start_time = time.perf_counter()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop tracking and compute metrics.

        Returns:
            False, so an exception raised inside the ``with`` body propagates.
        """
        # Stop the timer first so tracker shutdown is not counted as latency
        self.latency_ms = (time.perf_counter() - self.start_time) * 1000.0

        # stop() returns CO2-equivalent emissions in kg, which is NOT an energy
        # figure, so the return value is deliberately not used as energy.
        self.tracker.stop()
        self.energy_Wh = self._read_energy_kwh() * 1000.0  # kWh -> Wh

        # RSS after the measured block, converted from bytes to MiB
        end_memory = self.process.memory_info().rss / (2**20)
        # RSS can shrink (e.g. after garbage collection); report 0 in that case
        self.memory_MiB = max(end_memory - self.start_memory, 0.0)

        return False

    def _read_energy_kwh(self):
        """Return the energy consumed according to the stopped tracker, in kWh.

        Prefers the public ``final_emissions_data.energy_consumed`` value and
        falls back to the private ``_total_energy.kWh`` accumulator. Returns
        0.0 when neither is available rather than guessing from emissions.
        """
        final_data = getattr(self.tracker, "final_emissions_data", None)
        energy_kwh = getattr(final_data, "energy_consumed", None)
        if energy_kwh is None:
            total_energy = getattr(self.tracker, "_total_energy", None)
            energy_kwh = getattr(total_energy, "kWh", None)
        return float(energy_kwh) if energy_kwh is not None else 0.0

    def get_metrics(self):
        """Return the collected metrics as a dictionary of rounded floats.

        Returns:
            dict with keys ``energy_Wh`` (6 decimals), ``latency_ms``
            (2 decimals) and ``memory_MiB`` (2 decimals).
        """
        return {
            "energy_Wh": round(self.energy_Wh, 6),
            "latency_ms": round(self.latency_ms, 2),
            "memory_MiB": round(self.memory_MiB, 2)
        }


@contextmanager
def tracked_inference():
    """Yield a running ``MetricsTracker`` for the duration of a ``with`` block.

    The tracker is started on entry and stopped on exit, after which its
    metrics can be read.

    Usage:
        with tracked_inference() as tracker:
            result = model.generate(...)
        metrics = tracker.get_metrics()
    """
    with MetricsTracker() as active_tracker:
        yield active_tracker


def measure(inference_func):
    """Run an inference callable under metric tracking and record the result.

    The callable is invoked with no arguments inside ``tracked_inference``.
    The collected metrics are appended to the history file via
    ``save_metrics`` before returning.

    Args:
        inference_func: Zero-argument callable that performs inference

    Returns:
        tuple: (result, metrics_dict)
    """
    with tracked_inference() as session:
        output = inference_func()

    collected = session.get_metrics()

    # Persist the measurement together with the inference output
    save_metrics(collected, output)

    return output, collected


def save_metrics(metrics, summary=None):
    """Append one metrics record to the history file in JSONL format.

    Each record is a single JSON object on its own line holding the current
    Unix timestamp and the metrics dictionary. When ``summary`` is a non-empty
    string, its whitespace-separated word count is stored as
    ``summary_length``.

    Args:
        metrics: Dictionary of metrics (must be JSON-serializable)
        summary: Optional summary text. Non-string values are ignored, because
            ``measure`` forwards whatever the inference callable returned.
    """
    entry = {
        "timestamp": time.time(),
        "metrics": metrics,
    }
    # Only text has a word count; other result types would fail on .split()
    if isinstance(summary, str) and summary:
        entry["summary_length"] = len(summary.split())

    # Serialize before opening the file so a non-serializable entry cannot
    # leave a partial line behind in the history.
    line = json.dumps(entry, ensure_ascii=False)

    # The directory is created at import time but may have been removed since
    HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(HISTORY_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
