From c6709ae521a88526ac91cfba6fdafe05ccdb215b Mon Sep 17 00:00:00 2001 From: Chris Sexton Date: Wed, 25 Oct 2017 17:55:55 -0400 Subject: [PATCH] fact: add protection for aliases * No circular aliases may be created * Must have a resolvable destination --- plugins/fact/factoid.go | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/plugins/fact/factoid.go b/plugins/fact/factoid.go index 6107fbf..57d563c 100644 --- a/plugins/fact/factoid.go +++ b/plugins/fact/factoid.go @@ -38,7 +38,7 @@ type alias struct { Next string } -func (a *alias) resolve(db *sqlx.DB) *factoid { +func (a *alias) resolve(db *sqlx.DB) (*factoid, error) { // perform DB query to fill the To field q := `select fact, next from factoid_alias where fact=?` var next alias @@ -47,9 +47,10 @@ func (a *alias) resolve(db *sqlx.DB) *factoid { // we hit the end of the chain, get a factoid named Next fact, err := getSingleFact(db, a.Next) if err != nil { - log.Printf("Error resolvig alias %v: %v", a, err) + err := fmt.Errorf("Error resolvig alias %v: %v", a, err) + return nil, err } - return fact + return fact, nil } return next.resolve(db) } @@ -61,12 +62,21 @@ func findAlias(db *sqlx.DB, fact string) (bool, *factoid) { if err != nil { return false, nil } - return true, a.resolve(db) + f, err := a.resolve(db) + return err == nil, f } func (a *alias) save(db *sqlx.DB) error { + ok, _ := findAlias(db, a.Fact) + if ok { + return fmt.Errorf("DANGER: an opposite alias already exists") + } + _, err := a.resolve(db) + if err != nil { + return fmt.Errorf("there is no fact at that destination") + } q := `insert or replace into factoid_alias (fact, next) values (?, ?)` - _, err := db.Exec(q, a.Fact, a.Next) + _, err = db.Exec(q, a.Fact, a.Next) if err != nil { return err } @@ -77,10 +87,6 @@ func aliasFromStrings(from, to string) *alias { return &alias{from, to} } -func factoidFromAlias(db *sqlx.DB, a alias) *factoid { - return a.resolve(db) -} - func (f *factoid) save(db *sqlx.DB) error { var err error if f.id.Valid { @@ -611,8 +617,11 @@ func (p *Factoid) Message(message msg.Message) bool { return true } a := aliasFromStrings(strings.TrimSpace(parts[1]), strings.TrimSpace(parts[0])) - a.save(p.db) - p.Bot.SendAction(message.Channel, "learns a new synonym") + if err := a.save(p.db); err != nil { + p.Bot.SendMessage(message.Channel, err.Error()) + } else { + p.Bot.SendAction(message.Channel, "learns a new synonym") + } return true } @@ -620,9 +629,8 @@ func (p *Factoid) Message(message msg.Message) bool { if fact := p.randomFact(); fact != nil { p.sayFact(message, *fact) return true - } else { - log.Println("Got a nil fact.") } + log.Println("Got a nil fact.") } if strings.ToLower(message.Body) == "forget that" {