package auth import ( "crypto/rand" "encoding/hex" "errors" "time" "github.com/jmoiron/sqlx" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" "code.chrissexton.org/cws/cabinet/config" "code.chrissexton.org/cws/cabinet/db" ) type User struct { *db.Database ID int64 Name string Hash []byte AuthKey string `db:"auth_key"` Invalidate time.Time } func PrepareTable(tx *sqlx.Tx) error { q := `create table if not exists users ( id integer primary key, name text unique not null, hash text not null, auth_key text, invalidate datetime )` _, err := tx.Exec(q) return err } func makeKey() (string, error) { keySize := config.GetInt("key.size", 10) buf := make([]byte, keySize) _, err := rand.Read(buf) if err != nil { return "", err } key := hex.EncodeToString(buf) log.Debug().Msgf("Encoded secret key %s as %s", string(buf), key) return key, nil } func New(db *db.Database, name, password string) (*User, error) { q := `insert into users values (null, ?, ?, ?, ?)` key, err := makeKey() if err != nil { return nil, err } invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour) res, err := db.Exec(q, name, password, key, invalidate) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } u := &User{ ID: id, Name: name, AuthKey: key, Invalidate: invalidate, } u.Set(password) return u, nil } func Get(db *db.Database, name string) (*User, error) { q := `select * from users where name = ?` u := &User{} if err := db.Get(u, q, name); err != nil { return nil, err } return u, nil } func (u *User) Set(newPassword string) error { hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 0) if err != nil { return err } u.Hash = hash return nil } func (u *User) Validate(password string) bool { err := bcrypt.CompareHashAndPassword(u.Hash, []byte(password)) if err != nil { return false } return true } func GetByKey(db *db.Database, key string) (*User, error) { q := `select * from users where auth_key = ?` u := &User{} invalid := errors.New("invalid key") if err := db.Get(u, q, key); err != nil { if err.Error() == "sql: no rows in result set" { return nil, invalid } return nil, err } if u.Invalidate.Before(time.Now()) { return nil, invalid } return u, nil }