diff --git a/auth/auth.go b/auth/auth.go index 41ce89e..a964733 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,32 +1,63 @@ package auth import ( - "code.chrissexton.org/cws/cabinet/db" + "crypto/rand" + "encoding/hex" + "errors" + "time" + "github.com/jmoiron/sqlx" + "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" + + "code.chrissexton.org/cws/cabinet/config" + "code.chrissexton.org/cws/cabinet/db" ) type User struct { *db.Database - ID int64 - Name string - Hash []byte + ID int64 + Name string + Hash []byte + AuthKey string `db:"auth_key"` + Invalidate time.Time } func PrepareTable(tx *sqlx.Tx) error { q := `create table if not exists users ( id integer primary key, - name text unique, - hash text + name text unique not null, + hash text not null, + auth_key text, + invalidate datetime )` _, err := tx.Exec(q) return err } +func makeKey() (string, error) { + keySize := config.GetInt("key.size", 10) + buf := make([]byte, keySize) + _, err := rand.Read(buf) + if err != nil { + return "", err + } + key := hex.EncodeToString(buf) + log.Debug().Msgf("Encoded secret key %s as %s", string(buf), key) + return key, nil +} + func New(db *db.Database, name, password string) (*User, error) { - q := `insert into users (null, ?, ?)` - res, err := db.Exec(q, name, password) + q := `insert into users values (null, ?, ?, ?, ?)` + + key, err := makeKey() + if err != nil { + return nil, err + } + invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour) + + res, err := db.Exec(q, name, password, key, invalidate) if err != nil { return nil, err } @@ -35,8 +66,10 @@ func New(db *db.Database, name, password string) (*User, error) { return nil, err } u := &User{ - ID: id, - Name: name, + ID: id, + Name: name, + AuthKey: key, + Invalidate: invalidate, } u.Set(password) return u, nil @@ -67,3 +100,19 @@ func (u *User) Validate(password string) bool { } return true } + +func GetByKey(db *db.Database, key string) (*User, error) { + q := `select * from users where auth_key = ?` + u := &User{} + invalid := errors.New("invalid key") + if err := db.Get(u, q, key); err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, invalid + } + return nil, err + } + if u.Invalidate.Before(time.Now()) { + return nil, invalid + } + return u, nil +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 0000000..9f881b8 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,78 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "code.chrissexton.org/cws/cabinet/db" +) + +func TestMakeKey(t *testing.T) { + k, err := makeKey() + assert.Nil(t, err) + assert.NotEmpty(t, k) +} + +func TestGetByKey(t *testing.T) { + d, err := db.New(":memory:") + assert.Nil(t, err) + tx, err := d.Beginx() + assert.Nil(t, err) + err = PrepareTable(tx) + assert.Nil(t, err) + err = tx.Commit() + assert.Nil(t, err) + u, err := New(d, "test", "abc") + assert.Nil(t, err) + u2, err := GetByKey(d, u.AuthKey) + assert.Nil(t, err) + assert.Equal(t, u.ID, u2.ID) +} + +func TestGetByKeyFailure(t *testing.T) { + d, err := db.New(":memory:") + assert.Nil(t, err) + tx, err := d.Beginx() + assert.Nil(t, err) + err = PrepareTable(tx) + assert.Nil(t, err) + err = tx.Commit() + assert.Nil(t, err) + _, err = New(d, "test", "abc") + assert.Nil(t, err) + u2, err := GetByKey(d, "foobar") + assert.Error(t, err) + assert.Nil(t, u2) +} + +func TestGet(t *testing.T) { + d, err := db.New(":memory:") + assert.Nil(t, err) + tx, err := d.Beginx() + assert.Nil(t, err) + err = PrepareTable(tx) + assert.Nil(t, err) + err = tx.Commit() + assert.Nil(t, err) + u, err := New(d, "test", "abc") + assert.Nil(t, err) + u2, err := Get(d, "test") + assert.Nil(t, err) + assert.Equal(t, u.ID, u2.ID) +} + +func TestUser_Validate(t *testing.T) { + d, err := db.New(":memory:") + assert.Nil(t, err) + tx, err := d.Beginx() + assert.Nil(t, err) + err = PrepareTable(tx) + assert.Nil(t, err) + err = tx.Commit() + assert.Nil(t, err) + u, err := New(d, "test", "abc") + assert.Nil(t, err) + actual := u.Validate("abc") + assert.True(t, actual) +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..ad898aa --- /dev/null +++ b/config/config.go @@ -0,0 +1,24 @@ +package config + +import ( + "os" + "strconv" + "strings" +) + +func GetInt(key string, fallback int) int { + v := Get(key, strconv.Itoa(fallback)) + if out, err := strconv.Atoi(v); err == nil { + return out + } + return fallback +} + +func Get(key, fallback string) string { + key = strings.ToUpper(key) + key = strings.ReplaceAll(key, ".", "_") + if v, found := os.LookupEnv(key); found { + return v + } + return fallback +} diff --git a/db/db.go b/db/db.go index 0d43c40..a702794 100644 --- a/db/db.go +++ b/db/db.go @@ -2,6 +2,8 @@ package db import "github.com/jmoiron/sqlx" +import _ "github.com/mattn/go-sqlite3" + type Database struct { *sqlx.DB } diff --git a/go.mod b/go.mod index 2af1902..de7bd02 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/rs/zerolog v1.16.0 github.com/speps/go-hashids v2.0.0+incompatible github.com/stretchr/graceful v1.2.15 + github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20191107222254-f4817d981bb6 golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect google.golang.org/appengine v1.6.5 // indirect