// usePromptManagementWebSocket.ts
import { useCallback, useEffect, useRef, useState } from "react";
import { DEV_FAST_API_WS_PROMPT_MANAGEMENT } from "@/vendor/config.ts";
import { supabase } from "@/vendor/supabaseClient.ts";

/**
 * Kinds of message exchanged over the prompt-management WebSocket.
 *
 * Each member's string value is what is serialized into the `messageType`
 * field of the JSON envelope, so members must stay string-valued.
 */
export enum PromptManagementMessageType {
  /** Serialized on the wire as the literal string "RUN_EVAL". */
  RUN_EVAL = "RUN_EVAL",
}

/**
 * JSON envelope used for every prompt-management WebSocket message.
 *
 * Outgoing messages are built in this shape before `JSON.stringify`, and
 * incoming frames are parsed into it before being routed to handlers.
 *
 * @typeParam TPayload - Type of `messageObject`. Defaults to `any` so untyped
 *   callers keep compiling.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- default kept for backward compatibility
export interface WebSocketPromptManagementMessage<TPayload = any> {
  /** Routing key: selects which registered handlers receive the message. */
  messageType: PromptManagementMessageType;
  /** Message-specific payload. */
  messageObject: TPayload;
}

/**
 * Public surface returned by the prompt-management WebSocket hook.
 */
interface WebSocketHook {
  /** Wraps `payload` in a message envelope and sends it over the socket. */
  send: <TPayload>(messageType: PromptManagementMessageType, payload: TPayload) => void;
  /**
   * Registers `callback` for incoming messages whose `messageType` matches.
   * The callback receives the whole parsed envelope, not just `messageObject`.
   */
  onReceive: (
    messageType: PromptManagementMessageType,
    callback: (message: any) => void,
  ) => void;
  /** Connection status reported by the hook: true after `onopen`, false after `onclose`. */
  isConnected: boolean;
}

/**
 * React hook that keeps an authenticated WebSocket open to the prompt-management
 * endpoint and dispatches incoming messages to handlers registered by `messageType`.
 *
 * URL and authentication:
 * - The endpoint URL is `VITE_FAST_API_WS_PROMPT_MANAGEMENT`, falling back to
 *   `DEV_FAST_API_WS_PROMPT_MANAGEMENT`.
 * - The current Supabase access token is appended as the `token` query parameter.
 * - If there is no Supabase session, no connection is opened and no retry is scheduled.
 *
 * Reconnection: when the socket closes, a reconnect is attempted after 3 seconds.
 * This stops once the component unmounts.
 *
 * @returns An object with three members:
 *   - `send` publishes a message.
 *   - `onReceive` registers a handler.
 *   - `isConnected` is React state, so consumers re-render when it changes.
 */
const usePromptManagementWebSocket = (): WebSocketHook => {
  const ws = useRef<WebSocket | null>(null);
  const messageHandlers = useRef<Map<PromptManagementMessageType, ((message: any) => void)[]>>(
    new Map(),
  );
  // `ReturnType<typeof setTimeout>` fits both browser (number) and Node timer typings.
  const reconnectTimeout = useRef<ReturnType<typeof setTimeout> | undefined>(undefined);
  // Bumped by the effect cleanup. Async work compares against the value it captured
  // at start: a pending `getSession()` or a socket's `onclose` begun before an unmount
  // (or a StrictMode re-mount) then cannot open or reopen a socket.
  const epoch = useRef(0);
  // State rather than a ref so that consumers re-render when the connection changes.
  const [isConnected, setIsConnected] = useState(false);

  const connect = useCallback(async (): Promise<void> => {
    const startedAt = epoch.current;
    const isStale = () => epoch.current !== startedAt;

    try {
      // `||` (not `??`) on purpose: an empty env var also falls back to the dev URL.
      let url: string =
        import.meta.env.VITE_FAST_API_WS_PROMPT_MANAGEMENT || DEV_FAST_API_WS_PROMPT_MANAGEMENT;

      const { data } = await supabase.auth.getSession();
      // The hook may have been torn down while the session lookup was in flight.
      if (isStale()) {
        return;
      }
      const jwtToken = data?.session?.access_token;
      if (!jwtToken) {
        // Not signed in: do not connect (and do not schedule a retry).
        return;
      }
      url += `?token=${encodeURIComponent(jwtToken)}`;

      const socket = new WebSocket(url);
      ws.current = socket;

      socket.onopen = () => {
        console.log("WebSocket connected");
        setIsConnected(true);
        if (reconnectTimeout.current) {
          clearTimeout(reconnectTimeout.current);
          reconnectTimeout.current = undefined;
        }
      };

      socket.onclose = () => {
        console.log("WebSocket disconnected");
        // Only the current socket drives connection state. A socket already
        // superseded by a newer one (e.g. after a re-mount) is ignored here.
        if (ws.current === socket) {
          ws.current = null;
          setIsConnected(false);
        }
        // A socket closed by the effect cleanup must not schedule a reconnect.
        if (isStale()) {
          return;
        }
        // Attempt to reconnect after 3 seconds
        reconnectTimeout.current = setTimeout(() => {
          void connect();
        }, 3000);
      };

      socket.onmessage = (event: MessageEvent) => {
        let message: WebSocketPromptManagementMessage;
        try {
          message = JSON.parse(event.data);
        } catch (error) {
          console.error("Error parsing WebSocket message:", error);
          return;
        }

        // `?.` guards against payloads such as `null` that parse but are not envelopes.
        const handlers = messageHandlers.current.get(message?.messageType);
        // Isolate handlers: one throwing must not be misreported as a parse error,
        // nor prevent the remaining handlers from running.
        handlers?.forEach((handler) => {
          try {
            handler(message);
          } catch (error) {
            console.error("Error in WebSocket message handler:", error);
          }
        });
      };

      socket.onerror = (error) => {
        console.error("WebSocket error:", error);
      };
    } catch (error) {
      console.error("Error connecting to WebSocket:", error);
    }
  }, []);

  useEffect(() => {
    void connect();

    return () => {
      // Invalidate in-flight async work first. Otherwise the `onclose` triggered by
      // `close()` below (dispatched asynchronously) would schedule a reconnect.
      epoch.current += 1;
      if (reconnectTimeout.current) {
        clearTimeout(reconnectTimeout.current);
        reconnectTimeout.current = undefined;
      }
      ws.current?.close();
    };
  }, [connect]);

  /**
   * Serializes `{ messageType, messageObject }` as JSON and sends it.
   * Sends only while the socket is OPEN. Otherwise it logs an error and drops the message.
   */
  const send = useCallback(<T,>(messageType: PromptManagementMessageType, messageObject: T) => {
    if (ws.current?.readyState === WebSocket.OPEN) {
      const message: WebSocketPromptManagementMessage<T> = {
        messageType,
        messageObject,
      };
      ws.current.send(JSON.stringify(message));
    } else {
      console.error("WebSocket is not connected");
    }
  }, []);

  /**
   * Appends `handler` to the list for `messageType`. Handlers are never removed, and
   * registering the same function twice makes it run twice per message.
   */
  const onReceive = useCallback(
    (messageType: PromptManagementMessageType, handler: (message: any) => void) => {
      if (!messageHandlers.current.has(messageType)) {
        messageHandlers.current.set(messageType, []);
      }
      messageHandlers.current.get(messageType)!.push(handler);
    },
    [],
  );

  return {
    send,
    onReceive,
    isConnected,
  };
};

// Expose the hook as this module's default export.
export default usePromptManagementWebSocket;
