Merge pull request #3 from tldraw/lu/socket

Websocket
This commit is contained in:
Lu Wilson 2023-11-23 12:23:57 +00:00 committed by GitHub
commit 15cc361ef1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 102 additions and 17 deletions

View File

@ -13,12 +13,16 @@ import {
import { blobToDataUri } from "@/utils/blob";
import { debounce } from "@/utils/debounce";
import * as fal from "@fal-ai/serverless-client";
import { useCallback, useEffect, useRef, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import result from "postcss/lib/result";
// See https://www.fal.ai/models/latent-consistency-sd
const LatentConsistency = "110602490-lcm-sd15-i2i";
const DEBOUNCE_TIME = 0.1; // Adjust as needed
const URL = "wss://110602490-lcm-sd15-i2i.gateway.alpha.fal.ai/ws";
type Input = {
prompt: string;
image_url: string;
@ -60,6 +64,75 @@ export class LiveImageShapeUtil extends FrameShapeUtil {
const startedIteration = useRef<number>(0);
const finishedIteration = useRef<number>(0);
//===== SOCKET =====//
const webSocketRef = useRef<WebSocket | null>(null);
const isReconnecting = useRef(false);
const connect = useCallback(() => {
webSocketRef.current = new WebSocket(URL);
webSocketRef.current.onopen = () => {
console.log("WebSocket Open");
};
webSocketRef.current.onclose = () => {
console.log("WebSocket Close");
};
webSocketRef.current.onerror = (error) => {
console.error("WebSocket Error:", error);
};
webSocketRef.current.onmessage = (message) => {
try {
const data = JSON.parse(message.data);
console.log("WebSocket Message:", data);
if (data.images && data.images.length > 0) {
setImage(data.images[0].url);
}
} catch (e) {
console.error("Error parsing the WebSocket response:", e);
}
};
}, []);
const disconnect = useCallback(() => {
webSocketRef.current?.close();
}, []);
const sendMessage = useCallback(
async (message: string) => {
if (
!isReconnecting.current &&
webSocketRef.current?.readyState !== WebSocket.OPEN
) {
isReconnecting.current = true;
connect();
}
if (
isReconnecting.current &&
webSocketRef.current?.readyState !== WebSocket.OPEN
) {
await new Promise<void>((resolve) => {
const checkConnection = setInterval(() => {
if (webSocketRef.current?.readyState === WebSocket.OPEN) {
clearInterval(checkConnection);
resolve();
}
}, 100);
});
isReconnecting.current = false;
}
webSocketRef.current?.send(message);
},
[connect]
);
const sendCurrentData = useMemo(() => {
return debounce(sendMessage, DEBOUNCE_TIME);
}, [sendMessage]);
//===========//
// eslint-disable-next-line react-hooks/exhaustive-deps
const onDrawingChange = useCallback(
debounce(async () => {
@ -100,27 +173,39 @@ export class LiveImageShapeUtil extends FrameShapeUtil {
const prompt =
editor.getShape<TLFrameShape>(shape.id)?.props.name ?? "";
const imageDataUri = await blobToDataUri(image);
const request = {
image_url: imageDataUri,
prompt,
sync_mode: true,
strength: 0.7,
seed: 42, // TODO make this configurable in the UI
enable_safety_checks: false,
};
sendCurrentData(JSON.stringify(request));
if (iteration <= finishedIteration.current) return;
const result = await fal.run<Input, Output>(LatentConsistency, {
input: {
image_url: imageDataUri,
prompt,
sync_mode: true,
strength: 0.6,
seed: 42, // TODO make this configurable in the UI
enable_safety_checks: false,
},
// Disable auto-upload so we can submit the data uri of the image as is
autoUpload: true,
});
// const result = await fal.run<Input, Output>(LatentConsistency, {
// input: {
// image_url: imageDataUri,
// prompt,
// sync_mode: true,
// strength: 0.6,
// seed: 42, // TODO make this configurable in the UI
// enable_safety_checks: false,
// },
// // Disable auto-upload so we can submit the data uri of the image as is
// autoUpload: true,
// });
if (iteration <= finishedIteration.current) return;
finishedIteration.current = iteration;
if (result && result.images.length > 0) {
setImage(result.images[0].url);
}
}, 32),
// if (result && result.images.length > 0) {
// setImage(result.images[0].url);
// }
}, 0),
[]
);