import { useMemo } from "react";
import { Edge, Node } from "reactflow";
import {
  HierarchyNode,
  HierarchyPointNode,
  stratify,
  tree,
} from "d3-hierarchy";
/**
 * Options accepted by `useExpandCollapse`. Every field may be omitted; the
 * hook substitutes the default noted on each field.
 */
export type UseExpandCollapseOptions = Partial<{
  /** Lay visible nodes out with d3's `tree` layout (default `true`); when `false`, nodes keep their own `position`. */
  layoutNodes: boolean;
  /** First component passed to `tree().nodeSize` (default `20`). */
  treeWidth: number;
  /** Second component passed to `tree().nodeSize` (default `10`). */
  treeHeight: number;
}>;
/**
 * Type guard that tells a node produced by a d3 `tree` layout (which carries
 * numeric `x`/`y` coordinates) apart from a plain `stratify` node that has
 * not been positioned.
 * @param pointNode - hierarchy node to inspect
 * @returns `true` when both `x` and `y` are numbers
 */
function isHierarchyPointNode(
  pointNode: HierarchyNode<any> | HierarchyPointNode<any>
): pointNode is HierarchyPointNode<any> {
  const { x, y } = pointNode as HierarchyPointNode<any>;
  return [x, y].every((coord) => typeof coord === "number");
}

/**
 * Turns a flat React Flow graph into a tree. It hides the descendants of every
 * node whose `data.expanded` is falsy. Optionally, it lays the visible part out
 * with d3's `tree` layout.
 *
 * A node's parent is the `source` of the first edge in `edges` whose `target`
 * is that node; nodes with no incoming edge are roots.
 *
 * NOTE(review): this writes `data.expandable` (true when the node has
 * children) directly onto the `data` objects of the `nodes` passed in. The
 * returned nodes share those same `data` objects.
 *
 * @param nodes - React Flow nodes to arrange
 * @param edges - React Flow edges linking parents (`source`) to children (`target`)
 * @param options - see `UseExpandCollapseOptions`; defaults are
 *   `layoutNodes = true`, `treeWidth = 20`, `treeHeight = 10`
 * @returns the visible nodes (each with `type: "custom"` and a `position`)
 *   and the edges whose endpoints are both visible; empty arrays when
 *   `nodes` is empty
 * @throws when d3's `stratify` rejects the graph, e.g. more than one root or
 *   an edge `source` that is not among `nodes`
 */
const useExpandCollapse = (
  nodes: Node[],
  edges: Edge[],
  {
    layoutNodes = true,
    treeWidth = 20,
    treeHeight = 10,
  }: UseExpandCollapseOptions = {}
): { nodes: Node[]; edges: Edge[] } => {
  return useMemo(() => {
    // d3's `stratify` throws ("no root") on empty input. That would crash the
    // component while the graph is still empty, e.g. before data has loaded.
    if (nodes.length === 0) {
      return { nodes: [], edges: [] };
    }

    // Index each node's parent once instead of scanning `edges` for every
    // node (O(nodes × edges)). Only the first edge per target is kept, which
    // matches the result of `edges.find(e => e.target === id)`.
    const parentIdByTarget = new Map<string, string>();
    for (const edge of edges) {
      if (!parentIdByTarget.has(edge.target)) {
        parentIdByTarget.set(edge.target, edge.source);
      }
    }

    // stratify turns the flat node list into a traversable hierarchy.
    const hierarchy = stratify<Node>()
      .id((node) => node.id)
      .parentId((node) => parentIdByTarget.get(node.id))(nodes);

    // Record whether each node has children. Then detach the children of
    // collapsed nodes so that neither the layout nor the output sees them.
    hierarchy.descendants().forEach((h) => {
      h.data.data.expandable = !!h.children?.length;
      h.children = h.data.data.expanded ? h.children : undefined;
    });

    const layout = tree<Node>()
      .nodeSize([treeWidth, treeHeight])
      .separation(() => 1);

    const root: HierarchyNode<Node> = layoutNodes
      ? layout(hierarchy)
      : hierarchy;

    // Only nodes still reachable from the root (after pruning) are visible.
    const visible = root.descendants();
    const visibleIds = new Set(visible.map((h) => h.id));

    return {
      nodes: visible.map((h) => ({
        ...h.data,
        type: "custom",
        // Laid-out nodes take d3's coordinates; otherwise keep their own position.
        position: isHierarchyPointNode(h)
          ? { x: h.x, y: h.y }
          : h.data.position,
      })),
      // Keep an edge only when both endpoints are visible. A Set lookup
      // replaces two tree searches per edge.
      edges: edges.filter(
        (edge) => visibleIds.has(edge.source) && visibleIds.has(edge.target)
      ),
    };
  }, [nodes, edges, layoutNodes, treeWidth, treeHeight]);
};

// Expose the hook as this module's default export.
export default useExpandCollapse;
