package com.saas.voip.handler;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.saas.voip.config.VoiceAiConfig;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.net.URI;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

@Component
@RequiredArgsConstructor
@Slf4j
public class ElevenLabsSessionHandler implements AiSessionHandler {
    
    private final VoiceAiConfig voiceAiConfig;
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    private final Map<String, WebSocketClient> elevenLabsClients = new ConcurrentHashMap<>();
    private final Map<String, StringBuilder> sessionTranscripts = new ConcurrentHashMap<>();
    private final Map<String, String> sessionCallSids = new ConcurrentHashMap<>();
    private final Map<String, String> sessionStreamIds = new ConcurrentHashMap<>(); // sessionId -> streamSid (for Twilio)
    private final Map<String, String> conversationIdToCallSid = new ConcurrentHashMap<>(); // conversation_id -> callSid
    private final Map<String, String> conversationIdToSessionId = new ConcurrentHashMap<>(); // conversation_id -> sessionId
    
    @Override
    public void onClientConnect(WebSocketSession session, String streamSid, String callSid, String fromNumber, String toNumber) throws Exception {
        log.info("ElevenLabs handler - Client connected: streamSid={}, callSid={}", streamSid, callSid);
        sessionStreamIds.put(session.getId(), streamSid);
        sessionCallSids.put(session.getId(), callSid);
        sessionTranscripts.put(session.getId(), new StringBuilder());
        connectToElevenLabs(session);
    }
    
    private void connectToElevenLabs(WebSocketSession twilioSession) {
        try {
            String wsUrl = voiceAiConfig.getElevenlabs().getWsUrl();
            String agentId = voiceAiConfig.getElevenlabs().getAgentId();
            
            // Agent ID MUST be passed as URL parameter, not in init message
            String wsUrlWithAgent = wsUrl + "?agent_id=" + agentId;
            log.debug("Connecting to ElevenLabs with URL: {}", wsUrlWithAgent);
            
            WebSocketClient elevenLabsClient = new WebSocketClient(new URI(wsUrlWithAgent)) {
                @Override
                public void onOpen(ServerHandshake handshake) {
                    log.info("Connected to ElevenLabs Conversational AI for session: {}", twilioSession.getId());
                    initializeConversation(this, agentId);
                }

                @Override
                public void onMessage(String message) {
                    handleElevenLabsMessage(twilioSession, message);
                }

                @Override
                public void onClose(int code, String reason, boolean remote) {
                    log.info("Disconnected from ElevenLabs: {} - {}", code, reason);
                }

                @Override
                public void onError(Exception ex) {
                    log.error("ElevenLabs WebSocket error", ex);
                }
            };

            String apiKey = voiceAiConfig.getElevenlabs().getApiKey();
            elevenLabsClient.addHeader("xi-api-key", apiKey);
            elevenLabsClient.connect();

            elevenLabsClients.put(twilioSession.getId(), elevenLabsClient);

        } catch (Exception e) {
            log.error("Failed to connect to ElevenLabs", e);
        }
    }
    
    private void initializeConversation(WebSocketClient client, String agentId) {
        try {
            // Agent ID is already in the WebSocket URL
            // Configure audio formats: mulaw 8kHz for Twilio compatibility
            ObjectNode initMessage = objectMapper.createObjectNode();
            initMessage.put("type", "conversation_initiation_client_data");
            
            // Request mulaw 8kHz output to match Twilio's expected format
            ObjectNode conversationConfig = initMessage.putObject("conversation_config_override");
            ObjectNode tts = conversationConfig.putObject("tts");
            tts.put("output_format", "ulaw_8000"); // Request mulaw 8kHz for Twilio
            
            String message = objectMapper.writeValueAsString(initMessage);
            log.debug("Initializing ElevenLabs conversation with mulaw 8kHz output: {}", message);
            client.send(message);

        } catch (Exception e) {
            log.error("Error initializing ElevenLabs conversation", e);
        }
    }
    
    @Override
    public void onMediaFrame(WebSocketSession session, String payload) throws Exception {
        JsonNode mediaNode = objectMapper.readTree(payload);
        if (mediaNode.has("media") && mediaNode.get("media").has("payload")) {
            String audioPayload = mediaNode.get("media").get("payload").asText();
            sendAudioToElevenLabs(session, audioPayload);
        }
    }
    
    private void sendAudioToElevenLabs(WebSocketSession twilioSession, String audioBase64) {
        try {
            WebSocketClient client = elevenLabsClients.get(twilioSession.getId());
            if (client != null && client.isOpen()) {
                // Convert Twilio's mulaw 8kHz to PCM 16-bit for ElevenLabs
                String pcmBase64 = convertMulawToPcm(audioBase64);
                
                // ElevenLabs expects: {"user_audio_chunk": "base64_pcm_string"}
                ObjectNode audioMessage = objectMapper.createObjectNode();
                audioMessage.put("user_audio_chunk", pcmBase64);
                
                String message = objectMapper.writeValueAsString(audioMessage);
                log.debug("📣 Sending user audio to ElevenLabs: {} bytes PCM 16kHz (upsampled from {} bytes mulaw 8kHz)", 
                         pcmBase64.length(), audioBase64.length());
                client.send(message);
                log.debug("✅ User audio sent to ElevenLabs successfully");
            } else {
                log.warn("⚠️ Cannot send audio to ElevenLabs: client is null or closed");
            }
        } catch (Exception e) {
            log.error("❌ Error sending audio to ElevenLabs", e);
        }
    }
    
    private String convertMulawToPcm(String mulawBase64) {
        try {
            // Decode mulaw bytes from base64
            byte[] mulawBytes = Base64.getDecoder().decode(mulawBase64);
            
            // Step 1: Decode mulaw to 16-bit PCM samples (8kHz)
            short[] pcm8kSamples = new short[mulawBytes.length];
            for (int i = 0; i < mulawBytes.length; i++) {
                pcm8kSamples[i] = mulawToLinear(mulawBytes[i]);
            }
            
            // Step 2: Upsample from 8kHz to 16kHz (duplicate each sample)
            // ElevenLabs expects 16kHz, so we need to double the sample rate
            short[] pcm16kSamples = new short[pcm8kSamples.length * 2];
            for (int i = 0; i < pcm8kSamples.length; i++) {
                pcm16kSamples[i * 2] = pcm8kSamples[i];
                pcm16kSamples[i * 2 + 1] = pcm8kSamples[i];
            }
            
            // Step 3: Convert to byte array (little-endian 16-bit)
            byte[] pcmBytes = new byte[pcm16kSamples.length * 2];
            for (int i = 0; i < pcm16kSamples.length; i++) {
                pcmBytes[i * 2] = (byte) (pcm16kSamples[i] & 0xFF);        // Low byte
                pcmBytes[i * 2 + 1] = (byte) ((pcm16kSamples[i] >> 8) & 0xFF); // High byte
            }
            
            // Encode to base64
            return Base64.getEncoder().encodeToString(pcmBytes);
            
        } catch (Exception e) {
            log.error("Error converting mulaw to PCM", e);
            return mulawBase64; // Fallback to original
        }
    }
    
    private short mulawToLinear(byte mulawByte) {
        // G.711 mu-law to linear PCM decoding
        int mulaw = ~mulawByte & 0xFF;
        
        // Extract sign, exponent, and mantissa
        int sign = (mulaw & 0x80) >> 7;
        int exponent = (mulaw & 0x70) >> 4;
        int mantissa = mulaw & 0x0F;
        
        // Calculate linear value
        int linear = ((mantissa << 3) + 0x84) << exponent;
        linear = linear - 0x84;
        
        // Apply sign
        if (sign == 1) {
            linear = -linear;
        }
        
        return (short) linear;
    }
    
    private void handleElevenLabsMessage(WebSocketSession twilioSession, String message) {
        try {
            JsonNode response = objectMapper.readTree(message);
            String type = response.has("type") ? response.get("type").asText() : "";

            log.debug("Received ElevenLabs event: {}", type);
            
            // Log full message for audio events to debug structure
            if ("audio".equals(type) || "conversation_initiation_metadata".equals(type)) {
                log.debug("🔍 Full {} message: {}", type, message);
            }
            
            // Capture conversation_id from initialization metadata
            if ("conversation_initiation_metadata".equals(type) && response.has("conversation_initiation_metadata_event")) {
                JsonNode metadataEvent = response.get("conversation_initiation_metadata_event");
                if (metadataEvent.has("conversation_id")) {
                    String conversationId = metadataEvent.get("conversation_id").asText();
                    String callSid = sessionCallSids.get(twilioSession.getId());
                    if (callSid != null) {
                        conversationIdToCallSid.put(conversationId, callSid);
                        conversationIdToSessionId.put(conversationId, twilioSession.getId());
                        log.info("📝 Stored conversation mapping: {} → callSid: {}", conversationId, callSid);
                    }
                }
            }

            // Handle audio chunks from ElevenLabs (PCM 16kHz format)
            if ("audio".equals(type) && response.has("audio_event")) {
                JsonNode audioEvent = response.get("audio_event");
                if (audioEvent.has("audio_base_64")) {
                    String audioBase64 = audioEvent.get("audio_base_64").asText();
                    log.debug("Received audio from ElevenLabs, length: {} bytes", audioBase64.length());
                    sendAudioToTwilio(twilioSession, audioBase64);
                }
            }
            
            // Handle user transcript
            if ("user_transcript".equals(type) && response.has("user_transcript")) {
                String transcript = response.get("user_transcript").asText();
                log.info("User said: {}", transcript);
                appendTranscript(twilioSession.getId(), "user", transcript);
            }
            
            // Handle agent transcript
            if ("agent_response".equals(type) && response.has("agent_response")) {
                String transcript = response.get("agent_response").asText();
                log.info("Agent said: {}", transcript);
                appendTranscript(twilioSession.getId(), "assistant", transcript);
            }
            
            // Handle interruption events
            if ("interruption".equals(type)) {
                log.info("User interrupted the agent");
            }

        } catch (Exception e) {
            log.error("Error processing ElevenLabs message", e);
        }
    }
    
    private void sendAudioToTwilio(WebSocketSession twilioSession, String audioBase64) {
        try {
            if (!twilioSession.isOpen()) {
                log.warn("⚠️ Twilio session is closed, cannot send audio");
                return;
            }
            
            // Get the streamSid - CRITICAL for Twilio to accept audio!
            String streamSid = sessionStreamIds.get(twilioSession.getId());
            if (streamSid == null) {
                log.error("❌ No streamSid found for session {}", twilioSession.getId());
                return;
            }
            
            // Convert PCM 16kHz to mulaw 8kHz for Twilio
            String mulawBase64 = convertPcm16kToMulaw8k(audioBase64);
            
            ObjectNode mediaMessage = objectMapper.createObjectNode();
            mediaMessage.put("event", "media");
            mediaMessage.put("streamSid", streamSid); // REQUIRED by Twilio!
            
            ObjectNode media = objectMapper.createObjectNode();
            media.put("payload", mulawBase64);
            mediaMessage.set("media", media);
            
            String messageJson = objectMapper.writeValueAsString(mediaMessage);
            log.debug("📤 Sending {} bytes of mulaw audio to Twilio (converted from {} bytes PCM, streamSid: {})", 
                     mulawBase64.length(), audioBase64.length(), streamSid);
            
            twilioSession.sendMessage(new TextMessage(messageJson));
            log.debug("✅ Audio sent to Twilio successfully");
            
        } catch (Exception e) {
            log.error("❌ Error sending audio to Twilio", e);
        }
    }
    
    private String convertPcm16kToMulaw8k(String pcmBase64) {
        try {
            // Decode base64 to PCM bytes (16kHz, 16-bit signed, little-endian)
            byte[] pcm16k = java.util.Base64.getDecoder().decode(pcmBase64);
            
            // Downsample from 16kHz to 8kHz (take every other sample)
            int numSamples = pcm16k.length / 2; // 16-bit = 2 bytes per sample
            int outputSamples = numSamples / 2; // Downsample to half
            byte[] mulaw8k = new byte[outputSamples];
            
            for (int i = 0; i < outputSamples; i++) {
                // Read every other 16-bit PCM sample (little-endian)
                int sampleIndex = i * 4; // Skip every other sample (2 bytes * 2)
                if (sampleIndex + 1 >= pcm16k.length) break;
                
                // Read 16-bit sample as little-endian
                int lowByte = pcm16k[sampleIndex] & 0xFF;
                int highByte = pcm16k[sampleIndex + 1];
                int sample = (highByte << 8) | lowByte;
                
                // Convert to mulaw
                mulaw8k[i] = linearToMulaw(sample);
            }
            
            // Encode to base64
            return java.util.Base64.getEncoder().encodeToString(mulaw8k);
            
        } catch (Exception e) {
            log.error("Error converting PCM to mulaw", e);
            return pcmBase64; // Fallback to original
        }
    }
    
    private byte linearToMulaw(int pcmSample) {
        // G.711 mu-law encoding (standard algorithm)
        final int MULAW_BIAS = 0x84;
        final int MULAW_CLIP = 32635;
        
        // Get sign and absolute value
        int sign = (pcmSample >> 8) & 0x80;
        if (sign != 0) {
            pcmSample = -pcmSample;
        }
        
        // Clip to maximum value
        if (pcmSample > MULAW_CLIP) {
            pcmSample = MULAW_CLIP;
        }
        
        // Add bias
        pcmSample = pcmSample + MULAW_BIAS;
        
        // Find exponent (segment) 
        int exponent = 7;
        for (int mask = 0x4000; (pcmSample & mask) == 0 && exponent > 0; exponent--, mask >>= 1);
        
        // Find mantissa
        int mantissa = (pcmSample >> (exponent + 3)) & 0x0F;
        
        // Combine sign, exponent, and mantissa, then invert
        int mulawByte = ~(sign | (exponent << 4) | mantissa);
        
        return (byte) (mulawByte & 0xFF);
    }
    
    private void appendTranscript(String sessionId, String role, String text) {
        StringBuilder transcript = sessionTranscripts.get(sessionId);
        if (transcript != null) {
            if (transcript.length() > 0) {
                transcript.append(", ");
            }
            transcript.append(String.format("{\"role\": \"%s\", \"content\": \"%s\"}", 
                role, text.replace("\"", "\\\"")));
        }
    }
    
    @Override
    public void onMark(WebSocketSession session, Map<String, Object> markData) throws Exception {
        // No special handling needed for ElevenLabs
        log.debug("Mark event received (no action required for ElevenLabs)");
    }
    
    @Override
    public void onClose(WebSocketSession session) throws Exception {
        log.info("ElevenLabs handler - Closing session: {}", session.getId());
        
        WebSocketClient client = elevenLabsClients.remove(session.getId());
        if (client != null && client.isOpen()) {
            client.close();
        }
        
        sessionTranscripts.remove(session.getId());
        sessionCallSids.remove(session.getId());
        
        // Clean up conversation mappings (Note: we keep them briefly for callback processing)
        // They will be cleaned up after a delay to allow callback to arrive
    }
    
    @Override
    public boolean supportsStructuredExtraction() {
        return false; // ElevenLabs uses tool callbacks, not direct structured output
    }
    
    @Override
    public Map<String, Object> buildStructuredPayload(String transcript) {
        // ElevenLabs will call back via webhook with structured data
        return new HashMap<>();
    }
    
    public String getTranscript(String sessionId) {
        StringBuilder transcript = sessionTranscripts.get(sessionId);
        if (transcript != null && transcript.length() > 0) {
            return "[" + transcript.toString() + "]";
        }
        return null;
    }
    
    public String getCallSid(String sessionId) {
        return sessionCallSids.get(sessionId);
    }
    
    public String getCallSidByConversationId(String conversationId) {
        return conversationIdToCallSid.get(conversationId);
    }
    
    public String getTranscriptByConversationId(String conversationId) {
        String sessionId = conversationIdToSessionId.get(conversationId);
        if (sessionId != null) {
            return getTranscript(sessionId);
        }
        return null;
    }
}
