@@ -9,6 +9,7 @@ import type {
99 Dimensions ,
1010 Edge ,
1111 EdgeMarkerType ,
12+ Element ,
1213 ElementData ,
1314 Elements ,
1415 FlowElements ,
@@ -151,24 +152,63 @@ export function parseEdge(edge: Edge, defaults: Partial<GraphEdge> = {}): GraphE
151152 return Object . assign ( { } , defaults , edge , { id : edge . id . toString ( ) } ) as GraphEdge
152153}
153154
154- function getConnectedElements < T extends Elements = FlowElements > (
155+ function getConnectedElements < T extends Node = Node > (
155156 nodeOrId : Node | { id : string } | string ,
156- elements : T ,
157+ nodes : T [ ] ,
158+ edges : Edge [ ] ,
157159 dir : 'source' | 'target' ,
158- ) : T extends FlowElements ? GraphNode [ ] : Node [ ] {
160+ ) : T [ ] {
159161 const id = isString ( nodeOrId ) ? nodeOrId : nodeOrId . id
160162
163+ const connectedIds = new Set ( )
164+
161165 const origin = dir === 'source' ? 'target' : 'source'
162- const ids = elements . filter ( ( e ) => isEdge ( e ) && e [ origin ] === id ) . map ( ( e ) => isEdge ( e ) && e [ dir ] )
163166
164- return elements . filter ( ( e ) => ids . includes ( e . id ) ) as T extends FlowElements ? GraphNode [ ] : Node [ ]
167+ edges . forEach ( ( edge ) => {
168+ if ( edge [ origin ] === id ) {
169+ connectedIds . add ( edge [ dir ] )
170+ }
171+ } )
172+
173+ return nodes . filter ( ( n ) => connectedIds . has ( n . id ) )
165174}
166- export function getOutgoers < T extends Elements = FlowElements > ( nodeOrId : Node | { id : string } | string , elements : T ) {
167- return getConnectedElements ( nodeOrId , elements , 'target' )
175+
176+ export function getOutgoers < N extends Node > ( nodeOrId : Node | { id : string } | string , nodes : N [ ] , edges : Edge [ ] ) : N [ ]
177+ export function getOutgoers < T extends Elements > (
178+ nodeOrId : Node | { id : string } | string ,
179+ elements : T ,
180+ ) : T extends FlowElements ? GraphNode [ ] : Node [ ]
181+ export function getOutgoers ( ...args : any [ ] ) {
182+ if ( args . length === 3 ) {
183+ const [ nodeOrId , nodes , edges ] = args
184+ return getConnectedElements ( nodeOrId , nodes , edges , 'target' )
185+ }
186+
187+ const [ nodeOrId , elements ] = args
188+ const node : Node = isString ( nodeOrId ) ? { id : nodeOrId } : nodeOrId
189+
190+ const outgoers = elements . filter ( ( el : Element ) => isEdge ( el ) && el . source === node . id )
191+
192+ return outgoers . map ( ( edge : Edge ) => elements . find ( ( el : Element ) => isNode ( el ) && el . id === edge . target ) )
168193}
169194
170- export function getIncomers < T extends Elements = FlowElements > ( nodeOrId : Node | { id : string } | string , elements : T ) {
171- return getConnectedElements ( nodeOrId , elements , 'source' )
195+ export function getIncomers < N extends Node > ( nodeOrId : Node | { id : string } | string , nodes : N [ ] , edges : Edge [ ] ) : N [ ]
196+ export function getIncomers < T extends Elements > (
197+ nodeOrId : Node | { id : string } | string ,
198+ elements : T ,
199+ ) : T extends FlowElements ? GraphNode [ ] : Node [ ]
200+ export function getIncomers ( ...args : any [ ] ) {
201+ if ( args . length === 3 ) {
202+ const [ nodeOrId , nodes , edges ] = args
203+ return getConnectedElements ( nodeOrId , nodes , edges , 'source' )
204+ }
205+
206+ const [ nodeOrId , elements ] = args
207+ const node : Node = isString ( nodeOrId ) ? { id : nodeOrId } : nodeOrId
208+
209+ const incomers = elements . filter ( ( el : Element ) => isEdge ( el ) && el . target === node . id )
210+
211+ return incomers . map ( ( edge : Edge ) => elements . find ( ( el : Element ) => isNode ( el ) && el . id === edge . source ) )
172212}
173213
174214export function getEdgeId ( { source, sourceHandle, target, targetHandle } : Connection ) {
@@ -364,26 +404,34 @@ export function getNodesInside(
364404 } )
365405}
366406
367- export function getConnectedEdges < N extends Node | { id : string } | string , E extends Edge > ( nodes : N [ ] , edges : E [ ] ) {
368- const nodeIds = nodes . map ( ( node ) => ( isString ( node ) ? node : node . id ) )
407+ export function getConnectedEdges < E extends Edge > ( nodesOrId : Node [ ] | string , edges : E [ ] ) {
408+ const nodeIds = new Set ( )
369409
370- return edges . filter ( ( edge ) => nodeIds . includes ( edge . source ) || nodeIds . includes ( edge . target ) )
410+ if ( isString ( nodesOrId ) ) {
411+ nodeIds . add ( nodesOrId )
412+ } else if ( nodesOrId . length >= 1 ) {
413+ nodesOrId . forEach ( ( n ) => nodeIds . add ( n . id ) )
414+ }
415+
416+ return edges . filter ( ( edge ) => nodeIds . has ( edge . source ) || nodeIds . has ( edge . target ) )
371417}
372418
373- export function getConnectedNodes < N extends Node | { id : string } | string , E extends Edge > ( nodes : N [ ] , edges : E [ ] ) {
374- const nodeIds = nodes . map ( ( node ) => ( isString ( node ) ? node : node . id ) )
419+ export function getConnectedNodes < N extends Node | { id : string } | string > ( nodes : N [ ] , edges : Edge [ ] ) {
420+ const nodeIds = new Set ( )
421+
422+ nodes . forEach ( ( node ) => nodeIds . add ( isString ( node ) ? node : node . id ) )
375423
376424 const connectedNodeIds = edges . reduce ( ( acc , edge ) => {
377- if ( nodeIds . includes ( edge . source ) ) {
425+ if ( nodeIds . has ( edge . source ) ) {
378426 acc . add ( edge . target )
379427 }
380428
381- if ( nodeIds . includes ( edge . target ) ) {
429+ if ( nodeIds . has ( edge . target ) ) {
382430 acc . add ( edge . source )
383431 }
384432
385433 return acc
386- } , new Set < string > ( ) )
434+ } , new Set ( ) )
387435
388436 return nodes . filter ( ( node ) => connectedNodeIds . has ( isString ( node ) ? node : node . id ) )
389437}
0 commit comments