// canvas-website/src/hooks/useLiveImage.tsx

/**
* useLiveImage Hook
* Captures drawings within a frame shape and sends them to Fal.ai for AI enhancement
* Based on draw-fast implementation, adapted for canvas-website with Automerge sync
*/
import React, { createContext, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react'
import { Editor, TLShapeId, Box, exportToBlob } from 'tldraw'
import { fal } from '@fal-ai/client'
import { getFalConfig } from '@/lib/clientConfig'
/*
 * Fal.ai model endpoints selectable through `useLiveImage({ model })`.
 */

/** LCM (SD 1.5) image-to-image: fast enough for real-time sketching (~150ms). */
const FAL_MODEL_LCM = 'fal-ai/lcm-sd15-i2i'

/** Flux Canny control-LoRA image-to-image: slower, higher quality. */
const FAL_MODEL_FLUX_CANNY = 'fal-ai/flux-control-lora-canny/image-to-image'
/** Shape of the value that LiveImageProvider publishes to its descendants. */
type LiveImageContextValue = {
  /** True once an API key is present and the Fal client has been configured with it. */
  isConnected: boolean
  /** The active Fal.ai API key, or null when none is available. */
  apiKey: string | null
  /** Replaces the active API key (the provider also persists it to localStorage). */
  setApiKey: (key: string) => void
}

/** Props accepted by LiveImageProvider. */
type LiveImageProviderProps = {
  children: React.ReactNode
  /** Explicit Fal.ai API key; when given it wins over the env/config defaults. */
  apiKey?: string
}

/** Carries the Fal.ai connection; holds null when read outside a LiveImageProvider. */
const LiveImageContext = createContext<LiveImageContextValue | null>(null)
/** localStorage key under which a user-supplied Fal.ai API key is persisted. */
const FAL_API_KEY_STORAGE_KEY = 'fal_api_key'

/** Read the persisted Fal.ai key; null if none is saved or storage is unavailable. */
function readStoredFalApiKey(): string | null {
  try {
    return typeof localStorage === 'undefined' ? null : localStorage.getItem(FAL_API_KEY_STORAGE_KEY)
  } catch {
    // Accessing localStorage can throw (e.g. storage disabled by browser settings).
    return null
  }
}

/**
 * Persist a user-supplied Fal.ai key. Failures are logged rather than thrown because
 * the key still lives in React state for the current session.
 */
function writeStoredFalApiKey(key: string): void {
  try {
    localStorage.setItem(FAL_API_KEY_STORAGE_KEY, key)
  } catch (error) {
    console.warn('LiveImage: Failed to persist Fal API key:', error)
  }
}

/**
 * Provider component that manages the Fal.ai connection.
 *
 * The active API key is resolved once, on mount, from the first non-empty source:
 *   1. the `apiKey` prop,
 *   2. a key previously saved through `setApiKey` (localStorage),
 *   3. `import.meta.env.VITE_FAL_API_KEY`,
 *   4. the default from `getFalConfig()`.
 * A user-saved key deliberately outranks the env/config defaults. Otherwise, persisting
 * it in `setApiKey` would have no effect whenever a default key exists.
 */
export function LiveImageProvider({ children, apiKey: initialApiKey }: LiveImageProviderProps) {
  // Lazy initializer: resolve the key once instead of on every render.
  const [apiKey, setApiKeyState] = useState<string | null>(
    () =>
      initialApiKey ||
      readStoredFalApiKey() ||
      import.meta.env.VITE_FAL_API_KEY ||
      // NOTE(review): per the original comment, getFalConfig() includes a hard-coded
      // default key. Any key bundled into client code is readable by every visitor.
      getFalConfig()?.apiKey ||
      null
  )
  const [isConnected, setIsConnected] = useState(false)

  // `fal` is a shared module-level client: configure it before reporting "connected".
  useEffect(() => {
    if (apiKey) {
      fal.config({ credentials: apiKey })
      setIsConnected(true)
    } else {
      setIsConnected(false)
    }
  }, [apiKey])

  const setApiKey = useCallback((key: string) => {
    setApiKeyState(key)
    // Persist so the key survives reloads (read back by the initializer above).
    writeStoredFalApiKey(key)
  }, [])

  // Memoized so consumers re-render only when the connection state actually changes.
  const value = useMemo<LiveImageContextValue>(
    () => ({ isConnected, apiKey, setApiKey }),
    [isConnected, apiKey, setApiKey]
  )

  return <LiveImageContext.Provider value={value}>{children}</LiveImageContext.Provider>
}
/**
 * Read the Fal.ai connection state published by the nearest LiveImageProvider.
 *
 * @throws Error if no LiveImageProvider is mounted above the caller.
 */
export function useLiveImageContext() {
  const value = useContext(LiveImageContext)
  if (value) {
    return value
  }
  throw new Error('useLiveImageContext must be used within a LiveImageProvider')
}
/** Options accepted by `useLiveImage`. */
type UseLiveImageOptions = {
  /** tldraw editor whose store is watched and whose shapes are exported. */
  editor: Editor
  /** Frame shape; shapes overlapping its page bounds form the sketch. */
  shapeId: TLShapeId
  /** User prompt; quality keywords are appended before it is sent. */
  prompt: string
  /** When false, no store listener is installed and nothing is generated (default true). */
  enabled?: boolean
  /** Minimum spacing between generation requests, in milliseconds (default 500). */
  throttleMs?: number
  /** Fal.ai model: fast `'lcm'` (default) or higher-quality `'flux-canny'`. */
  model?: 'lcm' | 'flux-canny'
  /** Image-to-image strength forwarded to Fal.ai (default 0.65). */
  strength?: number
  /** Invoked with the generated image URL after a successful, non-superseded request. */
  onResult?: (imageUrl: string) => void
  /** Invoked when the newest request fails. */
  onError?: (error: Error) => void
}

/** Status reported by `useLiveImage`. */
type LiveImageState = {
  /** True while a capture/generation request is in flight. */
  isGenerating: boolean
  /** URL of the most recent successful generation, or null. */
  lastGeneratedUrl: string | null
  /** Message from the most recent failure, or null. */
  error: string | null
}
/**
 * Normalize one entry of a Fal.ai image payload to a URL string.
 * An entry is either a bare string or an object with a string `url`; anything else
 * (missing, empty, or non-string url) yields null.
 */
function toImageUrl(value: unknown): string | null {
  if (typeof value === 'string') return value || null
  if (value !== null && typeof value === 'object') {
    const url = (value as { url?: unknown }).url
    return typeof url === 'string' && url !== '' ? url : null
  }
  return null
}

/**
 * Extract the generated image URL from a Fal.ai result payload.
 * Checks, in order: `images[0]` (when `images` is a non-empty array), `image`, then
 * `output`. Returns null when none of them holds a usable string URL.
 */
function extractImageUrl(data: unknown): string | null {
  if (data === null || typeof data !== 'object') return null
  const { images, image, output } = data as { images?: unknown; image?: unknown; output?: unknown }
  if (Array.isArray(images) && images.length > 0) return toImageUrl(images[0])
  if (image) return toImageUrl(image)
  if (output) return toImageUrl(output)
  return null
}

/**
 * Hook that watches for drawing changes within a frame and generates AI images.
 *
 * Every user-sourced document change in the editor store schedules a throttled
 * generation (immediate if `throttleMs` has elapsed since the last request, otherwise
 * one trailing request). The shapes overlapping the frame are exported to a JPEG data
 * URL and sent to the selected Fal.ai image-to-image endpoint. Only the newest
 * request's outcome updates state or fires callbacks.
 *
 * @returns Current generation status, the last generated image URL, and the last error.
 */
export function useLiveImage({
  editor,
  shapeId,
  prompt,
  enabled = true,
  throttleMs = 500,
  model = 'lcm',
  strength = 0.65,
  onResult,
  onError,
}: UseLiveImageOptions): LiveImageState {
  const [state, setState] = useState<LiveImageState>({
    isGenerating: false,
    lastGeneratedUrl: null,
    error: null,
  })
  // Incremented per request. A request whose version is no longer current has been
  // superseded and must not touch state or fire callbacks.
  const requestVersionRef = useRef(0)
  const lastRequestTimeRef = useRef(0)
  // ReturnType<typeof setTimeout> is correct under both DOM (number) and Node typings.
  const pendingRequestRef = useRef<ReturnType<typeof setTimeout> | null>(null)
  const context = useContext(LiveImageContext)

  // Shapes on the current page that overlap (or lie inside) the frame, excluding the frame itself.
  const getChildShapes = useCallback(() => {
    const shape = editor.getShape(shapeId)
    if (!shape) return []
    const bounds = editor.getShapePageBounds(shapeId)
    if (!bounds) return []
    return editor.getCurrentPageShapes().filter((s) => {
      if (s.id === shapeId) return false
      const shapeBounds = editor.getShapePageBounds(s.id)
      if (!shapeBounds) return false
      return bounds.contains(shapeBounds) || bounds.collides(shapeBounds)
    })
  }, [editor, shapeId])

  // Export the overlapping shapes as a JPEG data URL; null if there is nothing to capture
  // or the export fails (the failure is logged).
  // NOTE(review): no `bounds` option is passed, so the export presumably covers the shapes'
  // combined bounds rather than the frame, and shapes crossing the frame edge are included whole.
  const captureDrawing = useCallback(async (): Promise<string | null> => {
    try {
      const childShapes = getChildShapes()
      if (childShapes.length === 0) return null
      const blob = await exportToBlob({
        editor,
        ids: childShapes.map((s) => s.id),
        format: 'jpeg',
        opts: {
          background: true,
          padding: 0,
          scale: 1,
        },
      })
      // Convert blob to data URL
      return await new Promise<string>((resolve, reject) => {
        const reader = new FileReader()
        reader.onloadend = () => resolve(reader.result as string)
        reader.onerror = reject
        reader.readAsDataURL(blob)
      })
    } catch (error) {
      console.error('LiveImage: Failed to capture drawing:', error)
      return null
    }
  }, [editor, getChildShapes])

  // Capture the sketch and run one Fal.ai generation; stale results are discarded.
  const generateImage = useCallback(async () => {
    if (!context?.isConnected || !enabled) {
      return
    }
    const currentVersion = ++requestVersionRef.current
    const isCurrent = () => currentVersion === requestVersionRef.current
    setState((prev) => ({ ...prev, isGenerating: true, error: null }))
    try {
      const imageDataUrl = await captureDrawing()
      if (!imageDataUrl) {
        // Only the newest request may clear the spinner; a newer one may still be running.
        if (isCurrent()) setState((prev) => ({ ...prev, isGenerating: false }))
        return
      }
      if (!isCurrent()) return

      const modelEndpoint = model === 'flux-canny' ? FAL_MODEL_FLUX_CANNY : FAL_MODEL_LCM
      const fullPrompt = prompt
        ? `${prompt}, hd, award-winning, impressive, detailed`
        : 'hd, award-winning, impressive, detailed illustration'

      const result = await fal.subscribe(modelEndpoint, {
        input: {
          prompt: fullPrompt,
          image_url: imageDataUrl,
          strength,
          sync_mode: true,
          seed: 42,
          num_inference_steps: model === 'lcm' ? 4 : 20,
          guidance_scale: model === 'lcm' ? 1 : 7.5,
          enable_safety_checks: false,
        },
        pollInterval: 1000,
        logs: true,
      })
      if (!isCurrent()) return

      const imageUrl = extractImageUrl(result.data)
      if (!imageUrl) {
        throw new Error('No image URL in response')
      }
      setState((prev) => ({
        ...prev,
        isGenerating: false,
        lastGeneratedUrl: imageUrl,
        error: null,
      }))
      onResult?.(imageUrl)
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : 'Unknown error'
      console.error('LiveImage: Generation failed:', errorMessage)
      if (isCurrent()) {
        setState((prev) => ({
          ...prev,
          isGenerating: false,
          error: errorMessage,
        }))
        onError?.(error instanceof Error ? error : new Error(errorMessage))
      }
    }
  }, [context?.isConnected, enabled, captureDrawing, model, prompt, strength, onResult, onError])

  // Always points at the newest generateImage. The store listener and the throttle timer
  // go through this ref, so a re-render with new props/callbacks (e.g. inline onResult)
  // no longer tears down the subscription and cancels a pending trailing generation.
  const generateImageRef = useRef(generateImage)
  useEffect(() => {
    generateImageRef.current = generateImage
  }, [generateImage])

  // Throttled trigger: run now if `throttleMs` has passed, else schedule one trailing run.
  const triggerGeneration = useCallback(() => {
    if (!enabled) return
    if (pendingRequestRef.current !== null) {
      clearTimeout(pendingRequestRef.current)
      pendingRequestRef.current = null
    }
    const run = () => {
      pendingRequestRef.current = null
      lastRequestTimeRef.current = Date.now()
      void generateImageRef.current()
    }
    const elapsed = Date.now() - lastRequestTimeRef.current
    if (elapsed >= throttleMs) {
      run()
    } else {
      pendingRequestRef.current = setTimeout(run, throttleMs - elapsed)
    }
  }, [enabled, throttleMs])

  // Subscribe to user-sourced document changes. This fires for any such change in the
  // store, not only for shapes inside the frame.
  useEffect(() => {
    if (!enabled) return
    const unsubscribe = editor.store.listen(() => triggerGeneration(), {
      source: 'user',
      scope: 'document',
    })
    return () => {
      unsubscribe()
      if (pendingRequestRef.current !== null) {
        clearTimeout(pendingRequestRef.current)
        pendingRequestRef.current = null
      }
    }
  }, [editor, enabled, triggerGeneration])

  return state
}
/**
 * Rasterize an SVG string to a JPEG data URL (fast alternative to a full export).
 *
 * The SVG is drawn onto a white-filled `width`×`height` canvas, because JPEG has no
 * alpha channel. The temporary blob URL is revoked on every path. Any failure
 * (image load error, missing 2D context, or `toDataURL` throwing, e.g. on a tainted
 * canvas) rejects the returned promise instead of leaving it pending.
 *
 * NOTE(review): not referenced in the visible part of this file.
 *
 * @param svgString Serialized SVG markup.
 * @param width Output width in pixels.
 * @param height Output height in pixels.
 * @param quality JPEG quality passed to `canvas.toDataURL` (default 0.3).
 * @returns A `data:image/jpeg` URL.
 */
async function svgToJpegDataUrl(
  svgString: string,
  width: number,
  height: number,
  quality: number = 0.3
): Promise<string> {
  const url = URL.createObjectURL(new Blob([svgString], { type: 'image/svg+xml' }))
  try {
    // Wrap the callback-based image loader so errors flow through the async function.
    const img = await new Promise<HTMLImageElement>((resolve, reject) => {
      const image = new Image()
      image.onload = () => resolve(image)
      image.onerror = () => reject(new Error('Failed to load SVG'))
      image.src = url
    })

    const canvas = document.createElement('canvas')
    canvas.width = width
    canvas.height = height
    const ctx = canvas.getContext('2d')
    if (!ctx) {
      throw new Error('Failed to get canvas context')
    }

    // Fill with white background, then draw the SVG on top
    ctx.fillStyle = 'white'
    ctx.fillRect(0, 0, width, height)
    ctx.drawImage(img, 0, 0, width, height)

    // Convert to JPEG; may throw (e.g. SecurityError on a tainted canvas)
    return canvas.toDataURL('image/jpeg', quality)
  } finally {
    URL.revokeObjectURL(url)
  }
}