mirror of https://github.com/velour/catbase.git
Merge pull request #50 from velour/babbler_persist
refactoring this to use a db took more time than all previous work co…
This commit is contained in:
commit
419c188085
|
@ -3,14 +3,15 @@
|
||||||
package babbler
|
package babbler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/velour/catbase/bot"
|
"github.com/velour/catbase/bot"
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
"github.com/velour/catbase/bot/user"
|
|
||||||
"github.com/velour/catbase/config"
|
"github.com/velour/catbase/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,95 +19,86 @@ type BabblerPlugin struct {
|
||||||
Bot bot.Bot
|
Bot bot.Bot
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
config *config.Config
|
config *config.Config
|
||||||
babblers map[string]*babbler
|
|
||||||
}
|
|
||||||
|
|
||||||
type babbler struct {
|
|
||||||
start *node
|
|
||||||
end *node
|
|
||||||
lookup map[string]*node
|
|
||||||
}
|
|
||||||
|
|
||||||
type node struct {
|
|
||||||
wordFrequency int
|
|
||||||
arcs map[string]*arc
|
|
||||||
}
|
|
||||||
|
|
||||||
type arc struct {
|
|
||||||
transitionFrequency int
|
|
||||||
next *node
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(bot bot.Bot) *BabblerPlugin {
|
func New(bot bot.Bot) *BabblerPlugin {
|
||||||
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
|
|
||||||
|
if bot.DBVersion() == 1 {
|
||||||
|
if _, err := bot.DB().Exec(`create table if not exists babblers (
|
||||||
|
id integer primary key,
|
||||||
|
babbler string
|
||||||
|
);`); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := bot.DB().Exec(`create table if not exists babblerWords (
|
||||||
|
id integer primary key,
|
||||||
|
babblerId integer,
|
||||||
|
word string,
|
||||||
|
root integer,
|
||||||
|
rootFrequency integer
|
||||||
|
);`); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := bot.DB().Exec(`create table if not exists babblerArcs (
|
||||||
|
id integer primary key,
|
||||||
|
fromWordId integer,
|
||||||
|
toWordId interger,
|
||||||
|
frequency integer
|
||||||
|
);`); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
plugin := &BabblerPlugin{
|
plugin := &BabblerPlugin{
|
||||||
Bot: bot,
|
Bot: bot,
|
||||||
db: bot.DB(),
|
db: bot.DB(),
|
||||||
config: bot.Config(),
|
config: bot.Config(),
|
||||||
babblers: map[string]*babbler{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return plugin
|
return plugin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) makeBabbler(newUser user.User) {
|
|
||||||
name := newUser.Name
|
|
||||||
babbler, err := getMarkovChain(p.db, name)
|
|
||||||
if err == nil {
|
|
||||||
p.babblers[name] = babbler
|
|
||||||
} else {
|
|
||||||
p.babblers[name] = newBabbler()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *BabblerPlugin) makeBabblers(newUser user.User) {
|
|
||||||
users := p.Bot.Who(p.config.MainChannel)
|
|
||||||
users = append(users, newUser)
|
|
||||||
for _, name := range p.config.Babbler.DefaultUsers {
|
|
||||||
users = append(users, user.New(name))
|
|
||||||
}
|
|
||||||
for _, u := range users {
|
|
||||||
p.makeBabbler(u)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *BabblerPlugin) Message(message msg.Message) bool {
|
func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||||
if len(p.babblers) == 0 {
|
|
||||||
p.makeBabblers(*message.User)
|
|
||||||
} else if _, ok := p.babblers[message.User.Name]; !ok {
|
|
||||||
p.makeBabbler(*message.User)
|
|
||||||
}
|
|
||||||
|
|
||||||
lowercase := strings.ToLower(message.Body)
|
lowercase := strings.ToLower(message.Body)
|
||||||
tokens := strings.Fields(lowercase)
|
tokens := strings.Fields(lowercase)
|
||||||
numTokens := len(tokens)
|
numTokens := len(tokens)
|
||||||
|
|
||||||
if numTokens >= 2 && tokens[1] == "says" {
|
if numTokens >= 2 && tokens[1] == "says" {
|
||||||
|
who := tokens[0]
|
||||||
|
_, err := p.getBabbler(who)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var saying string
|
var saying string
|
||||||
if len(tokens) == 2 {
|
if len(tokens) == 2 {
|
||||||
saying = p.babble(tokens[0])
|
saying, _ = p.babble(who)
|
||||||
} else {
|
} else {
|
||||||
saying = p.babbleSeed(tokens[0], tokens[2:])
|
saying, _ = p.babbleSeed(who, tokens[2:])
|
||||||
}
|
}
|
||||||
|
|
||||||
if saying != "" {
|
if saying != "" {
|
||||||
p.Bot.SendMessage(message.Channel, saying)
|
p.Bot.SendMessage(message.Channel, saying)
|
||||||
}
|
|
||||||
return true
|
return true
|
||||||
|
}
|
||||||
} else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 {
|
} else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 {
|
||||||
who := tokens[3]
|
who := tokens[3]
|
||||||
if _, ok := p.babblers[who]; !ok {
|
_, err := p.getOrCreateBabbler(who)
|
||||||
babbler, err := getMarkovChain(p.db, who)
|
if err != nil {
|
||||||
if err == nil {
|
p.Bot.SendMessage(message.Channel, "babbler initialization failed.")
|
||||||
p.babblers[who] = babbler
|
return true
|
||||||
} else {
|
|
||||||
p.babblers[who] = newBabbler()
|
|
||||||
}
|
}
|
||||||
p.Bot.SendMessage(message.Channel, "Okay.")
|
p.Bot.SendMessage(message.Channel, "Okay.")
|
||||||
return true
|
return true
|
||||||
}
|
|
||||||
} else if strings.Index(lowercase, "batch learn for ") == 0 {
|
} else if strings.Index(lowercase, "batch learn for ") == 0 {
|
||||||
who := tokens[3]
|
who := tokens[3]
|
||||||
if _, ok := p.babblers[who]; !ok {
|
babblerId, err := p.getOrCreateBabbler(who)
|
||||||
p.babblers[who] = newBabbler()
|
if err != nil {
|
||||||
|
p.Bot.SendMessage(message.Channel, "batch learn failed.")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
body := strings.Join(tokens[4:], " ")
|
body := strings.Join(tokens[4:], " ")
|
||||||
|
@ -118,14 +110,14 @@ func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||||
for _, d := range strings.Split(c, "\n") {
|
for _, d := range strings.Split(c, "\n") {
|
||||||
trimmed := strings.TrimSpace(d)
|
trimmed := strings.TrimSpace(d)
|
||||||
if trimmed != "" {
|
if trimmed != "" {
|
||||||
addToMarkovChain(p.babblers[who], trimmed)
|
p.addToMarkovChain(babblerId, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Bot.SendMessage(message.Channel, "Phew that was tiring.")
|
p.Bot.SendMessage(message.Channel, "phew that was tiring.")
|
||||||
return true
|
return true
|
||||||
} else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 {
|
} else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 {
|
||||||
if tokens[3] != "into" {
|
if tokens[3] != "into" {
|
||||||
|
@ -137,36 +129,30 @@ func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||||
into := tokens[4]
|
into := tokens[4]
|
||||||
|
|
||||||
if who == into {
|
if who == into {
|
||||||
p.Bot.SendMessage(message.Channel, "Fuck off")
|
p.Bot.SendMessage(message.Channel, "that's annoying. stop it.")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var whoBabbler *babbler
|
whoBabbler, err := p.getBabbler(who)
|
||||||
ok := false
|
if err != nil {
|
||||||
if whoBabbler, ok = p.babblers[who]; !ok {
|
p.Bot.SendMessage(message.Channel, "merge failed.")
|
||||||
babbler, err := getMarkovChain(p.db, who)
|
return true
|
||||||
if err == nil {
|
|
||||||
whoBabbler = babbler
|
|
||||||
} else {
|
|
||||||
whoBabbler = newBabbler()
|
|
||||||
}
|
}
|
||||||
|
intoBabbler, err := p.getOrCreateBabbler(into)
|
||||||
|
if err != nil {
|
||||||
|
p.Bot.SendMessage(message.Channel, "merge failed.")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := p.babblers[into]; !ok {
|
p.merge(intoBabbler, whoBabbler, into, who)
|
||||||
babbler, err := getMarkovChain(p.db, into)
|
|
||||||
if err == nil {
|
|
||||||
p.babblers[into] = babbler
|
|
||||||
} else {
|
|
||||||
p.babblers[into] = newBabbler()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p.babblers[into].merge(whoBabbler, into, who)
|
|
||||||
|
|
||||||
p.Bot.SendMessage(message.Channel, "mooooiggged")
|
p.Bot.SendMessage(message.Channel, "mooooiggged")
|
||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
addToMarkovChain(p.babblers[message.User.Name], lowercase)
|
babblerId, err := p.getOrCreateBabbler(message.User.Name)
|
||||||
|
if err == nil {
|
||||||
|
p.addToMarkovChain(babblerId, lowercase)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -188,202 +174,391 @@ func (p *BabblerPlugin) RegisterWeb() *string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addToMarkovChain(babble *babbler, phrase string) {
|
func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) {
|
||||||
words := strings.Fields(strings.ToLower(phrase))
|
res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler)
|
||||||
|
if err == nil {
|
||||||
prev := babble.start
|
id, _ := res.LastInsertId()
|
||||||
prev.wordFrequency++
|
return id, nil
|
||||||
for i := range words {
|
|
||||||
// has this word been seen before
|
|
||||||
if _, ok := babble.lookup[words[i]]; !ok {
|
|
||||||
babble.lookup[words[i]] = &node{
|
|
||||||
wordFrequency: 1,
|
|
||||||
arcs: map[string]*arc{},
|
|
||||||
}
|
}
|
||||||
} else {
|
return -1, err
|
||||||
babble.lookup[words[i]].wordFrequency++
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// has this word been seen after the previous word before
|
func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) {
|
||||||
if _, ok := prev.arcs[words[i]]; !ok {
|
id := int64(-1)
|
||||||
prev.arcs[words[i]] = &arc{
|
err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler)
|
||||||
transitionFrequency: 1,
|
|
||||||
next: babble.lookup[words[i]],
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
prev.arcs[words[i]].transitionFrequency++
|
|
||||||
}
|
|
||||||
prev = babble.lookup[words[i]]
|
|
||||||
}
|
|
||||||
|
|
||||||
// has this word ended a fact before
|
|
||||||
if _, ok := prev.arcs[""]; !ok {
|
|
||||||
prev.arcs[""] = &arc{
|
|
||||||
transitionFrequency: 1,
|
|
||||||
next: babble.end,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
prev.arcs[""].transitionFrequency++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBabbler() *babbler {
|
|
||||||
start := &node{
|
|
||||||
wordFrequency: 0,
|
|
||||||
arcs: map[string]*arc{},
|
|
||||||
}
|
|
||||||
return &babbler{
|
|
||||||
start: start,
|
|
||||||
end: &node{
|
|
||||||
wordFrequency: 0,
|
|
||||||
arcs: map[string]*arc{},
|
|
||||||
},
|
|
||||||
lookup: map[string]*node{"": start},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// this who string isn't escaped, just sooo, you know.
|
|
||||||
func getMarkovChain(db *sqlx.DB, who string) (*babbler, error) {
|
|
||||||
query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, who)
|
|
||||||
rows, err := db.Query(query)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
}
|
||||||
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
babble := newBabbler()
|
func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) {
|
||||||
|
id, err := p.getBabbler(babbler)
|
||||||
|
if err != nil {
|
||||||
|
id, err = p.makeBabbler(babbler)
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, babbler)
|
||||||
|
rows, err := p.db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
//we'll just ignore this but the actual creation succeeded previously
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
||||||
var tidbit string
|
var tidbit string
|
||||||
err := rows.Scan(&tidbit)
|
err := rows.Scan(&tidbit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return id, err
|
||||||
|
}
|
||||||
|
p.addToMarkovChain(id, tidbit)
|
||||||
}
|
}
|
||||||
|
|
||||||
addToMarkovChain(babble, tidbit)
|
|
||||||
}
|
}
|
||||||
return babble, nil
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) babble(who string) string {
|
func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) {
|
||||||
return p.babbleSeed(who, []string{""})
|
id := int64(-1)
|
||||||
|
err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word)
|
||||||
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) babbleSeed(who string, seed []string) string {
|
func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) {
|
||||||
if babbler, ok := p.babblers[who]; ok {
|
res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word)
|
||||||
if len(babbler.start.arcs) == 0 {
|
if err != nil {
|
||||||
return ""
|
return -1, err
|
||||||
|
}
|
||||||
|
id, _ := res.LastInsertId()
|
||||||
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
words := seed
|
func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) {
|
||||||
var cur *node
|
id, err := p.getWordId(babblerId, word)
|
||||||
if cur, ok = babbler.lookup[words[0]]; !ok {
|
if err != nil {
|
||||||
if len(words) == 1 {
|
return p.createNewWord(babblerId, word)
|
||||||
return fmt.Sprintf("%s hasn't used the word '%s'", who, words[0])
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%s hasn't used the phrase '%s'", who, strings.Join(words, " "))
|
|
||||||
}
|
}
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
|
||||||
|
id, err := p.getOrCreateWord(babblerId, word)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
_, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, id)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) {
|
||||||
|
id := int64(-1)
|
||||||
|
err := p.db.Get(&id, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId)
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, error) {
|
||||||
|
res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ?`, fromWordId, toWordId)
|
||||||
|
affectedRows := int64(0)
|
||||||
|
if err == nil {
|
||||||
|
affectedRows, _ = res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || affectedRows == 0 {
|
||||||
|
res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
id, _ := res.LastInsertId()
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) {
|
||||||
|
return p.incrementWordArc(wordId, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) {
|
||||||
|
toWordId, err := p.getOrCreateWord(babblerId, toWord)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
_, err = p.incrementWordArc(fromWordId, toWordId)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return toWordId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) {
|
||||||
|
words := strings.Fields(strings.ToLower(phrase))
|
||||||
|
|
||||||
|
id, err := p.incrementRootWordFrequency(babblerId, words[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 1; i < len(words); i++ {
|
for i := 1; i < len(words); i++ {
|
||||||
if arc, ok := cur.arcs[words[i]]; !ok {
|
id, err = p.incrementWordArcHelper(babblerId, id, words[i])
|
||||||
if len(words) == 1 {
|
if err != nil {
|
||||||
return fmt.Sprintf("%s hasn't used the word '%s'", who, words[0])
|
return
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%s hasn't used the phrase '%s'", who, strings.Join(words, " "))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cur = arc.next
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for cur != babbler.end {
|
_, err = p.incrementFinalWordArcHelper(id)
|
||||||
which := rand.Intn(cur.wordFrequency)
|
}
|
||||||
sum := 0
|
|
||||||
for word, arc := range cur.arcs {
|
func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) {
|
||||||
sum += arc.transitionFrequency
|
query := fmt.Sprintf("select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1", babblerId)
|
||||||
if sum > which {
|
rows, err := p.db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
idToWord := map[int64]string{}
|
||||||
|
idToFreq := map[int64]int64{}
|
||||||
|
total := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
var word string
|
||||||
|
var rootFrequency int64
|
||||||
|
err = rows.Scan(&id, &word, &rootFrequency)
|
||||||
|
if err != nil {
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
total += rootFrequency
|
||||||
|
idToFreq[id] = rootFrequency
|
||||||
|
idToWord[id] = word
|
||||||
|
}
|
||||||
|
|
||||||
|
which := rand.Int63n(total)
|
||||||
|
total = 0
|
||||||
|
for id, freq := range idToFreq {
|
||||||
|
if total+freq >= which {
|
||||||
|
return id, idToWord[id], nil
|
||||||
|
}
|
||||||
|
total += freq
|
||||||
|
}
|
||||||
|
log.Fatalf("shouldn't happen")
|
||||||
|
return -1, "", errors.New("failed to find weighted root word")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, error) {
|
||||||
|
query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId)
|
||||||
|
rows, err := p.db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
idToFreq := map[int64]int64{}
|
||||||
|
total := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var toWordId int64
|
||||||
|
var frequency int64
|
||||||
|
err = rows.Scan(&toWordId, &frequency)
|
||||||
|
if err != nil {
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
total += frequency
|
||||||
|
idToFreq[toWordId] = frequency
|
||||||
|
}
|
||||||
|
|
||||||
|
which := rand.Int63n(total)
|
||||||
|
total = 0
|
||||||
|
for id, freq := range idToFreq {
|
||||||
|
if total+freq >= which {
|
||||||
|
if id < 0 {
|
||||||
|
return -1, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var word string
|
||||||
|
err := p.db.Get(&word, `select word from babblerWords where id = ?`, id)
|
||||||
|
if err != nil {
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
return id, word, nil
|
||||||
|
}
|
||||||
|
total +=freq
|
||||||
|
}
|
||||||
|
log.Fatalf("shouldn't happen")
|
||||||
|
return -1, "", errors.New("failed to find weighted next word")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) babble(who string) (string, error) {
|
||||||
|
return p.babbleSeed(who, []string{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error) {
|
||||||
|
babblerId, err := p.getBabbler(babbler)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
words := seed
|
||||||
|
|
||||||
|
var curWordId int64
|
||||||
|
if len(seed) == 0 {
|
||||||
|
id, word, err := p.getWeightedRootWord(babblerId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
curWordId = id
|
||||||
words = append(words, word)
|
words = append(words, word)
|
||||||
cur = arc.next
|
} else {
|
||||||
|
id, err := p.getWordId(babblerId, seed[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
curWordId = id
|
||||||
|
for i := 1; i < len(seed); i++ {
|
||||||
|
nextWordId, err := p.getWordId(babblerId, seed[i])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
_, err = p.getWordArcHelper(curWordId, nextWordId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
curWordId = nextWordId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
id, word, err := p.getWeightedNextWord(curWordId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if id < 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
words = append(words, word)
|
||||||
|
curWordId = id
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.TrimSpace(strings.Join(words, " "))
|
return strings.TrimSpace(strings.Join(words, " ")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) error {
|
||||||
|
intoString := "<" + intoName + ">"
|
||||||
|
otherString := "<" + otherName + ">"
|
||||||
|
|
||||||
|
mapping := map[int64]int64{}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("select id, word, root, rootFrequency from babblerWords where babblerId = %d;", otherId)
|
||||||
|
rows, err := p.db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
type Word struct {
|
||||||
|
Id int64
|
||||||
|
Word string
|
||||||
|
Root int64
|
||||||
|
RootFrequency int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (into *babbler) merge(other *babbler, intoName, otherName string) {
|
words := []Word{}
|
||||||
intoID := "<" + intoName + ">"
|
|
||||||
otherID := "<" + otherName + ">"
|
|
||||||
|
|
||||||
for nodeWord, myNode := range other.lookup {
|
for rows.Next() {
|
||||||
if nodeWord == otherID {
|
word := Word{}
|
||||||
nodeWord = intoID
|
err = rows.Scan(&word.Id, &word.Word, &word.Root, &word.RootFrequency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
words = append(words, word)
|
||||||
}
|
}
|
||||||
|
|
||||||
//does this nodeWord exist yet?
|
for _, word := range words {
|
||||||
if _, ok := into.lookup[nodeWord]; !ok {
|
|
||||||
into.lookup[nodeWord] = &node{
|
if word.Word == otherString {
|
||||||
wordFrequency: myNode.wordFrequency,
|
word.Word = intoString
|
||||||
arcs: map[string]*arc{},
|
}
|
||||||
|
|
||||||
|
doInsert := false
|
||||||
|
wordId := int64(-1)
|
||||||
|
if word.Root > 0 {
|
||||||
|
res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ;`, word.RootFrequency, intoId, word.Word)
|
||||||
|
rowsAffected := int64(0)
|
||||||
|
if err == nil {
|
||||||
|
rowsAffected, _ = res.RowsAffected()
|
||||||
|
}
|
||||||
|
if err != nil || rowsAffected == 0 {
|
||||||
|
doInsert = true
|
||||||
|
} else {
|
||||||
|
wordId, _ = res.LastInsertId()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
into.lookup[nodeWord].wordFrequency += myNode.wordFrequency
|
res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id;`, word.RootFrequency, intoId, word.Word)
|
||||||
}
|
if err != nil {
|
||||||
|
doInsert = true
|
||||||
for arcWord, myArc := range myNode.arcs {
|
|
||||||
if arcWord == otherID {
|
|
||||||
arcWord = intoID
|
|
||||||
}
|
|
||||||
|
|
||||||
if myArc.next == other.end {
|
|
||||||
if _, ok := into.lookup[nodeWord].arcs[arcWord]; !ok {
|
|
||||||
into.lookup[nodeWord].arcs[arcWord] = &arc{
|
|
||||||
transitionFrequency: myArc.transitionFrequency,
|
|
||||||
next: into.end,
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
into.lookup[nodeWord].arcs[arcWord].transitionFrequency += myArc.transitionFrequency
|
wordId, _ = res.LastInsertId()
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
//does the arcWord exist yet?
|
|
||||||
if _, ok := into.lookup[arcWord]; !ok {
|
|
||||||
into.lookup[arcWord] = &node{
|
|
||||||
wordFrequency: 0,
|
|
||||||
arcs: map[string]*arc{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := into.lookup[nodeWord].arcs[arcWord]; !ok {
|
if doInsert {
|
||||||
into.lookup[nodeWord].arcs[arcWord] = &arc{
|
res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Word, word.Root, word.RootFrequency)
|
||||||
transitionFrequency: myArc.transitionFrequency,
|
if err != nil {
|
||||||
next: into.lookup[arcWord],
|
return err
|
||||||
}
|
|
||||||
} else {
|
|
||||||
into.lookup[nodeWord].arcs[arcWord].transitionFrequency += myArc.transitionFrequency
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
wordId, _ = res.LastInsertId()
|
||||||
}
|
}
|
||||||
|
|
||||||
into.start.wordFrequency += other.start.wordFrequency
|
log.Printf("%s %d -> %d\n", word.Word, word.Id, wordId)
|
||||||
|
|
||||||
for startWord, startArc := range other.start.arcs {
|
mapping[word.Id] = wordId
|
||||||
if startWord == otherID {
|
|
||||||
startWord = intoID
|
|
||||||
}
|
}
|
||||||
if _, ok := into.start.arcs[startWord]; !ok {
|
|
||||||
into.start.arcs[startWord] = &arc{
|
type Arc struct {
|
||||||
transitionFrequency: startArc.transitionFrequency,
|
ToWordId int64
|
||||||
next: into.lookup[startWord],
|
Frequency int64
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
into.start.arcs[startWord].transitionFrequency += startArc.transitionFrequency
|
for lookup, newArcStart := range mapping {
|
||||||
|
query = fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", lookup)
|
||||||
|
rows, err := p.db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
arcs := []Arc{}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var arc Arc
|
||||||
|
err = rows.Scan(&arc.ToWordId, &arc.Frequency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
arcs = append(arcs, arc)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, arc := range arcs {
|
||||||
|
newArcEnd := int64(-1) //handle end arcs
|
||||||
|
if arc.ToWordId >= 0 {
|
||||||
|
newArcEnd = mapping[arc.ToWordId]
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := p.db.Exec(`update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ?`, arc.Frequency, newArcStart, newArcEnd)
|
||||||
|
rowsAffected := int64(0)
|
||||||
|
if err == nil {
|
||||||
|
rowsAffected, _ = res.RowsAffected()
|
||||||
|
}
|
||||||
|
if err != nil || rowsAffected == 0 {
|
||||||
|
_, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?);`, newArcStart, newArcEnd, arc.Frequency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -40,9 +40,8 @@ func TestBabblerNothingSaid(t *testing.T) {
|
||||||
c := New(mb)
|
c := New(mb)
|
||||||
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
res := c.Message(makeMessage("!seabass says"))
|
c.Message(makeMessage("!seabass says"))
|
||||||
assert.Len(t, mb.Messages, 0)
|
assert.Len(t, mb.Messages, 0)
|
||||||
assert.True(t, res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBabbler(t *testing.T) {
|
func TestBabbler(t *testing.T) {
|
||||||
|
@ -53,7 +52,6 @@ func TestBabbler(t *testing.T) {
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
res = c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
|
@ -73,7 +71,6 @@ func TestBabblerSeed(t *testing.T) {
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
res = c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
|
@ -92,7 +89,6 @@ func TestBabblerMultiSeed(t *testing.T) {
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
res = c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
|
@ -111,7 +107,6 @@ func TestBabblerMultiSeed2(t *testing.T) {
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
res = c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
|
@ -129,16 +124,13 @@ func TestBabblerBadSeed(t *testing.T) {
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
res = c.Message(seabass)
|
c.Message(seabass)
|
||||||
res = c.Message(makeMessage("!seabass says noooo this is bad"))
|
c.Message(makeMessage("!seabass says noooo this is bad"))
|
||||||
assert.Len(t, mb.Messages, 1)
|
assert.Len(t, mb.Messages, 0)
|
||||||
assert.True(t, res)
|
|
||||||
assert.Contains(t, mb.Messages[0], "seabass hasn't used the phrase 'noooo this is bad'")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBabblerBadSeed2(t *testing.T) {
|
func TestBabblerBadSeed2(t *testing.T) {
|
||||||
|
@ -148,16 +140,13 @@ func TestBabblerBadSeed2(t *testing.T) {
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
seabass := makeMessage("This is a message")
|
seabass := makeMessage("This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
seabass.Body = "This is another message"
|
seabass.Body = "This is another message"
|
||||||
res = c.Message(seabass)
|
c.Message(seabass)
|
||||||
seabass.Body = "This is a long message"
|
seabass.Body = "This is a long message"
|
||||||
res = c.Message(seabass)
|
c.Message(seabass)
|
||||||
res = c.Message(makeMessage("!seabass says This is a really"))
|
c.Message(makeMessage("!seabass says This is a really"))
|
||||||
assert.Len(t, mb.Messages, 1)
|
assert.Len(t, mb.Messages, 0)
|
||||||
assert.True(t, res)
|
|
||||||
assert.Contains(t, mb.Messages[0], "seabass hasn't used the phrase 'this is a really'")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBabblerBatch(t *testing.T) {
|
func TestBabblerBatch(t *testing.T) {
|
||||||
|
@ -167,7 +156,6 @@ func TestBabblerBatch(t *testing.T) {
|
||||||
assert.NotNil(t, c)
|
assert.NotNil(t, c)
|
||||||
seabass := makeMessage("batch learn for seabass This is a message! This is another message. This is not a long message? This is not a message! This is not another message. This is a long message?")
|
seabass := makeMessage("batch learn for seabass This is a message! This is another message. This is not a long message? This is not a message! This is not another message. This is a long message?")
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 2)
|
|
||||||
assert.Len(t, mb.Messages, 1)
|
assert.Len(t, mb.Messages, 1)
|
||||||
res = c.Message(makeMessage("!seabass says"))
|
res = c.Message(makeMessage("!seabass says"))
|
||||||
assert.Len(t, mb.Messages, 2)
|
assert.Len(t, mb.Messages, 2)
|
||||||
|
@ -185,7 +173,6 @@ func TestBabblerMerge(t *testing.T) {
|
||||||
seabass := makeMessage("<seabass> This is a message")
|
seabass := makeMessage("<seabass> This is a message")
|
||||||
seabass.User = &user.User{Name: "seabass"}
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
res := c.Message(seabass)
|
res := c.Message(seabass)
|
||||||
assert.Len(t, c.babblers, 1)
|
|
||||||
assert.Len(t, mb.Messages, 0)
|
assert.Len(t, mb.Messages, 0)
|
||||||
|
|
||||||
seabass.Body = "<seabass> This is another message"
|
seabass.Body = "<seabass> This is another message"
|
||||||
|
|
Loading…
Reference in New Issue