#!/usr/bin/env python3
"""
Route Analyzer - Analyze completed routes for efficiency metrics
Compares actual GPS path to planned route, calculates efficiency,
and updates route analytics

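Run: python3 route_analyzer.py --mode batch
         Processes one batch of completed routes, then exits
     python3 route_analyzer.py --mode realtime
         Processes completed routes every 5 minutes (see --interval)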
"""

import sys
import argparse
import time
from datetime import datetime, timedelta
from math import radians, sin, cos, sqrt, atan2
import mysql.connector
import json

from db_config import load_db_config
from logging_config import get_logger

# Module-level logger used by RouteAnalyzer and main() in this file.
# NOTE(review): get_logger is project-local; presumably the second argument is
# the log file it writes to — confirm in logging_config.
logger = get_logger(__name__, '/var/log/route_analyzer.log')


class RouteAnalyzer:
    """Analyze route efficiency from GPS data.

    Holds one MySQL connection (opened in ``__init__``) and uses it to read
    routes and GPS fixes, derive distance / efficiency / max-speed figures,
    and write them back to the ``routes`` table.
    """

    # Sphere radius used by the haversine formula, in meters.
    EARTH_RADIUS_M = 6371000
    # Conversion factor used when reporting GPS distance in miles.
    METERS_PER_MILE = 1609.34

    def __init__(self, db_config):
        """Store the connection settings and connect immediately.

        Args:
            db_config: Keyword arguments passed to ``mysql.connector.connect``.

        Raises:
            mysql.connector.Error: If the initial connection fails.
        """
        self.db_config = db_config
        self.conn = None
        self.connect()

    # ------------------------------------------------------------------
    # Connection handling
    # ------------------------------------------------------------------

    def connect(self):
        """Establish the database connection; log and re-raise on failure."""
        try:
            self.conn = mysql.connector.connect(**self.db_config)
            logger.info("Database connection established")
        except mysql.connector.Error as e:
            logger.error(f"Database connection failed: {e}")
            raise

    def disconnect(self):
        """Close the database connection if it is currently open."""
        if self.conn and self.conn.is_connected():
            self.conn.close()
            logger.info("Database connection closed")

    def _ensure_connection(self):
        """Reconnect if the connection was never opened or has dropped."""
        if self.conn is None or not self.conn.is_connected():
            logger.warning("Database connection lost, reconnecting")
            self.connect()

    def _fetch(self, query, params=None, one=False):
        """Run a SELECT on a dictionary cursor and always close the cursor.

        Args:
            query: SQL text with ``%s`` placeholders.
            params: Placeholder values, or None when the query has none.
            one: Return only the first row (or None) instead of all rows.

        Returns:
            A list of row dicts, or a single row dict / None when ``one``.
        """
        cursor = self.conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            return cursor.fetchone() if one else cursor.fetchall()
        finally:
            # Close even when execute/fetch raises so cursors are not leaked.
            cursor.close()

    # ------------------------------------------------------------------
    # Pure calculations
    # ------------------------------------------------------------------

    @staticmethod
    def _to_datetime(value):
        """Coerce a DB timestamp (datetime or ISO-8601 text) to ``datetime``.

        Values that are already ``datetime`` objects are returned unchanged;
        anything else is stringified, a trailing ``Z`` is rewritten to
        ``+00:00``, and the result is parsed with ``datetime.fromisoformat``.
        """
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(str(value).replace('Z', '+00:00'))

    def _elapsed_minutes(self, start_value, end_value):
        """Return the minutes between two timestamps (datetime or ISO text)."""
        delta = self._to_datetime(end_value) - self._to_datetime(start_value)
        return delta.total_seconds() / 60

    def calculate_distance(self, lat1, lon1, lat2, lon2):
        """Calculate the haversine distance between two coordinates.

        Args:
            lat1, lon1: First point, in decimal degrees.
            lat2, lon2: Second point, in decimal degrees.

        Returns:
            Great-circle distance in meters.
        """
        phi1 = radians(lat1)
        phi2 = radians(lat2)
        half_dphi = radians(lat2 - lat1) / 2
        half_dlambda = radians(lon2 - lon1) / 2

        a = sin(half_dphi) ** 2 + cos(phi1) * cos(phi2) * sin(half_dlambda) ** 2
        # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
        # points, which would make sqrt(1 - a) raise a domain error.
        a = min(1.0, max(0.0, a))
        c = 2 * atan2(sqrt(a), sqrt(1 - a))

        return self.EARTH_RADIUS_M * c

    def calculate_actual_distance(self, gps_points):
        """Calculate total distance traveled along a sequence of GPS points.

        Args:
            gps_points: Ordered rows with ``latitude`` and ``longitude`` keys.
                Values may be anything ``float()`` accepts; rows where either
                coordinate is None are skipped.

        Returns:
            Path length in miles (0.0 when fewer than two usable points).
        """
        if not gps_points:
            return 0.0

        coords = [
            (float(point['latitude']), float(point['longitude']))
            for point in gps_points
            if point['latitude'] is not None and point['longitude'] is not None
        ]

        total_meters = sum(
            self.calculate_distance(lat_a, lon_a, lat_b, lon_b)
            for (lat_a, lon_a), (lat_b, lon_b) in zip(coords, coords[1:])
        )

        return total_meters / self.METERS_PER_MILE

    def analyze_stop_visit(self, stop_gps_points, estimated_time):
        """Return the minutes between the first and last GPS point at a stop.

        Args:
            stop_gps_points: Ordered rows with a ``timestamp`` key holding
                either a ``datetime`` or an ISO-8601 string.
            estimated_time: Currently unused; kept for interface compatibility.

        Returns:
            Elapsed minutes, or 0 when fewer than two points are given.
        """
        if not stop_gps_points or len(stop_gps_points) < 2:
            return 0

        return self._elapsed_minutes(
            stop_gps_points[0]['timestamp'],
            stop_gps_points[-1]['timestamp'],
        )

    def calculate_route_efficiency(self, actual_distance, estimated_distance,
                                   actual_time, estimated_time):
        """Calculate the combined distance/time efficiency percentage.

        100 means actuals matched the estimates exactly; values above 100 mean
        the route took less distance/time than planned. The result is clamped
        to the range 0..200 and rounded to one decimal place.

        Args:
            actual_distance: Distance actually traveled.
            estimated_distance: Planned distance (same unit as actual).
            actual_time: Time actually taken.
            estimated_time: Planned time (same unit as actual).

        Returns:
            Efficiency percentage, or 0 when either estimate is missing or
            not positive.
        """
        # Database numerics may arrive as Decimal, which cannot be mixed with
        # float in arithmetic; normalise everything to float first.
        est_distance = float(estimated_distance or 0)
        est_time = float(estimated_time or 0)
        if est_distance <= 0 or est_time <= 0:
            return 0

        distance_ratio = float(actual_distance or 0) / est_distance
        time_ratio = float(actual_time or 0) / est_time

        # Average the two ratios and invert so that 100% means "as planned".
        efficiency = (2 - (distance_ratio + time_ratio) / 2) * 100

        return round(max(0.0, min(200.0, efficiency)), 1)

    # ------------------------------------------------------------------
    # Database reads
    # ------------------------------------------------------------------

    def get_route_gps_data(self, route_id):
        """Get all GPS points for a route, oldest first.

        Returns:
            A list of row dicts, or an empty list on a database error.
        """
        query = """
                SELECT location_id, latitude, longitude, speed_kmh, timestamp
                FROM technician_location_updates
                WHERE route_id = %s
                ORDER BY timestamp ASC
            """
        try:
            return self._fetch(query, (route_id,))
        except mysql.connector.Error as e:
            logger.error(f"Error fetching GPS data for route {route_id}: {e}")
            return []

    def get_route_stops(self, route_id):
        """Get all stops for a route ordered by stop sequence.

        Returns:
            A list of row dicts, or an empty list on a database error.
        """
        query = """
                SELECT rs.stop_id, rs.stop_sequence, rs.latitude, rs.longitude,
                       rs.estimated_arrival_time, rs.actual_arrival_time,
                       rs.estimated_completion_time, rs.actual_completion_time
                FROM route_stops rs
                WHERE rs.route_id = %s
                ORDER BY rs.stop_sequence ASC
            """
        try:
            return self._fetch(query, (route_id,))
        except mysql.connector.Error as e:
            logger.error(f"Error fetching stops for route {route_id}: {e}")
            return []

    def get_completed_routes_to_analyze(self):
        """Get up to 20 completed routes that have not been analyzed yet.

        NOTE(review): routes that cannot be analyzed (e.g. no GPS data) keep
        their NULL columns and are selected again on every call; with the
        ``LIMIT 20`` / newest-first ordering they can starve older routes.

        Returns:
            A list of row dicts, or an empty list on a database error.
        """
        query = """
                SELECT route_id, route_date, status
                FROM routes
                WHERE status = 'completed'
                AND (gps_efficiency_percent IS NULL
                     OR actual_distance_gps IS NULL)
                AND end_time IS NOT NULL
                ORDER BY end_time DESC
                LIMIT 20
            """
        try:
            # End any transaction left open by earlier reads. With autocommit
            # off and InnoDB's REPEATABLE READ, an open transaction keeps
            # reading its original snapshot, so a long-running poller would
            # never see routes completed after its first query.
            self.conn.rollback()
            return self._fetch(query)
        except mysql.connector.Error as e:
            logger.error(f"Error fetching routes to analyze: {e}")
            return []

    # ------------------------------------------------------------------
    # Analysis
    # ------------------------------------------------------------------

    def _actual_route_minutes(self, route, gps_points, estimated_time):
        """Return the route's actual duration in minutes.

        Prefers the route's own start/end times, falls back to the span of the
        GPS timestamps, and finally to the estimate when neither is available.
        """
        if route['start_time'] and route['end_time']:
            return self._elapsed_minutes(route['start_time'], route['end_time'])
        if len(gps_points) >= 2:
            return self._elapsed_minutes(
                gps_points[0]['timestamp'], gps_points[-1]['timestamp']
            )
        return estimated_time

    def update_route_analytics(self, route_id):
        """Analyze a completed route and store the results on its row.

        Writes ``actual_distance_gps``, ``gps_efficiency_percent`` and
        ``max_speed_kmh`` (NULL when no positive speed was recorded).

        Returns:
            True when the route was analyzed and updated; False when the
            route or its GPS data is missing, or when any error occurred
            (errors are logged, not raised).
        """
        try:
            route = self._fetch(
                "SELECT * FROM routes WHERE route_id = %s",
                (route_id,),
                one=True,
            )
            if not route:
                logger.warning(f"Route {route_id} not found")
                return False

            gps_points = self.get_route_gps_data(route_id)
            if not gps_points:
                logger.info(f"No GPS data for route {route_id}")
                return False

            actual_distance_gps = self.calculate_actual_distance(gps_points)

            estimated_distance = route['total_distance_miles'] or 0
            estimated_time = route['total_duration_minutes'] or 0
            actual_time = self._actual_route_minutes(
                route, gps_points, estimated_time
            )

            gps_efficiency = self.calculate_route_efficiency(
                actual_distance_gps, estimated_distance,
                actual_time, estimated_time
            )

            # Highest positive recorded speed; 0 when none was recorded.
            max_speed = max(
                (
                    point['speed_kmh']
                    for point in gps_points
                    if point['speed_kmh'] and point['speed_kmh'] > 0
                ),
                default=0,
            )

            update_query = """
                UPDATE routes
                SET actual_distance_gps = %s,
                    gps_efficiency_percent = %s,
                    max_speed_kmh = %s
                WHERE route_id = %s
            """

            cursor = self.conn.cursor()
            try:
                cursor.execute(update_query, (
                    actual_distance_gps,
                    gps_efficiency,
                    max_speed if max_speed > 0 else None,
                    route_id
                ))
                self.conn.commit()
            finally:
                cursor.close()

            logger.info(
                f"Route {route_id} analyzed: "
                f"distance={actual_distance_gps:.1f}mi, "
                f"efficiency={gps_efficiency:.1f}%, "
                f"max_speed={max_speed:.1f}km/h"
            )

            return True

        except Exception as e:
            # Deliberately broad: one bad route must not abort the batch.
            logger.error(f"Error analyzing route {route_id}: {e}")
            return False

    def analyze_batch(self):
        """Analyze one batch of pending completed routes.

        Returns:
            The number of routes successfully analyzed.
        """
        routes = self.get_completed_routes_to_analyze()

        if not routes:
            logger.debug("No routes to analyze")
            return 0

        analyzed_count = sum(
            1 for route in routes
            if self.update_route_analytics(route['route_id'])
        )

        logger.info(f"Analyzed {analyzed_count} routes")
        return analyzed_count

    def run_continuous(self, interval=300):
        """Run ``analyze_batch`` forever, sleeping between runs.

        The connection is checked before every batch and re-opened if it has
        dropped; a failed reconnect skips that batch and is retried after the
        next sleep. The connection is closed when the loop exits.

        Args:
            interval: Seconds to sleep between batches.
        """
        logger.info(f"Route Analyzer started (interval: {interval}s)")

        try:
            while True:
                try:
                    self._ensure_connection()
                except mysql.connector.Error:
                    # connect() has already logged the failure.
                    pass
                else:
                    self.analyze_batch()
                time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("Route Analyzer stopped by user")
        except Exception as e:
            logger.error(f"Route Analyzer error: {e}")
        finally:
            self.disconnect()


def main():
    """Parse command-line options and run the analyzer.

    ``--mode batch`` (default) analyzes one batch and exits; ``--mode
    realtime`` loops forever, sleeping ``--interval`` seconds between batches.
    Any unhandled error is logged and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Route Analyzer')
    parser.add_argument('--mode', choices=['realtime', 'batch'], default='batch',
                        help='Execution mode: realtime (continuous) or batch (single run)')
    parser.add_argument('--interval', type=int, default=300,
                        help='Interval between analysis runs (seconds)')
    parser.add_argument('--config', default='.my.admin.cnf',
                        help='Path to database config file')

    args = parser.parse_args()

    # Reject a negative interval up front; time.sleep() would otherwise raise
    # only after the first batch had already run.
    if args.interval < 0:
        parser.error('--interval must be non-negative')

    try:
        # Load database config
        db_config = load_db_config(args.config)

        # Create and run analyzer
        analyzer = RouteAnalyzer(db_config)

        if args.mode == 'realtime':
            # run_continuous() closes the connection itself when it exits.
            analyzer.run_continuous(interval=args.interval)
        else:
            # Single batch run; always release the connection, even on error.
            try:
                count = analyzer.analyze_batch()
                logger.info(f"Batch analysis complete: {count} routes analyzed")
            finally:
                analyzer.disconnect()

    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
