From eb5d48154d6b532f5d4e5311ac596d3e75983ca8 Mon Sep 17 00:00:00 2001 From: Steve Ruiz Date: Sat, 25 Nov 2023 19:05:49 +0000 Subject: [PATCH] cleanup code --- src/app/page.tsx | 127 +++++++----- src/components/FrameHeading.tsx | 131 ++++++++++++ src/components/FrameLabelInput.tsx | 98 +++++++++ src/components/LiveImageShapeUtil.tsx | 233 +++++++++++++++++++++ src/components/live-image.tsx | 280 -------------------------- src/hooks/useFal.ts | 190 +++++++++++++++++ 6 files changed, 727 insertions(+), 332 deletions(-) create mode 100644 src/components/FrameHeading.tsx create mode 100644 src/components/FrameLabelInput.tsx create mode 100644 src/components/LiveImageShapeUtil.tsx delete mode 100644 src/components/live-image.tsx create mode 100644 src/hooks/useFal.ts diff --git a/src/app/page.tsx b/src/app/page.tsx index 0bb725c..3d01545 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -1,63 +1,86 @@ -"use client"; +'use client' -import { LiveImageShapeUtil } from "@/components/live-image"; -import * as fal from "@fal-ai/serverless-client"; -import { Editor, FrameShapeTool, Tldraw, useEditor } from "@tldraw/tldraw"; -import { useCallback } from "react"; -import { LiveImageTool, MakeLiveButton } from "../components/LiveImageTool"; +import { + LiveImageShape, + LiveImageShapeUtil, +} from '@/components/LiveImageShapeUtil' +import * as fal from '@fal-ai/serverless-client' +import { + AssetRecordType, + Editor, + FrameShapeTool, + Tldraw, + useEditor, +} from '@tldraw/tldraw' +import { useCallback, useEffect } from 'react' +import { LiveImageTool, MakeLiveButton } from '../components/LiveImageTool' fal.config({ - requestMiddleware: fal.withProxy({ - targetUrl: "/api/fal/proxy", - }), -}); + requestMiddleware: fal.withProxy({ + targetUrl: '/api/fal/proxy', + }), +}) -const shapeUtils = [LiveImageShapeUtil]; -const tools = [LiveImageTool]; +const shapeUtils = [LiveImageShapeUtil] +const tools = [LiveImageTool] export default function Home() { - const onEditorMount = (editor: Editor) => { - // @ts-expect-error: patch - editor.isShapeOfType = function (arg, type) { - const shape = typeof arg === "string" ? this.getShape(arg)! : arg; - if (shape.type === "live-image" && type === "frame") { - return true; - } - return shape.type === type; - }; + const onEditorMount = (editor: Editor) => { + // @ts-expect-error: patch + editor.isShapeOfType = function (arg, type) { + const shape = typeof arg === 'string' ? this.getShape(arg)! : arg + if (shape.type === 'live-image' && type === 'frame') { + return true + } + return shape.type === type + } - // If there isn't a live image shape, create one - const liveImage = editor.getCurrentPageShapes().find((shape) => { - return shape.type === "live-image"; - }); + // If there isn't a live image shape, create one + const liveImage = editor.getCurrentPageShapes().find((shape) => { + return shape.type === 'live-image' + }) - if (liveImage) { - return; - } + if (liveImage) { + return + } - editor.createShape({ - type: "live-image", - x: 120, - y: 180, - props: { - w: 512, - h: 512, - name: "a city skyline", - }, - }); - }; + editor.createShape({ + type: 'live-image', + x: 120, + y: 180, + props: { + w: 512, + h: 512, + name: 'a city skyline', + }, + }) + } - return ( -
-
- } - /> -
-
- ); + return ( +
+
+ } + > + + +
+
+ ) +} + +function SneakySideEffects() { + const editor = useEditor() + + useEffect(() => { + editor.sideEffects.registerAfterChangeHandler('shape', () => { + editor.emit('update-drawings' as any) + }) + }, [editor]) + + return null } diff --git a/src/components/FrameHeading.tsx b/src/components/FrameHeading.tsx new file mode 100644 index 0000000..760a8d9 --- /dev/null +++ b/src/components/FrameHeading.tsx @@ -0,0 +1,131 @@ +import { + SelectionEdge, + TLShapeId, + canonicalizeRotation, + getPointerInfo, + toDomPrecision, + useEditor, + useIsEditing, + useValue, +} from '@tldraw/editor' +import { useCallback, useEffect, useRef } from 'react' +import { FrameLabelInput } from './FrameLabelInput' +import { preventDefault, stopEventPropagation } from '@tldraw/tldraw' + +export function FrameHeading({ + id, + name, + width, + height, +}: { + id: TLShapeId + name: string + width: number + height: number +}) { + const editor = useEditor() + const pageRotation = useValue( + 'shape rotation', + () => canonicalizeRotation(editor.getShapePageTransform(id)!.rotation()), + [editor, id] + ) + + const isEditing = useIsEditing(id) + + const rInput = useRef(null) + + const handlePointerDown = useCallback( + (e: React.PointerEvent) => { + preventDefault(e) + stopEventPropagation(e) + + const event = getPointerInfo(e) + + console.log('hello') + + // If we're editing the frame label, we shouldn't hijack the pointer event + if (editor.getEditingShapeId() === id) return + + editor.dispatch({ + ...event, + type: 'pointer', + name: 'pointer_down', + target: 'shape', + shape: editor.getShape(id)!, + }) + }, + [editor, id] + ) + + useEffect(() => { + const el = rInput.current + if (el && isEditing) { + // On iOS, we must focus here + el.focus() + el.select() + + requestAnimationFrame(() => { + // On desktop, the input may have lost focus, so try try try again! + if (document.activeElement !== el) { + el.focus() + el.select() + } + }) + } + }, [rInput, isEditing]) + + // rotate right 45 deg + const offsetRotation = pageRotation + Math.PI / 4 + const scaledRotation = (offsetRotation * (2 / Math.PI) + 4) % 4 + const labelSide: SelectionEdge = ( + ['top', 'left', 'bottom', 'right'] as const + )[Math.floor(scaledRotation)] + + let labelTranslate: string + switch (labelSide) { + case 'top': + labelTranslate = `` + break + case 'right': + labelTranslate = `translate(${toDomPrecision( + width + )}px, 0px) rotate(90deg)` + break + case 'bottom': + labelTranslate = `translate(${toDomPrecision(width)}px, ${toDomPrecision( + height + )}px) rotate(180deg)` + break + case 'left': + labelTranslate = `translate(0px, ${toDomPrecision( + height + )}px) rotate(270deg)` + break + } + + return ( +
+
+ +
+
+ ) +} diff --git a/src/components/FrameLabelInput.tsx b/src/components/FrameLabelInput.tsx new file mode 100644 index 0000000..7d7b744 --- /dev/null +++ b/src/components/FrameLabelInput.tsx @@ -0,0 +1,98 @@ +import { + TLFrameShape, + TLShapeId, + stopEventPropagation, + useEditor, +} from '@tldraw/editor' +import { forwardRef, useCallback } from 'react' + +export const FrameLabelInput = forwardRef< + HTMLInputElement, + { id: TLShapeId; name: string; isEditing: boolean } +>(function FrameLabelInput({ id, name, isEditing }, ref) { + const editor = useEditor() + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (e.key === 'Enter' && !e.nativeEvent.isComposing) { + // need to prevent the enter keydown making it's way up to the Idle state + // and sending us back into edit mode + stopEventPropagation(e) + e.currentTarget.blur() + editor.setEditingShape(null) + } + }, + [editor] + ) + + const handleBlur = useCallback( + (e: React.FocusEvent) => { + const shape = editor.getShape(id) + if (!shape) return + + const name = shape.props.name + const value = e.currentTarget.value.trim() + if (name === value) return + + editor.updateShapes( + [ + { + id, + type: 'frame', + props: { name: value }, + }, + ], + { squashing: true } + ) + }, + [id, editor] + ) + + const handleChange = useCallback( + (e: React.ChangeEvent) => { + const shape = editor.getShape(id) + if (!shape) return + + const name = shape.props.name + const value = e.currentTarget.value + if (name === value) return + + editor.updateShapes( + [ + { + id, + type: 'frame', + props: { name: value }, + }, + ], + { squashing: true } + ) + }, + [id, editor] + ) + + return ( +
+ + {defaultEmptyAs(name, 'Frame') + String.fromCharCode(8203)} +
+ ) +}) + +export function defaultEmptyAs(str: string, dflt: string) { + if (str.match(/^\s*$/)) { + return dflt + } + return str +} diff --git a/src/components/LiveImageShapeUtil.tsx b/src/components/LiveImageShapeUtil.tsx new file mode 100644 index 0000000..79a65a8 --- /dev/null +++ b/src/components/LiveImageShapeUtil.tsx @@ -0,0 +1,233 @@ +/* eslint-disable @next/next/no-img-element */ +/* eslint-disable react-hooks/rules-of-hooks */ +import { + AssetRecordType, + canonicalizeRotation, + FrameShapeUtil, + Geometry2d, + getDefaultColorTheme, + getHashForObject, + getSvgAsImage, + HTMLContainer, + IdOf, + Rectangle2d, + resizeBox, + SelectionEdge, + ShapeUtil, + SVGContainer, + TLAsset, + TLBaseShape, + TLGroupShape, + TLOnResizeEndHandler, + TLOnResizeHandler, + TLShape, + TLShapeId, + toDomPrecision, + useEditor, + useIsDarkMode, +} from '@tldraw/tldraw' + +import { blobToDataUri } from '@/utils/blob' +import { debounce } from '@/utils/debounce' +import * as fal from '@fal-ai/serverless-client' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import result from 'postcss/lib/result' +import { FrameHeading } from './FrameHeading' +import image from 'next/image' +import { connect } from 'http2' +import { useFal } from '@/hooks/useFal' + +// See https://www.fal.ai/models/latent-consistency-sd + +type Input = { + prompt: string + image_url: string + sync_mode: boolean + seed: number + strength?: number + guidance_scale?: number + num_inference_steps?: number + enable_safety_checks?: boolean +} + +type Output = { + images: Array<{ + url: string + width: number + height: number + }> + seed: number + num_inference_steps: number +} + +export type LiveImageShape = TLBaseShape< + 'live-image', + { + w: number + h: number + name: string + } +> + +export class LiveImageShapeUtil extends ShapeUtil { + static type = 'live-image' as any + + override canBind = () => true + + override canEdit = () => true + + getDefaultProps() { + return { + w: 512, + h: 512, + name: 'a city skyline', + } + } + + override getGeometry(shape: LiveImageShape): Geometry2d { + return new Rectangle2d({ + width: shape.props.w, + height: shape.props.h, + isFilled: false, + }) + } + + canUnmount = () => false + + indicator(shape: LiveImageShape) { + const bounds = this.editor.getShapeGeometry(shape).bounds + + return ( + + ) + } + + override component(shape: LiveImageShape) { + const editor = useEditor() + + useFal(shape.id, { + debounceTime: 0, + url: 'wss://110602490-lcm-sd15-i2i.gateway.alpha.fal.ai/ws', + }) + + const bounds = this.editor.getShapeGeometry(shape).bounds + const assetId = AssetRecordType.createId(shape.id.split(':')[1]) + const asset = editor.getAsset(assetId) + + // eslint-disable-next-line react-hooks/rules-of-hooks + const theme = getDefaultColorTheme({ isDarkMode: useIsDarkMode() }) + + return ( + <> + + + + + {asset && ( + {shape.props.name} + )} + + ) + } + + override canReceiveNewChildrenOfType = ( + shape: TLShape, + _type: TLShape['type'] + ) => { + return !shape.isLocked + } + + providesBackgroundForChildren(): boolean { + return true + } + + override canDropShapes = ( + shape: LiveImageShape, + _shapes: TLShape[] + ): boolean => { + return !shape.isLocked + } + + override onDragShapesOver = ( + frame: LiveImageShape, + shapes: TLShape[] + ): { shouldHint: boolean } => { + if (!shapes.every((child) => child.parentId === frame.id)) { + this.editor.reparentShapes( + shapes.map((shape) => shape.id), + frame.id + ) + return { shouldHint: true } + } + return { shouldHint: false } + } + + override onDragShapesOut = ( + _shape: LiveImageShape, + shapes: TLShape[] + ): void => { + const parent = this.editor.getShape(_shape.parentId) + const isInGroup = + parent && this.editor.isShapeOfType(parent, 'group') + + // If frame is in a group, keep the shape + // moved out in that group + + if (isInGroup) { + this.editor.reparentShapes(shapes, parent.id) + } else { + this.editor.reparentShapes(shapes, this.editor.getCurrentPageId()) + } + } + + override onResizeEnd: TLOnResizeEndHandler = (shape) => { + const bounds = this.editor.getShapePageBounds(shape)! + const children = this.editor.getSortedChildIdsForParent(shape.id) + + const shapesToReparent: TLShapeId[] = [] + + for (const childId of children) { + const childBounds = this.editor.getShapePageBounds(childId)! + if (!bounds.includes(childBounds)) { + shapesToReparent.push(childId) + } + } + + if (shapesToReparent.length > 0) { + this.editor.reparentShapes( + shapesToReparent, + this.editor.getCurrentPageId() + ) + } + } + + override onResize: TLOnResizeHandler = (shape, info) => { + return resizeBox(shape, info) + } +} diff --git a/src/components/live-image.tsx b/src/components/live-image.tsx deleted file mode 100644 index 56b554b..0000000 --- a/src/components/live-image.tsx +++ /dev/null @@ -1,280 +0,0 @@ -/* eslint-disable @next/next/no-img-element */ -/* eslint-disable react-hooks/rules-of-hooks */ -import { - canonicalizeRotation, - FrameShapeUtil, - getDefaultColorTheme, - getSvgAsImage, - HTMLContainer, - SelectionEdge, - TLEventMapHandler, - TLFrameShape, - TLShape, - useEditor, -} from "@tldraw/tldraw"; - -import { blobToDataUri } from "@/utils/blob"; -import { debounce } from "@/utils/debounce"; -import * as fal from "@fal-ai/serverless-client"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import result from "postcss/lib/result"; - -// See https://www.fal.ai/models/latent-consistency-sd - -const DEBOUNCE_TIME = 0.0; // Adjust as needed -const URL = "wss://110602490-lcm-sd15-i2i.gateway.alpha.fal.ai/ws"; - -type Input = { - prompt: string; - image_url: string; - sync_mode: boolean; - seed: number; - strength?: number; - guidance_scale?: number; - num_inference_steps?: number; - enable_safety_checks?: boolean; -}; - -type Output = { - images: Array<{ - url: string; - width: number; - height: number; - }>; - seed: number; - num_inference_steps: number; -}; - -export class LiveImageShapeUtil extends FrameShapeUtil { - static override type = "live-image" as any; - - override getDefaultProps(): { w: number; h: number; name: string } { - return { - w: 512, - h: 512, - name: "a city skyline", - }; - } - - override canUnmount = () => false; - - override toSvg(shape: TLFrameShape) { - const theme = getDefaultColorTheme({ - isDarkMode: this.editor.user.getIsDarkMode(), - }); - const g = document.createElementNS("http://www.w3.org/2000/svg", "g"); - - const rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); - rect.setAttribute("width", shape.props.w.toString()); - rect.setAttribute("height", shape.props.h.toString()); - rect.setAttribute("fill", theme.solid); - g.appendChild(rect); - - return g; - } - - override component(shape: TLFrameShape) { - const editor = useEditor(); - const component = super.component(shape); - const [image, setImage] = useState(null); - - const imageDigest = useRef(null); - const startedIteration = useRef(0); - const finishedIteration = useRef(0); - - //===== SOCKET =====// - const webSocketRef = useRef(null); - const isReconnecting = useRef(false); - - const connect = useCallback(() => { - webSocketRef.current = new WebSocket(URL); - webSocketRef.current.onopen = () => { - // console.log("WebSocket Open"); - }; - - webSocketRef.current.onclose = () => { - // console.log("WebSocket Close"); - }; - - webSocketRef.current.onerror = (error) => { - // console.error("WebSocket Error:", error); - }; - - webSocketRef.current.onmessage = (message) => { - try { - const data = JSON.parse(message.data); - // console.log("WebSocket Message:", data); - if (data.images && data.images.length > 0) { - setImage(data.images[0].url); - } - } catch (e) { - console.error("Error parsing the WebSocket response:", e); - } - }; - }, []); - - const disconnect = useCallback(() => { - webSocketRef.current?.close(); - }, []); - - const sendMessage = useCallback( - async (message: string) => { - if ( - !isReconnecting.current && - webSocketRef.current?.readyState !== WebSocket.OPEN - ) { - isReconnecting.current = true; - connect(); - } - - if ( - isReconnecting.current && - webSocketRef.current?.readyState !== WebSocket.OPEN - ) { - await new Promise((resolve) => { - const checkConnection = setInterval(() => { - if (webSocketRef.current?.readyState === WebSocket.OPEN) { - clearInterval(checkConnection); - resolve(); - } - }, 100); - }); - isReconnecting.current = false; - } - webSocketRef.current?.send(message); - }, - [connect] - ); - - const sendCurrentData = useMemo(() => { - return debounce(sendMessage, DEBOUNCE_TIME); - }, [sendMessage]); - //===========// - - // eslint-disable-next-line react-hooks/exhaustive-deps - const onDrawingChange = useCallback( - debounce(async () => { - // TODO get actual drawing bounds - // const bounds = new Box2d(120, 180, 512, 512); - - const iteration = startedIteration.current++; - - const shapes = Array.from( - editor.getShapeAndDescendantIds([shape.id]) - ).map((id) => editor.getShape(id)) as TLShape[]; - - // Check if should submit request - const shapesDigest = JSON.stringify(shapes); - if (shapesDigest === imageDigest.current) { - return; - } - imageDigest.current = shapesDigest; - - const svg = await editor.getSvg([shape], { - background: true, - padding: 0, - darkMode: editor.user.getIsDarkMode(), - }); - if (iteration <= finishedIteration.current) return; - - if (!svg) { - return; - } - const image = await getSvgAsImage(svg, editor.environment.isSafari, { - type: "png", - quality: 1, - scale: 1, - }); - - if (iteration <= finishedIteration.current) return; - - if (!image) { - return; - } - - const prompt = - editor.getShape(shape.id)?.props.name ?? ""; - const imageDataUri = await blobToDataUri(image); - - const request = { - image_url: imageDataUri, - prompt, - sync_mode: true, - strength: 0.7, - seed: 42, // TODO make this configurable in the UI - enable_safety_checks: false, - }; - - sendCurrentData(JSON.stringify(request)); - - if (iteration <= finishedIteration.current) return; - - // const result = await fal.run(LatentConsistency, { - // input: { - // image_url: imageDataUri, - // prompt, - // sync_mode: true, - // strength: 0.6, - // seed: 42, // TODO make this configurable in the UI - // enable_safety_checks: false, - // }, - // // Disable auto-upload so we can submit the data uri of the image as is - // autoUpload: true, - // }); - if (iteration <= finishedIteration.current) return; - - finishedIteration.current = iteration; - // if (result && result.images.length > 0) { - // setImage(result.images[0].url); - // } - }, 0), - [] - ); - - useEffect(() => { - const onChange: TLEventMapHandler<"change"> = (event) => { - if (event.source !== "user") { - return; - } - if ( - Object.keys(event.changes.added).length || - Object.keys(event.changes.removed).length || - Object.keys(event.changes.updated).length - ) { - onDrawingChange(); - } - }; - editor.addListener("change", onChange); - return () => { - editor.removeListener("change", onChange); - }; - }, [editor, onDrawingChange]); - - return ( - -
- {component} - - {image && ( - - )} -
-
- ); - } -} diff --git a/src/hooks/useFal.ts b/src/hooks/useFal.ts new file mode 100644 index 0000000..739dabb --- /dev/null +++ b/src/hooks/useFal.ts @@ -0,0 +1,190 @@ +import { LiveImageShape } from '@/components/LiveImageShapeUtil' +import { blobToDataUri } from '@/utils/blob' +import { + AssetRecordType, + TLShape, + TLShapeId, + debounce, + getHashForObject, + getSvgAsImage, + throttle, + useEditor, +} from '@tldraw/tldraw' +import { useRef, useEffect } from 'react' + +export function useFal( + shapeId: TLShapeId, + opts: { + debounceTime?: number + throttleTime?: number + url: string + } +) { + const { url, throttleTime = 500, debounceTime = 0 } = opts + const editor = useEditor() + const startedIteration = useRef(0) + const finishedIteration = useRef(0) + + const prevHash = useRef(null) + + useEffect(() => { + let socket: WebSocket | null = null + + let isReconnecting = false + + function updateImage(url: string | null) { + const shape = editor.getShape(shapeId)! + const id = AssetRecordType.createId(shape.id.split(':')[1]) + + const asset = editor.getAsset(id) + + if (!asset) { + editor.createAssets([ + AssetRecordType.create({ + id, + type: 'image', + props: { + name: shape.props.name, + w: shape.props.w, + h: shape.props.h, + src: url, + isAnimated: false, + mimeType: 'image/jpeg', + }, + }), + ]) + } else { + editor.updateAssets([ + { + ...asset, + type: 'image', + props: { + ...asset.props, + w: shape.props.w, + h: shape.props.h, + src: url, + }, + }, + ]) + } + } + + async function connect() { + { + socket = new WebSocket(url) + socket.onopen = () => { + // console.log("WebSocket Open"); + } + + socket.onclose = () => { + // console.log("WebSocket Close"); + } + + socket.onerror = (error) => { + // console.error("WebSocket Error:", error); + } + + socket.onmessage = (message) => { + try { + const data = JSON.parse(message.data) + // console.log("WebSocket Message:", data); + if (data.images && data.images.length > 0) { + updateImage(data.images[0].url ?? '') + } + } catch (e) { + console.error('Error parsing the WebSocket response:', e) + } + } + } + } + + async function sendCurrentData(message: string) { + if (!isReconnecting && socket?.readyState !== WebSocket.OPEN) { + isReconnecting = true + connect() + } + + if (isReconnecting && socket?.readyState !== WebSocket.OPEN) { + await new Promise((resolve) => { + const checkConnection = setInterval(() => { + if (socket?.readyState === WebSocket.OPEN) { + clearInterval(checkConnection) + resolve() + } + }, 100) + }) + isReconnecting = false + } + socket?.send(message) + } + + async function updateDrawing() { + const iteration = startedIteration.current++ + + const shapes = Array.from(editor.getShapeAndDescendantIds([shapeId])).map( + (id) => editor.getShape(id) + ) as TLShape[] + + const hash = getHashForObject(shapes) + if (hash === prevHash.current) return + + const shape = editor.getShape(shapeId)! + + const svg = await editor.getSvg([shape], { + background: true, + padding: 0, + darkMode: editor.user.getIsDarkMode(), + }) + if (!svg) { + updateImage('') + return + } + + // We might be stale + if (iteration <= finishedIteration.current) return + + const image = await getSvgAsImage(svg, editor.environment.isSafari, { + type: 'png', + quality: 1, + scale: 1, + }) + if (!image) { + updateImage('') + return + } + + // We might be stale + if (iteration <= finishedIteration.current) return + + const prompt = shape.props.name + ? shape.props.name + ' hd award-winning impressive' + : 'A random image that is safe for work and not surprising—something boring like a city or shoe watercolor' + const imageDataUri = await blobToDataUri(image) + const request = { + image_url: imageDataUri, + prompt, + sync_mode: true, + strength: 0.7, + seed: 42, // TODO make this configurable in the UI + enable_safety_checks: false, + } + + // We might be stale + if (iteration <= finishedIteration.current) return + sendCurrentData(JSON.stringify(request)) + finishedIteration.current = iteration + } + + const onDrawingChange = debounceTime + ? debounce(updateDrawing, debounceTime) + : throttleTime + ? throttle(updateDrawing, throttleTime) + : debounce(updateDrawing, 16) + + editor.on('update-drawings' as any, onDrawingChange) + + return () => { + editor.off('update-drawings' as any, onDrawingChange) + } + }, [editor, shapeId, throttleTime, debounceTime, url]) +}