diff --git a/src/hooks/store/useGraph.tsx b/src/hooks/store/useGraph.tsx index 04fee85..b230885 100644 --- a/src/hooks/store/useGraph.tsx +++ b/src/hooks/store/useGraph.tsx @@ -1,8 +1,8 @@ import create from "zustand"; import { EdgeData, NodeData } from "reaflow/dist/types"; import { Graph } from "src/components/Graph"; -import { findEdgeChildren } from "src/utils/findEdgeChildren"; -import { findNodeChildren } from "src/utils/findNodeChildren"; +import { getChildrenEdges } from "src/utils/getChildrenEdges"; +import { getOutgoers } from "src/utils/getOutgoers"; export interface Graph { nodes: NodeData[]; @@ -34,12 +34,8 @@ const useGraph = create((set) => ({ }), expandNodes: (nodeId) => set((state) => { - const childrenNodes = findNodeChildren(nodeId, state.nodes, state.edges); - const childrenEdges = findEdgeChildren( - nodeId, - childrenNodes, - state.edges - ); + const childrenNodes = getOutgoers(nodeId, state.nodes, state.edges); + const childrenEdges = getChildrenEdges(childrenNodes, state.edges); const nodeIds = childrenNodes.map((node) => node.id); const edgeIds = childrenEdges.map((edge) => edge.id); @@ -56,12 +52,8 @@ const useGraph = create((set) => ({ }), collapseNodes: (nodeId) => set((state) => { - const childrenNodes = findNodeChildren(nodeId, state.nodes, state.edges); - const childrenEdges = findEdgeChildren( - nodeId, - childrenNodes, - state.edges - ); + const childrenNodes = getOutgoers(nodeId, state.nodes, state.edges); + const childrenEdges = getChildrenEdges(childrenNodes, state.edges); const nodeIds = childrenNodes.map((node) => node.id); const edgeIds = childrenEdges.map((edge) => edge.id); diff --git a/src/utils/findEdgeChildren.ts b/src/utils/findEdgeChildren.ts deleted file mode 100644 index 86bb8df..0000000 --- a/src/utils/findEdgeChildren.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { NodeData, EdgeData } from "reaflow/dist/types"; - -export const findEdgeChildren = ( - selectedNode: string, - connections: NodeData[], - edges: EdgeData[] -) => { - const nodeIds = connections.map((n) => n.id); - - nodeIds.push(selectedNode); - const newEdges = edges.filter( - (e) => - nodeIds.includes(e.from as string) && nodeIds.includes(e.to as string) - ); - - return newEdges; -}; diff --git a/src/utils/findNodeChildren.ts b/src/utils/findNodeChildren.ts deleted file mode 100644 index e124904..0000000 --- a/src/utils/findNodeChildren.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { NodeData, EdgeData } from "reaflow/dist/types"; - -export const findNodeChildren = ( - selectedNode: string, - nodes: NodeData[], - edges: EdgeData[] -) => { - const toByFrom = {}; - for (const edge of edges) { - if (edge.from) { - toByFrom[edge.from] ??= []; - toByFrom[edge.from].push(edge.to); - } - } - - const getNodes = (parent, allNodesIds: string[] = []) => { - const tos = toByFrom[parent]; - if (tos) { - allNodesIds.push(...tos); - for (const to of tos) { - getNodes(to, allNodesIds); - } - } - return allNodesIds; - }; - - const myNodes = getNodes(selectedNode); - - const findNodes = myNodes.map((id) => { - const node = nodes.find((n) => n.id === id); - return node as NodeData; - }); - - return findNodes; -}; diff --git a/src/utils/getChildrenEdges.ts b/src/utils/getChildrenEdges.ts new file mode 100644 index 0000000..25a1f60 --- /dev/null +++ b/src/utils/getChildrenEdges.ts @@ -0,0 +1,14 @@ +import { NodeData, EdgeData } from "reaflow/dist/types"; + +export const getChildrenEdges = ( + nodes: NodeData[], + edges: EdgeData[] +): EdgeData[] => { + const nodeIds = nodes.map((node) => node.id); + + return edges.filter( + (edge) => + nodeIds.includes(edge.from as string) || + nodeIds.includes(edge.to as string) + ); +}; diff --git a/src/utils/getOutgoers.ts b/src/utils/getOutgoers.ts new file mode 100644 index 0000000..17c0c09 --- /dev/null +++ b/src/utils/getOutgoers.ts @@ -0,0 +1,20 @@ +import { NodeData, EdgeData } from "reaflow/dist/types"; + +export const getOutgoers = ( + nodeId: string, + nodes: NodeData[], + edges: EdgeData[] +): NodeData[] => { + const allOutgoers: NodeData[] = []; + + const runner = (nodeId: string) => { + const outgoerIds = edges.filter((e) => e.from === nodeId).map((e) => e.to); + const nodeList = nodes.filter((n) => outgoerIds.includes(n.id)); + allOutgoers.push(...nodeList); + nodeList.forEach((node) => runner(node.id)); + }; + + runner(nodeId); + + return allOutgoers; +};