fact: add protection for aliases

* No circular aliases may be created
* Must have a resolvable destination
This commit is contained in:
Chris Sexton 2017-10-25 17:55:55 -04:00
parent d04e6741ea
commit c6709ae521
1 changed files with 21 additions and 13 deletions

View File

@ -38,7 +38,7 @@ type alias struct {
Next string Next string
} }
func (a *alias) resolve(db *sqlx.DB) *factoid { func (a *alias) resolve(db *sqlx.DB) (*factoid, error) {
// perform DB query to fill the To field // perform DB query to fill the To field
q := `select fact, next from factoid_alias where fact=?` q := `select fact, next from factoid_alias where fact=?`
var next alias var next alias
@ -47,9 +47,10 @@ func (a *alias) resolve(db *sqlx.DB) *factoid {
// we hit the end of the chain, get a factoid named Next // we hit the end of the chain, get a factoid named Next
fact, err := getSingleFact(db, a.Next) fact, err := getSingleFact(db, a.Next)
if err != nil { if err != nil {
log.Printf("Error resolvig alias %v: %v", a, err) err := fmt.Errorf("Error resolvig alias %v: %v", a, err)
return nil, err
} }
return fact return fact, nil
} }
return next.resolve(db) return next.resolve(db)
} }
@ -61,12 +62,21 @@ func findAlias(db *sqlx.DB, fact string) (bool, *factoid) {
if err != nil { if err != nil {
return false, nil return false, nil
} }
return true, a.resolve(db) f, err := a.resolve(db)
return err == nil, f
} }
func (a *alias) save(db *sqlx.DB) error { func (a *alias) save(db *sqlx.DB) error {
ok, _ := findAlias(db, a.Fact)
if ok {
return fmt.Errorf("DANGER: an opposite alias already exists")
}
_, err := a.resolve(db)
if err != nil {
return fmt.Errorf("there is no fact at that destination")
}
q := `insert or replace into factoid_alias (fact, next) values (?, ?)` q := `insert or replace into factoid_alias (fact, next) values (?, ?)`
_, err := db.Exec(q, a.Fact, a.Next) _, err = db.Exec(q, a.Fact, a.Next)
if err != nil { if err != nil {
return err return err
} }
@ -77,10 +87,6 @@ func aliasFromStrings(from, to string) *alias {
return &alias{from, to} return &alias{from, to}
} }
func factoidFromAlias(db *sqlx.DB, a alias) *factoid {
return a.resolve(db)
}
func (f *factoid) save(db *sqlx.DB) error { func (f *factoid) save(db *sqlx.DB) error {
var err error var err error
if f.id.Valid { if f.id.Valid {
@ -611,8 +617,11 @@ func (p *Factoid) Message(message msg.Message) bool {
return true return true
} }
a := aliasFromStrings(strings.TrimSpace(parts[1]), strings.TrimSpace(parts[0])) a := aliasFromStrings(strings.TrimSpace(parts[1]), strings.TrimSpace(parts[0]))
a.save(p.db) if err := a.save(p.db); err != nil {
p.Bot.SendMessage(message.Channel, err.Error())
} else {
p.Bot.SendAction(message.Channel, "learns a new synonym") p.Bot.SendAction(message.Channel, "learns a new synonym")
}
return true return true
} }
@ -620,9 +629,8 @@ func (p *Factoid) Message(message msg.Message) bool {
if fact := p.randomFact(); fact != nil { if fact := p.randomFact(); fact != nil {
p.sayFact(message, *fact) p.sayFact(message, *fact)
return true return true
} else {
log.Println("Got a nil fact.")
} }
log.Println("Got a nil fact.")
} }
if strings.ToLower(message.Body) == "forget that" { if strings.ToLower(message.Body) == "forget that" {