cabinet/auth/auth.go

119 lines
2.4 KiB
Go

package auth
import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"time"

	"github.com/jmoiron/sqlx"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/bcrypt"

	"code.chrissexton.org/cws/cabinet/config"
	"code.chrissexton.org/cws/cabinet/db"
)
// User is one row of the users table: an account name, the bcrypt hash of its
// password, and an authentication key with an expiry time.
type User struct {
// Embedded database handle.
// NOTE(review): neither New nor Get in this file assigns it explicitly, so it
// may be nil on returned users — confirm before calling through it.
*db.Database
// ID corresponds to the id column (integer primary key) of the users table.
ID int64
// Name is the account name; the users table declares the column unique.
Name string
// Hash is the bcrypt hash written by Set and checked by Validate.
Hash []byte
// AuthKey is the hex-encoded random key produced by makeKey; GetByKey looks
// users up by it. It maps to the auth_key column.
AuthKey string `db:"auth_key"`
// Invalidate is the instant after which GetByKey rejects AuthKey.
Invalidate time.Time
}
// PrepareTable makes sure the users table exists, issuing a
// "create table if not exists" statement on the supplied transaction.
// It returns the error reported by executing that statement, if any.
func PrepareTable(tx *sqlx.Tx) error {
	const schema = `create table if not exists users (
id integer primary key,
name text unique not null,
hash text not null,
auth_key text,
invalidate datetime
)`
	if _, err := tx.Exec(schema); err != nil {
		return err
	}
	return nil
}
// makeKey returns a fresh authentication key: random bytes from crypto/rand,
// hex-encoded, so the string is twice as long as the byte count. The byte
// count is read from the "key.size" config entry, with 10 passed as the
// fallback.
//
// It returns an error when the configured size is not positive or when the
// random source fails. The key itself is never written to the log.
func makeKey() (string, error) {
	keySize := config.GetInt("key.size", 10)
	if keySize <= 0 {
		// A negative size would panic in make, and a zero size would hand
		// every user the same empty key.
		return "", errors.New("key.size must be positive")
	}

	buf := make([]byte, keySize)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}

	// Log only the size: the key is a credential and must not end up in
	// debug logs.
	log.Debug().Msgf("Generated auth key from %d random bytes", keySize)
	return hex.EncodeToString(buf), nil
}
// New creates a user called name with the given password, inserts it into the
// users table and returns it.
//
// The password is hashed with bcrypt before the insert, so only the hash is
// persisted and the stored row can later be checked with Validate. A new auth
// key is generated with makeKey, and its expiry is set "invalidate.hours"
// hours from now (7*24 is passed as the fallback).
//
// It returns an error if hashing, key generation, the insert, or reading the
// new row id fails; no row is inserted when hashing or key generation fails.
func New(db *db.Database, name, password string) (*User, error) {
	u := &User{Name: name}

	// Hash before touching the database so the plaintext password is never
	// stored and a hashing failure cannot leave a half-initialised row behind.
	if err := u.Set(password); err != nil {
		return nil, err
	}

	key, err := makeKey()
	if err != nil {
		return nil, err
	}
	u.AuthKey = key
	u.Invalidate = time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour)

	// The leading null fills the id column. The hash is bound as a string
	// because the hash column is declared as text.
	q := `insert into users values (null, ?, ?, ?, ?)`
	res, err := db.Exec(q, u.Name, string(u.Hash), u.AuthKey, u.Invalidate)
	if err != nil {
		return nil, err
	}

	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	u.ID = id
	return u, nil
}
// Get loads the row of the users table whose name column equals name and
// returns it as a *User. Any error from the query is returned unchanged, with
// a nil user.
func Get(db *db.Database, name string) (*User, error) {
	const byName = `select * from users where name = ?`

	found := new(User)
	err := db.Get(found, byName, name)
	if err != nil {
		return nil, err
	}
	return found, nil
}
// Set replaces u.Hash with the bcrypt hash of newPassword. Only the in-memory
// value changes; this method performs no database write. On failure the
// existing hash is left untouched and the bcrypt error is returned.
func (u *User) Set(newPassword string) error {
	// A cost of 0 is below bcrypt's minimum, so the library falls back to its
	// default cost.
	hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), 0)
	if err == nil {
		u.Hash = hashed
	}
	return err
}
// Validate reports whether password matches the bcrypt hash held in u.Hash.
// Any comparison error, including a mismatch, yields false.
func (u *User) Validate(password string) bool {
	return bcrypt.CompareHashAndPassword(u.Hash, []byte(password)) == nil
}
// ErrInvalidKey is returned by GetByKey when the key is empty, matches no
// user, or has passed its Invalidate time. Callers can test for it with
// errors.Is.
var ErrInvalidKey = errors.New("invalid key")

// GetByKey returns the user whose auth_key column equals key.
//
// It returns ErrInvalidKey when key is empty, when no row matches, or when the
// matching user's Invalidate time is already in the past. Any other query
// error is returned unchanged.
func GetByKey(db *db.Database, key string) (*User, error) {
	// An empty key can never be a legitimate credential; reject it without
	// querying so it cannot match a row whose key was blanked out.
	if key == "" {
		return nil, ErrInvalidKey
	}

	q := `select * from users where auth_key = ?`
	u := &User{}
	if err := db.Get(u, q, key); err != nil {
		// Match the sentinel rather than the message text so wrapped errors
		// are still recognised.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrInvalidKey
		}
		return nil, err
	}

	if u.Invalidate.Before(time.Now()) {
		return nil, ErrInvalidKey
	}
	return u, nil
}