diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 00574c0..4b144b9 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -3,6 +3,7 @@ package babbler import ( + "database/sql" "errors" "fmt" "log" @@ -15,6 +16,12 @@ import ( "github.com/velour/catbase/config" ) +const ( + NO_BABBLER = "babbler not found" + SAID_NOTHING = "hasn't said anything yet" +) + + type BabblerPlugin struct { Bot bot.Bot db *sqlx.DB @@ -70,16 +77,27 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { _, err := p.getBabbler(who) if err != nil { + if err.Error() == NO_BABBLER { + p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s babbler not found.", who)) + return true + } return false } var saying string if len(tokens) == 2 { - saying, _ = p.babble(who) + saying, err = p.babble(who) } else { - saying, _ = p.babbleSeed(who, tokens[2:]) + saying, err = p.babbleSeed(who, tokens[2:]) } + if err != nil { + if err.Error() == SAID_NOTHING { + p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s hasn't said anything yet.", who)) + return true + } + return false + } if saying != "" { p.Bot.SendMessage(message.Channel, saying) return true @@ -91,7 +109,7 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { p.Bot.SendMessage(message.Channel, "babbler initialization failed.") return true } - p.Bot.SendMessage(message.Channel, "Okay.") + p.Bot.SendMessage(message.Channel, "okay.") return true } else if strings.Index(lowercase, "batch learn for ") == 0 { who := tokens[3] @@ -186,7 +204,8 @@ func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) - if err != nil { + if err != nil && err == sql.ErrNoRows { + return -1, errors.New(NO_BABBLER) } return id, err } @@ -204,6 +223,7 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { //we'll just ignore this but the actual creation succeeded previously return id, nil } + for rows.Next() { var tidbit string err := rows.Scan(&tidbit) @@ -336,6 +356,10 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err idToWord[id] = word } + if total == 0 { + return -1, "", errors.New(SAID_NOTHING) + } + which := rand.Int63n(total) total = 0 for id, freq := range idToFreq { @@ -370,6 +394,10 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er idToFreq[toWordId] = frequency } + if total == 0 { + return -1, "", errors.New("missing arcs") + } + which := rand.Int63n(total) total = 0 for id, freq := range idToFreq { diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index 7263ce7..da8e502 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -31,8 +31,9 @@ func TestBabblerNoBabbler(t *testing.T) { c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) c.Message(makeMessage("!seabass2 says")) - res := assert.Len(t, mb.Messages, 0) + res := assert.Len(t, mb.Messages, 1) assert.True(t, res) + assert.Contains(t, mb.Messages[0], "seabass2 babbler not found") } func TestBabblerNothingSaid(t *testing.T) { @@ -40,8 +41,13 @@ func TestBabblerNothingSaid(t *testing.T) { c := New(mb) c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) - c.Message(makeMessage("!seabass says")) - assert.Len(t, mb.Messages, 0) + res := c.Message(makeMessage("initialize babbler for seabass")) + assert.True(t, res) + res = c.Message(makeMessage("!seabass says")) + assert.True(t, res) + assert.Len(t, mb.Messages, 2) + assert.Contains(t, mb.Messages[0], "okay.") + assert.Contains(t, mb.Messages[1], "seabass hasn't said anything yet.") } func TestBabbler(t *testing.T) {