more tests

This commit is contained in:
Lu[ke] Wilson 2023-11-21 17:09:03 +00:00
parent 6be0379c49
commit 7e6fcbe1dc
3 changed files with 166 additions and 146 deletions

View File

@ -2,8 +2,9 @@
import { LiveImageShapeUtil } from "@/components/live-image"; import { LiveImageShapeUtil } from "@/components/live-image";
import * as fal from "@fal-ai/serverless-client"; import * as fal from "@fal-ai/serverless-client";
import { Editor, Tldraw } from "@tldraw/tldraw"; import { Editor, FrameShapeTool, Tldraw, useEditor } from "@tldraw/tldraw";
import { useCallback } from "react"; import { useCallback } from "react";
import { LiveImageTool, MakeLiveButton } from "../components/LiveImageTool";
fal.config({ fal.config({
requestMiddleware: fal.withProxy({ requestMiddleware: fal.withProxy({
@ -12,6 +13,7 @@ fal.config({
}); });
const shapeUtils = [LiveImageShapeUtil]; const shapeUtils = [LiveImageShapeUtil];
const tools = [LiveImageTool];
export default function Home() { export default function Home() {
const onEditorMount = (editor: Editor) => { const onEditorMount = (editor: Editor) => {
@ -28,7 +30,11 @@ export default function Home() {
type: "live-image", type: "live-image",
x: 120, x: 120,
y: 180, y: 180,
isLocked: true, props: {
w: 512,
h: 512,
name: "a city skyline",
},
}); });
}; };
@ -39,6 +45,8 @@ export default function Home() {
persistenceKey="tldraw-fal" persistenceKey="tldraw-fal"
onMount={onEditorMount} onMount={onEditorMount}
shapeUtils={shapeUtils} shapeUtils={shapeUtils}
tools={tools}
shareZone={<MakeLiveButton />}
/> />
</div> </div>
</main> </main>

View File

@ -0,0 +1,27 @@
import { FrameShapeTool, useEditor } from "@tldraw/tldraw";
import { useCallback } from "react";
export class LiveImageTool extends FrameShapeTool {
static override id = "live-image";
static override initial = "idle";
override shapeType = "live-image";
}
export function MakeLiveButton() {
const editor = useEditor();
const makeLive = useCallback(() => {
editor.setCurrentTool("live-image");
}, [editor]);
return (
<button
onClick={makeLive}
className="p-2"
style={{ cursor: "pointer", zIndex: 100000, pointerEvents: "all" }}
>
<div className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">
Make Live
</div>
</button>
);
}

View File

@ -1,10 +1,11 @@
/* eslint-disable @next/next/no-img-element */
/* eslint-disable react-hooks/rules-of-hooks */
import { import {
Box2d, FrameShapeUtil,
getSvgAsImage, getSvgAsImage,
Rectangle2d, HTMLContainer,
ShapeUtil,
TLBaseShape,
TLEventMapHandler, TLEventMapHandler,
TLFrameShape,
TLShape, TLShape,
useEditor, useEditor,
} from "@tldraw/tldraw"; } from "@tldraw/tldraw";
@ -13,7 +14,6 @@ import { blobToDataUri } from "@/utils/blob";
import { debounce } from "@/utils/debounce"; import { debounce } from "@/utils/debounce";
import * as fal from "@fal-ai/serverless-client"; import * as fal from "@fal-ai/serverless-client";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { FalLogo } from "./fal-logo";
// See https://www.fal.ai/models/latent-consistency-sd // See https://www.fal.ai/models/latent-consistency-sd
@ -24,6 +24,10 @@ type Input = {
image_url: string; image_url: string;
sync_mode: boolean; sync_mode: boolean;
seed: number; seed: number;
strength?: number;
guidance_scale?: number;
num_inference_steps?: number;
enable_safety_checks?: boolean;
}; };
type Output = { type Output = {
@ -36,16 +40,25 @@ type Output = {
num_inference_steps: number; num_inference_steps: number;
}; };
// TODO make this an input on the canvas export class LiveImageShapeUtil extends FrameShapeUtil {
const PROMPT = "a city skyline"; static override type = "live-image" as any;
export function LiveImage() { override getDefaultProps(): { w: number; h: number; name: string } {
return {
w: 512,
h: 512,
name: "a city skyline",
};
}
override component(shape: TLFrameShape) {
const editor = useEditor(); const editor = useEditor();
const component = super.component(shape);
const [image, setImage] = useState<string | null>(null); const [image, setImage] = useState<string | null>(null);
// Used to prevent multiple requests from being sent at once for the same image
// There's probably a better way to do this using TLDraw's state
const imageDigest = useRef<string | null>(null); const imageDigest = useRef<string | null>(null);
const startedIteration = useRef<number>(0);
const finishedIteration = useRef<number>(0);
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
const onDrawingChange = useCallback( const onDrawingChange = useCallback(
@ -53,17 +66,11 @@ export function LiveImage() {
// TODO get actual drawing bounds // TODO get actual drawing bounds
// const bounds = new Box2d(120, 180, 512, 512); // const bounds = new Box2d(120, 180, 512, 512);
const shapes = editor.getCurrentPageShapes().filter((shape) => { const iteration = startedIteration.current++;
if (shape.type === "live-image") {
return false; const shapes = Array.from(editor.getShapeAndDescendantIds([shape.id]))
} .filter((id) => id !== shape.id)
return true; .map((id) => editor.getShape(id)) as TLShape[];
// const pageBounds = editor.getShapeMaskedPageBounds(shape);
// if (!pageBounds) {
// return false;
// }
// return bounds.includes(pageBounds);
});
// Check if should submit request // Check if should submit request
const shapesDigest = JSON.stringify(shapes); const shapesDigest = JSON.stringify(shapes);
@ -73,33 +80,47 @@ export function LiveImage() {
imageDigest.current = shapesDigest; imageDigest.current = shapesDigest;
const svg = await editor.getSvg(shapes, { background: true }); const svg = await editor.getSvg(shapes, { background: true });
if (iteration <= finishedIteration.current) return;
if (!svg) { if (!svg) {
return; return;
} }
const image = await getSvgAsImage(svg, editor.environment.isSafari, { const image = await getSvgAsImage(svg, editor.environment.isSafari, {
type: "png", type: "png",
quality: 0.5, quality: 1,
scale: 1, scale: 1,
}); });
if (iteration <= finishedIteration.current) return;
if (!image) { if (!image) {
return; return;
} }
const prompt =
editor.getShape<TLFrameShape>(shape.id)?.props.name ?? "";
const imageDataUri = await blobToDataUri(image); const imageDataUri = await blobToDataUri(image);
if (iteration <= finishedIteration.current) return;
const result = await fal.run<Input, Output>(LatentConsistency, { const result = await fal.run<Input, Output>(LatentConsistency, {
input: { input: {
image_url: imageDataUri, image_url: imageDataUri,
prompt: PROMPT, prompt,
sync_mode: true, sync_mode: true,
strength: 0.6,
seed: 42, // TODO make this configurable in the UI seed: 42, // TODO make this configurable in the UI
enable_safety_checks: false,
}, },
// Disable auto-upload so we can submit the data uri of the image as is // Disable auto-upload so we can submit the data uri of the image as is
autoUpload: false, autoUpload: true,
}); });
if (iteration <= finishedIteration.current) return;
finishedIteration.current = iteration;
if (result && result.images.length > 0) { if (result && result.images.length > 0) {
setImage(result.images[0].url); setImage(result.images[0].url);
} }
}, 16), }, 32),
[] []
); );
@ -120,69 +141,33 @@ export function LiveImage() {
return () => { return () => {
editor.removeListener("change", onChange); editor.removeListener("change", onChange);
}; };
}, []); }, [editor, onDrawingChange]);
return ( return (
<div className="flex flex-row w-[1060px] h-[560px] absolute bg-indigo-200 border border-indigo-500 rounded space-x-4 p-4 pb-8"> <HTMLContainer>
<div className="flex-1 h-[512px] bg-white border border-indigo-500"> <div
<div className="flex flex-row items-center px-4 py-2 space-x-2"> style={{
<span className="font-mono text-indigo-900/50">/imagine</span> display: "flex",
<input }}
className="border-0 bg-transparent flex-1 text-base text-indigo-900"
placeholder="something cool..."
value={PROMPT}
/>
</div>
</div>
<div className="flex-1 h-[512px] bg-white border border-indigo-500">
{/* eslint-disable-next-line @next/next/no-img-element */}
{image && <img src={image} alt="" width={512} height={512} />}
</div>
<span className="absolute bottom-1.5 right-4">
<a
href="https://fal.ai/models/latent-consistency"
target="_blank"
className="flex flex-row space-x-1"
> >
<span className="text-xs text-indigo-900/50">powered by</span> {component}
<span className="w-[36px] opacity-50">
<FalLogo />
</span>
</a>
</span>
</div>
);
}
type LiveImageShape = TLBaseShape<"live-image", { w: number; h: number }>; {image && (
<img
export class LiveImageShapeUtil extends ShapeUtil<LiveImageShape> { src={image}
static override type = "live-image" as const; alt=""
width={shape.props.w}
override canResize = () => false; height={shape.props.h}
style={{
getDefaultProps(): LiveImageShape["props"] { position: "relative",
return { left: shape.props.w,
w: 1060,
h: 560,
};
}
getGeometry(shape: LiveImageShape) {
return new Rectangle2d({
width: shape.props.w, width: shape.props.w,
height: shape.props.h, height: shape.props.h,
isFilled: true, }}
}); />
} )}
</div>
component(shape: LiveImageShape) { </HTMLContainer>
return <LiveImage />;
}
indicator(shape: LiveImageShape) {
return (
<rect width={shape.props.w} height={shape.props.h} radius={4}></rect>
); );
} }
} }