users #16
69
auth/auth.go
69
auth/auth.go
|
@ -1,32 +1,63 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"code.chrissexton.org/cws/cabinet/db"
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"code.chrissexton.org/cws/cabinet/config"
|
||||||
|
"code.chrissexton.org/cws/cabinet/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
*db.Database
|
*db.Database
|
||||||
|
|
||||||
ID int64
|
ID int64
|
||||||
Name string
|
Name string
|
||||||
Hash []byte
|
Hash []byte
|
||||||
|
AuthKey string `db:"auth_key"`
|
||||||
|
Invalidate time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrepareTable(tx *sqlx.Tx) error {
|
func PrepareTable(tx *sqlx.Tx) error {
|
||||||
q := `create table if not exists users (
|
q := `create table if not exists users (
|
||||||
id integer primary key,
|
id integer primary key,
|
||||||
name text unique,
|
name text unique not null,
|
||||||
hash text
|
hash text not null,
|
||||||
|
auth_key text,
|
||||||
|
invalidate datetime
|
||||||
)`
|
)`
|
||||||
_, err := tx.Exec(q)
|
_, err := tx.Exec(q)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeKey() (string, error) {
|
||||||
|
keySize := config.GetInt("key.size", 10)
|
||||||
|
buf := make([]byte, keySize)
|
||||||
|
_, err := rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
key := hex.EncodeToString(buf)
|
||||||
|
log.Debug().Msgf("Encoded secret key %s as %s", string(buf), key)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
func New(db *db.Database, name, password string) (*User, error) {
|
func New(db *db.Database, name, password string) (*User, error) {
|
||||||
q := `insert into users (null, ?, ?)`
|
q := `insert into users values (null, ?, ?, ?, ?)`
|
||||||
res, err := db.Exec(q, name, password)
|
|
||||||
|
key, err := makeKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour)
|
||||||
|
|
||||||
|
res, err := db.Exec(q, name, password, key, invalidate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -35,8 +66,10 @@ func New(db *db.Database, name, password string) (*User, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
u := &User{
|
u := &User{
|
||||||
ID: id,
|
ID: id,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
AuthKey: key,
|
||||||
|
Invalidate: invalidate,
|
||||||
}
|
}
|
||||||
u.Set(password)
|
u.Set(password)
|
||||||
return u, nil
|
return u, nil
|
||||||
|
@ -67,3 +100,19 @@ func (u *User) Validate(password string) bool {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetByKey(db *db.Database, key string) (*User, error) {
|
||||||
|
q := `select * from users where auth_key = ?`
|
||||||
|
u := &User{}
|
||||||
|
invalid := errors.New("invalid key")
|
||||||
|
if err := db.Get(u, q, key); err != nil {
|
||||||
|
if err.Error() == "sql: no rows in result set" {
|
||||||
|
return nil, invalid
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if u.Invalidate.Before(time.Now()) {
|
||||||
|
return nil, invalid
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"code.chrissexton.org/cws/cabinet/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMakeKey(t *testing.T) {
|
||||||
|
k, err := makeKey()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotEmpty(t, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetByKey(t *testing.T) {
|
||||||
|
d, err := db.New(":memory:")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
tx, err := d.Beginx()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = PrepareTable(tx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = tx.Commit()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u, err := New(d, "test", "abc")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u2, err := GetByKey(d, u.AuthKey)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, u.ID, u2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetByKeyFailure(t *testing.T) {
|
||||||
|
d, err := db.New(":memory:")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
tx, err := d.Beginx()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = PrepareTable(tx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = tx.Commit()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = New(d, "test", "abc")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u2, err := GetByKey(d, "foobar")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, u2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
d, err := db.New(":memory:")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
tx, err := d.Beginx()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = PrepareTable(tx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = tx.Commit()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u, err := New(d, "test", "abc")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u2, err := Get(d, "test")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, u.ID, u2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_Validate(t *testing.T) {
|
||||||
|
d, err := db.New(":memory:")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
tx, err := d.Beginx()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = PrepareTable(tx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = tx.Commit()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
u, err := New(d, "test", "abc")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
actual := u.Validate("abc")
|
||||||
|
assert.True(t, actual)
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetInt(key string, fallback int) int {
|
||||||
|
v := Get(key, strconv.Itoa(fallback))
|
||||||
|
if out, err := strconv.Atoi(v); err == nil {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func Get(key, fallback string) string {
|
||||||
|
key = strings.ToUpper(key)
|
||||||
|
key = strings.ReplaceAll(key, ".", "_")
|
||||||
|
if v, found := os.LookupEnv(key); found {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
2
db/db.go
2
db/db.go
|
@ -2,6 +2,8 @@ package db
|
||||||
|
|
||||||
import "github.com/jmoiron/sqlx"
|
import "github.com/jmoiron/sqlx"
|
||||||
|
|
||||||
|
import _ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
*sqlx.DB
|
*sqlx.DB
|
||||||
}
|
}
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -11,6 +11,7 @@ require (
|
||||||
github.com/rs/zerolog v1.16.0
|
github.com/rs/zerolog v1.16.0
|
||||||
github.com/speps/go-hashids v2.0.0+incompatible
|
github.com/speps/go-hashids v2.0.0+incompatible
|
||||||
github.com/stretchr/graceful v1.2.15
|
github.com/stretchr/graceful v1.2.15
|
||||||
|
github.com/stretchr/testify v1.4.0
|
||||||
golang.org/x/crypto v0.0.0-20191107222254-f4817d981bb6
|
golang.org/x/crypto v0.0.0-20191107222254-f4817d981bb6
|
||||||
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
|
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
|
||||||
google.golang.org/appengine v1.6.5 // indirect
|
google.golang.org/appengine v1.6.5 // indirect
|
||||||
|
|
Loading…
Reference in New Issue