users #16

Merged
cws merged 6 commits from users into master 2020-03-16 20:29:55 +00:00
5 changed files with 164 additions and 10 deletions
Showing only changes of commit ce02dca041 - Show all commits

View File

@ -1,32 +1,63 @@
package auth
import (
"code.chrissexton.org/cws/cabinet/db"
"crypto/rand"
"encoding/hex"
"errors"
"time"
"github.com/jmoiron/sqlx"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"code.chrissexton.org/cws/cabinet/config"
"code.chrissexton.org/cws/cabinet/db"
)
type User struct {
*db.Database
ID int64
Name string
Hash []byte
ID int64
Name string
Hash []byte
AuthKey string `db:"auth_key"`
Invalidate time.Time
}
func PrepareTable(tx *sqlx.Tx) error {
q := `create table if not exists users (
id integer primary key,
name text unique,
hash text
name text unique not null,
hash text not null,
auth_key text,
invalidate datetime
)`
_, err := tx.Exec(q)
return err
}
func makeKey() (string, error) {
keySize := config.GetInt("key.size", 10)
buf := make([]byte, keySize)
_, err := rand.Read(buf)
if err != nil {
return "", err
}
key := hex.EncodeToString(buf)
log.Debug().Msgf("Encoded secret key %s as %s", string(buf), key)
return key, nil
}
func New(db *db.Database, name, password string) (*User, error) {
q := `insert into users (null, ?, ?)`
res, err := db.Exec(q, name, password)
q := `insert into users values (null, ?, ?, ?, ?)`
key, err := makeKey()
if err != nil {
return nil, err
}
invalidate := time.Now().Add(time.Duration(config.GetInt("invalidate.hours", 7*24)) * time.Hour)
res, err := db.Exec(q, name, password, key, invalidate)
if err != nil {
return nil, err
}
@ -35,8 +66,10 @@ func New(db *db.Database, name, password string) (*User, error) {
return nil, err
}
u := &User{
ID: id,
Name: name,
ID: id,
Name: name,
AuthKey: key,
Invalidate: invalidate,
}
u.Set(password)
return u, nil
@ -67,3 +100,19 @@ func (u *User) Validate(password string) bool {
}
return true
}
func GetByKey(db *db.Database, key string) (*User, error) {
q := `select * from users where auth_key = ?`
u := &User{}
invalid := errors.New("invalid key")
if err := db.Get(u, q, key); err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, invalid
}
return nil, err
}
if u.Invalidate.Before(time.Now()) {
return nil, invalid
}
return u, nil
}

78
auth/auth_test.go Normal file
View File

@ -0,0 +1,78 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"code.chrissexton.org/cws/cabinet/db"
)
func TestMakeKey(t *testing.T) {
k, err := makeKey()
assert.Nil(t, err)
assert.NotEmpty(t, k)
}
func TestGetByKey(t *testing.T) {
d, err := db.New(":memory:")
assert.Nil(t, err)
tx, err := d.Beginx()
assert.Nil(t, err)
err = PrepareTable(tx)
assert.Nil(t, err)
err = tx.Commit()
assert.Nil(t, err)
u, err := New(d, "test", "abc")
assert.Nil(t, err)
u2, err := GetByKey(d, u.AuthKey)
assert.Nil(t, err)
assert.Equal(t, u.ID, u2.ID)
}
func TestGetByKeyFailure(t *testing.T) {
d, err := db.New(":memory:")
assert.Nil(t, err)
tx, err := d.Beginx()
assert.Nil(t, err)
err = PrepareTable(tx)
assert.Nil(t, err)
err = tx.Commit()
assert.Nil(t, err)
_, err = New(d, "test", "abc")
assert.Nil(t, err)
u2, err := GetByKey(d, "foobar")
assert.Error(t, err)
assert.Nil(t, u2)
}
func TestGet(t *testing.T) {
d, err := db.New(":memory:")
assert.Nil(t, err)
tx, err := d.Beginx()
assert.Nil(t, err)
err = PrepareTable(tx)
assert.Nil(t, err)
err = tx.Commit()
assert.Nil(t, err)
u, err := New(d, "test", "abc")
assert.Nil(t, err)
u2, err := Get(d, "test")
assert.Nil(t, err)
assert.Equal(t, u.ID, u2.ID)
}
func TestUser_Validate(t *testing.T) {
d, err := db.New(":memory:")
assert.Nil(t, err)
tx, err := d.Beginx()
assert.Nil(t, err)
err = PrepareTable(tx)
assert.Nil(t, err)
err = tx.Commit()
assert.Nil(t, err)
u, err := New(d, "test", "abc")
assert.Nil(t, err)
actual := u.Validate("abc")
assert.True(t, actual)
}

24
config/config.go Normal file
View File

@ -0,0 +1,24 @@
package config
import (
"os"
"strconv"
"strings"
)
func GetInt(key string, fallback int) int {
v := Get(key, strconv.Itoa(fallback))
if out, err := strconv.Atoi(v); err == nil {
return out
}
return fallback
}
func Get(key, fallback string) string {
key = strings.ToUpper(key)
key = strings.ReplaceAll(key, ".", "_")
if v, found := os.LookupEnv(key); found {
return v
}
return fallback
}

View File

@ -2,6 +2,8 @@ package db
import "github.com/jmoiron/sqlx"
import _ "github.com/mattn/go-sqlite3"
type Database struct {
*sqlx.DB
}

1
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/rs/zerolog v1.16.0
github.com/speps/go-hashids v2.0.0+incompatible
github.com/stretchr/graceful v1.2.15
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191107222254-f4817d981bb6
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
google.golang.org/appengine v1.6.5 // indirect