package handler

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	gonanoid "github.com/matoous/go-nanoid/v2"
	"golang.org/x/crypto/bcrypt"

	"github.com/jeffemmett/upload-service/internal/config"
	"github.com/jeffemmett/upload-service/internal/r2"
	"github.com/jeffemmett/upload-service/internal/store"
)

// slugRe matches identifiers accepted as custom file IDs and batch slugs:
// 1-64 characters, starting with an ASCII letter or digit, followed by
// letters, digits, '.', '_' or '-'. It excludes '/', so a matching slug
// cannot add path segments to an R2 key or URL.
var slugRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$`)

// reservedSlugs lists batch slugs clients may not claim (matched
// case-insensitively by isReservedSlug). Batch URLs are BaseURL/<batch_id>,
// so these presumably mirror other top-level routes — confirm against the router.
var reservedSlugs = map[string]bool{
	"f":           true,
	"b":           true,
	"health":      true,
	"cli":         true,
	"static":      true,
	"upload":      true,
	"favicon.ico": true,
}

// isReservedSlug reports whether s, compared case-insensitively, is reserved.
func isReservedSlug(s string) bool {
	return reservedSlugs[strings.ToLower(s)]
}

const (
	// bcryptMaxPasswordBytes is the longest password bcrypt operates on.
	bcryptMaxPasswordBytes = 72

	// r2CleanupTimeout bounds the deletion of an object orphaned by a failed upload.
	r2CleanupTimeout = 30 * time.Second
)

var (
	// errUploadTooLarge is returned by a countingReader once its limit has been
	// exceeded, which makes the in-flight R2 upload fail.
	errUploadTooLarge = errors.New("upload exceeds maximum size")

	// errFormFieldTooLong is returned by readFormField for oversized values.
	errFormFieldTooLong = errors.New("form field too long")
)

// uploadFieldLimits maps each text field Upload understands to the maximum
// accepted size of its value, in bytes. Parts with other names (apart from
// "file") are ignored.
var uploadFieldLimits = map[string]int64{
	"expires_in": 64,
	"password":   256,
	"slug":       128,
	"batch_id":   64,
}

// Handler serves the upload endpoints.
type Handler struct {
	store  *store.Store  // file metadata records
	r2     *r2.Client    // object storage for file contents
	config *config.Config // limits and the public base URL
}

// New returns a Handler backed by the given store, R2 client and config.
func New(s *store.Store, r *r2.Client, c *config.Config) *Handler {
	return &Handler{store: s, r2: r, config: c}
}

// Upload handles a multipart/form-data upload of a single file.
//
// Recognized form fields:
//   - file: the file contents (required; only the first file part is used)
//   - slug: custom file ID; applied only when sent before the file part
//   - expires_in: lifetime such as "12h" or "7d" (see parseDuration)
//   - password: stored on the record as a bcrypt hash
//   - batch_id: batch slug, reported back as batch_url
//
// The file is streamed to R2 under uploads/<id>/<filename>. The stream is
// aborted once it exceeds config.MaxUploadSize. Text fields are validated as
// they arrive, so a bad request whose fields precede the file is rejected
// before anything is uploaded. On success a FileRecord is created and a 201
// JSON response with the file's URLs and delete token is written. If the
// request fails after the object was stored, the object is deleted again.
func (h *Handler) Upload(w http.ResponseWriter, r *http.Request) {
	reader, err := r.MultipartReader()
	if err != nil {
		http.Error(w, "expected multipart/form-data", http.StatusBadRequest)
		return
	}

	maxSize := h.config.MaxUploadSize
	tooLarge := func() {
		http.Error(w, fmt.Sprintf("file too large (max %d bytes)", maxSize), http.StatusRequestEntityTooLarge)
	}

	// The declared body length also counts the other fields and the multipart
	// framing, so it over-estimates the file size. This is only a cheap early
	// rejection; ContentLength is -1 when the length is unknown.
	if r.ContentLength >= 0 && r.ContentLength > maxSize {
		tooLarge()
		return
	}

	deleteToken := make([]byte, 32)
	if _, err := rand.Read(deleteToken); err != nil {
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	deleteTokenHex := hex.EncodeToString(deleteToken)

	var (
		filename     string
		contentType  string
		password     string
		customSlug   string
		batchID      string
		fileID       string
		r2Key        string
		fileSize     int64
		expiresAfter time.Duration // zero means no expiry was requested
		fileUploaded bool          // an object now exists in R2 under r2Key
		committed    bool          // the database record for r2Key was created
	)

	// Without a record the stored object would be orphaned in R2, so remove it
	// whenever the request ends after the upload but before Create succeeded.
	defer func() {
		if fileUploaded && !committed {
			h.discardObject(r2Key)
		}
	}()

	for {
		part, err := reader.NextPart()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			// A truncated or malformed body must not be mistaken for the end
			// of the form.
			http.Error(w, "malformed multipart body", http.StatusBadRequest)
			return
		}

		name := part.FormName()
		if name == "file" {
			if fileUploaded {
				// Only the first file part is stored; later ones are skipped.
				part.Close()
				continue
			}

			filename = part.FileName()
			if filename == "" {
				http.Error(w, "no filename", http.StatusBadRequest)
				return
			}

			// A slug received before this part has already been validated.
			if customSlug != "" {
				fileID = customSlug
			} else {
				fileID, err = gonanoid.New(8)
				if err != nil {
					http.Error(w, "internal error", http.StatusInternalServerError)
					return
				}
			}

			// Detect the content type from the extension, falling back to the
			// part header and finally to a generic binary type.
			contentType = mime.TypeByExtension("." + fileExtension(filename))
			if contentType == "" {
				contentType = part.Header.Get("Content-Type")
			}
			if contentType == "" {
				contentType = "application/octet-stream"
			}

			r2Key = fmt.Sprintf("uploads/%s/%s", fileID, filename)

			// Stream the part straight to R2. The counting reader fails the
			// stream as soon as the size limit is exceeded, so an oversized
			// file is not transferred in full.
			cr := &countingReader{r: part, limit: maxSize}
			if err := h.r2.Upload(r.Context(), r2Key, contentType, -1, cr); err != nil {
				if cr.overLimit() {
					tooLarge()
					return
				}
				log.Printf("r2 upload error: %v", err)
				http.Error(w, "upload failed", http.StatusInternalServerError)
				return
			}
			fileSize = cr.n
			fileUploaded = true
			part.Close()

			// The streaming limit normally catches this already. This check
			// also covers maxSize <= 0, which countingReader does not enforce.
			// The deferred cleanup removes the object.
			if fileSize > maxSize {
				tooLarge()
				return
			}
			continue
		}

		limit, known := uploadFieldLimits[name]
		if !known {
			part.Close()
			continue
		}
		value, err := readFormField(part, limit)
		part.Close()
		if errors.Is(err, errFormFieldTooLong) {
			http.Error(w, name+" too long", http.StatusBadRequest)
			return
		}
		if err != nil {
			http.Error(w, "malformed multipart body", http.StatusBadRequest)
			return
		}

		switch name {
		case "expires_in":
			expiresAfter = 0
			if value != "" {
				if expiresAfter, err = parseDuration(value); err != nil {
					http.Error(w, "invalid expires_in value (use: 1h, 1d, 7d, 30d)", http.StatusBadRequest)
					return
				}
			}
		case "password":
			// bcrypt only operates on the first 72 bytes. Reject longer
			// passwords instead of having them truncated or fail to hash.
			if len(value) > bcryptMaxPasswordBytes {
				http.Error(w, fmt.Sprintf("password too long (max %d bytes)", bcryptMaxPasswordBytes), http.StatusBadRequest)
				return
			}
			password = value
		case "slug":
			if fileUploaded {
				// The file is already stored under its final ID, so a slug sent
				// after the file part cannot be applied; ignore it.
				continue
			}
			if value != "" && !h.acceptCustomSlug(w, value) {
				return
			}
			customSlug = value
		case "batch_id":
			if value != "" {
				if !slugRe.MatchString(value) {
					http.Error(w, "invalid batch slug: use letters, numbers, hyphens, dots, underscores (1-64 chars)", http.StatusBadRequest)
					return
				}
				if isReservedSlug(value) {
					http.Error(w, "that batch slug is reserved", http.StatusConflict)
					return
				}
			}
			batchID = value
		}
	}

	if !fileUploaded {
		http.Error(w, "no file provided", http.StatusBadRequest)
		return
	}

	// The slug was checked when it arrived, but the upload may have taken a
	// while; make sure nobody claimed it in the meantime.
	if customSlug != "" && !h.acceptCustomSlug(w, customSlug) {
		return
	}

	rec := &store.FileRecord{
		ID:          fileID,
		Filename:    filename,
		R2Key:       r2Key,
		SizeBytes:   fileSize,
		ContentType: contentType,
		DeleteToken: deleteTokenHex,
	}
	if batchID != "" {
		rec.BatchID = &batchID
	}
	if expiresAfter > 0 {
		t := time.Now().UTC().Add(expiresAfter)
		rec.ExpiresAt = &t
	}
	if password != "" {
		hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
		if err != nil {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		hashStr := string(hash)
		rec.PasswordHash = &hashStr
	}

	// NOTE(review): the slug check above and Create are not atomic. A
	// concurrent request claiming the same slug presumably makes Create fail
	// (e.g. on a unique key) and surfaces as a 500 here — confirm.
	if err := h.store.Create(rec); err != nil {
		log.Printf("db create error: %v", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	committed = true

	resp := map[string]any{
		"id":       fileID,
		"filename": filename,
		"size":     fileSize,
		"url":      fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
		// NOTE(review): identical to "url"; presumably deletion uses another
		// HTTP method on the same path — confirm.
		"delete_url":   fmt.Sprintf("%s/f/%s", h.config.BaseURL, fileID),
		"delete_token": deleteTokenHex,
	}
	if rec.ExpiresAt != nil {
		resp["expires_at"] = rec.ExpiresAt.Format(time.RFC3339)
	}
	if batchID != "" {
		resp["batch_id"] = batchID
		resp["batch_url"] = fmt.Sprintf("%s/%s", h.config.BaseURL, batchID)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("encode upload response: %v", err)
	}
}

// acceptCustomSlug reports whether slug may be used as the ID of a new file:
// it must match slugRe and must not already exist in the store. When it may
// not be used, an error response has already been written to w.
func (h *Handler) acceptCustomSlug(w http.ResponseWriter, slug string) bool {
	if !slugRe.MatchString(slug) {
		http.Error(w, "invalid slug: use letters, numbers, hyphens, dots, underscores (1-64 chars)", http.StatusBadRequest)
		return false
	}
	_, err := h.store.Get(slug)
	switch {
	case err == nil:
		http.Error(w, "slug already taken", http.StatusConflict)
		return false
	case !errors.Is(err, sql.ErrNoRows):
		log.Printf("db check slug error: %v", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return false
	}
	return true
}

// discardObject deletes an R2 object written by a request that later failed.
// It deliberately does not use the request context: when the failure was the
// client disconnecting, that context is already canceled and the delete would
// never be sent.
func (h *Handler) discardObject(key string) {
	ctx, cancel := context.WithTimeout(context.Background(), r2CleanupTimeout)
	defer cancel()
	// NOTE(review): the result of Delete is discarded, as before; consider
	// logging it if Delete reports an error.
	h.r2.Delete(ctx, key)
}

// readFormField reads the whole value of a small text form field and trims
// surrounding whitespace. It reads until EOF, because a single Read on a
// multipart part may return only a prefix of the value. It returns
// errFormFieldTooLong when the value exceeds limit bytes.
func readFormField(part io.Reader, limit int64) (string, error) {
	data, err := io.ReadAll(io.LimitReader(part, limit+1))
	if err != nil {
		return "", err
	}
	if int64(len(data)) > limit {
		return "", errFormFieldTooLong
	}
	return strings.TrimSpace(string(data)), nil
}

// parseDuration parses an expiry such as "12h" (1-8760 hours) or "7d"
// (1-365 days). Input is trimmed and case-insensitive; any other form is an
// error.
func parseDuration(s string) (time.Duration, error) {
	s = strings.TrimSpace(strings.ToLower(s))
	if strings.HasSuffix(s, "d") {
		days, err := strconv.Atoi(strings.TrimSuffix(s, "d"))
		if err != nil || days < 1 || days > 365 {
			return 0, fmt.Errorf("invalid days")
		}
		return time.Duration(days) * 24 * time.Hour, nil
	}
	if strings.HasSuffix(s, "h") {
		hours, err := strconv.Atoi(strings.TrimSuffix(s, "h"))
		if err != nil || hours < 1 || hours > 8760 {
			return 0, fmt.Errorf("invalid hours")
		}
		return time.Duration(hours) * time.Hour, nil
	}
	return 0, fmt.Errorf("unsupported format")
}

// fileExtension returns the part of name after its last '.', or "" when name
// contains no '.'. Unlike filepath.Ext, the result excludes the dot.
func fileExtension(name string) string {
	if i := strings.LastIndexByte(name, '.'); i >= 0 {
		return name[i+1:]
	}
	return ""
}

// countingReader wraps a reader and counts the bytes read through it in n.
// When limit is positive, it fails with errUploadTooLarge as soon as more than
// limit bytes have been read, which aborts the consumer's transfer. A
// non-positive limit disables enforcement.
type countingReader struct {
	r     io.Reader
	n     int64
	limit int64
}

func (cr *countingReader) Read(p []byte) (int, error) {
	if cr.limit > 0 {
		// Allow at most one byte past the limit: enough to detect an
		// oversized stream without reading more of it.
		allowed := cr.limit - cr.n + 1
		if allowed <= 0 {
			return 0, errUploadTooLarge
		}
		if int64(len(p)) > allowed {
			p = p[:allowed]
		}
	}
	n, err := cr.r.Read(p)
	cr.n += int64(n)
	if cr.overLimit() {
		return n, errUploadTooLarge
	}
	return n, err
}

// overLimit reports whether more than the enforced limit has been read.
func (cr *countingReader) overLimit() bool {
	return cr.limit > 0 && cr.n > cr.limit
}