canvas-website/src/propagators/tlgraph.ts

115 lines
3.4 KiB
TypeScript

import { isShapeOfType } from "@/propagators/utils";
import { Editor, TLArrowBinding, TLArrowShape, TLShape, TLShapeId } from "tldraw";
export interface Edge {
arrowId: TLShapeId
from: TLShapeId
to: TLShapeId
text?: string
}
export interface Graph {
nodes: TLShapeId[]
edges: Edge[]
}
export function getEdge(shape: TLShape | undefined, editor: Editor): Edge | undefined {
if (!shape || !isShapeOfType<TLArrowShape>(shape, 'arrow')) return undefined
const bindings = editor.getBindingsInvolvingShape<TLArrowBinding>(shape.id)
if (!bindings || bindings.length !== 2) return undefined
if (bindings[0].props.terminal === "end") {
return {
arrowId: shape.id,
from: bindings[1].toId,
to: bindings[0].toId,
text: shape.props.text
}
}
return {
arrowId: shape.id,
from: bindings[0].toId,
to: bindings[1].toId,
text: shape.props.text
}
}
/**
* Returns the graph(s) of edges and nodes from a list of shapes
*/
export function getGraph(shapes: TLShape[], editor: Editor): Graph {
const nodes: Set<TLShapeId> = new Set<TLShapeId>()
const edges: Edge[] = []
for (const shape of shapes) {
const edge = getEdge(shape, editor)
if (edge) {
edges.push({
arrowId: edge.arrowId,
from: edge.from,
to: edge.to,
text: edge.text
})
nodes.add(edge.from)
nodes.add(edge.to)
}
}
return { nodes: Array.from(nodes), edges }
}
/**
* Returns the start and end nodes of a topologically sorted graph
*/
export function sortGraph(graph: Graph): { startNodes: TLShapeId[], endNodes: TLShapeId[] } {
const targetNodes = new Set<TLShapeId>(graph.edges.map(e => e.to));
const sourceNodes = new Set<TLShapeId>(graph.edges.map(e => e.from));
const startNodes = [];
const endNodes = [];
for (const node of graph.nodes) {
if (sourceNodes.has(node) && !targetNodes.has(node)) {
startNodes.push(node);
} else if (targetNodes.has(node) && !sourceNodes.has(node)) {
endNodes.push(node);
}
}
return { startNodes, endNodes };
}
/**
* Returns the arrows starting from the given shape
*/
export function getArrowsFromShape(editor: Editor, shapeId: TLShapeId): TLShapeId[] {
const bindings = editor.getBindingsToShape<TLArrowBinding>(shapeId, 'arrow')
return bindings.filter(edge => edge.props.terminal === 'start').map(edge => edge.fromId)
}
/**
* Returns the arrows ending at the given shape
*/
export function getArrowsToShape(editor: Editor, shapeId: TLShapeId): TLShapeId[] {
const bindings = editor.getBindingsToShape<TLArrowBinding>(shapeId, 'arrow')
return bindings.filter(edge => edge.props.terminal === 'end').map(edge => edge.fromId)
}
/**
* Returns the arrows which share the same start shape as the given arrow
*/
export function getSiblingArrowIds(editor: Editor, arrow: TLShape): TLShapeId[] {
if (arrow.type !== 'arrow') return [];
const bindings = editor.getBindingsInvolvingShape<TLArrowBinding>(arrow.id);
if (!bindings || bindings.length !== 2) return [];
const startShapeId = bindings.find(binding => binding.props.terminal === 'start')?.toId;
if (!startShapeId) return [];
const siblingBindings = editor.getBindingsToShape<TLArrowBinding>(startShapeId, 'arrow');
const siblingArrows = siblingBindings
.filter(binding => binding.props.terminal === 'start' && binding.fromId !== arrow.id)
.map(binding => binding.fromId);
return siblingArrows;
}