mirror of https://github.com/velour/catbase.git
Merge pull request #24 from velour/MarkovBabblerPlugin
Markov babbler plugin
This commit is contained in:
commit
e4f4cae6a7
|
@ -16,11 +16,13 @@ type MockBot struct {
|
|||
mock.Mock
|
||||
db *sqlx.DB
|
||||
|
||||
Cfg config.Config
|
||||
|
||||
Messages []string
|
||||
Actions []string
|
||||
}
|
||||
|
||||
func (mb *MockBot) Config() *config.Config { return &config.Config{} }
|
||||
func (mb *MockBot) Config() *config.Config { return &mb.Cfg }
|
||||
func (mb *MockBot) DBVersion() int64 { return 1 }
|
||||
func (mb *MockBot) DB() *sqlx.DB { return mb.db }
|
||||
func (mb *MockBot) Who(string) []user.User { return []user.User{} }
|
||||
|
|
|
@ -64,6 +64,9 @@ type Config struct {
|
|||
QuoteTime int
|
||||
StartupFact string
|
||||
}
|
||||
Babbler struct {
|
||||
DefaultUsers []string
|
||||
}
|
||||
}
|
||||
|
||||
// Readconfig loads the config data out of a JSON file located in cfile
|
||||
|
|
|
@ -58,5 +58,10 @@
|
|||
"QuoteTime": 1,
|
||||
"StartupFact": "speed test",
|
||||
"MinLen": 5
|
||||
},
|
||||
"Babbler": {
|
||||
"DefaultUsers": [
|
||||
"seabass"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
2
main.go
2
main.go
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/velour/catbase/plugins/fact"
|
||||
"github.com/velour/catbase/plugins/leftpad"
|
||||
"github.com/velour/catbase/plugins/reminder"
|
||||
"github.com/velour/catbase/plugins/babbler"
|
||||
"github.com/velour/catbase/plugins/talker"
|
||||
"github.com/velour/catbase/plugins/your"
|
||||
"github.com/velour/catbase/slack"
|
||||
|
@ -53,6 +54,7 @@ func main() {
|
|||
b.AddHandler("your", your.New(b))
|
||||
b.AddHandler("counter", counter.New(b))
|
||||
b.AddHandler("reminder", reminder.New(b))
|
||||
b.AddHandler("babbler", babbler.New(b))
|
||||
// catches anything left, will always return true
|
||||
b.AddHandler("factoid", fact.New(b))
|
||||
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors.
|
||||
|
||||
package babbler
|
||||
|
||||
import (
|
||||
// "database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/velour/catbase/bot"
|
||||
"github.com/velour/catbase/bot/msg"
|
||||
"github.com/velour/catbase/bot/user"
|
||||
"github.com/velour/catbase/config"
|
||||
)
|
||||
|
||||
type BabblerPlugin struct {
|
||||
Bot bot.Bot
|
||||
db *sqlx.DB
|
||||
config *config.Config
|
||||
babblers map[string]*babbler
|
||||
}
|
||||
|
||||
type babbler struct {
|
||||
start *node
|
||||
end *node
|
||||
lookup map[string]*node
|
||||
}
|
||||
|
||||
type node struct {
|
||||
wordFrequency int
|
||||
arcs map[string]*arc
|
||||
}
|
||||
|
||||
type arc struct {
|
||||
transitionFrequency int
|
||||
next *node
|
||||
}
|
||||
|
||||
func New(bot bot.Bot) *BabblerPlugin {
|
||||
plugin := &BabblerPlugin{
|
||||
Bot: bot,
|
||||
db: bot.DB(),
|
||||
config: bot.Config(),
|
||||
babblers: map[string]*babbler{},
|
||||
}
|
||||
|
||||
return plugin
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) makeBabbler(newUser user.User) {
|
||||
name := newUser.Name
|
||||
babbler, err := getMarkovChain(p.db, name)
|
||||
if err == nil {
|
||||
p.babblers[name] = babbler
|
||||
} else {
|
||||
p.babblers[name] = newBabbler()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) makeBabblers(newUser user.User) {
|
||||
users := p.Bot.Who(p.config.MainChannel)
|
||||
users = append(users, newUser)
|
||||
for _, name := range p.config.Babbler.DefaultUsers {
|
||||
users = append(users, user.New(name))
|
||||
}
|
||||
for _, u := range users {
|
||||
p.makeBabbler(u)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||
if len(p.babblers) == 0 {
|
||||
p.makeBabblers(*message.User)
|
||||
} else if _, ok := p.babblers[message.User.Name]; !ok {
|
||||
p.makeBabbler(*message.User)
|
||||
}
|
||||
|
||||
lowercase := strings.ToLower(message.Body)
|
||||
tokens := strings.Fields(lowercase)
|
||||
|
||||
addToMarkovChain(p.babblers[message.User.Name], lowercase)
|
||||
|
||||
if len(tokens) == 4 && strings.Contains(lowercase, "initialize babbler for ") {
|
||||
who := tokens[len(tokens)-1]
|
||||
if _, ok := p.babblers[who]; !ok {
|
||||
babbler, err := getMarkovChain(p.db, who)
|
||||
if err == nil {
|
||||
p.babblers[who] = babbler
|
||||
} else {
|
||||
p.babblers[who] = newBabbler()
|
||||
}
|
||||
p.Bot.SendMessage(message.Channel, "Okay.")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(tokens) == 2 && tokens[1] == "says" {
|
||||
saying := p.babble(tokens[0])
|
||||
if saying == "" {
|
||||
p.Bot.SendMessage(message.Channel, "Ze ain't said nothin'")
|
||||
}
|
||||
p.Bot.SendMessage(message.Channel, saying)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) Help(channel string, parts []string) {
|
||||
p.Bot.SendMessage(channel, "initialize babbler for seabass\n\nseabass says")
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) Event(kind string, message msg.Message) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) BotMessage(message msg.Message) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) RegisterWeb() *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addToMarkovChain(babble *babbler, phrase string) {
|
||||
words := strings.Fields(strings.ToLower(phrase))
|
||||
|
||||
prev := babble.start
|
||||
prev.wordFrequency++
|
||||
for i := range words {
|
||||
// has this word been seen before
|
||||
if _, ok := babble.lookup[words[i]]; !ok {
|
||||
babble.lookup[words[i]] = &node{
|
||||
wordFrequency: 1,
|
||||
arcs: map[string]*arc{},
|
||||
}
|
||||
} else {
|
||||
babble.lookup[words[i]].wordFrequency++
|
||||
}
|
||||
|
||||
// has this word been seen after the previous word before
|
||||
if _, ok := prev.arcs[words[i]]; !ok {
|
||||
prev.arcs[words[i]] = &arc{
|
||||
transitionFrequency: 1,
|
||||
next: babble.lookup[words[i]],
|
||||
}
|
||||
} else {
|
||||
prev.arcs[words[i]].transitionFrequency++
|
||||
}
|
||||
prev = babble.lookup[words[i]]
|
||||
}
|
||||
|
||||
// has this word ended a fact before
|
||||
if _, ok := prev.arcs[""]; !ok {
|
||||
prev.arcs[""] = &arc{
|
||||
transitionFrequency: 1,
|
||||
next: babble.end,
|
||||
}
|
||||
} else {
|
||||
prev.arcs[""].transitionFrequency++
|
||||
}
|
||||
}
|
||||
|
||||
func newBabbler() *babbler {
|
||||
return &babbler{
|
||||
start: &node{
|
||||
wordFrequency: 0,
|
||||
arcs: map[string]*arc{},
|
||||
},
|
||||
end: &node{
|
||||
wordFrequency: 0,
|
||||
arcs: map[string]*arc{},
|
||||
},
|
||||
lookup: map[string]*node{},
|
||||
}
|
||||
}
|
||||
|
||||
// this who string isn't escaped, just sooo, you know.
|
||||
func getMarkovChain(db *sqlx.DB, who string) (*babbler, error) {
|
||||
query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, who)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
babble := newBabbler()
|
||||
|
||||
for rows.Next() {
|
||||
|
||||
var tidbit string
|
||||
err := rows.Scan(&tidbit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addToMarkovChain(babble, tidbit)
|
||||
}
|
||||
return babble, nil
|
||||
}
|
||||
|
||||
func (p *BabblerPlugin) babble(who string) string {
|
||||
if babbler, ok := p.babblers[who]; ok {
|
||||
if len(babbler.start.arcs) == 0 {
|
||||
return ""
|
||||
}
|
||||
words := []string{}
|
||||
cur := babbler.start
|
||||
for cur != babbler.end {
|
||||
which := rand.Intn(cur.wordFrequency)
|
||||
sum := 0
|
||||
for word, arc := range cur.arcs {
|
||||
sum += arc.transitionFrequency
|
||||
if sum > which {
|
||||
words = append(words, word)
|
||||
cur = arc.next
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("could not find babbler: %s", who)
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors.
|
||||
|
||||
package babbler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/velour/catbase/bot"
|
||||
"github.com/velour/catbase/bot/msg"
|
||||
"github.com/velour/catbase/bot/user"
|
||||
)
|
||||
|
||||
func makeMessage(payload string) msg.Message {
|
||||
isCmd := strings.HasPrefix(payload, "!")
|
||||
if isCmd {
|
||||
payload = payload[1:]
|
||||
}
|
||||
return msg.Message{
|
||||
User: &user.User{Name: "tester"},
|
||||
Channel: "test",
|
||||
Body: payload,
|
||||
Command: isCmd,
|
||||
}
|
||||
}
|
||||
|
||||
func TestBabbler(t *testing.T) {
|
||||
mb := bot.NewMockBot()
|
||||
c := New(mb)
|
||||
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||
assert.NotNil(t, c)
|
||||
seabass := makeMessage("This is a message")
|
||||
seabass.User = &user.User{Name: "seabass"}
|
||||
res := c.Message(seabass)
|
||||
assert.Len(t, c.babblers, 1)
|
||||
seabass.Body = "This is another message"
|
||||
res = c.Message(seabass)
|
||||
seabass.Body = "This is a long message"
|
||||
res = c.Message(seabass)
|
||||
res = c.Message(makeMessage("!seabass says"))
|
||||
assert.Len(t, mb.Messages, 1)
|
||||
assert.True(t, res)
|
||||
assert.Contains(t, mb.Messages[0], "this is")
|
||||
assert.Contains(t, mb.Messages[0], "message")
|
||||
}
|
||||
|
||||
func TestHelp(t *testing.T) {
|
||||
mb := bot.NewMockBot()
|
||||
c := New(mb)
|
||||
assert.NotNil(t, c)
|
||||
c.Help("channel", []string{})
|
||||
assert.Len(t, mb.Messages, 1)
|
||||
}
|
||||
|
||||
func TestBotMessage(t *testing.T) {
|
||||
mb := bot.NewMockBot()
|
||||
c := New(mb)
|
||||
assert.NotNil(t, c)
|
||||
assert.False(t, c.BotMessage(makeMessage("test")))
|
||||
}
|
||||
|
||||
func TestEvent(t *testing.T) {
|
||||
mb := bot.NewMockBot()
|
||||
c := New(mb)
|
||||
assert.NotNil(t, c)
|
||||
assert.False(t, c.Event("dummy", makeMessage("test")))
|
||||
}
|
||||
|
||||
func TestRegisterWeb(t *testing.T) {
|
||||
mb := bot.NewMockBot()
|
||||
c := New(mb)
|
||||
assert.NotNil(t, c)
|
||||
assert.Nil(t, c.RegisterWeb())
|
||||
}
|
|
@ -258,9 +258,14 @@ func (s *Slack) Who(id string) []string {
|
|||
u := s.url + "channels.info"
|
||||
resp, err := http.PostForm(u,
|
||||
url.Values{"token": {s.config.Slack.Token}, "channel": {id}})
|
||||
if err != nil || resp.StatusCode != 200 {
|
||||
log.Printf("Error posting user info request: %d %s",
|
||||
resp.StatusCode, err)
|
||||
if err != nil {
|
||||
log.Printf("Error posting user info request: %s",
|
||||
err)
|
||||
return []string{}
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Error posting user info request: %d",
|
||||
resp.StatusCode)
|
||||
return []string{}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
|
Loading…
Reference in New Issue