cleanup code

This commit is contained in:
Steve Ruiz 2023-11-25 19:05:49 +00:00
parent 25c620a066
commit eb5d48154d
6 changed files with 727 additions and 332 deletions

View File

@ -1,63 +1,86 @@
"use client";
'use client'
import { LiveImageShapeUtil } from "@/components/live-image";
import * as fal from "@fal-ai/serverless-client";
import { Editor, FrameShapeTool, Tldraw, useEditor } from "@tldraw/tldraw";
import { useCallback } from "react";
import { LiveImageTool, MakeLiveButton } from "../components/LiveImageTool";
import {
LiveImageShape,
LiveImageShapeUtil,
} from '@/components/LiveImageShapeUtil'
import * as fal from '@fal-ai/serverless-client'
import {
AssetRecordType,
Editor,
FrameShapeTool,
Tldraw,
useEditor,
} from '@tldraw/tldraw'
import { useCallback, useEffect } from 'react'
import { LiveImageTool, MakeLiveButton } from '../components/LiveImageTool'
fal.config({
requestMiddleware: fal.withProxy({
targetUrl: "/api/fal/proxy",
}),
});
requestMiddleware: fal.withProxy({
targetUrl: '/api/fal/proxy',
}),
})
const shapeUtils = [LiveImageShapeUtil];
const tools = [LiveImageTool];
const shapeUtils = [LiveImageShapeUtil]
const tools = [LiveImageTool]
export default function Home() {
const onEditorMount = (editor: Editor) => {
// @ts-expect-error: patch
editor.isShapeOfType = function (arg, type) {
const shape = typeof arg === "string" ? this.getShape(arg)! : arg;
if (shape.type === "live-image" && type === "frame") {
return true;
}
return shape.type === type;
};
const onEditorMount = (editor: Editor) => {
// @ts-expect-error: patch
editor.isShapeOfType = function (arg, type) {
const shape = typeof arg === 'string' ? this.getShape(arg)! : arg
if (shape.type === 'live-image' && type === 'frame') {
return true
}
return shape.type === type
}
// If there isn't a live image shape, create one
const liveImage = editor.getCurrentPageShapes().find((shape) => {
return shape.type === "live-image";
});
// If there isn't a live image shape, create one
const liveImage = editor.getCurrentPageShapes().find((shape) => {
return shape.type === 'live-image'
})
if (liveImage) {
return;
}
if (liveImage) {
return
}
editor.createShape({
type: "live-image",
x: 120,
y: 180,
props: {
w: 512,
h: 512,
name: "a city skyline",
},
});
};
editor.createShape<LiveImageShape>({
type: 'live-image',
x: 120,
y: 180,
props: {
w: 512,
h: 512,
name: 'a city skyline',
},
})
}
return (
<main className="flex min-h-screen flex-col items-center justify-between">
<div className="fixed inset-0">
<Tldraw
persistenceKey="tldraw-fal"
onMount={onEditorMount}
shapeUtils={shapeUtils}
tools={tools}
shareZone={<MakeLiveButton />}
/>
</div>
</main>
);
return (
<main className="flex min-h-screen flex-col items-center justify-between">
<div className="fixed inset-0">
<Tldraw
persistenceKey="tldraw-fal"
onMount={onEditorMount}
shapeUtils={shapeUtils}
tools={tools}
shareZone={<MakeLiveButton />}
>
<SneakySideEffects />
</Tldraw>
</div>
</main>
)
}
function SneakySideEffects() {
const editor = useEditor()
useEffect(() => {
editor.sideEffects.registerAfterChangeHandler('shape', () => {
editor.emit('update-drawings' as any)
})
}, [editor])
return null
}

View File

@ -0,0 +1,131 @@
import {
SelectionEdge,
TLShapeId,
canonicalizeRotation,
getPointerInfo,
toDomPrecision,
useEditor,
useIsEditing,
useValue,
} from '@tldraw/editor'
import { useCallback, useEffect, useRef } from 'react'
import { FrameLabelInput } from './FrameLabelInput'
import { preventDefault, stopEventPropagation } from '@tldraw/tldraw'
export function FrameHeading({
id,
name,
width,
height,
}: {
id: TLShapeId
name: string
width: number
height: number
}) {
const editor = useEditor()
const pageRotation = useValue(
'shape rotation',
() => canonicalizeRotation(editor.getShapePageTransform(id)!.rotation()),
[editor, id]
)
const isEditing = useIsEditing(id)
const rInput = useRef<HTMLInputElement>(null)
const handlePointerDown = useCallback(
(e: React.PointerEvent) => {
preventDefault(e)
stopEventPropagation(e)
const event = getPointerInfo(e)
console.log('hello')
// If we're editing the frame label, we shouldn't hijack the pointer event
if (editor.getEditingShapeId() === id) return
editor.dispatch({
...event,
type: 'pointer',
name: 'pointer_down',
target: 'shape',
shape: editor.getShape(id)!,
})
},
[editor, id]
)
useEffect(() => {
const el = rInput.current
if (el && isEditing) {
// On iOS, we must focus here
el.focus()
el.select()
requestAnimationFrame(() => {
// On desktop, the input may have lost focus, so try try try again!
if (document.activeElement !== el) {
el.focus()
el.select()
}
})
}
}, [rInput, isEditing])
// rotate right 45 deg
const offsetRotation = pageRotation + Math.PI / 4
const scaledRotation = (offsetRotation * (2 / Math.PI) + 4) % 4
const labelSide: SelectionEdge = (
['top', 'left', 'bottom', 'right'] as const
)[Math.floor(scaledRotation)]
let labelTranslate: string
switch (labelSide) {
case 'top':
labelTranslate = ``
break
case 'right':
labelTranslate = `translate(${toDomPrecision(
width
)}px, 0px) rotate(90deg)`
break
case 'bottom':
labelTranslate = `translate(${toDomPrecision(width)}px, ${toDomPrecision(
height
)}px) rotate(180deg)`
break
case 'left':
labelTranslate = `translate(0px, ${toDomPrecision(
height
)}px) rotate(270deg)`
break
}
return (
<div
className="tl-frame-heading"
style={{
overflow: isEditing ? 'visible' : 'hidden',
maxWidth: `calc(var(--tl-zoom) * ${
labelSide === 'top' || labelSide === 'bottom'
? Math.ceil(width)
: Math.ceil(height)
}px + var(--space-5))`,
bottom: '100%',
transform: `${labelTranslate} scale(var(--tl-scale)) translateX(calc(-1 * var(--space-3))`,
}}
onPointerDown={handlePointerDown}
>
<div className="tl-frame-heading-hit-area">
<FrameLabelInput
ref={rInput}
id={id}
name={name}
isEditing={isEditing}
/>
</div>
</div>
)
}

View File

@ -0,0 +1,98 @@
import {
TLFrameShape,
TLShapeId,
stopEventPropagation,
useEditor,
} from '@tldraw/editor'
import { forwardRef, useCallback } from 'react'
export const FrameLabelInput = forwardRef<
HTMLInputElement,
{ id: TLShapeId; name: string; isEditing: boolean }
>(function FrameLabelInput({ id, name, isEditing }, ref) {
const editor = useEditor()
const handleKeyDown = useCallback(
(e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter' && !e.nativeEvent.isComposing) {
// need to prevent the enter keydown making it's way up to the Idle state
// and sending us back into edit mode
stopEventPropagation(e)
e.currentTarget.blur()
editor.setEditingShape(null)
}
},
[editor]
)
const handleBlur = useCallback(
(e: React.FocusEvent<HTMLInputElement>) => {
const shape = editor.getShape<TLFrameShape>(id)
if (!shape) return
const name = shape.props.name
const value = e.currentTarget.value.trim()
if (name === value) return
editor.updateShapes(
[
{
id,
type: 'frame',
props: { name: value },
},
],
{ squashing: true }
)
},
[id, editor]
)
const handleChange = useCallback(
(e: React.ChangeEvent<HTMLInputElement>) => {
const shape = editor.getShape<TLFrameShape>(id)
if (!shape) return
const name = shape.props.name
const value = e.currentTarget.value
if (name === value) return
editor.updateShapes(
[
{
id,
type: 'frame',
props: { name: value },
},
],
{ squashing: true }
)
},
[id, editor]
)
return (
<div
className={`tl-frame-label ${isEditing ? 'tl-frame-label__editing' : ''}`}
>
<input
className="tl-frame-name-input"
ref={ref}
style={{ display: isEditing ? undefined : 'none' }}
value={name}
autoFocus
onKeyDown={handleKeyDown}
onBlur={handleBlur}
onChange={handleChange}
/>
{defaultEmptyAs(name, 'Frame') + String.fromCharCode(8203)}
</div>
)
})
export function defaultEmptyAs(str: string, dflt: string) {
if (str.match(/^\s*$/)) {
return dflt
}
return str
}

View File

@ -0,0 +1,233 @@
/* eslint-disable @next/next/no-img-element */
/* eslint-disable react-hooks/rules-of-hooks */
import {
AssetRecordType,
canonicalizeRotation,
FrameShapeUtil,
Geometry2d,
getDefaultColorTheme,
getHashForObject,
getSvgAsImage,
HTMLContainer,
IdOf,
Rectangle2d,
resizeBox,
SelectionEdge,
ShapeUtil,
SVGContainer,
TLAsset,
TLBaseShape,
TLGroupShape,
TLOnResizeEndHandler,
TLOnResizeHandler,
TLShape,
TLShapeId,
toDomPrecision,
useEditor,
useIsDarkMode,
} from '@tldraw/tldraw'
import { blobToDataUri } from '@/utils/blob'
import { debounce } from '@/utils/debounce'
import * as fal from '@fal-ai/serverless-client'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import result from 'postcss/lib/result'
import { FrameHeading } from './FrameHeading'
import image from 'next/image'
import { connect } from 'http2'
import { useFal } from '@/hooks/useFal'
// See https://www.fal.ai/models/latent-consistency-sd
type Input = {
prompt: string
image_url: string
sync_mode: boolean
seed: number
strength?: number
guidance_scale?: number
num_inference_steps?: number
enable_safety_checks?: boolean
}
type Output = {
images: Array<{
url: string
width: number
height: number
}>
seed: number
num_inference_steps: number
}
export type LiveImageShape = TLBaseShape<
'live-image',
{
w: number
h: number
name: string
}
>
export class LiveImageShapeUtil extends ShapeUtil<LiveImageShape> {
static type = 'live-image' as any
override canBind = () => true
override canEdit = () => true
getDefaultProps() {
return {
w: 512,
h: 512,
name: 'a city skyline',
}
}
override getGeometry(shape: LiveImageShape): Geometry2d {
return new Rectangle2d({
width: shape.props.w,
height: shape.props.h,
isFilled: false,
})
}
canUnmount = () => false
indicator(shape: LiveImageShape) {
const bounds = this.editor.getShapeGeometry(shape).bounds
return (
<rect
width={toDomPrecision(bounds.width)}
height={toDomPrecision(bounds.height)}
className={`tl-frame-indicator`}
/>
)
}
override component(shape: LiveImageShape) {
const editor = useEditor()
useFal(shape.id, {
debounceTime: 0,
url: 'wss://110602490-lcm-sd15-i2i.gateway.alpha.fal.ai/ws',
})
const bounds = this.editor.getShapeGeometry(shape).bounds
const assetId = AssetRecordType.createId(shape.id.split(':')[1])
const asset = editor.getAsset(assetId)
// eslint-disable-next-line react-hooks/rules-of-hooks
const theme = getDefaultColorTheme({ isDarkMode: useIsDarkMode() })
return (
<>
<SVGContainer>
<rect
className={'tl-frame__body'}
width={bounds.width}
height={bounds.height}
fill={theme.solid}
stroke={theme.text}
/>
</SVGContainer>
<FrameHeading
id={shape.id}
name={shape.props.name}
width={bounds.width}
height={bounds.height}
/>
{asset && (
<img
src={asset.props.src!}
alt={shape.props.name}
width={shape.props.w}
height={shape.props.h}
style={{
position: 'relative',
left: shape.props.w,
width: shape.props.w,
height: shape.props.h,
}}
/>
)}
</>
)
}
override canReceiveNewChildrenOfType = (
shape: TLShape,
_type: TLShape['type']
) => {
return !shape.isLocked
}
providesBackgroundForChildren(): boolean {
return true
}
override canDropShapes = (
shape: LiveImageShape,
_shapes: TLShape[]
): boolean => {
return !shape.isLocked
}
override onDragShapesOver = (
frame: LiveImageShape,
shapes: TLShape[]
): { shouldHint: boolean } => {
if (!shapes.every((child) => child.parentId === frame.id)) {
this.editor.reparentShapes(
shapes.map((shape) => shape.id),
frame.id
)
return { shouldHint: true }
}
return { shouldHint: false }
}
override onDragShapesOut = (
_shape: LiveImageShape,
shapes: TLShape[]
): void => {
const parent = this.editor.getShape(_shape.parentId)
const isInGroup =
parent && this.editor.isShapeOfType<TLGroupShape>(parent, 'group')
// If frame is in a group, keep the shape
// moved out in that group
if (isInGroup) {
this.editor.reparentShapes(shapes, parent.id)
} else {
this.editor.reparentShapes(shapes, this.editor.getCurrentPageId())
}
}
override onResizeEnd: TLOnResizeEndHandler<LiveImageShape> = (shape) => {
const bounds = this.editor.getShapePageBounds(shape)!
const children = this.editor.getSortedChildIdsForParent(shape.id)
const shapesToReparent: TLShapeId[] = []
for (const childId of children) {
const childBounds = this.editor.getShapePageBounds(childId)!
if (!bounds.includes(childBounds)) {
shapesToReparent.push(childId)
}
}
if (shapesToReparent.length > 0) {
this.editor.reparentShapes(
shapesToReparent,
this.editor.getCurrentPageId()
)
}
}
override onResize: TLOnResizeHandler<any> = (shape, info) => {
return resizeBox(shape, info)
}
}

View File

@ -1,280 +0,0 @@
/* eslint-disable @next/next/no-img-element */
/* eslint-disable react-hooks/rules-of-hooks */
import {
canonicalizeRotation,
FrameShapeUtil,
getDefaultColorTheme,
getSvgAsImage,
HTMLContainer,
SelectionEdge,
TLEventMapHandler,
TLFrameShape,
TLShape,
useEditor,
} from "@tldraw/tldraw";
import { blobToDataUri } from "@/utils/blob";
import { debounce } from "@/utils/debounce";
import * as fal from "@fal-ai/serverless-client";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import result from "postcss/lib/result";
// See https://www.fal.ai/models/latent-consistency-sd
const DEBOUNCE_TIME = 0.0; // Adjust as needed
const URL = "wss://110602490-lcm-sd15-i2i.gateway.alpha.fal.ai/ws";
type Input = {
prompt: string;
image_url: string;
sync_mode: boolean;
seed: number;
strength?: number;
guidance_scale?: number;
num_inference_steps?: number;
enable_safety_checks?: boolean;
};
type Output = {
images: Array<{
url: string;
width: number;
height: number;
}>;
seed: number;
num_inference_steps: number;
};
export class LiveImageShapeUtil extends FrameShapeUtil {
static override type = "live-image" as any;
override getDefaultProps(): { w: number; h: number; name: string } {
return {
w: 512,
h: 512,
name: "a city skyline",
};
}
override canUnmount = () => false;
override toSvg(shape: TLFrameShape) {
const theme = getDefaultColorTheme({
isDarkMode: this.editor.user.getIsDarkMode(),
});
const g = document.createElementNS("http://www.w3.org/2000/svg", "g");
const rect = document.createElementNS("http://www.w3.org/2000/svg", "rect");
rect.setAttribute("width", shape.props.w.toString());
rect.setAttribute("height", shape.props.h.toString());
rect.setAttribute("fill", theme.solid);
g.appendChild(rect);
return g;
}
override component(shape: TLFrameShape) {
const editor = useEditor();
const component = super.component(shape);
const [image, setImage] = useState<string | null>(null);
const imageDigest = useRef<string | null>(null);
const startedIteration = useRef<number>(0);
const finishedIteration = useRef<number>(0);
//===== SOCKET =====//
const webSocketRef = useRef<WebSocket | null>(null);
const isReconnecting = useRef(false);
const connect = useCallback(() => {
webSocketRef.current = new WebSocket(URL);
webSocketRef.current.onopen = () => {
// console.log("WebSocket Open");
};
webSocketRef.current.onclose = () => {
// console.log("WebSocket Close");
};
webSocketRef.current.onerror = (error) => {
// console.error("WebSocket Error:", error);
};
webSocketRef.current.onmessage = (message) => {
try {
const data = JSON.parse(message.data);
// console.log("WebSocket Message:", data);
if (data.images && data.images.length > 0) {
setImage(data.images[0].url);
}
} catch (e) {
console.error("Error parsing the WebSocket response:", e);
}
};
}, []);
const disconnect = useCallback(() => {
webSocketRef.current?.close();
}, []);
const sendMessage = useCallback(
async (message: string) => {
if (
!isReconnecting.current &&
webSocketRef.current?.readyState !== WebSocket.OPEN
) {
isReconnecting.current = true;
connect();
}
if (
isReconnecting.current &&
webSocketRef.current?.readyState !== WebSocket.OPEN
) {
await new Promise<void>((resolve) => {
const checkConnection = setInterval(() => {
if (webSocketRef.current?.readyState === WebSocket.OPEN) {
clearInterval(checkConnection);
resolve();
}
}, 100);
});
isReconnecting.current = false;
}
webSocketRef.current?.send(message);
},
[connect]
);
const sendCurrentData = useMemo(() => {
return debounce(sendMessage, DEBOUNCE_TIME);
}, [sendMessage]);
//===========//
// eslint-disable-next-line react-hooks/exhaustive-deps
const onDrawingChange = useCallback(
debounce(async () => {
// TODO get actual drawing bounds
// const bounds = new Box2d(120, 180, 512, 512);
const iteration = startedIteration.current++;
const shapes = Array.from(
editor.getShapeAndDescendantIds([shape.id])
).map((id) => editor.getShape(id)) as TLShape[];
// Check if should submit request
const shapesDigest = JSON.stringify(shapes);
if (shapesDigest === imageDigest.current) {
return;
}
imageDigest.current = shapesDigest;
const svg = await editor.getSvg([shape], {
background: true,
padding: 0,
darkMode: editor.user.getIsDarkMode(),
});
if (iteration <= finishedIteration.current) return;
if (!svg) {
return;
}
const image = await getSvgAsImage(svg, editor.environment.isSafari, {
type: "png",
quality: 1,
scale: 1,
});
if (iteration <= finishedIteration.current) return;
if (!image) {
return;
}
const prompt =
editor.getShape<TLFrameShape>(shape.id)?.props.name ?? "";
const imageDataUri = await blobToDataUri(image);
const request = {
image_url: imageDataUri,
prompt,
sync_mode: true,
strength: 0.7,
seed: 42, // TODO make this configurable in the UI
enable_safety_checks: false,
};
sendCurrentData(JSON.stringify(request));
if (iteration <= finishedIteration.current) return;
// const result = await fal.run<Input, Output>(LatentConsistency, {
// input: {
// image_url: imageDataUri,
// prompt,
// sync_mode: true,
// strength: 0.6,
// seed: 42, // TODO make this configurable in the UI
// enable_safety_checks: false,
// },
// // Disable auto-upload so we can submit the data uri of the image as is
// autoUpload: true,
// });
if (iteration <= finishedIteration.current) return;
finishedIteration.current = iteration;
// if (result && result.images.length > 0) {
// setImage(result.images[0].url);
// }
}, 0),
[]
);
useEffect(() => {
const onChange: TLEventMapHandler<"change"> = (event) => {
if (event.source !== "user") {
return;
}
if (
Object.keys(event.changes.added).length ||
Object.keys(event.changes.removed).length ||
Object.keys(event.changes.updated).length
) {
onDrawingChange();
}
};
editor.addListener("change", onChange);
return () => {
editor.removeListener("change", onChange);
};
}, [editor, onDrawingChange]);
return (
<HTMLContainer>
<div
style={{
display: "flex",
}}
>
{component}
{image && (
<img
src={image}
alt=""
width={shape.props.w}
height={shape.props.h}
style={{
position: "relative",
left: shape.props.w,
width: shape.props.w,
height: shape.props.h,
}}
/>
)}
</div>
</HTMLContainer>
);
}
}

190
src/hooks/useFal.ts Normal file
View File

@ -0,0 +1,190 @@
import { LiveImageShape } from '@/components/LiveImageShapeUtil'
import { blobToDataUri } from '@/utils/blob'
import {
AssetRecordType,
TLShape,
TLShapeId,
debounce,
getHashForObject,
getSvgAsImage,
throttle,
useEditor,
} from '@tldraw/tldraw'
import { useRef, useEffect } from 'react'
export function useFal(
shapeId: TLShapeId,
opts: {
debounceTime?: number
throttleTime?: number
url: string
}
) {
const { url, throttleTime = 500, debounceTime = 0 } = opts
const editor = useEditor()
const startedIteration = useRef<number>(0)
const finishedIteration = useRef<number>(0)
const prevHash = useRef<string | null>(null)
useEffect(() => {
let socket: WebSocket | null = null
let isReconnecting = false
function updateImage(url: string | null) {
const shape = editor.getShape<LiveImageShape>(shapeId)!
const id = AssetRecordType.createId(shape.id.split(':')[1])
const asset = editor.getAsset(id)
if (!asset) {
editor.createAssets([
AssetRecordType.create({
id,
type: 'image',
props: {
name: shape.props.name,
w: shape.props.w,
h: shape.props.h,
src: url,
isAnimated: false,
mimeType: 'image/jpeg',
},
}),
])
} else {
editor.updateAssets([
{
...asset,
type: 'image',
props: {
...asset.props,
w: shape.props.w,
h: shape.props.h,
src: url,
},
},
])
}
}
async function connect() {
{
socket = new WebSocket(url)
socket.onopen = () => {
// console.log("WebSocket Open");
}
socket.onclose = () => {
// console.log("WebSocket Close");
}
socket.onerror = (error) => {
// console.error("WebSocket Error:", error);
}
socket.onmessage = (message) => {
try {
const data = JSON.parse(message.data)
// console.log("WebSocket Message:", data);
if (data.images && data.images.length > 0) {
updateImage(data.images[0].url ?? '')
}
} catch (e) {
console.error('Error parsing the WebSocket response:', e)
}
}
}
}
async function sendCurrentData(message: string) {
if (!isReconnecting && socket?.readyState !== WebSocket.OPEN) {
isReconnecting = true
connect()
}
if (isReconnecting && socket?.readyState !== WebSocket.OPEN) {
await new Promise<void>((resolve) => {
const checkConnection = setInterval(() => {
if (socket?.readyState === WebSocket.OPEN) {
clearInterval(checkConnection)
resolve()
}
}, 100)
})
isReconnecting = false
}
socket?.send(message)
}
async function updateDrawing() {
const iteration = startedIteration.current++
const shapes = Array.from(editor.getShapeAndDescendantIds([shapeId])).map(
(id) => editor.getShape(id)
) as TLShape[]
const hash = getHashForObject(shapes)
if (hash === prevHash.current) return
const shape = editor.getShape<LiveImageShape>(shapeId)!
const svg = await editor.getSvg([shape], {
background: true,
padding: 0,
darkMode: editor.user.getIsDarkMode(),
})
if (!svg) {
updateImage('')
return
}
// We might be stale
if (iteration <= finishedIteration.current) return
const image = await getSvgAsImage(svg, editor.environment.isSafari, {
type: 'png',
quality: 1,
scale: 1,
})
if (!image) {
updateImage('')
return
}
// We might be stale
if (iteration <= finishedIteration.current) return
const prompt = shape.props.name
? shape.props.name + ' hd award-winning impressive'
: 'A random image that is safe for work and not surprising—something boring like a city or shoe watercolor'
const imageDataUri = await blobToDataUri(image)
const request = {
image_url: imageDataUri,
prompt,
sync_mode: true,
strength: 0.7,
seed: 42, // TODO make this configurable in the UI
enable_safety_checks: false,
}
// We might be stale
if (iteration <= finishedIteration.current) return
sendCurrentData(JSON.stringify(request))
finishedIteration.current = iteration
}
const onDrawingChange = debounceTime
? debounce(updateDrawing, debounceTime)
: throttleTime
? throttle(updateDrawing, throttleTime)
: debounce(updateDrawing, 16)
editor.on('update-drawings' as any, onDrawingChange)
return () => {
editor.off('update-drawings' as any, onDrawingChange)
}
}, [editor, shapeId, throttleTime, debounceTime, url])
}