package handler import ( "database/sql" "embed" "html/template" "log" "net/http" "time" "golang.org/x/crypto/bcrypt" ) var passwordTmpl *template.Template func InitTemplates(webFS embed.FS) { passwordTmpl = template.Must(template.ParseFS(webFS, "web/password.html")) } func (h *Handler) AuthPage(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") if id == "" { http.NotFound(w, r) return } rec, err := h.store.Get(id) if err == sql.ErrNoRows { http.NotFound(w, r) return } if err != nil { log.Printf("db get error: %v", err) http.Error(w, "internal error", http.StatusInternalServerError) return } if rec.PasswordHash == nil { http.Redirect(w, r, "/f/"+id, http.StatusSeeOther) return } data := map[string]any{ "ID": id, "Filename": rec.Filename, "Error": r.URL.Query().Get("error"), } w.Header().Set("Content-Type", "text/html; charset=utf-8") passwordTmpl.Execute(w, data) } func (h *Handler) AuthSubmit(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") if id == "" { http.NotFound(w, r) return } if err := r.ParseForm(); err != nil { http.Error(w, "bad request", http.StatusBadRequest) return } password := r.FormValue("password") if password == "" { http.Redirect(w, r, "/f/"+id+"/auth?error=password+required", http.StatusSeeOther) return } rec, err := h.store.Get(id) if err == sql.ErrNoRows { http.NotFound(w, r) return } if err != nil { log.Printf("db get error: %v", err) http.Error(w, "internal error", http.StatusInternalServerError) return } if rec.PasswordHash == nil { http.Redirect(w, r, "/f/"+id, http.StatusSeeOther) return } if err := bcrypt.CompareHashAndPassword([]byte(*rec.PasswordHash), []byte(password)); err != nil { http.Redirect(w, r, "/f/"+id+"/auth?error=wrong+password", http.StatusSeeOther) return } // Set auth cookie (10 min lifetime) http.SetCookie(w, &http.Cookie{ Name: "auth_" + id, Value: "granted", Path: "/f/" + id, MaxAge: 600, HttpOnly: true, SameSite: http.SameSiteLaxMode, Secure: true, Expires: time.Now().Add(10 * time.Minute), }) http.Redirect(w, r, "/f/"+id, http.StatusSeeOther) }