From 1b8e5b4d7068f182a5bfd558f2846888c8a0367e Mon Sep 17 00:00:00 2001 From: Chris Sexton Date: Sun, 15 Mar 2020 09:40:47 -0400 Subject: [PATCH] auth: auth flow complete on server-side --- auth/auth.go | 7 ++++++- main.go | 4 ++++ web/auth.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++- web/routes.go | 28 ++++++++++++++++++---------- 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index a964733..ec44a2b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -56,8 +56,12 @@ func New(db *db.Database, name, password string) (*User, error) { return nil, err } invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour) + hash, err := bcrypt.GenerateFromPassword([]byte(password), config.GetInt("hash.cost", 16)) + if err != nil { + return nil, err + } - res, err := db.Exec(q, name, password, key, invalidate) + res, err := db.Exec(q, name, hash, key, invalidate) if err != nil { return nil, err } @@ -96,6 +100,7 @@ func (u *User) Set(newPassword string) error { func (u *User) Validate(password string) bool { err := bcrypt.CompareHashAndPassword(u.Hash, []byte(password)) if err != nil { + log.Debug().Err(err).Msg("incorrect credentials") return false } return true diff --git a/main.go b/main.go index c8106b3..b335b6a 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "os" + "code.chrissexton.org/cws/cabinet/auth" "code.chrissexton.org/cws/cabinet/entry" "code.chrissexton.org/cws/cabinet/db" @@ -45,6 +46,9 @@ func main() { if err := entry.PrepareTable(tx); err != nil { log.Fatal().Err(err).Msg("could not create database") } + if err = auth.PrepareTable(tx); err != nil { + log.Fatal().Err(err).Msg("could not create database") + } tx.Commit() s := web.New(*httpAddr, db, box) diff --git a/web/auth.go b/web/auth.go index 28b80d7..27ca92e 100644 --- a/web/auth.go +++ b/web/auth.go @@ -2,9 +2,12 @@ package web import ( "encoding/json" + "fmt" "net/http" "code.chrissexton.org/cws/cabinet/auth" + "code.chrissexton.org/cws/cabinet/config" + "github.com/rs/zerolog/log" ) @@ -19,7 +22,21 @@ func (web *Web) auth(w http.ResponseWriter, r *http.Request) { } if user.Validate(password) { + resp := struct { + Status bool + User auth.User + }{ + true, + *user, + } + j, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(500) + log.Error().Err(err).Msg("Error encoding json response") + return + } w.WriteHeader(200) + w.Write(j) return } w.WriteHeader(401) @@ -30,7 +47,35 @@ func (web *Web) auth(w http.ResponseWriter, r *http.Request) { j, err := json.Marshal(resp) if err != nil { w.WriteHeader(500) - log.Error().Err(err).Msg("Error") + log.Error().Err(err).Msg("Error encoding json response") + return } w.Write(j) } + +func (web *Web) newUser(w http.ResponseWriter, r *http.Request) { + secret := r.Header.Get("X-secret") + if secret != config.Get("secret", "abc123") { + w.WriteHeader(401) + return + } + dec := json.NewDecoder(r.Body) + req := struct { + Username string + Password string + }{} + err := dec.Decode(&req) + if err != nil { + w.WriteHeader(500) + fmt.Fprint(w, err) + return + } + _, err = auth.New(web.db, req.Username, req.Password) + if err != nil { + w.WriteHeader(500) + fmt.Fprint(w, err) + log.Error().Err(err).Msg("Could not create user") + return + } + w.WriteHeader(200) +} diff --git a/web/routes.go b/web/routes.go index e09019d..eaa990b 100644 --- a/web/routes.go +++ b/web/routes.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/zerolog/log" + "code.chrissexton.org/cws/cabinet/auth" "code.chrissexton.org/cws/cabinet/db" packr "github.com/gobuffalo/packr/v2" @@ -38,15 +39,20 @@ func New(addr string, db *db.Database, box *packr.Box) *Web { return w } -type AuthMiddleware struct{} +type AuthMiddleware struct { + db *db.Database +} func (aw *AuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("X-Auth-Key") == "" { + key := r.Header.Get("X-Auth-Key") + u, err := auth.GetByKey(aw.db, key) + if key == "" || err != nil { w.WriteHeader(401) fmt.Fprint(w, "invalid login") + return } - log.Debug().Msgf("This shit is authed!") + log.Debug().Msgf("This shit is authed to user %s!", u.Name) next.ServeHTTP(w, r) }) } @@ -54,7 +60,7 @@ func (aw *AuthMiddleware) Middleware(next http.Handler) http.Handler { func (web *Web) routeSetup() http.Handler { r := mux.NewRouter() api := r.PathPrefix("/v1/").Subrouter() - auth := AuthMiddleware{} + auth := AuthMiddleware{web.db} authedApi := r.PathPrefix("/v1/").Subrouter() authedApi.Use(auth.Middleware) @@ -69,16 +75,18 @@ func (web *Web) routeSetup() http.Handler { // curl 'http://127.0.0.1:8080/v1/test' -X POST -H 'Accept: application/json, text/plain, */*' --compressed -H 'Content-Type: application/json;charset=utf-8' --data '{ "test": 1 }' - api.HandleFunc("/entries", web.allEntries).Methods(http.MethodGet) - api.HandleFunc("/entries", web.newEntry).Methods(http.MethodPost) - api.HandleFunc("/entries", web.newEntry).Methods(http.MethodPost). + authedApi.HandleFunc("/entries", web.newEntry).Methods(http.MethodPost) + authedApi.HandleFunc("/entries", web.newEntry).Methods(http.MethodPost). HeadersRegexp("Content-Type", "application/(text|json).*") - api.HandleFunc("/entries", web.newMarkdownEntry).Methods(http.MethodPost). + authedApi.HandleFunc("/entries", web.newMarkdownEntry).Methods(http.MethodPost). HeadersRegexp("Content-Type", "application/markdown.*") - api.HandleFunc("/entries/{slug}", web.removeEntry).Methods(http.MethodDelete) - api.HandleFunc("/entries/{slug}", web.editEntry).Methods(http.MethodPut) + authedApi.HandleFunc("/entries/{slug}", web.removeEntry).Methods(http.MethodDelete) + authedApi.HandleFunc("/entries/{slug}", web.editEntry).Methods(http.MethodPut) api.HandleFunc("/entries/{slug}", web.getEntry).Methods(http.MethodGet) + api.HandleFunc("/entries", web.allEntries).Methods(http.MethodGet) + + api.HandleFunc("/auth/new", web.newUser).Methods(http.MethodPost) api.HandleFunc("/auth", web.auth).Methods(http.MethodPost) r.PathPrefix("/").HandlerFunc(web.indexHandler("/index.html")) loggedRouter := handlers.LoggingHandler(os.Stdout, r)