#!/usr/bin/env python3
"""
Location Batch Processor - Process offline GPS location syncs
Handles deduplication, validation, and cleanup of GPS location data

Run: python3 location_batch_processor.py --mode batch
     Runs every 5 minutes for batch processing

     python3 location_batch_processor.py --mode archive
     Daily cleanup of old location data
"""

import sys
import argparse
import time
from datetime import datetime, timedelta
from math import radians, sin, cos, sqrt, atan2
import mysql.connector

from db_config import load_db_config
from logging_config import get_logger

logger = get_logger(__name__, '/var/log/location_batch_processor.log')


class LocationBatchProcessor:
    """Process GPS location batches: validation, deduplication, and archiving.

    Reads offline-sync rows from ``technician_location_updates``, drops
    near-duplicate fixes, flags invalid coordinates, and purges old history.
    """

    def __init__(self, db_config):
        """Initialize the processor and open a database connection.

        Args:
            db_config: dict of keyword arguments for mysql.connector.connect().
        """
        self.db_config = db_config
        self.conn = None
        self.connect()

        # Batch processing parameters (tunable after construction)
        self.dedup_distance_m = 5  # merge fixes closer than this (meters)
        self.dedup_time_s = 10  # ...and within this many seconds
        self.archive_age_days = 30  # Keep 30 days of history
        self.offline_gap_threshold_min = 5  # Gap > 5 min = offline period

    def connect(self):
        """Establish database connection; re-raises on failure."""
        try:
            self.conn = mysql.connector.connect(**self.db_config)
            logger.info("Database connection established")
        except mysql.connector.Error as e:
            logger.error(f"Database connection failed: {e}")
            raise

    def disconnect(self):
        """Close database connection if it is open."""
        if self.conn and self.conn.is_connected():
            self.conn.close()
            logger.info("Database connection closed")

    @staticmethod
    def _parse_timestamp(value):
        """Return *value* as a datetime.

        Accepts either a datetime (as returned by the MySQL driver) or an
        ISO-8601 string, including a trailing 'Z' UTC designator.
        """
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        return value

    def calculate_distance(self, lat1, lon1, lat2, lon2):
        """Return the great-circle (haversine) distance in meters.

        Args:
            lat1, lon1: first point in decimal degrees.
            lat2, lon2: second point in decimal degrees.
        """
        R = 6371000  # Earth radius in meters

        lat1_rad = radians(lat1)
        lat2_rad = radians(lat2)
        delta_lat = radians(lat2 - lat1)
        delta_lon = radians(lon2 - lon1)

        # Haversine formula
        a = sin(delta_lat/2)**2 + cos(lat1_rad) * cos(lat2_rad) * sin(delta_lon/2)**2
        c = 2 * atan2(sqrt(a), sqrt(1-a))

        return R * c

    def get_pending_syncs(self, limit=1000):
        """Fetch up to *limit* unprocessed offline-sync rows.

        Rows are ordered by route then timestamp so per-route sequences stay
        contiguous. Returns a list of dict rows; [] on database error.
        """
        query = """
            SELECT location_id, route_id, user_id, latitude, longitude,
                   accuracy_meters, timestamp
            FROM technician_location_updates
            WHERE processed = FALSE
            AND is_offline_sync = TRUE
            ORDER BY route_id, timestamp ASC
            LIMIT %s
        """
        try:
            cursor = self.conn.cursor(dictionary=True)
            try:
                cursor.execute(query, (limit,))
                return cursor.fetchall()
            finally:
                # Always release the cursor, even if execute/fetch raises.
                cursor.close()
        except mysql.connector.Error as e:
            logger.error(f"Error fetching pending syncs: {e}")
            return []

    def validate_location(self, latitude, longitude, accuracy_meters):
        """Validate GPS location data.

        Returns:
            (True, None) when valid, otherwise (False, reason-string).
        """
        if latitude < -90 or latitude > 90:
            return False, "Invalid latitude"

        if longitude < -180 or longitude > 180:
            return False, "Invalid longitude"

        # accuracy_meters may be NULL in the DB; only reject when present.
        if accuracy_meters and accuracy_meters > 100:
            return False, "Accuracy too low (>100m)"

        return True, None

    def dedup_locations(self, locations):
        """Remove duplicate/nearby locations.

        A point is dropped when it is within dedup_distance_m meters AND
        dedup_time_s seconds of the previously kept point. Returns the kept
        points sorted by timestamp.
        """
        if not locations:
            return []

        # Sort by timestamp
        sorted_locs = sorted(locations, key=lambda x: x['timestamp'])
        deduplicated = []

        for loc in sorted_locs:
            if not deduplicated:
                deduplicated.append(loc)
                continue

            # Compare with last kept location
            last = deduplicated[-1]

            # Driver rows carry datetimes; raw syncs may carry ISO strings.
            curr_time = self._parse_timestamp(loc['timestamp'])
            last_time = self._parse_timestamp(last['timestamp'])
            time_delta = (curr_time - last_time).total_seconds()

            distance = self.calculate_distance(
                float(last['latitude']), float(last['longitude']),
                float(loc['latitude']), float(loc['longitude'])
            )

            # Keep if far enough or time gap is large
            if distance > self.dedup_distance_m or time_delta > self.dedup_time_s:
                deduplicated.append(loc)

        logger.info(f"Deduplicated {len(locations)} to {len(deduplicated)} locations")
        return deduplicated

    def process_batch(self, batch):
        """Validate and mark each location in *batch*.

        Invalid rows are still marked processed (so they are not retried) but
        counted as failed. Returns (processed, failed).
        """
        if not batch:
            return 0, 0

        processed = 0
        failed = 0

        for loc in batch:
            try:
                is_valid, error = self.validate_location(
                    float(loc['latitude']),
                    float(loc['longitude']),
                    loc['accuracy_meters']
                )

                if not is_valid:
                    logger.warning(f"Invalid location {loc['location_id']}: {error}")
                    # Mark processed anyway so the bad row is not re-fetched.
                    self.mark_location_processed(loc['location_id'])
                    failed += 1
                    continue

                if self.mark_location_processed(loc['location_id']):
                    processed += 1
                else:
                    failed += 1

            except Exception as e:
                logger.error(f"Error processing location {loc['location_id']}: {e}")
                failed += 1
                continue

        logger.info(f"Batch processing: {processed} processed, {failed} failed")
        return processed, failed

    def mark_location_processed(self, location_id):
        """Set processed = TRUE for one row; return True on success."""
        try:
            cursor = self.conn.cursor()
            try:
                cursor.execute(
                    "UPDATE technician_location_updates SET processed = TRUE WHERE location_id = %s",
                    (location_id,)
                )
                self.conn.commit()
            finally:
                cursor.close()

            return True
        except mysql.connector.Error as e:
            logger.error(f"Error marking location {location_id} as processed: {e}")
            return False

    def detect_offline_period(self, locations):
        """Return gaps longer than offline_gap_threshold_min between fixes.

        Accepts ISO-string or datetime timestamps (see _parse_timestamp).
        Each entry: {'start': datetime, 'end': datetime, 'gap_minutes': float}.
        """
        if not locations or len(locations) < 2:
            return []

        offline_periods = []

        for i in range(len(locations) - 1):
            curr_time = self._parse_timestamp(locations[i]['timestamp'])
            next_time = self._parse_timestamp(locations[i + 1]['timestamp'])

            gap_minutes = (next_time - curr_time).total_seconds() / 60

            if gap_minutes > self.offline_gap_threshold_min:
                offline_periods.append({
                    'start': curr_time,
                    'end': next_time,
                    'gap_minutes': gap_minutes
                })

                if gap_minutes > 30:
                    logger.warning(f"Long offline period: {gap_minutes:.0f} minutes")

        return offline_periods

    def cleanup_old_data(self, days_to_keep=30):
        """Delete location rows older than *days_to_keep* days.

        Returns the number of rows deleted (0 on error).
        NOTE(review): cutoff uses naive local time via datetime.now(); confirm
        the `timestamp` column is stored in the same timezone.
        """
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            cursor = self.conn.cursor()
            try:
                # rowcount from the DELETE itself avoids the race of counting
                # in a separate query before deleting.
                cursor.execute(
                    """
                    DELETE FROM technician_location_updates
                    WHERE timestamp < %s
                    """,
                    (cutoff_date,)
                )
                record_count = cursor.rowcount
                self.conn.commit()
            finally:
                cursor.close()

            logger.info(f"Cleaned up {record_count} records older than {days_to_keep} days")
            return record_count
        except mysql.connector.Error as e:
            logger.error(f"Error cleaning up old data: {e}")
            return 0

    def process_all_pending(self):
        """Fetch, deduplicate, and process all pending offline syncs.

        Returns (processed, failed) counts for the kept (deduplicated) points.
        """
        pending = self.get_pending_syncs()

        if not pending:
            logger.debug("No pending syncs to process")
            return 0, 0

        logger.info(f"Processing {len(pending)} pending sync locations")

        deduplicated = self.dedup_locations(pending)

        # Mark the dropped duplicates processed too; otherwise they stay
        # processed = FALSE and are re-fetched on every subsequent run.
        kept_ids = {loc['location_id'] for loc in deduplicated}
        for loc in pending:
            if loc['location_id'] not in kept_ids:
                self.mark_location_processed(loc['location_id'])

        processed, failed = self.process_batch(deduplicated)

        return processed, failed

    def run_batch_job(self):
        """Run one batch-processing pass, then disconnect."""
        logger.info("Location Batch Processor started (batch mode)")

        processed, failed = self.process_all_pending()
        logger.info(f"Batch processing complete: {processed} processed, {failed} failed")

        self.disconnect()

    def run_archive_job(self):
        """Run one archive/cleanup pass, then disconnect."""
        logger.info("Location Archive job started")

        cleaned = self.cleanup_old_data(self.archive_age_days)
        logger.info(f"Archive complete: {cleaned} records cleaned")

        self.disconnect()

    def run_continuous(self, interval=300):
        """Process pending syncs in a loop, sleeping *interval* seconds between passes."""
        logger.info(f"Location Batch Processor started (interval: {interval}s)")

        try:
            while True:
                processed, failed = self.process_all_pending()
                time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("Location Batch Processor stopped by user")
        except Exception as e:
            logger.error(f"Location Batch Processor error: {e}")
        finally:
            self.disconnect()


def main():
    """Parse CLI arguments and dispatch to the requested processing mode."""
    parser = argparse.ArgumentParser(description='Location Batch Processor')
    parser.add_argument('--mode', choices=['batch', 'archive', 'realtime'], default='batch',
                        help='Mode: batch (single run), archive (cleanup), realtime (continuous)')
    parser.add_argument('--interval', type=int, default=300,
                        help='Interval for realtime mode (seconds)')
    parser.add_argument('--config', default='.my.admin.cnf',
                        help='Path to database config file')
    parser.add_argument('--keep-days', type=int, default=30,
                        help='Days of location data to keep (archive mode)')
    opts = parser.parse_args()

    try:
        # Build the processor from the on-disk DB config and apply the
        # retention override before running the selected mode.
        processor = LocationBatchProcessor(load_db_config(opts.config))
        processor.archive_age_days = opts.keep_days

        if opts.mode == 'archive':
            processor.run_archive_job()
        elif opts.mode == 'realtime':
            processor.run_continuous(interval=opts.interval)
        else:
            # argparse restricts --mode to three choices, so this is 'batch'.
            processor.run_batch_job()

    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)


# Script entry point; see the module docstring for the supported run modes.
if __name__ == '__main__':
    main()
