upload-service/internal/handler/upload.go

275 lines
6.9 KiB
Go

package handler
import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	gonanoid "github.com/matoous/go-nanoid/v2"
	"golang.org/x/crypto/bcrypt"

	"github.com/jeffemmett/upload-service/internal/config"
	"github.com/jeffemmett/upload-service/internal/r2"
	"github.com/jeffemmett/upload-service/internal/store"
)
// slugRe matches the custom IDs accepted for uploads: 1-64 characters in
// total, starting with an ASCII letter or digit and continuing with ASCII
// letters, digits, '.', '_' or '-'.
var slugRe = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$")
// Handler bundles the dependencies shared by the upload HTTP handlers.
type Handler struct {
	store  *store.Store   // file records (e.g. looked up and created by Upload)
	r2     *r2.Client     // object storage holding the uploaded bytes
	config *config.Config // settings such as MaxUploadSize and BaseURL
}
// New returns a Handler wired to the given record store, R2 client and
// configuration. The pointers are stored as given; nothing is validated here.
func New(s *store.Store, r *r2.Client, c *config.Config) *Handler {
	h := new(Handler)
	h.store, h.r2, h.config = s, r, c
	return h
}
// Upload handles a multipart/form-data upload. It streams the "file" part to
// R2 and stores a record describing it. It then replies 201 Created with a
// JSON body containing id, filename, size, url, delete_url, delete_token and,
// when an expiry was requested, expires_at.
//
// Recognised form fields:
//   - file: the payload. Only the first file part is stored.
//   - slug: optional custom ID (see slugRe). It must be unused and must be
//     sent before the file part, because the object key embeds the ID.
//   - expires_in: optional lifetime such as "12h" or "7d" (see parseDuration).
//   - password: optional; only its bcrypt hash is stored.
//
// Fields that arrive before the file are validated before anything is sent to
// R2. If the request is rejected after the file was stored, the object is
// deleted again.
func (h *Handler) Upload(w http.ResponseWriter, r *http.Request) {
	reader, err := r.MultipartReader()
	if err != nil {
		http.Error(w, "expected multipart/form-data", http.StatusBadRequest)
		return
	}

	maxSize := h.config.MaxUploadSize
	tooLarge := fmt.Sprintf("file too large (max %d bytes)", maxSize)
	// The declared length covers the whole multipart body, so a value above
	// the limit can be rejected before reading anything (-1 means unknown).
	if r.ContentLength > maxSize {
		http.Error(w, tooLarge, http.StatusRequestEntityTooLarge)
		return
	}

	deleteToken := make([]byte, 32)
	if _, err := rand.Read(deleteToken); err != nil {
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	deleteTokenHex := hex.EncodeToString(deleteToken)

	var (
		filename    string
		contentType string
		fileID      string
		r2Key       string
		customSlug  string
		password    string
		// ttl is the requested lifetime; zero means the upload never expires.
		ttl          time.Duration
		fileSize     int64
		fileUploaded bool
	)

	// fail rejects the request. Once the file has been stored it also deletes
	// the object, so rejected uploads leave nothing behind in R2.
	fail := func(msg string, code int) {
		if fileUploaded {
			h.r2.Delete(r.Context(), r2Key)
		}
		http.Error(w, msg, code)
	}

	// readField reads a small text part. On failure it rejects the request
	// and reports false.
	readField := func(part io.Reader, name string, limit int64) (string, bool) {
		v, err := readUploadField(part, limit)
		if err != nil {
			fail(fmt.Sprintf("invalid %s field: %v", name, err), http.StatusBadRequest)
			return "", false
		}
		return v, true
	}

	const badExpiry = "invalid expires_in value (use: 1h, 1d, 7d, 30d)"

	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			// A truncated or malformed body must not be mistaken for the
			// end of the form.
			fail("malformed multipart body", http.StatusBadRequest)
			return
		}

		switch part.FormName() {
		case "file":
			if fileUploaded {
				break // only the first file part is stored; leave the switch
			}
			filename = part.FileName()
			if filename == "" {
				fail("no filename", http.StatusBadRequest)
				return
			}

			// Use the custom slug (validated when its field was read), or
			// generate a random ID.
			fileID = customSlug
			if fileID == "" {
				if fileID, err = gonanoid.New(8); err != nil {
					fail("internal error", http.StatusInternalServerError)
					return
				}
			}

			// Content type comes from the extension, then the part header,
			// then a generic fallback.
			contentType = mime.TypeByExtension("." + fileExtension(filename))
			if contentType == "" {
				contentType = part.Header.Get("Content-Type")
			}
			if contentType == "" {
				contentType = "application/octet-stream"
			}

			// NOTE(review): with a custom slug the key is predictable, so two
			// concurrent uploads using the same slug and filename can target
			// the same object. Presumably store.Create rejects the loser;
			// confirm it enforces uniqueness on ID.
			r2Key = fmt.Sprintf("uploads/%s/%s", fileID, filename)

			// Stream straight to R2, but stop one byte past the limit. An
			// oversized file is still detected, yet at most maxSize+1 bytes
			// are ever sent.
			cr := &countingReader{r: io.LimitReader(part, maxSize+1)}
			if err := h.r2.Upload(r.Context(), r2Key, contentType, -1, cr); err != nil {
				log.Printf("r2 upload error: %v", err)
				http.Error(w, "upload failed", http.StatusInternalServerError)
				return
			}
			fileSize = cr.n
			fileUploaded = true
			if fileSize > maxSize {
				fail(tooLarge, http.StatusRequestEntityTooLarge)
				return
			}

		case "slug":
			slug, ok := readField(part, "slug", 128)
			if !ok {
				return
			}
			if slug != "" && slug != fileID {
				if fileUploaded {
					// The object key already embeds a generated ID, so the
					// requested slug can no longer be honoured.
					fail("slug must be sent before the file field", http.StatusBadRequest)
					return
				}
				// Validate before the file arrives, so an invalid or taken
				// slug never reaches (or clobbers) an object in R2.
				if code, msg := h.checkCustomSlug(slug); code != 0 {
					fail(msg, code)
					return
				}
			}
			customSlug = slug

		case "expires_in":
			v, ok := readField(part, "expires_in", 64)
			if !ok {
				return
			}
			ttl = 0
			if v != "" {
				if ttl, err = parseDuration(v); err != nil {
					fail(badExpiry, http.StatusBadRequest)
					return
				}
			}

		case "password":
			v, ok := readField(part, "password", 256)
			if !ok {
				return
			}
			password = v
		}
		part.Close()
	}

	if !fileUploaded {
		http.Error(w, "no file provided", http.StatusBadRequest)
		return
	}

	rec := &store.FileRecord{
		ID:          fileID,
		Filename:    filename,
		R2Key:       r2Key,
		SizeBytes:   fileSize,
		ContentType: contentType,
		DeleteToken: deleteTokenHex,
	}
	if ttl > 0 {
		expiresAt := time.Now().UTC().Add(ttl)
		rec.ExpiresAt = &expiresAt
	}
	if password != "" {
		hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
		if err != nil {
			log.Printf("bcrypt error: %v", err)
			fail("internal error", http.StatusInternalServerError)
			return
		}
		hashStr := string(hash)
		rec.PasswordHash = &hashStr
	}
	if err := h.store.Create(rec); err != nil {
		log.Printf("db create error: %v", err)
		fail("internal error", http.StatusInternalServerError)
		return
	}

	resp := map[string]any{
		"id":           fileID,
		"filename":     filename,
		"size":         fileSize,
		"url":          fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
		"delete_url":   fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
		"delete_token": deleteTokenHex,
	}
	if rec.ExpiresAt != nil {
		resp["expires_at"] = rec.ExpiresAt.Format(time.RFC3339)
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		// The status line has already been sent; logging is all that is left.
		log.Printf("encode upload response: %v", err)
	}
}

// checkCustomSlug validates a requested custom slug and confirms that no
// stored record already uses it. It returns status 0 when the slug may be
// used. Otherwise it returns the HTTP status code and message to reply with.
func (h *Handler) checkCustomSlug(slug string) (status int, msg string) {
	if !slugRe.MatchString(slug) {
		return http.StatusBadRequest, "invalid slug: use letters, numbers, hyphens, dots, underscores (1-64 chars)"
	}
	_, err := h.store.Get(slug)
	if err == nil {
		return http.StatusConflict, "slug already taken"
	}
	if err != sql.ErrNoRows {
		log.Printf("db check slug error: %v", err)
		return http.StatusInternalServerError, "internal error"
	}
	return 0, ""
}

// readUploadField reads a small text form field in full and trims
// surrounding whitespace. It reads until EOF, so values the multipart reader
// delivers in several chunks are not cut short. Values longer than limit
// bytes are rejected rather than silently truncated.
func readUploadField(part io.Reader, limit int64) (string, error) {
	b, err := io.ReadAll(io.LimitReader(part, limit+1))
	if err != nil {
		return "", err
	}
	if int64(len(b)) > limit {
		return "", fmt.Errorf("longer than %d bytes", limit)
	}
	return strings.TrimSpace(string(b)), nil
}
// parseDuration converts an expiry spec such as "12h" or "7d" into a
// time.Duration. Input is lower-cased and trimmed first. Hours must be in
// [1, 8760] and days in [1, 365]. Any other suffix or value is an error.
func parseDuration(s string) (time.Duration, error) {
	s = strings.TrimSpace(strings.ToLower(s))

	var (
		unit    time.Duration
		ceiling int
		badSpec string
	)
	switch {
	case strings.HasSuffix(s, "d"):
		unit, ceiling, badSpec = 24*time.Hour, 365, "invalid days"
	case strings.HasSuffix(s, "h"):
		unit, ceiling, badSpec = time.Hour, 8760, "invalid hours"
	default:
		return 0, fmt.Errorf("unsupported format")
	}

	// Exactly one suffix byte is present, so drop it and parse the count.
	count, err := strconv.Atoi(s[:len(s)-1])
	if err != nil || count < 1 || count > ceiling {
		return 0, fmt.Errorf("%s", badSpec)
	}
	return unit * time.Duration(count), nil
}
// fileExtension returns the text after the last '.' in name, without the
// dot. It returns "" when name contains no dot or ends with one.
func fileExtension(name string) string {
	dot := strings.LastIndexByte(name, '.')
	if dot < 0 {
		return ""
	}
	return name[dot+1:]
}
// countingReader wraps a reader and keeps a running total of the bytes read
// through it. Upload uses it to learn a file's size while streaming it.
type countingReader struct {
	r interface{ Read([]byte) (int, error) } // underlying source
	n int64                                  // bytes returned by Read so far
}

// Read forwards to the wrapped reader and adds the byte count it reports,
// including on a call that also returns an error, to cr.n.
func (cr *countingReader) Read(p []byte) (n int, err error) {
	n, err = cr.r.Read(p)
	cr.n += int64(n)
	return n, err
}