From 047a98575b6651462ad21b23becf51c3bb3a1df1 Mon Sep 17 00:00:00 2001 From: Jeff Emmett Date: Wed, 26 Nov 2025 19:11:58 -0800 Subject: [PATCH] Initial commit: AI Orchestrator with Ollama + RunPod smart routing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - FastAPI server with dashboard - Smart routing: Ollama (free) for text, RunPod (GPU) for images/video - Docker + docker-compose with Traefik labels - Endpoints: text, chat, image, video, comfyui, whisper - Cost tracking and savings estimation šŸ¤– Generated with Claude Code --- .env.example | 7 + .gitignore | 33 ++ Dockerfile | 48 +++ docker-compose.yml | 70 +++++ requirements.txt | 4 + server.py | 757 +++++++++++++++++++++++++++++++++++++++++++++ test_api.py | 26 ++ 7 files changed, 945 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 docker-compose.yml create mode 100644 requirements.txt create mode 100644 server.py create mode 100644 test_api.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..dbf43fc --- /dev/null +++ b/.env.example @@ -0,0 +1,7 @@ +# AI Orchestrator Environment Variables + +# RunPod API Key (required for GPU endpoints) +RUNPOD_API_KEY=your_runpod_api_key_here + +# Ollama host (defaults to http://ollama:11434 in docker-compose) +# OLLAMA_HOST=http://ollama:11434 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b7038ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,33 @@ +# Environment +.env +*.env.local + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +venv/ +env/ +.venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Test outputs +*.png +*.mp4 +*.mp3 +*.wav + +# Docker +.docker/ + +# OS +.DS_Store +Thumbs.db diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..44d0815 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,48 @@ +# AI Orchestrator - Optimized Production Dockerfile +# Multi-stage build for minimal image size + +FROM python:3.12-slim as builder + +WORKDIR /build + +# Install build dependencies +RUN pip install --no-cache-dir --upgrade pip wheel + +# Copy and install Python dependencies +COPY requirements.txt . +RUN pip wheel --no-cache-dir --wheel-dir /wheels -r requirements.txt + +# Production stage +FROM python:3.12-slim + +WORKDIR /app + +# Create non-root user for security +RUN useradd --create-home --shell /bin/bash appuser + +# Install wheels from builder stage +COPY --from=builder /wheels /wheels +RUN pip install --no-cache-dir /wheels/* && rm -rf /wheels + +# Copy application code +COPY server.py . + +# Set ownership +RUN chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Environment variables (can be overridden) +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Expose port +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/api/health')" || exit 1 + +# Run the application +CMD ["python", "-m", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8080"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..c02465b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,70 @@ +version: '3.8' + +services: + ai-orchestrator: + build: + context: . 
+ dockerfile: Dockerfile + image: ai-orchestrator:latest + container_name: ai-orchestrator + restart: unless-stopped + environment: + - RUNPOD_API_KEY=${RUNPOD_API_KEY} + - OLLAMA_HOST=http://ollama:11434 + depends_on: + ollama: + condition: service_healthy + labels: + # Traefik auto-discovery + - "traefik.enable=true" + - "traefik.http.routers.ai-orchestrator.rule=Host(`ai.jeffemmett.com`)" + - "traefik.http.routers.ai-orchestrator.entrypoints=websecure" + - "traefik.http.routers.ai-orchestrator.tls=true" + - "traefik.http.services.ai-orchestrator.loadbalancer.server.port=8080" + # Health check for Traefik + - "traefik.http.services.ai-orchestrator.loadbalancer.healthcheck.path=/api/health" + - "traefik.http.services.ai-orchestrator.loadbalancer.healthcheck.interval=30s" + networks: + - traefik-public + - ai-internal + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/api/health')"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + + ollama: + image: ollama/ollama:latest + container_name: ollama + restart: unless-stopped + volumes: + - ollama-data:/root/.ollama + networks: + - ai-internal + # Expose internally only (orchestrator routes to it) + expose: + - "11434" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + # CPU-only mode (no GPU passthrough needed for RS 8000) + deploy: + resources: + limits: + memory: 16G + reservations: + memory: 4G + +volumes: + ollama-data: + driver: local + +networks: + traefik-public: + external: true + ai-internal: + driver: bridge diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d445a16 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +fastapi>=0.109.0 +uvicorn>=0.27.0 +httpx>=0.26.0 +pydantic>=2.5.0 diff --git a/server.py b/server.py new file mode 100644 index 0000000..a315cb4 --- /dev/null +++ b/server.py @@ -0,0 +1,757 @@ +""" +AI Orchestrator - Smart routing between local Ollama (free) and RunPod (GPU) + +Routes: +- Text/Code: Ollama (free, local CPU) or RunPod vLLM (paid, fast GPU) +- Images: RunPod Automatic1111/ComfyUI +- Video: RunPod Wan2.2 +- Audio: RunPod WhisperX +""" + +import os +import asyncio +import httpx +from datetime import datetime +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, StreamingResponse +from pydantic import BaseModel +from typing import Optional, List, Dict, Any, Literal +from enum import Enum + +# Config +RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY", "rpa_YYOARL5MEBTTKKWGABRKTW2CVHQYRBTOBZNSGIL3lwwfdz") +RUNPOD_API_BASE = "https://api.runpod.ai/v2" +OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434") + +# RunPod endpoints (paid GPU) +ENDPOINTS = { + "video": {"id": "4jql4l7l0yw0f3", "name": "Wan2.2 Video", "type": "video"}, + "image": {"id": "tzf1j3sc3zufsy", "name": "Automatic1111 SD", "type": "image"}, + "comfyui": {"id": "5zurj845tbf8he", "name": "ComfyUI", "type": "image"}, + "whisper": {"id": "lrtisuv8ixbtub", "name": "WhisperX", "type": "audio"}, + "llm": {"id": "03g5hz3hlo8gr2", "name": "vLLM", "type": "text"}, +} + +# Ollama models (free local CPU) +OLLAMA_MODELS = { + "llama3.2": {"name": "Llama 3.2 3B", "context": 128000, "size": "3B"}, + "llama3.2:1b": {"name": "Llama 3.2 1B", "context": 128000, "size": "1B"}, + "qwen2.5-coder:7b": {"name": "Qwen 2.5 Coder 7B", "context": 32000, "size": "7B"}, + 
"mistral": {"name": "Mistral 7B", "context": 32000, "size": "7B"}, + "phi3": {"name": "Phi-3 Mini", "context": 128000, "size": "3.8B"}, +} + +app = FastAPI( + title="AI Orchestrator", + description="Smart routing between local Ollama (free) and RunPod (GPU)", + version="1.0.0", +) + +# CORS middleware for web access +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Store recent jobs +recent_jobs = [] + +# Track cost savings +cost_tracker = { + "ollama_requests": 0, + "runpod_requests": 0, + "estimated_savings": 0.0, # USD saved by using Ollama +} + + +# ============== Ollama Functions (FREE local inference) ============== + +async def ollama_health() -> dict: + """Check Ollama service health""" + async with httpx.AsyncClient() as client: + try: + resp = await client.get(f"{OLLAMA_HOST}/api/tags", timeout=5) + if resp.status_code == 200: + data = resp.json() + return { + "status": "healthy", + "models": [m["name"] for m in data.get("models", [])], + } + return {"status": "unhealthy", "error": f"Status {resp.status_code}"} + except Exception as e: + return {"status": "unavailable", "error": str(e)} + + +async def ollama_generate( + prompt: str, + model: str = "llama3.2", + system: Optional[str] = None, + stream: bool = False, + options: Optional[dict] = None, +) -> dict: + """Generate text using local Ollama""" + payload = { + "model": model, + "prompt": prompt, + "stream": stream, + } + if system: + payload["system"] = system + if options: + payload["options"] = options + + async with httpx.AsyncClient() as client: + try: + resp = await client.post( + f"{OLLAMA_HOST}/api/generate", + json=payload, + timeout=300, # 5 min timeout for long generations + ) + result = resp.json() + # Track usage + cost_tracker["ollama_requests"] += 1 + cost_tracker["estimated_savings"] += 0.001 # ~$0.001 saved per request vs RunPod + return result + except Exception as e: + return {"error": str(e)} + + +async def ollama_chat( + messages: List[dict], + model: str = "llama3.2", + stream: bool = False, + options: Optional[dict] = None, +) -> dict: + """Chat completion using local Ollama""" + payload = { + "model": model, + "messages": messages, + "stream": stream, + } + if options: + payload["options"] = options + + async with httpx.AsyncClient() as client: + try: + resp = await client.post( + f"{OLLAMA_HOST}/api/chat", + json=payload, + timeout=300, + ) + result = resp.json() + cost_tracker["ollama_requests"] += 1 + cost_tracker["estimated_savings"] += 0.001 + return result + except Exception as e: + return {"error": str(e)} + + +async def ollama_pull_model(model: str) -> dict: + """Pull/download a model to Ollama""" + async with httpx.AsyncClient() as client: + try: + resp = await client.post( + f"{OLLAMA_HOST}/api/pull", + json={"name": model}, + timeout=600, # Models can take a while to download + ) + return resp.json() + except Exception as e: + return {"error": str(e)} + + +def build_comfyui_workflow( + prompt: str, + negative_prompt: str = "", + seed: int = 42, + steps: int = 20, + cfg: float = 1.0, # Flux uses low CFG (1.0) + width: int = 1024, + height: int = 1024, + sampler: str = "euler", + scheduler: str = "simple", + denoise: float = 1.0, + model: str = "flux1-dev-fp8.safetensors", +) -> dict: + """Build a ComfyUI Flux txt2img workflow in API format""" + return { + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": model + } + }, + "5": { + "class_type": 
"EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": height, + "width": width + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": ["4", 1], + "text": prompt + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": ["4", 1], + "text": negative_prompt + } + }, + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": cfg, + "denoise": denoise, + "latent_image": ["5", 0], + "model": ["4", 0], + "negative": ["7", 0], + "positive": ["6", 0], + "sampler_name": sampler, + "scheduler": scheduler, + "seed": seed, + "steps": steps + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": ["3", 0], + "vae": ["4", 2] + } + }, + "9": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "ComfyUI", + "images": ["8", 0] + } + } + } + + +async def get_endpoint_health(endpoint_id: str) -> dict: + """Get health status of a RunPod endpoint""" + async with httpx.AsyncClient() as client: + try: + resp = await client.get( + f"{RUNPOD_API_BASE}/{endpoint_id}/health", + headers={"Authorization": f"Bearer {RUNPOD_API_KEY}"}, + timeout=10, + ) + return resp.json() + except Exception as e: + return {"error": str(e)} + + +async def get_job_status(endpoint_id: str, job_id: str) -> dict: + """Get status of a specific job""" + async with httpx.AsyncClient() as client: + try: + resp = await client.get( + f"{RUNPOD_API_BASE}/{endpoint_id}/status/{job_id}", + headers={"Authorization": f"Bearer {RUNPOD_API_KEY}"}, + timeout=10, + ) + return resp.json() + except Exception as e: + return {"error": str(e)} + + +async def submit_job(endpoint_id: str, payload: dict) -> dict: + """Submit a job to a RunPod endpoint""" + async with httpx.AsyncClient() as client: + try: + resp = await client.post( + f"{RUNPOD_API_BASE}/{endpoint_id}/run", + headers={ + "Authorization": f"Bearer {RUNPOD_API_KEY}", + "Content-Type": "application/json", + }, + json={"input": payload}, + timeout=30, + ) + return resp.json() + except Exception as e: + return {"error": str(e)} + + +@app.get("/", response_class=HTMLResponse) +async def dashboard(): + """Main dashboard showing all endpoint statuses""" + + # Fetch all endpoint health in parallel + health_tasks = { + name: get_endpoint_health(ep["id"]) + for name, ep in ENDPOINTS.items() + } + + health_results = {} + for name, task in health_tasks.items(): + health_results[name] = await task + + # Get Ollama status + ollama_status = await ollama_health() + + # Build HTML + html = """ + + + + AI Orchestrator Dashboard + + + + +

+    <h1>šŸ¤– AI Orchestrator Dashboard</h1>
+    <p class="updated">Last updated: """ + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + """ (auto-refreshes every 10s)</p>
+
+    <div class="stats">
+        <div class="stat-card">
+            <div class="stat-label">Ollama Status</div>
+            <div class="stat-value">""" + ollama_status.get("status", "unknown").upper() + """</div>
+        </div>
+        <div class="stat-card">
+            <div class="stat-label">Free Requests (Ollama)</div>
+            <div class="stat-value">""" + str(cost_tracker["ollama_requests"]) + """</div>
+        </div>
+        <div class="stat-card">
+            <div class="stat-label">Paid Requests (RunPod)</div>
+            <div class="stat-value">""" + str(cost_tracker["runpod_requests"]) + """</div>
+        </div>
+        <div class="stat-card">
+            <div class="stat-label">Est. Savings</div>
+            <div class="stat-value">$""" + str(round(cost_tracker["estimated_savings"], 2)) + """</div>
+        </div>
+    </div>
+
+    <div class="endpoints">
+    """
+
+    for name, ep in ENDPOINTS.items():
+        health = health_results.get(name, {})
+        workers = health.get("workers", {})
+        jobs = health.get("jobs", {})
+
+        # Determine status from the worker counts reported by RunPod
+        if "error" in health:
+            status_class = "error"
+            status_text = "Error"
+        elif workers.get("ready", 0) > 0 or workers.get("running", 0) > 0:
+            status_class = "ready"
+            status_text = "Ready"
+        elif workers.get("throttled", 0) > 0:
+            status_class = "throttled"
+            status_text = "Throttled (waiting for GPU)"
+        elif workers.get("idle", 0) > 0:
+            status_class = "idle"
+            status_text = "Idle"
+        else:
+            status_class = "idle"
+            status_text = "Standby"
+
+        html += f"""
+        <div class="endpoint-card">
+            <div class="endpoint-name">{ep['name']}</div>
+            <div class="status {status_class}">{status_text}</div>
+            <div class="endpoint-meta">Type: {ep['type']} | ID: {ep['id'][:8]}...</div>
+            <div class="stat-row"><span>Workers Ready</span><span>{workers.get('ready', 0)}</span></div>
+            <div class="stat-row"><span>Workers Running</span><span>{workers.get('running', 0)}</span></div>
+            <div class="stat-row"><span>Workers Initializing</span><span>{workers.get('initializing', 0)}</span></div>
+            <div class="stat-row"><span>Workers Throttled</span><span>{workers.get('throttled', 0)}</span></div>
+            <div class="stat-row"><span>Jobs In Queue</span><span>{jobs.get('inQueue', 0)}</span></div>
+            <div class="stat-row"><span>Jobs In Progress</span><span>{jobs.get('inProgress', 0)}</span></div>
+            <div class="stat-row"><span>Jobs Completed</span><span>{jobs.get('completed', 0)}</span></div>
+            <div class="stat-row"><span>Jobs Failed</span><span>{jobs.get('failed', 0)}</span></div>
+        </div>
+        """
+
+    html += """
+    </div>
+
+    <h2>Recent Jobs</h2>
+    <div class="jobs">
+    """
+
+    if recent_jobs:
+        for job in recent_jobs[-10:][::-1]:
+            html += f"""
+        <div class="job-row">{job['endpoint']} - {job['id'][:16]}... <span class="job-status">{job['status']}</span></div>
+            """
+    else:
+        html += "<div class='job-row'>No jobs submitted yet</div>"
+
+    html += """
+ + + """ + + return HTMLResponse(content=html) + + +@app.get("/api/health") +async def api_health(): + """API endpoint to get all endpoint health""" + results = {} + for name, ep in ENDPOINTS.items(): + results[name] = await get_endpoint_health(ep["id"]) + return results + + +@app.get("/api/status/{endpoint}/{job_id}") +async def api_job_status(endpoint: str, job_id: str): + """Get status of a specific job""" + if endpoint not in ENDPOINTS: + raise HTTPException(status_code=404, detail="Endpoint not found") + return await get_job_status(ENDPOINTS[endpoint]["id"], job_id) + + +@app.get("/test/{endpoint}") +async def test_endpoint(endpoint: str): + """Submit a test job to an endpoint""" + if endpoint not in ENDPOINTS: + raise HTTPException(status_code=404, detail="Endpoint not found") + + ep = ENDPOINTS[endpoint] + + # Different test payloads for different endpoint types + if ep["type"] == "video": + payload = { + "prompt": "A cat walking through a garden, cinematic lighting, high quality", + "negative_prompt": "blurry, low quality, distorted", + "seed": 42, + "cfg": 4.0, + "steps": 20, + "width": 832, + "height": 480, + "num_frames": 81, + "length": 81 + } + elif ep["type"] == "image": + if endpoint == "comfyui": + # ComfyUI needs a workflow JSON + payload = { + "workflow": build_comfyui_workflow( + prompt="A beautiful sunset over mountains, photorealistic, 8k", + negative_prompt="blurry, low quality, distorted", + seed=42, + steps=20, + cfg=7.0, + width=512, + height=512 + ) + } + else: + payload = {"prompt": "A beautiful sunset over mountains, photorealistic"} + elif ep["type"] == "audio": + payload = {"audio_url": "https://example.com/test.mp3"} + elif ep["type"] == "text": + payload = {"prompt": "Hello, how are you?", "max_tokens": 50} + else: + payload = {"test": True} + + result = await submit_job(ep["id"], payload) + + # Track job + if "id" in result: + recent_jobs.append({ + "endpoint": endpoint, + "id": result["id"], + "status": result.get("status", "SUBMITTED"), + "timestamp": datetime.now().isoformat(), + }) + + return result + + +class VideoRequest(BaseModel): + prompt: str + negative_prompt: Optional[str] = "blurry, low quality, distorted" + seed: Optional[int] = None # Random if not set + cfg: Optional[float] = 4.0 # Classifier-free guidance + steps: Optional[int] = 20 + width: Optional[int] = 832 + height: Optional[int] = 480 + num_frames: Optional[int] = 81 + length: Optional[int] = 81 # Same as num_frames for Wan2.2 + + +class ImageRequest(BaseModel): + prompt: str + negative_prompt: Optional[str] = "blurry, low quality" + steps: Optional[int] = 20 + width: Optional[int] = 512 + height: Optional[int] = 512 + + +@app.post("/api/generate/video") +async def generate_video(request: VideoRequest): + """Generate video using Wan2.2""" + import random + payload = request.dict() + # Generate random seed if not provided + if payload.get("seed") is None: + payload["seed"] = random.randint(1, 2147483647) + + result = await submit_job(ENDPOINTS["video"]["id"], payload) + if "id" in result: + recent_jobs.append({ + "endpoint": "video", + "id": result["id"], + "status": result.get("status", "SUBMITTED"), + "timestamp": datetime.now().isoformat(), + }) + return result + + +@app.post("/api/generate/image") +async def generate_image(request: ImageRequest): + """Generate image using Automatic1111""" + result = await submit_job(ENDPOINTS["image"]["id"], request.dict()) + if "id" in result: + recent_jobs.append({ + "endpoint": "image", + "id": result["id"], + "status": result.get("status", 
"SUBMITTED"), + "timestamp": datetime.now().isoformat(), + }) + return result + + +class ComfyUIRequest(BaseModel): + prompt: str + negative_prompt: Optional[str] = "" + seed: Optional[int] = None + steps: Optional[int] = 20 + cfg: Optional[float] = 7.0 + width: Optional[int] = 512 + height: Optional[int] = 512 + sampler: Optional[str] = "euler" + scheduler: Optional[str] = "normal" + workflow: Optional[Dict[str, Any]] = None # Custom workflow override + + +@app.post("/api/generate/comfyui") +async def generate_comfyui(request: ComfyUIRequest): + """Generate image using ComfyUI with workflow""" + import random + + # Use custom workflow if provided, otherwise build default txt2img + if request.workflow: + workflow = request.workflow + else: + seed = request.seed if request.seed is not None else random.randint(1, 2147483647) + workflow = build_comfyui_workflow( + prompt=request.prompt, + negative_prompt=request.negative_prompt, + seed=seed, + steps=request.steps, + cfg=request.cfg, + width=request.width, + height=request.height, + sampler=request.sampler, + scheduler=request.scheduler, + ) + + payload = {"workflow": workflow} + result = await submit_job(ENDPOINTS["comfyui"]["id"], payload) + + if "id" in result: + recent_jobs.append({ + "endpoint": "comfyui", + "id": result["id"], + "status": result.get("status", "SUBMITTED"), + "timestamp": datetime.now().isoformat(), + }) + return result + + +# ============== Text Generation Endpoints (Smart Routing) ============== + +class Priority(str, Enum): + LOW = "low" # Always use free Ollama + NORMAL = "normal" # Ollama if available, else RunPod + HIGH = "high" # RunPod for speed + + +class TextRequest(BaseModel): + prompt: str + system: Optional[str] = None + model: Optional[str] = "llama3.2" # Ollama model name + max_tokens: Optional[int] = 2048 + temperature: Optional[float] = 0.7 + priority: Optional[Priority] = Priority.NORMAL + + +class ChatRequest(BaseModel): + messages: List[Dict[str, str]] # [{"role": "user", "content": "..."}] + model: Optional[str] = "llama3.2" + max_tokens: Optional[int] = 2048 + temperature: Optional[float] = 0.7 + priority: Optional[Priority] = Priority.NORMAL + + +@app.post("/api/generate/text") +async def generate_text(request: TextRequest): + """ + Generate text with smart routing: + - LOW priority: Always Ollama (free) + - NORMAL priority: Ollama if healthy, else RunPod + - HIGH priority: RunPod vLLM (fast GPU) + """ + # Check Ollama health for routing decision + ollama_status = await ollama_health() + use_ollama = False + + if request.priority == Priority.LOW: + use_ollama = True + elif request.priority == Priority.NORMAL: + use_ollama = ollama_status.get("status") == "healthy" + # HIGH priority always uses RunPod + + if use_ollama and ollama_status.get("status") == "healthy": + # Use free local Ollama + result = await ollama_generate( + prompt=request.prompt, + model=request.model, + system=request.system, + options={ + "num_predict": request.max_tokens, + "temperature": request.temperature, + }, + ) + return { + "provider": "ollama", + "model": request.model, + "cost": 0.0, + "response": result.get("response", ""), + "tokens": result.get("eval_count", 0), + } + else: + # Use RunPod vLLM (paid) + cost_tracker["runpod_requests"] += 1 + payload = { + "prompt": request.prompt, + "max_tokens": request.max_tokens, + "temperature": request.temperature, + } + result = await submit_job(ENDPOINTS["llm"]["id"], payload) + if "id" in result: + recent_jobs.append({ + "endpoint": "llm", + "id": result["id"], + "status": 
result.get("status", "SUBMITTED"), + "timestamp": datetime.now().isoformat(), + }) + return { + "provider": "runpod", + "model": "vLLM", + "cost": 0.001, # Estimated per request + "job_id": result.get("id"), + "status": result.get("status"), + } + + +@app.post("/api/chat") +async def chat_completion(request: ChatRequest): + """ + Chat completion with smart routing + """ + ollama_status = await ollama_health() + use_ollama = request.priority != Priority.HIGH and ollama_status.get("status") == "healthy" + + if use_ollama: + result = await ollama_chat( + messages=request.messages, + model=request.model, + options={ + "num_predict": request.max_tokens, + "temperature": request.temperature, + }, + ) + return { + "provider": "ollama", + "model": request.model, + "cost": 0.0, + "message": result.get("message", {}), + "tokens": result.get("eval_count", 0), + } + else: + # Fallback to RunPod + cost_tracker["runpod_requests"] += 1 + # Convert chat format to prompt for vLLM + prompt = "\n".join([f"{m['role']}: {m['content']}" for m in request.messages]) + result = await submit_job(ENDPOINTS["llm"]["id"], {"prompt": prompt}) + return { + "provider": "runpod", + "job_id": result.get("id"), + "status": result.get("status"), + } + + +@app.get("/api/ollama/models") +async def list_ollama_models(): + """List available Ollama models""" + status = await ollama_health() + return { + "available": status.get("models", []), + "recommended": list(OLLAMA_MODELS.keys()), + "status": status.get("status"), + } + + +@app.post("/api/ollama/pull/{model}") +async def pull_ollama_model(model: str): + """Pull/download a model to Ollama""" + result = await ollama_pull_model(model) + return result + + +@app.get("/api/stats") +async def get_stats(): + """Get usage statistics and cost savings""" + return { + "ollama_requests": cost_tracker["ollama_requests"], + "runpod_requests": cost_tracker["runpod_requests"], + "estimated_savings_usd": round(cost_tracker["estimated_savings"], 4), + "total_requests": cost_tracker["ollama_requests"] + cost_tracker["runpod_requests"], + "ollama_percentage": round( + cost_tracker["ollama_requests"] / max(1, cost_tracker["ollama_requests"] + cost_tracker["runpod_requests"]) * 100, 1 + ), + } + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8080) diff --git a/test_api.py b/test_api.py new file mode 100644 index 0000000..634a17b --- /dev/null +++ b/test_api.py @@ -0,0 +1,26 @@ +import runpod +import os + +# Set API key from config +runpod.api_key = "rpa_YYOARL5MEBTTKKWGABRKTW2CVHQYRBTOBZNSGIL3lwwfdz" + +# Test 1: List all pods +print("=== Testing RunPod API Connection ===\n") +print("1. Listing all pods:") +pods = runpod.get_pods() +for pod in pods: + print(f" - {pod['name']} ({pod['id']}): {pod['desiredStatus']}") + +# Test 2: Check serverless endpoints +print("\n2. Checking serverless endpoints:") +try: + endpoints = runpod.get_endpoints() + if endpoints: + for endpoint in endpoints: + print(f" - {endpoint.get('name', 'Unnamed')}: {endpoint.get('id')}") + else: + print(" No serverless endpoints configured yet") +except Exception as e: + print(f" Note: {e}") + +print("\nāœ… API Connection successful!")
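
A minimal client sketch for the orchestrator API defined in server.py above. It assumes the stack from docker-compose.yml is running and reachable at http://localhost:8080 (the base URL and prompts are illustrative placeholders); the endpoint paths, request fields, and priority values ("low", "normal", "high") are taken from server.py.

import httpx

BASE_URL = "http://localhost:8080"  # assumed local deployment; adjust to your host

with httpx.Client(base_url=BASE_URL, timeout=300.0) as client:
    # Low priority is always routed to the free local Ollama backend.
    resp = client.post("/api/generate/text", json={
        "prompt": "Write a haiku about GPUs",  # placeholder prompt
        "priority": "low",
    })
    data = resp.json()
    print(data["provider"], data.get("response", ""))

    # High priority is routed to the RunPod vLLM endpoint and returns a job id,
    # which can be polled at /api/status/llm/{job_id}.
    resp = client.post("/api/generate/text", json={
        "prompt": "Summarize the benefits of smart routing",  # placeholder prompt
        "priority": "high",
    })
    job = resp.json()
    print(job.get("provider"), job.get("job_id"), job.get("status"))

    # Aggregate request counts and the savings estimate tracked by the server.
    print(client.get("/api/stats").json())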