import { useCallback, useState } from "react";
import { Edge, Node, Position } from "reactflow";
import dagre from "dagre";

/**
 * React hook that positions React Flow nodes with a top-to-bottom dagre layout.
 *
 * @returns An object with:
 *   - `getLayoutedElements`: a referentially stable function (memoized with an
 *     empty dependency list) that computes node positions.
 *   - `redrawKey`: a random number that is replaced every time a layout is
 *     actually computed. Presumably callers pass it as a React `key` to force
 *     the flow to re-render; confirm at call sites.
 */
export const useLayoutedElements = () => {
  // Lazy initializer: `Math.random()` runs only on the first render. Passing
  // the value directly would call it on every render and throw the result away.
  const [redrawKey, setRedrawKey] = useState(() => Math.random());

  /**
   * Lays out `nodes` top-to-bottom using dagre and returns copies with positions.
   *
   * Layout is skipped, and the inputs are returned unchanged without touching
   * `redrawKey`, when there are no nodes, no edges, or any edge is missing or
   * has no `target`.
   *
   * Every node is treated as a fixed 504x104 box, and ranks are spaced 100
   * apart. Returned nodes get `targetPosition: Top`, `sourcePosition: Bottom`
   * and a `position` that is the box's top-left corner. Edges are returned
   * as-is.
   *
   * NOTE: on success this updates hook state via `setRedrawKey`. Call it from an
   * effect or event handler, not during render.
   *
   * @param nodes - Nodes to position; each `node.id` becomes a dagre vertex.
   * @param edges - Edges whose `source`/`target` ids define the dagre edges.
   * @returns The laid-out nodes (new objects) and the original `edges`.
   */
  const getLayoutedElements = useCallback(
    (nodes: Node[], edges: Edge[]): { nodes: Node[]; edges: Edge[] } => {
      if (nodes.length === 0 || edges.length === 0 || edges.some((edge) => !edge || !edge.target)) {
        return { nodes, edges };
      }

      // A fresh graph per call, so no layout state leaks between invocations.
      const dagreGraph = new dagre.graphlib.Graph();
      dagreGraph.setDefaultEdgeLabel(() => ({}));
      dagreGraph.setGraph({ rankdir: "TB", ranksep: 100 });
      // Fixed box size used for every node, both for dagre and for the offset below.
      const nodeWidth = 504;
      const nodeHeight = 104;

      nodes.forEach((node) => {
        dagreGraph.setNode(node.id, { width: nodeWidth, height: nodeHeight });
      });

      edges.forEach((edge) => {
        // `target` is already guaranteed by the early return; `source` is not.
        if (edge.source && edge.target) {
          dagreGraph.setEdge(edge.source, edge.target);
        }
      });

      dagre.layout(dagreGraph);

      const layoutedNodes: Node[] = nodes.map((node) => {
        const nodeWithPosition = dagreGraph.node(node.id);
        return {
          ...node,
          targetPosition: Position.Top,
          sourcePosition: Position.Bottom,
          // dagre reports the node's centre; shift by half the box to get the
          // top-left corner used for `position`.
          position: {
            x: nodeWithPosition.x - nodeWidth / 2,
            y: nodeWithPosition.y - nodeHeight / 2,
          },
        };
      });

      setRedrawKey(Math.random());
      return { nodes: layoutedNodes, edges };
    },
    // `setRedrawKey` is a stable React state setter, so no dependencies are needed.
    [],
  );

  return { getLayoutedElements, redrawKey };
};
