upload-service/main.go

175 lines
4.7 KiB
Go

package main
import (
	"context"
	"embed"
	"errors"
	"io/fs"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/jeffemmett/upload-service/internal/cleanup"
	"github.com/jeffemmett/upload-service/internal/config"
	"github.com/jeffemmett/upload-service/internal/db"
	"github.com/jeffemmett/upload-service/internal/handler"
	"github.com/jeffemmett/upload-service/internal/middleware"
	"github.com/jeffemmett/upload-service/internal/r2"
	"github.com/jeffemmett/upload-service/internal/store"
)
// webFS is the web/ directory embedded into the binary at build time. It backs
// the index page, the /static/ asset tree, the /cli script download, and is
// handed to handler.InitTemplates.
//
//go:embed web
var webFS embed.FS
// main loads configuration, opens the database, registers the HTTP routes and
// middleware chain, starts the background cleanup goroutine, and serves on
// cfg.Port until SIGINT or SIGTERM. On a signal it stops accepting new
// connections and waits up to 30 seconds for in-flight requests to finish
// before returning, which then runs the deferred database.Close.
func main() {
	cfg, err := config.Load()
	if err != nil {
		log.Fatalf("config: %v", err)
	}

	// Embedded assets are read once, up front. A missing file is a build
	// defect, so fail fast rather than serving empty responses. Doing this
	// before db.Open also means log.Fatalf here cannot skip database.Close.
	indexHTML, err := fs.ReadFile(webFS, "web/index.html")
	if err != nil {
		log.Fatalf("web assets: %v", err)
	}
	cliScript, err := fs.ReadFile(webFS, "web/static/upload.sh")
	if err != nil {
		log.Fatalf("web assets: %v", err)
	}
	staticFS, err := fs.Sub(webFS, "web/static")
	if err != nil {
		log.Fatalf("web assets: %v", err)
	}

	database, err := db.Open(cfg.DBPath)
	if err != nil {
		log.Fatalf("db: %v", err)
	}
	defer database.Close()

	s := store.New(database)
	r2Client := r2.NewClient(cfg)
	h := handler.New(s, r2Client, cfg)
	handler.InitTemplates(webFS)

	mux := http.NewServeMux()

	// Web UI
	mux.HandleFunc("GET /{$}", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.Header().Set("Cache-Control", "no-cache, must-revalidate")
		w.Write(indexHTML)
	})

	// Static assets
	mux.Handle("GET /static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))

	// API
	mux.HandleFunc("POST /upload", h.Upload)
	mux.HandleFunc("GET /f/{id}", h.Download)
	mux.HandleFunc("GET /f/{id}/dl", h.DirectDownload)
	mux.HandleFunc("GET /f/{id}/view", h.ViewFile)
	mux.HandleFunc("GET /f/{id}/info", h.Info)
	mux.HandleFunc("GET /f/{id}/auth", h.AuthPage)
	mux.HandleFunc("POST /f/{id}/auth", h.AuthSubmit)
	mux.HandleFunc("DELETE /f/{id}", h.Delete)

	// Catch-all for batch routes: /{slug}, /{slug}/dl, /{slug}/auth.
	// Uses {path...} to avoid conflicts with the /static/ subtree pattern; the
	// more specific patterns registered on this mux still take precedence.
	mux.HandleFunc("GET /{path...}", func(w http.ResponseWriter, r *http.Request) {
		slug, rest, hasRest := strings.Cut(r.PathValue("path"), "/")
		if slug == "" {
			http.NotFound(w, r)
			return
		}
		// Backward compat: /b/{id} -> /{id}
		if slug == "b" && hasRest {
			legacyBatchRedirect(w, r, rest, http.StatusMovedPermanently)
			return
		}
		r.SetPathValue("slug", slug)
		switch {
		case !hasRest:
			h.Batch(w, r)
		case rest == "dl":
			h.BatchDownload(w, r)
		case rest == "auth":
			h.BatchAuthPage(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	mux.HandleFunc("POST /{path...}", func(w http.ResponseWriter, r *http.Request) {
		slug, rest, hasRest := strings.Cut(r.PathValue("path"), "/")
		// Backward compat: /b/{id}/auth -> /{id}/auth. 307 keeps the method
		// and body so the form submission is replayed at the new location.
		if slug == "b" && hasRest {
			legacyBatchRedirect(w, r, rest, http.StatusTemporaryRedirect)
			return
		}
		if slug != "" && rest == "auth" {
			r.SetPathValue("slug", slug)
			h.BatchAuthSubmit(w, r)
			return
		}
		http.NotFound(w, r)
	})

	// Favicon (prevent 404)
	mux.HandleFunc("GET /favicon.ico", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	})

	// Health
	mux.HandleFunc("GET /health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"status":"ok"}`))
	})

	// CLI download
	mux.HandleFunc("GET /cli", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.Header().Set("Content-Disposition", `attachment; filename="upload.sh"`)
		w.Write(cliScript)
	})

	// Middleware chain: Security wraps the rate limiter, which wraps the mux.
	rl := middleware.NewRateLimiter(cfg.RateLimit, cfg.RateBurst)
	var chain http.Handler = mux
	chain = rl.Middleware(chain)
	chain = middleware.Security(chain)

	srv := &http.Server{
		Addr:              ":" + cfg.Port,
		Handler:           chain,
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
		// No read/write timeout — large uploads can take a long time
	}

	// ctx is cancelled by the first SIGINT/SIGTERM; that stops the cleanup
	// goroutine and triggers the graceful shutdown below.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// NOTE(review): cleanup.Start is not awaited before database.Close; how it
	// reacts to ctx cancellation is not visible from this file.
	go cleanup.Start(ctx, s, r2Client)

	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		<-ctx.Done()
		stop() // restore default signal handling so a second signal exits immediately
		log.Println("shutting down...")
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := srv.Shutdown(shutdownCtx); err != nil {
			log.Printf("shutdown: %v", err)
		}
	}()

	log.Printf("listening on :%s", cfg.Port)
	if err := srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
		log.Fatalf("server: %v", err)
	}
	// ListenAndServe returns as soon as Shutdown begins; wait for draining to
	// finish so the deferred database.Close does not race in-flight requests.
	<-shutdownDone
}

// legacyBatchRedirect redirects a request under the retired /b/ prefix to the
// same location without it (e.g. /b/{id}/auth -> /{id}/auth), keeping the
// query string. rest is the already-decoded path that followed "b/".
//
// Each segment is re-escaped with url.PathEscape and leading slashes are
// dropped, so input such as /b/%5Cevil.com or /b/%2F%2Fevil.com cannot produce
// a Location like "/\evil.com" or "//evil.com", which browsers resolve to an
// off-site URL.
func legacyBatchRedirect(w http.ResponseWriter, r *http.Request, rest string, code int) {
	segments := strings.Split(rest, "/")
	for i, seg := range segments {
		segments[i] = url.PathEscape(seg)
	}
	target := "/" + strings.TrimLeft(strings.Join(segments, "/"), "/")
	if r.URL.RawQuery != "" {
		target += "?" + r.URL.RawQuery
	}
	http.Redirect(w, r, target, code)
}