import logging
from decimal import Decimal, InvalidOperation

from rest_framework import serializers

from .models import Salary, Staff

class StaffSerializer(serializers.ModelSerializer):
    """Model serializer for ``Staff`` that always emits ``salary`` as a float."""

    class Meta:
        model = Staff
        fields = [
            "id",
            "name",
            "father_name",
            "nic",
            "photo",
            "address",
            "position",
            "salary",
            "status",
        ]

    def to_representation(self, instance):
        """Return the default representation with ``salary`` coerced to float.

        When ``float(instance.salary)`` raises ``ValueError`` or ``TypeError``
        the salary is reported as ``0.0`` instead of propagating the error.
        """
        payload = super().to_representation(instance)
        try:
            salary_value = float(instance.salary)
        except (ValueError, TypeError):
            # Best-effort output: an unconvertible salary is shown as zero.
            salary_value = 0.0
        payload["salary"] = salary_value
        return payload

logger = logging.getLogger(__name__)


class SalarySerializer(serializers.ModelSerializer):
    """Model serializer for ``Salary`` that maintains the per-staff ledger.

    ``customers_list`` is handled as a dict keyed by staff id (as a string).
    Each entry written by :meth:`update` is a dict with the keys ``salary``,
    ``taken``, ``remainder`` (``salary - taken``) and ``description``; the
    numeric values are stored as floats.
    """

    class Meta:
        model = Salary
        fields = ("id", "month", "year", "total", "customers_list")

    @staticmethod
    def _to_amount(value):
        """Convert *value* to a finite ``Decimal``.

        The value is passed through ``str()`` first so floats are parsed from
        their printed form rather than their binary expansion.

        Raises:
            InvalidOperation: if the text is not a number, or is NaN/Infinity
                (those parse as ``Decimal`` but must not reach the stored
                ledger or the totals).
        """
        amount = Decimal(str(value))
        if not amount.is_finite():
            raise InvalidOperation(f"non-finite amount: {value!r}")
        return amount

    def _merge_entries(self, ledger, incoming):
        """Merge *incoming* per-staff data into *ledger* in place.

        Fields missing from an incoming entry fall back to the stored entry,
        then to ``0.0`` (amounts) or ``""`` (description). Entries that are not
        dicts or carry invalid amounts are logged and skipped.

        Returns:
            bool: ``True`` if at least one ledger entry was written.
        """
        changed = False
        for staff_id, entry in incoming.items():
            if not isinstance(entry, dict):
                logger.warning(
                    "Invalid data format for staff_id %s. Skipping.", staff_id
                )
                continue

            key = str(staff_id)
            stored = ledger.get(key)
            if not isinstance(stored, dict):
                # A malformed stored entry must not crash the update; treat it
                # as absent so the incoming data replaces it.
                stored = {}

            try:
                salary = self._to_amount(
                    entry.get("salary", stored.get("salary", "0.0"))
                )
                taken = self._to_amount(
                    entry.get("taken", stored.get("taken", "0.0"))
                )
            except (InvalidOperation, ValueError, TypeError) as exc:
                logger.warning(
                    "Error processing data for staff_id %s: %s. Skipping.",
                    key,
                    exc,
                )
                continue

            ledger[key] = {
                "salary": float(salary),
                "taken": float(taken),
                # The remainder is stored so readers do not have to recompute it.
                "remainder": float(salary - taken),
                "description": entry.get(
                    "description", stored.get("description", "")
                ),
            }
            changed = True
        return changed

    def _sum_salaries(self, ledger):
        """Return the ``Decimal`` sum of ``salary`` over all dict entries.

        Entries that are not dicts are ignored; entries whose salary is not a
        finite number are logged and left out of the sum.
        """
        total = Decimal("0.0")
        for entry in ledger.values():
            if not isinstance(entry, dict):
                continue
            try:
                total += self._to_amount(entry.get("salary", "0.0"))
            except (InvalidOperation, ValueError, TypeError):
                logger.warning(
                    "Invalid salary '%s' during total calculation.",
                    entry.get("salary"),
                )
        return total

    def update(self, instance, validated_data):
        """Apply *validated_data* to *instance*, save it and return it.

        ``month`` and ``year`` are overwritten when present. ``customers_list``
        is merged entry by entry into the existing ledger rather than replacing
        it. ``total`` is never read from *validated_data*; it is recalculated
        from the ledger whenever at least one entry was written.
        """
        instance.month = validated_data.get("month", instance.month)
        instance.year = validated_data.get("year", instance.year)

        if not isinstance(instance.customers_list, dict):
            instance.customers_list = {}

        incoming = validated_data.get("customers_list")
        if incoming and isinstance(incoming, dict):
            if self._merge_entries(instance.customers_list, incoming):
                instance.total = self._sum_salaries(instance.customers_list)

        instance.save()
        return instance

    def to_representation(self, instance):
        """Return the default representation plus aggregate float fields.

        Adds ``total_taken`` and ``total_remainder`` (summed over the dict
        entries of ``customers_list``) and coerces ``total`` to float, using
        ``0.0`` when ``float(instance.total)`` raises ``ValueError`` or
        ``TypeError``.
        """
        data = super().to_representation(instance)

        entries = data.get("customers_list", {})
        if not isinstance(entries, dict):
            entries = {}

        total_taken = Decimal("0.0")
        total_remainder = Decimal("0.0")
        for entry in entries.values():
            if not isinstance(entry, dict):
                continue
            try:
                taken = self._to_amount(entry.get("taken", "0.0"))
                total_taken += taken
                if "remainder" in entry:
                    total_remainder += self._to_amount(entry["remainder"])
                else:
                    # Entries not written by update() carry no stored
                    # remainder; derive it with the same formula update() uses
                    # instead of silently counting it as zero.
                    salary = self._to_amount(entry.get("salary", "0.0"))
                    total_remainder += salary - taken
            except (InvalidOperation, ValueError, TypeError) as exc:
                logger.warning(
                    "Non-numeric value during representation sum: %s, Error: %s",
                    entry,
                    exc,
                )

        data["total_taken"] = float(total_taken)
        data["total_remainder"] = float(total_remainder)

        try:
            data["total"] = float(instance.total)
        except (ValueError, TypeError):
            data["total"] = 0.0

        return data