"""
This script migrates custom content policies from one deployment to another.
It can be used in three modes:
1. Export mode: Fetches the policy, its dataset, and its latest finetuned-model details from the source and saves them to files
2. Import mode: Loads the policy and dataset details from files and creates them in the target (or updates an existing target policy if --target-policy-id is provided)
3. Update-trained mode: Updates an existing target policy to the trained state using the provided model weights
"""

import argparse
import copy
import json
import os
import sys
from pathlib import Path

import dotenv
import requests

# Load environment variables from .env file
env_path = Path(__file__).parent / ".env"
if not env_path.exists():
    print("Error: .env file not found. Please copy .env.example to .env and fill in the values.")
    sys.exit(1)
dotenv.load_dotenv(env_path)

# Required for all modes (all four are validated below regardless of the selected mode).
# Trailing slashes are stripped from the base URLs because every request is built as
# f"{url}/v1/...", which would otherwise produce a "//v1" path.
SOURCE_VPC_URL = (os.getenv("SOURCE_VPC_URL") or "").rstrip("/")
SOURCE_API_KEY = os.getenv("SOURCE_API_KEY")
TARGET_VPC_URL = (os.getenv("TARGET_VPC_URL") or "").rstrip("/")
TARGET_API_KEY = os.getenv("TARGET_API_KEY")

# Validate required environment variables and name the ones that are missing or empty,
# so the user knows exactly what to fix in .env.
_missing_env_vars = [
    name
    for name, value in (
        ("SOURCE_VPC_URL", SOURCE_VPC_URL),
        ("SOURCE_API_KEY", SOURCE_API_KEY),
        ("TARGET_VPC_URL", TARGET_VPC_URL),
        ("TARGET_API_KEY", TARGET_API_KEY),
    )
    if not value
]
if _missing_env_vars:
    print(
        "Error: Missing required environment variables: "
        f"{', '.join(_missing_env_vars)}. Please check your .env file."
    )
    sys.exit(1)

REQUEST_TIMEOUT = 300  # seconds


def parse_arguments():
    """
    Builds the command-line interface and parses ``sys.argv``.

    Only ``--mode`` is mandatory at the argparse level; per-mode requirements
    (e.g. ``--policy-id`` for export) are checked by the mode handlers.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Migrate policy from source to target VPC.")
    add = parser.add_argument

    # Identifiers, inputs and the mode selector.
    add(
        "--policy-id",
        help="Policy ID: required for export mode (source policy) and update-trained mode (target policy).",
    )
    add(
        "--target-policy-id",
        help="Target policy ID for import mode: if provided, updates existing target policy (versioning); if omitted, creates new policy.",
    )
    add("--model-weights", help="S3 path to the model weights (required for update-trained mode)")
    add(
        "--mode",
        required=True,
        choices=["export", "import", "update-trained"],
        help="Mode: 'export' fetches and saves, 'import' loads and creates (or updates existing if --target-policy-id is provided), 'update-trained' updates existing policy to trained state",
    )

    # Boolean switches, all off unless passed on the command line.
    switch_specs = (
        ("--import-model", "Import model weights from the source environment"),
        ("--is-encrypted", "Mark the model as encrypted (update-trained mode only)"),
        (
            "--default-policy",
            "Use when migrating default/OOB policies. Preserves original methodParams (adapter_path) "
            "instead of using a dummy S3 path. Do NOT use with custom-trained policies.",
        ),
    )
    for switch, switch_help in switch_specs:
        add(switch, action="store_true", default=False, help=switch_help)

    add(
        "--output-dir",
        default="./policy_data",
        help="Directory to save/load policy data (default: ./policy_data)",
    )
    return parser.parse_args()


def get_paginated_dataset(vpc_url, api_key, policy_id, page=1, limit=500):
    """
    Fetches a single page of a policy's guardrail dataset.

    Args:
        vpc_url (str): The VPC URL
        api_key (str): The API key (sent as a Bearer token)
        policy_id (str): The policy ID
        page (int): Page number (1-based)
        limit (int): Number of items per page

    Returns:
        tuple: (dataset_page_data, total_pages). dataset_page_data is the decoded JSON
        response; total_pages comes from datasetPageDetails.totalPages and falls back
        to 1 when that object or field is missing, null, or 0.

    Raises:
        requests.exceptions.HTTPError: if the server returns an error status.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    params = {
        "populateCreatorEmails": "false",
        "page": page,
        "limit": limit,
        "includeDataset": "true",
    }

    response = requests.get(
        f"{vpc_url}/v1/guardrail-dataset/{policy_id}",
        headers=headers,
        params=params,
        timeout=REQUEST_TIMEOUT,
    )
    response.raise_for_status()
    data = response.json()

    # `or` also covers explicit JSON nulls: a null page-details object would break
    # `.get`, and a null page count would break the caller's integer comparison.
    page_details = data.get("datasetPageDetails") or {}
    total_pages = page_details.get("totalPages") or 1
    return data, total_pages


def get_complete_dataset(vpc_url, api_key, policy_id):
    """
    Fetches every page of a policy's guardrail dataset and merges them.

    The first page's response object is returned. When there is more than one page,
    its "dataset" list is extended in place with the datapoints from pages 2..N.

    Returns:
        dict: The first page's decoded JSON response, with "dataset" holding all datapoints.
    """
    print("Fetching complete dataset...")

    first_page, page_count = get_paginated_dataset(vpc_url, api_key, policy_id)

    if page_count > 1:
        print(f"Total {page_count} pages found. Fetching all pages...")
        merged_points = first_page["dataset"]
        for page_number in range(2, page_count + 1):
            print(f"Fetching page {page_number}/{page_count}...")
            next_page, _ = get_paginated_dataset(vpc_url, api_key, policy_id, page=page_number)
            merged_points.extend(next_page["dataset"])
        first_page["dataset"] = merged_points

    return first_page


def get_finetuned_model(vpc_url, api_key, policy_id):
    """
    Fetches the latest finetuned-model record for a policy.

    Args:
        vpc_url (str): Base URL of the deployment to query.
        api_key (str): API key, sent as a Bearer token.
        policy_id (str): The policy ID to look up.

    Returns:
        The decoded JSON body of the get-latest-finetuned-model endpoint, or an empty
        dict when that endpoint answers 404 (no finetuned model for this policy).

    Raises:
        requests.exceptions.HTTPError: for any other error status.
    """
    response = requests.get(
        f"{vpc_url}/v1/guardrail-dataset/{policy_id}/get-latest-finetuned-model",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=REQUEST_TIMEOUT,
    )

    if response.status_code != 404:
        response.raise_for_status()
        return response.json()

    # 404 is treated as "no model yet" rather than an error.
    print(
        f"No finetuned model found for policy {policy_id} (404). Continuing without model data."
    )
    return {}


def get_policy_details(vpc_url, api_key, policy_id):
    """
    Collects everything export mode needs for one policy from a deployment.

    Args:
        vpc_url (str): Base URL of the deployment.
        api_key (str): API key, sent as a Bearer token.
        policy_id (str): The policy to read.

    Returns:
        tuple: (policy_data, dataset_data, finetuned_model), where policy_data is the
        expanded policy JSON, dataset_data is the merged guardrail dataset response, and
        finetuned_model is the latest finetuned-model record (or {} when none exists).

    Raises:
        requests.exceptions.HTTPError: if any request returns an error status.
    """
    # expand=true so the policy response also carries the behavior details.
    policy_response = requests.get(
        f"{vpc_url}/v1/moderation/policy/{policy_id}",
        params={"expand": "true"},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=REQUEST_TIMEOUT,
    )
    policy_response.raise_for_status()

    # Tuple elements evaluate left to right: policy JSON, then dataset, then model.
    return (
        policy_response.json(),
        get_complete_dataset(vpc_url, api_key, policy_id),
        get_finetuned_model(vpc_url, api_key, policy_id),
    )


def save_policy_data(policy_data, dataset_data, finetuned_model, output_dir):
    """
    Writes the exported policy, dataset and finetuned-model payloads as JSON files.

    Creates ``output_dir`` if needed and writes policy.json, dataset.json and
    finetuned_model.json (2-space indented), then prints where each file went.

    Args:
        policy_data: Policy JSON to save.
        dataset_data: Dataset JSON to save.
        finetuned_model: Finetuned-model JSON to save (may be an empty dict).
        output_dir (str): Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = (
        ("Policy data", "policy.json", policy_data),
        ("Dataset data", "dataset.json", dataset_data),
        ("Finetuned model data", "finetuned_model.json", finetuned_model),
    )

    # Write every file first, then report, matching the original output order.
    written = []
    for label, file_name, payload in outputs:
        destination = os.path.join(output_dir, file_name)
        with open(destination, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2)
        written.append((label, destination))

    for label, destination in written:
        print(f"{label} saved to {destination}")


def load_policy_data(output_dir):
    """
    Reads back the policy and dataset JSON files written by export mode.

    Args:
        output_dir (str): Directory containing policy.json and dataset.json.

    Returns:
        tuple: (policy_data, dataset_data) as decoded JSON.
    """

    def _read_json(file_name):
        with open(os.path.join(output_dir, file_name), "r", encoding="utf-8") as handle:
            return json.load(handle)

    # policy.json is read before dataset.json.
    return _read_json("policy.json"), _read_json("dataset.json")


def create_policy_and_dataset(
    target_vpc, target_api_key, policy_data, dataset_data, preserve_method_params=False
):
    """
    Creates a new policy and its guardrail dataset in the target deployment.

    Sequence: create the policy -> create a manual guardrail dataset on the input or
    output endpoint -> set the dataset stage to "approveReady" -> push all datapoints
    in batches.

    Args:
        target_vpc (str): Base URL of the target deployment.
        target_api_key (str): API key for the target, sent as a Bearer token.
        policy_data (dict): Exported policy JSON (name, applicableTo, method, action required).
        dataset_data (dict): Exported dataset JSON; its "dataset" list is pushed.
        preserve_method_params (bool): Keep the exported methodParams instead of the
            dummy adapter path.

    Returns:
        The ID of the newly created policy.

    Raises:
        requests.exceptions.HTTPError: if any request returns an error status.
    """
    auth_headers = {"Authorization": f"Bearer {target_api_key}"}

    # Default/OOB policies (--default-policy flag) keep their original methodParams
    # (adapter_path). Custom policies get a dummy path that update-trained mode replaces.
    if preserve_method_params:
        method_params = policy_data.get("methodParams", {})
    else:
        method_params = {"adapter_path": "s3://dummy"}

    policy_response = requests.post(
        f"{target_vpc}/v1/moderation/policy",
        json={
            "name": policy_data["name"],
            "applicableTo": policy_data["applicableTo"],
            "method": policy_data["method"],
            "action": policy_data["action"],
            "methodParams": method_params,
            "isDefault": policy_data.get("isDefault", False),
            "description": policy_data.get("description", ""),
            "allowedBehaviors": policy_data.get("allowedBehaviors", []),
            "disallowedBehaviors": policy_data.get("disallowedBehaviors", []),
        },
        headers=auth_headers,
        timeout=REQUEST_TIMEOUT,
    )
    policy_response.raise_for_status()
    new_policy_id = policy_response.json()["id"]
    print(f"Created new policy with ID: {new_policy_id}")

    # INPUT policies use the input dataset endpoint; everything else uses the output one.
    direction = "input" if policy_data["applicableTo"] == "INPUT" else "output"
    # Behaviors are optional here; they were already set by the create-policy call above.
    dataset_response = requests.post(
        f"{target_vpc}/v1/guardrail-dataset/{direction}/{new_policy_id}",
        json={"manualCreation": True},
        headers=auth_headers,
        timeout=REQUEST_TIMEOUT,
    )
    dataset_response.raise_for_status()
    print(f"Created guardrail dataset for policy ID: {new_policy_id}")

    # Explicitly mark the dataset "Ready to train" before any datapoints are uploaded.
    stage_response = requests.patch(
        f"{target_vpc}/v1/guardrail-dataset/{new_policy_id}",
        json={"stage": "approveReady"},
        headers=auth_headers,
        timeout=REQUEST_TIMEOUT,
    )
    stage_response.raise_for_status()
    print(f"Updated dataset stage to 'approveReady' for policy ID: {new_policy_id}")

    push_dataset_in_batches(target_vpc, target_api_key, new_policy_id, dataset_data["dataset"])
    return new_policy_id


def push_dataset_in_batches(vpc_url, api_key, policy_id, dataset, batch_size=500):
    """
    Uploads datapoints to a policy's guardrail dataset, ``batch_size`` points per request.

    Args:
        vpc_url (str): Base URL of the deployment.
        api_key (str): API key, sent as a Bearer token.
        policy_id (str): The policy whose dataset receives the points.
        dataset (list): Datapoints to upload, sent in list order.
        batch_size (int, optional): Points per request. Defaults to 500.

    Raises:
        requests.exceptions.HTTPError: if a batch upload fails. Earlier batches stay uploaded.
    """
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    point_count = len(dataset)
    batch_total = -(-point_count // batch_size)  # ceiling division for a positive batch_size

    print(f"\nPushing dataset in batches (total points: {point_count})...")
    for batch_num, start in enumerate(range(0, point_count, batch_size), start=1):
        chunk = dataset[start : start + batch_size]
        print(f"Pushing batch {batch_num}/{batch_total} ({len(chunk)} points)...")

        requests.post(
            f"{vpc_url}/v1/guardrail-dataset/{policy_id}/dataset",
            json={"dataset": chunk},
            headers=auth_headers,
            timeout=REQUEST_TIMEOUT,
        ).raise_for_status()

    print(f"Successfully pushed all {point_count} dataset points")


def get_policy_jobs(target_vpc, target_api_key, policy_id):
    """
    Fetches the jobs recorded for a policy (stats-only view).

    Returns:
        The decoded JSON body. get_latest_job treats it as a list of job dicts;
        NOTE(review): confirm against the jobs API.

    Raises:
        requests.exceptions.HTTPError: on an error status.
    """
    jobs_response = requests.get(
        f"{target_vpc}/v1/jobs/policy/{policy_id}/",
        headers={"Authorization": f"Bearer {target_api_key}"},
        params={"stats-only": "true"},
        timeout=REQUEST_TIMEOUT,
    )
    jobs_response.raise_for_status()
    return jobs_response.json()


def get_latest_job(jobs):
    """
    Returns the job with the greatest queueEnterTime, or None if ``jobs`` is empty/None.

    Jobs whose queueEnterTime is missing or null sort as the empty string, i.e. before
    any job that has a timestamp. Timestamps are compared as-is, so this assumes they
    share one lexicographically sortable format (e.g. ISO-8601 strings).
    NOTE(review): confirm that format against the jobs API.
    """
    if not jobs:
        return None
    # `or ""` also covers an explicit JSON null; otherwise max() would compare None
    # with str and raise TypeError.
    return max(jobs, key=lambda job: job.get("queueEnterTime") or "")


def update_job_status(target_vpc, target_api_key, job_id, status):
    """
    Sets a job's status via PATCH /v1/jobs/{job_id} and logs the change.

    Raises:
        requests.exceptions.HTTPError: if the server rejects the update.
    """
    patch_response = requests.patch(
        f"{target_vpc}/v1/jobs/{job_id}",
        json={"status": status},
        headers={"Authorization": f"Bearer {target_api_key}"},
        timeout=REQUEST_TIMEOUT,
    )
    patch_response.raise_for_status()
    print(f"Updated job {job_id} status to {status}")


def get_model_weights(output_dir):
    """
    Reads the exported finetuned-model info and returns its weights path and encryption flag.

    Args:
        output_dir (str): Directory containing finetuned_model.json written by export mode.

    Returns:
        tuple: (weights_path, is_encrypted). weights_path is fineTunedModelZipS3Path;
        is_encrypted is isModelEncrypted as a bool (False when absent or null).

    Raises:
        ValueError: if fineTunedModelZipS3Path is missing, null or empty. For example,
            export saves an empty object when the source had no finetuned model.
        OSError / json.JSONDecodeError: if the file cannot be read or parsed.
    """
    model_weights_file = os.path.join(output_dir, "finetuned_model.json")
    with open(model_weights_file, "r", encoding="utf-8") as f:
        model_weights = json.load(f)

    # A null/empty path is as unusable as a missing one. Reject it here rather than
    # later PATCHing the target model with an invalid weights location.
    weights_path = model_weights.get("fineTunedModelZipS3Path")
    if not weights_path:
        raise ValueError("model weights not found in the exported data")

    is_encrypted = bool(model_weights.get("isModelEncrypted"))
    return weights_path, is_encrypted


def update_to_trained(target_vpc, target_api_key, policy_id, weights_path, is_encrypted=False):
    """
    Puts an existing target policy into the trained state using pre-existing model weights.

    No real training happens. The steps are:
      1. Start a training run and immediately cancel it.
      2. Fetch the policy's latest finetuned-model record and patch it to TRAINED with
         the supplied weights path and encryption flag.
      3. Report the training as successful.
      4. If the policy's latest job is CANCELLED, walk it through
         QUEUED -> IN_PROGRESS -> COMPLETED.

    Args:
        target_vpc (str): Base URL of the target deployment.
        target_api_key (str): API key for the target, sent as a Bearer token.
        policy_id (str): ID of the target policy to update.
        weights_path (str): S3 path of the model weights zip to record on the model.
        is_encrypted (bool): Value stored as the model's isModelEncrypted flag.

    Raises:
        requests.exceptions.HTTPError: if any API call returns an error status.
        RuntimeError: if no finetuned-model record is found for the policy after the
            training run was started and cancelled.
    """
    headers = {"Authorization": f"Bearer {target_api_key}"}

    res = requests.post(
        f"{target_vpc}/v1/guardrail-dataset/{policy_id}/startTraining",
        json={"title": "Migrating policy", "description": "This policy version was migrated."},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    res.raise_for_status()
    print("Started training.")

    # Now there would be a training job in the queue. We can cancel it before the training starts.
    res = requests.post(
        f"{target_vpc}/v1/guardrail-dataset/{policy_id}/cancelTraining",
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    res.raise_for_status()
    print(f"Successfully cancelled training for policy {policy_id}")

    # get_finetuned_model returns {} on 404. Fail with a clear message instead of a bare
    # KeyError('_id'), which the callers would print as just "Error: '_id'".
    finetuned_model = get_finetuned_model(target_vpc, target_api_key, policy_id)
    finetuned_model_id = finetuned_model.get("_id")
    if not finetuned_model_id:
        raise RuntimeError(
            f"No finetuned model record found for policy {policy_id}; cannot mark it as trained."
        )

    res = requests.patch(
        f"{target_vpc}/v1/guardrail-dataset/fineTunedModelInfo/{finetuned_model_id}",
        json={
            "status": "TRAINED",
            "isModelEncrypted": is_encrypted,
            "fineTunedModelZipS3Path": weights_path,
        },
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    res.raise_for_status()
    print(f"Patched finetunedModelInfo (isModelEncrypted={is_encrypted})")

    # NOTE(review): the key "fineTunedModeId" (no "l") is sent as-is. Confirm it is what
    # the updateTrainingStatus endpoint expects rather than a typo for "fineTunedModelId".
    res = requests.post(
        f"{target_vpc}/v1/guardrail-dataset/{policy_id}/updateTrainingStatus",
        json={"isSuccess": True, "fineTunedModeId": finetuned_model_id},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    res.raise_for_status()
    print("Patched training status")

    # Update latest job status to COMPLETED, to simulate indicating that the training is complete
    print("Fetching policy jobs...")
    jobs = get_policy_jobs(target_vpc, target_api_key, policy_id)
    latest_job = get_latest_job(jobs)

    # The training job cancelled above is expected to be the policy's latest job; only a
    # CANCELLED latest job is advanced, any other state is left untouched.
    if latest_job and latest_job.get("status") == "CANCELLED":
        job_id = latest_job["id"]
        print(f"The latest job is cancelled, updating status of job {job_id} to COMPLETED...")

        # Update status following state machine transitions
        update_job_status(target_vpc, target_api_key, job_id, "QUEUED")
        update_job_status(target_vpc, target_api_key, job_id, "IN_PROGRESS")
        update_job_status(target_vpc, target_api_key, job_id, "COMPLETED")
        print("Successfully updated job status to COMPLETED")


def export_mode(args):
    """
    Export mode: fetches the policy, its complete dataset and its latest finetuned-model
    info from the source deployment, and saves them as JSON files in ``args.output_dir``.

    Exits the process with status 1 when --policy-id is missing or any step fails.
    """
    source_policy_id = args.policy_id
    if not source_policy_id:
        print("Error: --policy-id is required in export mode")
        sys.exit(1)

    try:
        print(f"Fetching policy and dataset data from {SOURCE_VPC_URL}...")
        policy_json, dataset_json, model_json = get_policy_details(
            SOURCE_VPC_URL, SOURCE_API_KEY, source_policy_id
        )

        print(f"Saving policy and dataset data to {args.output_dir}...")
        save_policy_data(policy_json, dataset_json, model_json, args.output_dir)

        print(f"Successfully exported policy '{policy_json['name']}' from {SOURCE_VPC_URL}")
        print(f"Source policy ID: {source_policy_id}")
    except requests.exceptions.HTTPError as http_error:
        print(f"HTTP Error: {http_error}")
        print(f"Response: {http_error.response.text}")
        sys.exit(1)
    except Exception as error:  # top-level boundary: report and exit non-zero
        print(f"Error: {error}")
        sys.exit(1)


def import_mode(args):
    """
    Import mode: loads the exported policy and dataset from ``args.output_dir`` and
    applies them to the target deployment.

    - Without --target-policy-id: creates a new policy and dataset.
    - With --target-policy-id: updates that policy's guardrail config and incrementally
      syncs its datapoints, creating a new version.

    With --import-model, the target policy is then marked as trained using the weights
    path from the exported finetuned_model.json. Otherwise the exact follow-up
    update-trained command is printed. Exits with status 1 on any failure.
    """
    try:
        print(f"Loading policy and dataset data from {args.output_dir}...")
        policy_data, dataset_data = load_policy_data(args.output_dir)

        target_policy_id = None

        if args.target_policy_id:
            # Versioning-import behavior: Update existing policy to create new version
            target_policy_id = args.target_policy_id
            print(f"Updating existing policy {target_policy_id} in {TARGET_VPC_URL}...")

            # Step 1: Update the guardrail definition
            print("Step 1: Updating guardrail config...")
            update_guardrail_config(TARGET_VPC_URL, TARGET_API_KEY, target_policy_id, policy_data)

            # Step 2: Incrementally sync datapoints (only delete/add what changed)
            print("Step 2: Syncing datapoints (skipping unchanged datapoints)...")
            sync_datapoints_incremental(
                TARGET_VPC_URL,
                TARGET_API_KEY,
                target_policy_id,
                dataset_data["dataset"],
            )

            print(
                f"Successfully created new version for policy '{policy_data['name']}' (ID: {target_policy_id})"
            )
        else:
            # Regular import behavior: Create new policy
            print(f"Creating new policy and dataset in {TARGET_VPC_URL}...")
            target_policy_id = create_policy_and_dataset(
                TARGET_VPC_URL,
                TARGET_API_KEY,
                policy_data,
                dataset_data,
                preserve_method_params=args.default_policy,
            )

            print(f"Successfully imported policy '{policy_data['name']}' to {TARGET_VPC_URL}")
            print(f"Target policy ID: {target_policy_id}")

        # Common logic: Update to trained state if model weights are available
        if args.import_model:
            print("Updating policy to trained state with model weights...")
            model_weights, is_encrypted = get_model_weights(args.output_dir)
            update_to_trained(
                TARGET_VPC_URL, TARGET_API_KEY, target_policy_id, model_weights, is_encrypted
            )
            print(f"Successfully updated policy {target_policy_id} to trained state")
        else:
            # Print a command that can actually be run: update-trained mode requires both
            # --policy-id and --model-weights (see update_trained_mode).
            print(
                "To update this policy to trained state, run the script again with "
                f"--mode update-trained --policy-id {target_policy_id} "
                "--model-weights <s3-path-to-model-weights>"
            )

    except requests.exceptions.HTTPError as err:
        print(f"HTTP Error: {err}")
        print(f"Response: {err.response.text}")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


def update_trained_mode(args):
    """
    Update-trained mode: marks an existing target policy as trained using the model
    weights at ``--model-weights`` (and the ``--is-encrypted`` flag).

    Exits with status 1 when --policy-id or --model-weights is missing, or on any failure.
    """
    # Check the required options in order, reporting the first one that is missing.
    for value, flag in ((args.policy_id, "--policy-id"), (args.model_weights, "--model-weights")):
        if not value:
            print(f"Error: {flag} is required in update-trained mode")
            sys.exit(1)

    policy_id = args.policy_id
    weights = args.model_weights
    try:
        print(
            f"Updating policy {policy_id} to trained state with model weights from {weights}..."
        )
        update_to_trained(TARGET_VPC_URL, TARGET_API_KEY, policy_id, weights, args.is_encrypted)
        print(f"Successfully updated policy {policy_id} to trained state")
    except requests.exceptions.HTTPError as http_error:
        print(f"HTTP Error: {http_error}")
        print(f"Response: {http_error.response.text}")
        sys.exit(1)
    except Exception as error:  # top-level boundary: report and exit non-zero
        print(f"Error: {error}")
        sys.exit(1)


def update_guardrail_config(target_vpc, target_api_key, policy_id, policy_data):
    """
    Replaces the guardrail config of an existing policy via PUT .../config.

    Sends name, description, allowedBehaviors and disallowedBehaviors from
    ``policy_data``, defaulting to "" or [] when a field is absent.

    Args:
        target_vpc (str): The target VPC URL
        target_api_key (str): The target API key
        policy_id (str): The policy ID to update
        policy_data (dict): The exported policy data to take the fields from

    Raises:
        requests.exceptions.HTTPError: if the server rejects the update.
    """
    field_defaults = (
        ("name", ""),
        ("description", ""),
        ("allowedBehaviors", []),
        ("disallowedBehaviors", []),
    )
    config_body = {field: policy_data.get(field, default) for field, default in field_defaults}

    put_response = requests.put(
        f"{target_vpc}/v1/guardrail-dataset/{policy_id}/config",
        json=config_body,
        headers={"Authorization": f"Bearer {target_api_key}"},
        timeout=REQUEST_TIMEOUT,
    )
    put_response.raise_for_status()
    print(f"Updated guardrail config for policy ID: {policy_id}")


def normalize_for_comparison(datapoint):
    """
    Returns a deep copy of ``datapoint`` suitable for content comparison.

    The copy drops the top-level "id", "createdAt" and "updatedAt" keys. When
    "generated" is a dict, its "relevantBehaviors" and
    "violatedCategoriesResolutionCodes" lists are sorted. The input is not modified.

    Args:
        datapoint (dict): The datapoint to normalize.

    Returns:
        dict: The normalized deep copy.
    """
    result = copy.deepcopy(datapoint)

    for volatile_key in ("id", "createdAt", "updatedAt"):
        result.pop(volatile_key, None)

    generated = result.get("generated")
    if isinstance(generated, dict):
        # Order of these lists carries no meaning for comparison purposes.
        for list_key in ("relevantBehaviors", "violatedCategoriesResolutionCodes"):
            values = generated.get(list_key)
            if isinstance(values, list):
                generated[list_key] = sorted(values)

    return result


def datapoints_equal(dp1, dp2):
    """
    Tells whether two datapoints have the same content.

    Both are compared after normalize_for_comparison, so id/createdAt/updatedAt and the
    order of the sorted "generated" lists are ignored.

    Args:
        dp1 (dict): First datapoint
        dp2 (dict): Second datapoint

    Returns:
        bool: True if the normalized forms are equal.
    """
    return normalize_for_comparison(dp1) == normalize_for_comparison(dp2)


def delete_datapoints_by_ids(target_vpc, target_api_key, policy_id, datapoint_ids):
    """
    Deletes the given datapoints from a policy's guardrail dataset in one request.

    Does nothing (and makes no request) when ``datapoint_ids`` is empty or None.

    Args:
        target_vpc (str): The target VPC URL
        target_api_key (str): The target API key
        policy_id (str): The policy ID
        datapoint_ids (list): IDs of the datapoints to delete

    Raises:
        requests.exceptions.HTTPError: if the server rejects the deletion.
    """
    if datapoint_ids:
        delete_response = requests.delete(
            f"{target_vpc}/v1/guardrail-dataset/{policy_id}/dataset",
            json={"ids": datapoint_ids},
            headers={"Authorization": f"Bearer {target_api_key}"},
            timeout=REQUEST_TIMEOUT,
        )
        delete_response.raise_for_status()


def _datapoint_sync_key(datapoint):
    """Return a hashable key that is equal for datapoints equal under normalize_for_comparison."""
    # Dicts aren't hashable, so compare the canonical JSON (sorted keys) of the normalized form.
    return json.dumps(normalize_for_comparison(datapoint), sort_keys=True)


def _plan_datapoint_sync(source_dataset, target_dataset):
    """
    Pairs identical source/target datapoints and works out what must change.

    Each source datapoint is matched with at most one not-yet-matched target datapoint
    of identical content, so duplicates are handled by count.

    Returns:
        tuple: (ids_to_delete, datapoints_to_add, exact_matches_count)
            ids_to_delete: IDs of target datapoints left without a source counterpart.
            datapoints_to_add: Source datapoints left without a target counterpart,
                in source order.
            exact_matches_count: Number of matched (skipped) pairs.
    """
    # content key -> IDs of target datapoints with that content, in target order.
    # A dict (insertion-ordered) keeps the resulting plan deterministic across runs,
    # unlike iterating a set of strings, whose order depends on hash randomization.
    unmatched_target_ids = {}
    for target_dp in target_dataset:
        unmatched_target_ids.setdefault(_datapoint_sync_key(target_dp), []).append(
            target_dp["id"]
        )

    datapoints_to_add = []
    exact_matches_count = 0
    for source_dp in source_dataset:
        candidates = unmatched_target_ids.get(_datapoint_sync_key(source_dp))
        if candidates:
            # Consume one identical target datapoint; pop from the end (O(1)), so the
            # earliest surplus IDs are the ones that end up deleted.
            candidates.pop()
            exact_matches_count += 1
        else:
            datapoints_to_add.append(source_dp)

    ids_to_delete = [dp_id for ids in unmatched_target_ids.values() for dp_id in ids]
    return ids_to_delete, datapoints_to_add, exact_matches_count


def sync_datapoints_incremental(target_vpc, target_api_key, policy_id, source_dataset):
    """
    Incrementally syncs datapoints between source and target by:
    1. Skipping exact matches (no add, no delete)
    2. Deleting target datapoints that don't exist in source
    3. Adding source datapoints that don't exist in target

    Datapoints are compared after normalize_for_comparison (ignoring id/createdAt/updatedAt),
    and duplicates are matched by count. The add/delete lists follow dataset order, so
    repeated runs against the same data issue the same requests. This minimizes
    unnecessary diff entries by only creating versions for actual changes.

    Args:
        target_vpc (str): The target VPC URL
        target_api_key (str): The target API key
        policy_id (str): The policy ID
        source_dataset (list): List of source datapoints to sync

    Raises:
        requests.exceptions.HTTPError: if fetching, deleting or adding datapoints fails.
    """
    # Get the complete target dataset
    print("Fetching target dataset for comparison...")
    target_dataset_data = get_complete_dataset(target_vpc, target_api_key, policy_id)
    # `or []` also covers an explicit null "dataset" field.
    target_dataset = target_dataset_data.get("dataset") or []

    print(f"Source datapoints: {len(source_dataset)}")
    print(f"Target datapoints: {len(target_dataset)}")

    # Every matched pair is counted, including matches inside groups whose duplicate
    # counts differ between source and target.
    datapoints_to_delete, datapoints_to_add, exact_matches_count = _plan_datapoint_sync(
        source_dataset, target_dataset
    )

    print("\nComparison results:")
    print(f"  Exact matches (skipped): {exact_matches_count}")
    print(f"  Datapoints to delete: {len(datapoints_to_delete)}")
    print(f"  Datapoints to add: {len(datapoints_to_add)}")

    # Delete datapoints that don't exist in source
    if datapoints_to_delete:
        print(f"\nDeleting {len(datapoints_to_delete)} datapoints...")
        delete_datapoints_by_ids(target_vpc, target_api_key, policy_id, datapoints_to_delete)
        print(f"Successfully deleted {len(datapoints_to_delete)} datapoints")
    else:
        print("\nNo datapoints to delete.")

    # Add new datapoints from source
    if datapoints_to_add:
        print(f"\nAdding {len(datapoints_to_add)} new datapoints...")
        push_dataset_in_batches(target_vpc, target_api_key, policy_id, datapoints_to_add)
    else:
        print("\nNo new datapoints to add.")

    if exact_matches_count > 0:
        print(f"\nSkipped {exact_matches_count} unchanged datapoints (no diff entries created)")


def main():
    """Parses the CLI arguments and runs the handler for the selected --mode."""
    args = parse_arguments()
    handlers = {
        "export": export_mode,
        "import": import_mode,
        "update-trained": update_trained_mode,
    }
    handler = handlers.get(args.mode)
    if handler is not None:
        handler(args)


if __name__ == "__main__":
    main()
