users #16
59
auth/auth.go
59
auth/auth.go
|
@ -1,9 +1,17 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"code.chrissexton.org/cws/cabinet/db"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"code.chrissexton.org/cws/cabinet/config"
|
||||
"code.chrissexton.org/cws/cabinet/db"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
|
@ -12,21 +20,44 @@ type User struct {
|
|||
ID int64
|
||||
Name string
|
||||
Hash []byte
|
||||
AuthKey string `db:"auth_key"`
|
||||
Invalidate time.Time
|
||||
}
|
||||
|
||||
func PrepareTable(tx *sqlx.Tx) error {
|
||||
q := `create table if not exists users (
|
||||
id integer primary key,
|
||||
name text unique,
|
||||
hash text
|
||||
name text unique not null,
|
||||
hash text not null,
|
||||
auth_key text,
|
||||
invalidate datetime
|
||||
)`
|
||||
_, err := tx.Exec(q)
|
||||
return err
|
||||
}
|
||||
|
||||
func makeKey() (string, error) {
|
||||
keySize := config.GetInt("key.size", 10)
|
||||
buf := make([]byte, keySize)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key := hex.EncodeToString(buf)
|
||||
log.Debug().Msgf("Encoded secret key %s as %s", string(buf), key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func New(db *db.Database, name, password string) (*User, error) {
|
||||
q := `insert into users (null, ?, ?)`
|
||||
res, err := db.Exec(q, name, password)
|
||||
q := `insert into users values (null, ?, ?, ?, ?)`
|
||||
|
||||
key, err := makeKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour)
|
||||
|
||||
res, err := db.Exec(q, name, password, key, invalidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,6 +68,8 @@ func New(db *db.Database, name, password string) (*User, error) {
|
|||
u := &User{
|
||||
ID: id,
|
||||
Name: name,
|
||||
AuthKey: key,
|
||||
Invalidate: invalidate,
|
||||
}
|
||||
u.Set(password)
|
||||
return u, nil
|
||||
|
@ -67,3 +100,19 @@ func (u *User) Validate(password string) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func GetByKey(db *db.Database, key string) (*User, error) {
|
||||
q := `select * from users where auth_key = ?`
|
||||
u := &User{}
|
||||
invalid := errors.New("invalid key")
|
||||
if err := db.Get(u, q, key); err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
return nil, invalid
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if u.Invalidate.Before(time.Now()) {
|
||||
return nil, invalid
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"code.chrissexton.org/cws/cabinet/db"
|
||||
)
|
||||
|
||||
func TestMakeKey(t *testing.T) {
|
||||
k, err := makeKey()
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, k)
|
||||
}
|
||||
|
||||
func TestGetByKey(t *testing.T) {
|
||||
d, err := db.New(":memory:")
|
||||
assert.Nil(t, err)
|
||||
tx, err := d.Beginx()
|
||||
assert.Nil(t, err)
|
||||
err = PrepareTable(tx)
|
||||
assert.Nil(t, err)
|
||||
err = tx.Commit()
|
||||
assert.Nil(t, err)
|
||||
u, err := New(d, "test", "abc")
|
||||
assert.Nil(t, err)
|
||||
u2, err := GetByKey(d, u.AuthKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, u.ID, u2.ID)
|
||||
}
|
||||
|
||||
func TestGetByKeyFailure(t *testing.T) {
|
||||
d, err := db.New(":memory:")
|
||||
assert.Nil(t, err)
|
||||
tx, err := d.Beginx()
|
||||
assert.Nil(t, err)
|
||||
err = PrepareTable(tx)
|
||||
assert.Nil(t, err)
|
||||
err = tx.Commit()
|
||||
assert.Nil(t, err)
|
||||
_, err = New(d, "test", "abc")
|
||||
assert.Nil(t, err)
|
||||
u2, err := GetByKey(d, "foobar")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, u2)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
d, err := db.New(":memory:")
|
||||
assert.Nil(t, err)
|
||||
tx, err := d.Beginx()
|
||||
assert.Nil(t, err)
|
||||
err = PrepareTable(tx)
|
||||
assert.Nil(t, err)
|
||||
err = tx.Commit()
|
||||
assert.Nil(t, err)
|
||||
u, err := New(d, "test", "abc")
|
||||
assert.Nil(t, err)
|
||||
u2, err := Get(d, "test")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, u.ID, u2.ID)
|
||||
}
|
||||
|
||||
func TestUser_Validate(t *testing.T) {
|
||||
d, err := db.New(":memory:")
|
||||
assert.Nil(t, err)
|
||||
tx, err := d.Beginx()
|
||||
assert.Nil(t, err)
|
||||
err = PrepareTable(tx)
|
||||
assert.Nil(t, err)
|
||||
err = tx.Commit()
|
||||
assert.Nil(t, err)
|
||||
u, err := New(d, "test", "abc")
|
||||
assert.Nil(t, err)
|
||||
actual := u.Validate("abc")
|
||||
assert.True(t, actual)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetInt(key string, fallback int) int {
|
||||
v := Get(key, strconv.Itoa(fallback))
|
||||
if out, err := strconv.Atoi(v); err == nil {
|
||||
return out
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func Get(key, fallback string) string {
|
||||
key = strings.ToUpper(key)
|
||||
key = strings.ReplaceAll(key, ".", "_")
|
||||
if v, found := os.LookupEnv(key); found {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
2
db/db.go
2
db/db.go
|
@ -2,6 +2,8 @@ package db
|
|||
|
||||
import "github.com/jmoiron/sqlx"
|
||||
|
||||
import _ "github.com/mattn/go-sqlite3"
|
||||
|
||||
type Database struct {
|
||||
*sqlx.DB
|
||||
}
|
||||
|
|
1
go.mod
1
go.mod
|
@ -11,6 +11,7 @@ require (
|
|||
github.com/rs/zerolog v1.16.0
|
||||
github.com/speps/go-hashids v2.0.0+incompatible
|
||||
github.com/stretchr/graceful v1.2.15
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.org/x/crypto v0.0.0-20191107222254-f4817d981bb6
|
||||
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
|
||||
google.golang.org/appengine v1.6.5 // indirect
|
||||
|
|
Loading…
Reference in New Issue