311 lines
7.9 KiB
Go
311 lines
7.9 KiB
Go
package handler
|
|
|
|
import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	gonanoid "github.com/matoous/go-nanoid/v2"
	"golang.org/x/crypto/bcrypt"

	"github.com/jeffemmett/upload-service/internal/config"
	"github.com/jeffemmett/upload-service/internal/r2"
	"github.com/jeffemmett/upload-service/internal/store"
)
|
|
|
|
var (
	// slugRe accepts 1-64 characters: a leading ASCII letter or digit,
	// followed by up to 63 letters, digits, dots, underscores or hyphens.
	slugRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$`)

	// reservedSlugs lists lower-case names that isReservedSlug refuses.
	// NOTE(review): presumably these collide with top-level routes - confirm
	// against the router when adding new paths.
	reservedSlugs = map[string]bool{
		"b":           true,
		"cli":         true,
		"f":           true,
		"favicon.ico": true,
		"health":      true,
		"static":      true,
		"upload":      true,
	}
)

// isReservedSlug reports whether s, compared case-insensitively, is one of
// the names in reservedSlugs.
func isReservedSlug(s string) bool {
	key := strings.ToLower(s)
	return reservedSlugs[key]
}
|
|
|
|
type Handler struct {
|
|
store *store.Store
|
|
r2 *r2.Client
|
|
config *config.Config
|
|
}
|
|
|
|
func New(s *store.Store, r *r2.Client, c *config.Config) *Handler {
|
|
return &Handler{store: s, r2: r, config: c}
|
|
}
|
|
|
|
func (h *Handler) Upload(w http.ResponseWriter, r *http.Request) {
|
|
reader, err := r.MultipartReader()
|
|
if err != nil {
|
|
http.Error(w, "expected multipart/form-data", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var (
|
|
filename string
|
|
contentType string
|
|
expiresIn string
|
|
password string
|
|
customSlug string
|
|
batchID string
|
|
fileSize int64
|
|
fileUploaded bool
|
|
fileID string
|
|
r2Key string
|
|
)
|
|
|
|
deleteToken := make([]byte, 32)
|
|
if _, err := rand.Read(deleteToken); err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
deleteTokenHex := hex.EncodeToString(deleteToken)
|
|
|
|
for {
|
|
part, err := reader.NextPart()
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
switch part.FormName() {
|
|
case "file":
|
|
if fileUploaded {
|
|
part.Close()
|
|
continue
|
|
}
|
|
filename = part.FileName()
|
|
if filename == "" {
|
|
http.Error(w, "no filename", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Resolve file ID: use custom slug or generate nanoid
|
|
if customSlug != "" {
|
|
fileID = customSlug
|
|
} else {
|
|
fileID, err = gonanoid.New(8)
|
|
if err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Detect content type from extension, fall back to part header
|
|
ct := mime.TypeByExtension("." + fileExtension(filename))
|
|
if ct == "" {
|
|
ct = part.Header.Get("Content-Type")
|
|
}
|
|
if ct == "" {
|
|
ct = "application/octet-stream"
|
|
}
|
|
contentType = ct
|
|
|
|
r2Key = fmt.Sprintf("uploads/%s/%s", fileID, filename)
|
|
|
|
// Check Content-Length hint if available
|
|
if cl := r.Header.Get("Content-Length"); cl != "" {
|
|
if size, err := strconv.ParseInt(cl, 10, 64); err == nil && size > h.config.MaxUploadSize {
|
|
http.Error(w, fmt.Sprintf("file too large (max %d bytes)", h.config.MaxUploadSize), http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Stream directly to R2 — the part reader is the body pipe
|
|
// We use a counting reader to track size
|
|
cr := &countingReader{r: part}
|
|
if err := h.r2.Upload(r.Context(), r2Key, contentType, -1, cr); err != nil {
|
|
log.Printf("r2 upload error: %v", err)
|
|
http.Error(w, "upload failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
fileSize = cr.n
|
|
fileUploaded = true
|
|
|
|
if fileSize > h.config.MaxUploadSize {
|
|
// File exceeded limit after streaming — clean up
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, fmt.Sprintf("file too large (max %d bytes)", h.config.MaxUploadSize), http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
|
|
case "expires_in":
|
|
buf := make([]byte, 64)
|
|
n, _ := part.Read(buf)
|
|
expiresIn = strings.TrimSpace(string(buf[:n]))
|
|
|
|
case "password":
|
|
buf := make([]byte, 256)
|
|
n, _ := part.Read(buf)
|
|
password = strings.TrimSpace(string(buf[:n]))
|
|
|
|
case "slug":
|
|
buf := make([]byte, 128)
|
|
n, _ := part.Read(buf)
|
|
customSlug = strings.TrimSpace(string(buf[:n]))
|
|
|
|
case "batch_id":
|
|
buf := make([]byte, 64)
|
|
n, _ := part.Read(buf)
|
|
batchID = strings.TrimSpace(string(buf[:n]))
|
|
}
|
|
part.Close()
|
|
}
|
|
|
|
if !fileUploaded {
|
|
http.Error(w, "no file provided", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// If slug was provided after file part (handle late slug)
|
|
if customSlug != "" && fileID != customSlug {
|
|
// Slug came after file — need to validate and re-key
|
|
// This shouldn't happen with well-ordered form data, but handle gracefully
|
|
// by ignoring the late slug since file is already uploaded with nanoid
|
|
}
|
|
|
|
// Validate custom slug
|
|
if customSlug != "" {
|
|
if !slugRe.MatchString(customSlug) {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "invalid slug: use letters, numbers, hyphens, dots, underscores (1-64 chars)", http.StatusBadRequest)
|
|
return
|
|
}
|
|
// Check slug isn't already taken
|
|
if _, err := h.store.Get(customSlug); err == nil {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "slug already taken", http.StatusConflict)
|
|
return
|
|
} else if err != sql.ErrNoRows {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
log.Printf("db check slug error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Validate batch_id format and reserved slugs
|
|
if batchID != "" {
|
|
if !slugRe.MatchString(batchID) {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "invalid batch slug: use letters, numbers, hyphens, dots, underscores (1-64 chars)", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if isReservedSlug(batchID) {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "that batch slug is reserved", http.StatusConflict)
|
|
return
|
|
}
|
|
}
|
|
|
|
rec := &store.FileRecord{
|
|
ID: fileID,
|
|
Filename: filename,
|
|
R2Key: r2Key,
|
|
SizeBytes: fileSize,
|
|
ContentType: contentType,
|
|
DeleteToken: deleteTokenHex,
|
|
}
|
|
if batchID != "" {
|
|
rec.BatchID = &batchID
|
|
}
|
|
|
|
// Handle expiry
|
|
if expiresIn != "" {
|
|
dur, err := parseDuration(expiresIn)
|
|
if err != nil {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "invalid expires_in value (use: 1h, 1d, 7d, 30d)", http.StatusBadRequest)
|
|
return
|
|
}
|
|
t := time.Now().UTC().Add(dur)
|
|
rec.ExpiresAt = &t
|
|
}
|
|
|
|
// Handle password
|
|
if password != "" {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
hashStr := string(hash)
|
|
rec.PasswordHash = &hashStr
|
|
}
|
|
|
|
if err := h.store.Create(rec); err != nil {
|
|
h.r2.Delete(r.Context(), r2Key)
|
|
log.Printf("db create error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
resp := map[string]any{
|
|
"id": fileID,
|
|
"filename": filename,
|
|
"size": fileSize,
|
|
"url": fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
|
|
"delete_url": fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
|
|
"delete_token": deleteTokenHex,
|
|
}
|
|
if rec.ExpiresAt != nil {
|
|
resp["expires_at"] = rec.ExpiresAt.Format(time.RFC3339)
|
|
}
|
|
if batchID != "" {
|
|
resp["batch_id"] = batchID
|
|
resp["batch_url"] = fmt.Sprintf("%s/%s", h.config.BaseURL, batchID)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
// parseDuration turns a user-supplied lifetime into a time.Duration.
// The input is lower-cased and trimmed, then must be an integer count
// followed by a unit: "d" for days (1-365) or "h" for hours (1-8760).
// Any other shape or an out-of-range count yields an error.
func parseDuration(s string) (time.Duration, error) {
	s = strings.TrimSpace(strings.ToLower(s))

	switch {
	case strings.HasSuffix(s, "d"):
		days, err := strconv.Atoi(s[:len(s)-1])
		if err == nil && days >= 1 && days <= 365 {
			return time.Duration(days) * 24 * time.Hour, nil
		}
		return 0, fmt.Errorf("invalid days")

	case strings.HasSuffix(s, "h"):
		hours, err := strconv.Atoi(s[:len(s)-1])
		if err == nil && hours >= 1 && hours <= 8760 {
			return time.Duration(hours) * time.Hour, nil
		}
		return 0, fmt.Errorf("invalid hours")

	default:
		return 0, fmt.Errorf("unsupported format")
	}
}
|
|
|
|
// fileExtension returns whatever follows the final '.' in name, without the
// dot. It returns "" when name contains no '.' or ends with one.
func fileExtension(name string) string {
	lastDot := -1
	for idx := 0; idx < len(name); idx++ {
		if name[idx] == '.' {
			lastDot = idx
		}
	}
	if lastDot < 0 {
		return ""
	}
	return name[lastDot+1:]
}
|
|
|
|
// countingReader wraps a reader and keeps n, the running total of bytes
// returned by its Read calls.
type countingReader struct {
	r interface {
		Read(p []byte) (n int, err error)
	}
	n int64
}

// Read forwards to the wrapped reader and adds the number of bytes read to
// the running total, including bytes returned alongside a non-nil error.
func (c *countingReader) Read(buf []byte) (read int, err error) {
	read, err = c.r.Read(buf)
	c.n += int64(read)
	return read, err
}
|