#!/usr/bin/env python3
"""
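GPS Data Processor - real-time GPS validation and processing.

Processes location updates from the mobile app: validates coordinates,
queues speeding alerts, and detects (and logs) arrivals at pending or
in-progress route stops. Departure detection is not implemented here.

Run: python3 gps_processor.py --mode realtime [--interval SECONDS]
     Runs continuously, processing GPS data every --interval seconds
     (default: 30).
     python3 gps_processor.py --mode batch
     Processes a single batch of pending updates, then exits.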
"""

import argparse
import json
import sys
import time
from contextlib import closing
from datetime import datetime, timedelta
from math import radians, sin, cos, sqrt, atan2

import mysql.connector

from db_config import load_db_config
from logging_config import get_logger

# Module-wide logger. The project helper get_logger() receives this path,
# presumably as its log-file destination (confirm in logging_config).
_LOG_FILE = '/var/log/gps_processor.log'
logger = get_logger(__name__, _LOG_FILE)


class GPSProcessor:
    """Validate and process technician GPS updates stored in MySQL.

    Polls ``technician_location_updates`` for unprocessed mobile-app rows,
    validates each one, queues speeding alerts into ``alerts`` and logs
    arrivals at pending/in-progress stops of the technician's route.
    """

    def __init__(self, db_config):
        """Connect to the database and set the GPS validation thresholds.

        Args:
            db_config: keyword arguments passed straight to
                ``mysql.connector.connect``.

        Raises:
            mysql.connector.Error: if the initial connection fails.
        """
        self.db_config = db_config
        self.conn = None
        self.connect()

        # GPS validation parameters.
        # NOTE(review): min_accuracy, stop_speed_threshold and stop_duration
        # are not referenced anywhere else in this module yet.
        self.min_accuracy = 5  # meters
        self.max_accuracy = 100  # meters; worse fixes are rejected as unreliable
        self.max_speed = 80  # km/h; above this a speeding alert is queued
        self.stop_speed_threshold = 2  # km/h
        self.stop_duration = 120  # seconds before classifying as stopped

    def connect(self):
        """Open ``self.conn`` using ``self.db_config``.

        Raises:
            mysql.connector.Error: re-raised after logging if connecting fails.
        """
        try:
            self.conn = mysql.connector.connect(**self.db_config)
            logger.info("Database connection established")
        except mysql.connector.Error as e:
            logger.error(f"Database connection failed: {e}")
            raise

    def disconnect(self):
        """Close the database connection if it is open; safe to call repeatedly."""
        if self.conn and self.conn.is_connected():
            self.conn.close()
            logger.info("Database connection closed")

    def _ensure_connection(self):
        """Ping the server, reconnecting if the connection has dropped.

        Returns:
            True if the connection is usable, False if the server could not
            be reached (the failure is logged).
        """
        try:
            # Without this, a dropped connection (server restart, network
            # blip) makes every later query fail and the loop never recovers.
            self.conn.ping(reconnect=True, attempts=3, delay=5)
            return True
        except mysql.connector.Error as e:
            logger.error(f"Database unavailable, skipping this cycle: {e}")
            return False

    def calculate_distance(self, lat1, lon1, lat2, lon2):
        """Return the great-circle distance in meters between two points.

        Uses the haversine formula on a spherical Earth (R = 6,371 km).
        Inputs may be ``int``, ``float`` or ``decimal.Decimal`` (e.g. values
        read from DECIMAL columns). They are converted to ``float`` first
        because ``float - Decimal`` raises ``TypeError``.
        """
        earth_radius_m = 6371000

        lat1, lon1, lat2, lon2 = (float(v) for v in (lat1, lon1, lat2, lon2))
        phi1 = radians(lat1)
        phi2 = radians(lat2)
        delta_phi = radians(lat2 - lat1)
        delta_lambda = radians(lon2 - lon1)

        a = sin(delta_phi / 2) ** 2 + cos(phi1) * cos(phi2) * sin(delta_lambda / 2) ** 2
        # Rounding can push `a` a hair outside [0, 1] for (near-)antipodal
        # points, which would make sqrt(1 - a) raise a math domain error.
        a = min(1.0, max(0.0, a))
        return earth_radius_m * 2 * atan2(sqrt(a), sqrt(1 - a))

    @staticmethod
    def _to_datetime(value):
        """Return *value* as a datetime; strings are parsed as ISO 8601 (``Z`` means UTC)."""
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value.replace('Z', '+00:00'))

    def calculate_speed(self, prev_lat, prev_lon, prev_time, curr_lat, curr_lon, curr_time):
        """Return the average speed in km/h between two GPS fixes.

        Args:
            prev_lat, prev_lon, curr_lat, curr_lon: coordinates in degrees.
            prev_time, curr_time: ``datetime`` objects or ISO 8601 strings
                (a trailing ``Z`` is accepted). Both must be naive or both
                timezone-aware, otherwise the subtraction raises TypeError.

        Returns:
            Speed in km/h, or 0 if ``curr_time`` is not after ``prev_time``.
        """
        distance_m = self.calculate_distance(prev_lat, prev_lon, curr_lat, curr_lon)

        elapsed_s = (self._to_datetime(curr_time) - self._to_datetime(prev_time)).total_seconds()
        if elapsed_s <= 0:
            return 0

        # m/s -> km/h
        return distance_m / elapsed_s * 3.6

    @staticmethod
    def _as_float(value):
        """Return *value* as a float, or None if it is missing or not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def validate_coordinates(self, latitude, longitude, accuracy_meters=None):
        """Check that a GPS fix is usable.

        Rejects missing or non-numeric coordinates, NaN, and values outside
        [-90, 90] / [-180, 180]. Also rejects fixes whose reported accuracy
        is worse than ``self.max_accuracy``. A missing or zero accuracy is
        not checked.

        Returns:
            ``(is_valid, error_message)``; ``error_message`` is None when valid.
        """
        lat = self._as_float(latitude)
        # `not lo <= x <= hi` (rather than `x < lo or x > hi`) also rejects
        # NaN, because every comparison with NaN is False.
        if lat is None or not -90 <= lat <= 90:
            return False, f"Invalid latitude: {latitude}"

        lon = self._as_float(longitude)
        if lon is None or not -180 <= lon <= 180:
            return False, f"Invalid longitude: {longitude}"

        if accuracy_meters and not accuracy_meters <= self.max_accuracy:
            return False, f"Accuracy too low: {accuracy_meters}m > {self.max_accuracy}m"

        return True, None

    def get_unprocessed_locations(self, limit=50):
        """Fetch up to *limit* unprocessed mobile-app location rows, oldest first.

        Returns:
            A list of dict rows, or ``[]`` if the query fails (the error is logged).
        """
        query = """
            SELECT location_id, route_id, user_id, latitude, longitude,
                   accuracy_meters, speed_kmh, timestamp
            FROM technician_location_updates
            WHERE processed = FALSE AND source = 'mobile_app'
            ORDER BY timestamp ASC
            LIMIT %s
        """
        try:
            with closing(self.conn.cursor(dictionary=True)) as cursor:
                cursor.execute(query, (limit,))
                locations = cursor.fetchall()
            # End the read transaction. Connector/Python's autocommit defaults
            # to off (unless db_config enables it). Under InnoDB's default
            # REPEATABLE READ, a transaction left open keeps reading the same
            # snapshot, so after an empty poll newly inserted rows would never
            # become visible.
            self.conn.commit()
            return locations
        except mysql.connector.Error as e:
            logger.error(f"Error fetching unprocessed locations: {e}")
            return []

    def mark_location_processed(self, location_id):
        """Set ``processed = TRUE`` for one location row and commit.

        Returns:
            True on success, False if the update failed (the error is logged).
        """
        query = "UPDATE technician_location_updates SET processed = TRUE WHERE location_id = %s"
        try:
            with closing(self.conn.cursor()) as cursor:
                cursor.execute(query, (location_id,))
            self.conn.commit()
            return True
        except mysql.connector.Error as e:
            logger.error(f"Error marking location {location_id} as processed: {e}")
            return False

    def queue_alert(self, route_id, alert_type, severity, message, data=None):
        """Insert one row into ``alerts`` and commit immediately.

        Args:
            data: optional JSON-serializable payload for ``alerts.data``;
                falsy values are stored as NULL.

        Returns:
            True on success, False if *data* cannot be serialized or the
            insert fails (both cases are logged).
        """
        try:
            data_json = json.dumps(data) if data else None
        except (TypeError, ValueError) as e:
            # e.g. Decimal/datetime values: report through the normal False
            # return instead of letting the exception escape to the caller.
            logger.error(f"Error serializing alert data for {alert_type} on route {route_id}: {e}")
            return False

        query = """
            INSERT INTO alerts (route_id, alert_type, severity, message, data, created_at)
            VALUES (%s, %s, %s, %s, %s, NOW())
        """
        try:
            with closing(self.conn.cursor()) as cursor:
                cursor.execute(query, (route_id, alert_type, severity, message, data_json))
            self.conn.commit()
        except mysql.connector.Error as e:
            logger.error(f"Error queueing alert: {e}")
            return False

        logger.info(f"Alert queued: {alert_type} for route {route_id}")
        return True

    def detect_speeding(self, route_id, speed_kmh):
        """Queue a 'speeding' warning if *speed_kmh* exceeds ``self.max_speed``.

        Returns:
            True if the speed was over the limit (an alert was attempted),
            otherwise False. A missing or zero speed is never speeding.

        NOTE(review): one alert is queued per over-limit sample, so a sustained
        speeding episode produces many alerts; consider de-duplicating.
        """
        if not speed_kmh:
            return False

        # Normalise to float so Decimal values from the DB are JSON-serializable.
        speed = float(speed_kmh)
        if speed <= self.max_speed:
            return False

        self.queue_alert(
            route_id,
            'speeding',
            'warning',
            f"Technician exceeding speed limit: {speed:.1f} km/h",
            {'speed_kmh': speed, 'limit_kmh': self.max_speed}
        )
        return True

    def detect_arrival(self, route_id, latitude, longitude, stop_radius=50):
        """Check whether the position is within *stop_radius* meters of an open stop.

        All pending/in-progress stops of the route are considered. If several
        are within range, the nearest one wins.

        Returns:
            ``(True, stop_row)`` for the nearest stop in range, otherwise
            ``(False, None)``. Also ``(False, None)`` if the query fails
            (the error is logged).
        """
        # No LIMIT: limiting without ORDER BY checked an arbitrary subset of
        # stops and could miss the one the technician is actually at.
        query = """
            SELECT rs.stop_id, rs.customer_id, c.first_name, c.last_name,
                   rs.latitude, rs.longitude
            FROM route_stops rs
            JOIN customers c ON rs.customer_id = c.customer_id
            WHERE rs.route_id = %s AND rs.status IN ('pending', 'in_progress')
        """
        try:
            with closing(self.conn.cursor(dictionary=True)) as cursor:
                cursor.execute(query, (route_id,))
                stops = cursor.fetchall()
        except mysql.connector.Error as e:
            logger.error(f"Error detecting arrival: {e}")
            return False, None

        nearest_stop, nearest_distance = None, None
        for stop in stops:
            # Explicit None checks: 0 is a valid latitude/longitude.
            if stop['latitude'] is None or stop['longitude'] is None:
                continue
            distance = self.calculate_distance(
                latitude, longitude, stop['latitude'], stop['longitude']
            )
            if distance <= stop_radius and (nearest_distance is None or distance < nearest_distance):
                nearest_stop, nearest_distance = stop, distance

        if nearest_stop is None:
            return False, None

        logger.info(f"Arrival detected at stop {nearest_stop['stop_id']}: {nearest_distance:.1f}m")
        return True, nearest_stop

    def _process_location(self, loc):
        """Validate one location row and run speeding/arrival detection on it."""
        is_valid, error = self.validate_coordinates(
            loc['latitude'], loc['longitude'], loc['accuracy_meters']
        )
        if not is_valid:
            logger.warning(f"Invalid coordinates for location {loc['location_id']}: {error}")
            return

        self.detect_speeding(loc['route_id'], loc['speed_kmh'])

        has_arrival, stop_info = self.detect_arrival(
            loc['route_id'], loc['latitude'], loc['longitude']
        )
        if has_arrival and stop_info:
            logger.info(f"Stop arrival detected for route {loc['route_id']}")

    def process_locations_batch(self, limit=100):
        """Process up to *limit* unprocessed GPS locations.

        Invalid rows are logged and marked processed without further checks.

        Returns:
            The number of rows successfully marked as processed.

        NOTE(review): a row whose processing raises is left unprocessed and
        retried next batch. Because rows are fetched oldest-first with a
        LIMIT, a backlog of >= *limit* such rows would stall the queue.
        Consider a retry counter.
        """
        locations = self.get_unprocessed_locations(limit=limit)
        if not locations:
            return 0

        processed_count = 0
        for loc in locations:
            location_id = loc['location_id']
            try:
                self._process_location(loc)
                # Only count rows whose processed flag was actually written.
                if self.mark_location_processed(location_id):
                    processed_count += 1
            except Exception as e:
                logger.exception(f"Error processing location {location_id}: {e}")

        logger.info(f"Processed {processed_count} GPS locations")
        return processed_count

    def run_continuous(self, interval=30):
        """Process batches every *interval* seconds until interrupted.

        The connection is checked (and re-established if needed) before each
        batch. Ctrl-C stops cleanly. Any other exception is logged with its
        traceback and re-raised, so the caller can exit non-zero. The
        connection is closed in every case.
        """
        logger.info(f"GPS Processor started (interval: {interval}s)")

        try:
            while True:
                if self._ensure_connection():
                    self.process_locations_batch()
                time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("GPS Processor stopped by user")
        except Exception as e:
            logger.exception(f"GPS Processor error: {e}")
            raise
        finally:
            self.disconnect()


def main():
    """Parse command-line arguments and run the GPS processor.

    In ``realtime`` mode the processor loops every ``--interval`` seconds. In
    ``batch`` mode it processes one batch and exits. Any failure (loading the
    config, connecting, or processing) is logged with its traceback and the
    script exits with status 1.
    """
    parser = argparse.ArgumentParser(description='GPS Data Processor')
    parser.add_argument('--mode', choices=['realtime', 'batch'], default='realtime',
                        help='Execution mode: realtime (continuous) or batch (single run)')
    parser.add_argument('--interval', type=int, default=30,
                        help='Interval between processing runs (seconds)')
    parser.add_argument('--config', default='.my.admin.cnf',
                        help='Path to database config file')

    args = parser.parse_args()

    # 0 would busy-loop against the database; a negative value makes
    # time.sleep() raise on the first cycle.
    if args.mode == 'realtime' and args.interval <= 0:
        parser.error("--interval must be a positive number of seconds")

    try:
        db_config = load_db_config(args.config)
        processor = GPSProcessor(db_config)

        if args.mode == 'realtime':
            # run_continuous closes the connection itself on exit.
            processor.run_continuous(interval=args.interval)
        else:
            # Single batch run; always release the connection, even on error.
            try:
                count = processor.process_locations_batch()
                logger.info(f"Batch processing complete: {count} locations processed")
            finally:
                processor.disconnect()

    except Exception as e:
        logger.exception(f"Fatal error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
