From 8c98734e7e74430942c41f5c42a5d8df5d978775 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 08:41:41 -0400 Subject: [PATCH 1/6] Add useful error messages when a babbler isn't found or when a user hasn't said something yet --- plugins/babbler/babbler.go | 36 +++++++++++++++++++++++++++++---- plugins/babbler/babbler_test.go | 12 ++++++++--- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 00574c0..4b144b9 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -3,6 +3,7 @@ package babbler import ( + "database/sql" "errors" "fmt" "log" @@ -15,6 +16,12 @@ import ( "github.com/velour/catbase/config" ) +const ( + NO_BABBLER = "babbler not found" + SAID_NOTHING = "hasn't said anything yet" +) + + type BabblerPlugin struct { Bot bot.Bot db *sqlx.DB @@ -70,16 +77,27 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { _, err := p.getBabbler(who) if err != nil { + if err.Error() == NO_BABBLER { + p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s babbler not found.", who)) + return true + } return false } var saying string if len(tokens) == 2 { - saying, _ = p.babble(who) + saying, err = p.babble(who) } else { - saying, _ = p.babbleSeed(who, tokens[2:]) + saying, err = p.babbleSeed(who, tokens[2:]) } + if err != nil { + if err.Error() == SAID_NOTHING { + p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s hasn't said anything yet.", who)) + return true + } + return false + } if saying != "" { p.Bot.SendMessage(message.Channel, saying) return true @@ -91,7 +109,7 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { p.Bot.SendMessage(message.Channel, "babbler initialization failed.") return true } - p.Bot.SendMessage(message.Channel, "Okay.") + p.Bot.SendMessage(message.Channel, "okay.") return true } else if strings.Index(lowercase, "batch learn for ") == 0 { who := tokens[3] @@ -186,7 +204,8 @@ func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) - if err != nil { + if err != nil && err == sql.ErrNoRows { + return -1, errors.New(NO_BABBLER) } return id, err } @@ -204,6 +223,7 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { //we'll just ignore this but the actual creation succeeded previously return id, nil } + for rows.Next() { var tidbit string err := rows.Scan(&tidbit) @@ -336,6 +356,10 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err idToWord[id] = word } + if total == 0 { + return -1, "", errors.New(SAID_NOTHING) + } + which := rand.Int63n(total) total = 0 for id, freq := range idToFreq { @@ -370,6 +394,10 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er idToFreq[toWordId] = frequency } + if total == 0 { + return -1, "", errors.New("missing arcs") + } + which := rand.Int63n(total) total = 0 for id, freq := range idToFreq { diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index 7263ce7..da8e502 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -31,8 +31,9 @@ func TestBabblerNoBabbler(t *testing.T) { c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) c.Message(makeMessage("!seabass2 says")) - res := assert.Len(t, mb.Messages, 0) + res := assert.Len(t, mb.Messages, 1) assert.True(t, res) + assert.Contains(t, mb.Messages[0], "seabass2 babbler not found") } func TestBabblerNothingSaid(t *testing.T) { @@ -40,8 +41,13 @@ func TestBabblerNothingSaid(t *testing.T) { c := New(mb) c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) - c.Message(makeMessage("!seabass says")) - assert.Len(t, mb.Messages, 0) + res := c.Message(makeMessage("initialize babbler for seabass")) + assert.True(t, res) + res = c.Message(makeMessage("!seabass says")) + assert.True(t, res) + assert.Len(t, mb.Messages, 2) + assert.Contains(t, mb.Messages[0], "okay.") + assert.Contains(t, mb.Messages[1], "seabass hasn't said anything yet.") } func TestBabbler(t *testing.T) { From f28dabf0bad0905de6dd49eb71c2859553b031e3 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 09:42:27 -0400 Subject: [PATCH 2/6] clean up and ignore empty strings when adding to the markov chain --- plugins/babbler/babbler.go | 126 +++++++----------------------------- plugins/babbler/commands.go | 105 ++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 103 deletions(-) create mode 100644 plugins/babbler/commands.go diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 4b144b9..49c2aed 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -16,9 +16,9 @@ import ( "github.com/velour/catbase/config" ) -const ( - NO_BABBLER = "babbler not found" - SAID_NOTHING = "hasn't said anything yet" +var ( + NO_BABBLER = errors.New("babbler not found") + SAID_NOTHING = errors.New("hasn't said anything yet") ) @@ -72,108 +72,26 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { tokens := strings.Fields(lowercase) numTokens := len(tokens) + saidSomething := false + saidWhat := "" + if numTokens >= 2 && tokens[1] == "says" { - who := tokens[0] - _, err := p.getBabbler(who) - - if err != nil { - if err.Error() == NO_BABBLER { - p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s babbler not found.", who)) - return true - } - return false - } - - var saying string - if len(tokens) == 2 { - saying, err = p.babble(who) - } else { - saying, err = p.babbleSeed(who, tokens[2:]) - } - - if err != nil { - if err.Error() == SAID_NOTHING { - p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s hasn't said anything yet.", who)) - return true - } - return false - } - if saying != "" { - p.Bot.SendMessage(message.Channel, saying) - return true - } + saidWhat, saidSomething = p.getBabble(tokens) } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { - who := tokens[3] - _, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "babbler initialization failed.") - return true - } - p.Bot.SendMessage(message.Channel, "okay.") - return true + saidWhat, saidSomething = p.initializeBabbler(tokens) } else if strings.Index(lowercase, "batch learn for ") == 0 { - who := tokens[3] - babblerId, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "batch learn failed.") - return true - } - - body := strings.Join(tokens[4:], " ") - body = strings.ToLower(body) - - for _, a := range strings.Split(body, ".") { - for _, b := range strings.Split(a, "!") { - for _, c := range strings.Split(b, "?") { - for _, d := range strings.Split(c, "\n") { - trimmed := strings.TrimSpace(d) - if trimmed != "" { - p.addToMarkovChain(babblerId, trimmed) - } - } - } - } - } - - p.Bot.SendMessage(message.Channel, "phew that was tiring.") - return true + saidWhat, saidSomething = p.batchLearn(tokens) } else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 { - if tokens[3] != "into" { - p.Bot.SendMessage(message.Channel, "try using 'merge babbler [x] into [y]'") - return true - } - - who := tokens[2] - into := tokens[4] - - if who == into { - p.Bot.SendMessage(message.Channel, "that's annoying. stop it.") - return true - } - - whoBabbler, err := p.getBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - intoBabbler, err := p.getOrCreateBabbler(into) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - - p.merge(intoBabbler, whoBabbler, into, who) - - p.Bot.SendMessage(message.Channel, "mooooiggged") - return true + saidWhat, saidSomething = p.merge(tokens) } else { - babblerId, err := p.getOrCreateBabbler(message.User.Name) - if err == nil { - p.addToMarkovChain(babblerId, lowercase) - } + //this should always return "", false + saidWhat, saidSomething = p.addToBabbler(message.User.Name, lowercase) } - return false + if saidSomething { + p.Bot.SendMessage(message.Channel, saidWhat) + } + return saidSomething } func (p *BabblerPlugin) Help(channel string, parts []string) { @@ -205,7 +123,7 @@ func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) if err != nil && err == sql.ErrNoRows { - return -1, errors.New(NO_BABBLER) + return -1, NO_BABBLER } return id, err } @@ -315,6 +233,10 @@ func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWo func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { words := strings.Fields(strings.ToLower(phrase)) + if len(words) <= 0 { + return + } + id, err := p.incrementRootWordFrequency(babblerId, words[0]) if err != nil { return @@ -357,7 +279,7 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err } if total == 0 { - return -1, "", errors.New(SAID_NOTHING) + return -1, "", SAID_NOTHING } which := rand.Int63n(total) @@ -473,7 +395,7 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error return strings.TrimSpace(strings.Join(words, " ")), nil } -func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) error { +func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { intoString := "<" + intoName + ">" otherString := "<" + otherName + ">" @@ -540,8 +462,6 @@ func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) wordId, _ = res.LastInsertId() } - log.Printf("%s %d -> %d\n", word.Word, word.Id, wordId) - mapping[word.Id] = wordId } diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go new file mode 100644 index 0000000..c52210a --- /dev/null +++ b/plugins/babbler/commands.go @@ -0,0 +1,105 @@ +package babbler + +import ( + "fmt" + "strings" +) + +func (p *BabblerPlugin) initializeBabbler(tokens []string) (string, bool) { + who := tokens[3] + _, err := p.getOrCreateBabbler(who) + if err != nil { + return "babbler initialization failed.", true + } + return "okay.", true +} + +func (p *BabblerPlugin) addToBabbler(babblerName, whatWasSaid string) (string, bool) { + babblerId, err := p.getOrCreateBabbler(babblerName) + if err == nil { + p.addToMarkovChain(babblerId, whatWasSaid) + } + return "", false +} + +func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) { + who := tokens[0] + _, err := p.getBabbler(who) + + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + } else { + + var saying string + if len(tokens) == 2 { + saying, err = p.babble(who) + } else { + saying, err = p.babbleSeed(who, tokens[2:]) + } + + if err != nil { + if err == SAID_NOTHING { + return fmt.Sprintf("%s hasn't said anything yet.", who), true + } + } else if saying != "" { + return saying, true + } + } + return "", false +} + +func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) { + who := tokens[3] + babblerId, err := p.getOrCreateBabbler(who) + if err != nil { + return "batch learn failed.", true + } + + body := strings.Join(tokens[4:], " ") + body = strings.ToLower(body) + + for _, a := range strings.Split(body, ".") { + for _, b := range strings.Split(a, "!") { + for _, c := range strings.Split(b, "?") { + for _, d := range strings.Split(c, "\n") { + trimmed := strings.TrimSpace(d) + if trimmed != "" { + p.addToMarkovChain(babblerId, trimmed) + } + } + } + } + } + return "phew that was tiring.", true +} + +func (p *BabblerPlugin) merge(tokens []string) (string, bool) { + if tokens[3] != "into" { + return "try using 'merge babbler [x] into [y]'", true + } + + who := tokens[2] + into := tokens[4] + + if who == into { + return "that's annoying. stop it.", true + } + + whoBabbler, err := p.getBabbler(who) + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + return "merge failed.", true + } + intoBabbler, err := p.getOrCreateBabbler(into) + if err != nil { + return "merge failed.", true + } + + p.mergeBabblers(intoBabbler, whoBabbler, into, who) + + return "mooooiggged", true +} From 24f373d20f8252508748c11eb0384ef1af1e3913 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 09:50:37 -0400 Subject: [PATCH 3/6] when read through the whole set of quote rows before beginning adding back into the babbler --- plugins/babbler/babbler.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 49c2aed..b6e1598 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -141,13 +141,19 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { //we'll just ignore this but the actual creation succeeded previously return id, nil } + defer rows.Close() + tidbits := []string{} for rows.Next() { var tidbit string err := rows.Scan(&tidbit) if err != nil { return id, err } + tidbits = append(tidbits, tidbit) + } + + for _, tidbit := range tidbits { p.addToMarkovChain(id, tidbit) } From efde5d804d90738586aadb6c876ee1143527e02c Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 10:11:49 -0400 Subject: [PATCH 4/6] add feedback for when a seed isn't found --- plugins/babbler/babbler.go | 9 ++++++++- plugins/babbler/babbler_test.go | 6 ++++-- plugins/babbler/commands.go | 2 ++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index b6e1598..6073d2d 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -19,6 +19,7 @@ import ( var ( NO_BABBLER = errors.New("babbler not found") SAID_NOTHING = errors.New("hasn't said anything yet") + NEVER_SAID = errors.New("never said that") ) @@ -122,7 +123,7 @@ func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) - if err != nil && err == sql.ErrNoRows { + if err == sql.ErrNoRows { return -1, NO_BABBLER } return id, err @@ -164,6 +165,9 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word) + if err == sql.ErrNoRows { + return -1, NEVER_SAID + } return id, err } @@ -199,6 +203,9 @@ func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) + if err == sql.ErrNoRows { + return -1, NEVER_SAID + } return id, err } diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index da8e502..a568b8b 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -136,7 +136,8 @@ func TestBabblerBadSeed(t *testing.T) { seabass.Body = "This is a long message" c.Message(seabass) c.Message(makeMessage("!seabass says noooo this is bad")) - assert.Len(t, mb.Messages, 0) + assert.Len(t, mb.Messages, 1) + assert.Contains(t, mb.Messages[0], "seabass never said 'noooo this is bad'") } func TestBabblerBadSeed2(t *testing.T) { @@ -152,7 +153,8 @@ func TestBabblerBadSeed2(t *testing.T) { seabass.Body = "This is a long message" c.Message(seabass) c.Message(makeMessage("!seabass says This is a really")) - assert.Len(t, mb.Messages, 0) + assert.Len(t, mb.Messages, 1) + assert.Contains(t, mb.Messages[0], "seabass never said 'this is a really'") } func TestBabblerBatch(t *testing.T) { diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go index c52210a..b16f768 100644 --- a/plugins/babbler/commands.go +++ b/plugins/babbler/commands.go @@ -42,6 +42,8 @@ func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) { if err != nil { if err == SAID_NOTHING { return fmt.Sprintf("%s hasn't said anything yet.", who), true + } else if err == NEVER_SAID { + return fmt.Sprintf("%s never said '%s'", who, strings.Join(tokens[2:], " ")), true } } else if saying != "" { return saying, true From e97ce576b8fa0ea2af3e0dac2868794f759bee14 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 11:59:35 -0400 Subject: [PATCH 5/6] probably not a huge memory savings, but only represent each unique word once and then reference them by their key for each babbler --- plugins/babbler/babbler.go | 211 ++++++++++++++++++++++++++---------- plugins/babbler/commands.go | 5 +- 2 files changed, 158 insertions(+), 58 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 6073d2d..41b4f56 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -39,10 +39,18 @@ func New(bot bot.Bot) *BabblerPlugin { );`); err != nil { log.Fatal(err) } + if _, err := bot.DB().Exec(`create table if not exists babblerWords ( + id integer primary key, + word string + );`); err != nil { + log.Fatal(err) + } + + if _, err := bot.DB().Exec(`create table if not exists babblerNodes ( id integer primary key, babblerId integer, - word string, + wordId integer, root integer, rootFrequency integer );`); err != nil { @@ -162,39 +170,97 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { return id, err } -func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) { +func (p *BabblerPlugin) getWordId(word string) (int64, error) { id := int64(-1) - err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word) - if err == sql.ErrNoRows { + if err := p.db.Get(&id, `select id from babblerWords where word = ?`, word); err == sql.ErrNoRows { return -1, NEVER_SAID + } else { + if err != nil { + log.Print(err) + } + return id, err } - return id, err } -func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word) - if err != nil { +func (p *BabblerPlugin) createNewWord(word string) (int64, error) { + if res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word); err == nil { + id, _ := res.LastInsertId() + return id, nil + } else { + if err != nil { + log.Print(err) + } return -1, err } - id, _ := res.LastInsertId() - return id, nil } -func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) { - id, err := p.getWordId(babblerId, word) +func (p *BabblerPlugin) getOrCreateWord(word string) (int64, error) { + if id, err := p.getWordId(word); err == NEVER_SAID { + return p.createNewWord(word) + } else { + return id, err + } +} + +func (p *BabblerPlugin) getWordNodeId(babblerId int64, word string) (int64, error) { + var err error + wordId := int64(-1) + if wordId, err = p.getWordId(word); err != nil { + return -1, err + } + id := int64(-1) + if err = p.db.Get(&id, `select id from babblerNodes where babblerId = ? and wordId = ?`, babblerId, wordId); err == sql.ErrNoRows { + return -1, NEVER_SAID + } else { + if err != nil { + log.Print(err) + } + return id, err + } +} + +func (p *BabblerPlugin) createWordNodeId(babblerId int64, word string) (int64, error) { + var err error + wordId := int64(-1) + if wordId, err = p.getOrCreateWord(word); err != nil { + log.Print(err) + return -1, err + } + id := int64(-1) + res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babblerId, wordId) + if err == sql.ErrNoRows { + return -1, NEVER_SAID + } else if err == nil { + id, _ = res.LastInsertId() + } if err != nil { - return p.createNewWord(babblerId, word) + log.Print(err) } return id, err } +func (p *BabblerPlugin) getOrCreateWordNode(babblerId int64, word string) (int64, error) { + if id, err := p.getWordNodeId(babblerId, word); err == NEVER_SAID { + return p.createWordNodeId(babblerId, word) + } else { + if err != nil { + log.Print(err) + } + return id, err + } +} + func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) { - id, err := p.getOrCreateWord(babblerId, word) - if err != nil { + var err error + id := int64(-1) + if id, err = p.getOrCreateWordNode(babblerId, word); err != nil { return -1, err } - _, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, id) - if err != nil { + + if _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where babblerId = ? and wordId = ?;`, babblerId, id); err != nil { + if err != nil { + log.Print(err) + } return -1, err } return id, nil @@ -206,6 +272,9 @@ func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, err if err == sql.ErrNoRows { return -1, NEVER_SAID } + if err != nil { + log.Print(err) + } return id, err } @@ -219,7 +288,7 @@ func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, err if err != nil || affectedRows == 0 { res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId) if err != nil { - + log.Print(err) return -1, err } } @@ -232,7 +301,7 @@ func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) } func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) { - toWordId, err := p.getOrCreateWord(babblerId, toWord) + toWordId, err := p.getOrCreateWordNode(babblerId, toWord) if err != nil { return -1, err } @@ -252,12 +321,14 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { id, err := p.incrementRootWordFrequency(babblerId, words[0]) if err != nil { + log.Print(err) return } for i := 1; i < len(words); i++ { id, err = p.incrementWordArcHelper(babblerId, id, words[i]) if err != nil { + log.Print(err) return } } @@ -266,29 +337,31 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { } func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) { - query := fmt.Sprintf("select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1", babblerId) + query := fmt.Sprintf("select id, wordId, rootFrequency from babblerNodes where babblerId = %d and root = 1", babblerId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return -1, "", err } defer rows.Close() - idToWord := map[int64]string{} + idToWordId := map[int64]int64{} idToFreq := map[int64]int64{} total := int64(0) for rows.Next() { var id int64 - var word string + var wordId int64 var rootFrequency int64 - err = rows.Scan(&id, &word, &rootFrequency) + err = rows.Scan(&id, &wordId, &rootFrequency) if err != nil { + log.Print(err) return -1, "", err } total += rootFrequency idToFreq[id] = rootFrequency - idToWord[id] = word + idToWordId[id] = wordId } if total == 0 { @@ -299,7 +372,13 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err total = 0 for id, freq := range idToFreq { if total+freq >= which { - return id, idToWord[id], nil + var word string + err := p.db.Get(&word, `select word from babblerWords where id = ?`, idToWordId[id]) + if err != nil { + log.Print(err) + return -1, "", err + } + return id, word, nil } total += freq } @@ -311,6 +390,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return -1, "", err } defer rows.Close() @@ -323,6 +403,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er var frequency int64 err = rows.Scan(&toWordId, &frequency) if err != nil { + log.Print(err) return -1, "", err } total += frequency @@ -341,11 +422,17 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er return -1, "", nil } - var word string - err := p.db.Get(&word, `select word from babblerWords where id = ?`, id) - if err != nil { + var wordId int64 + if err := p.db.Get(&wordId, `select wordId from babblerNodes where id = ?`, id); err != nil { + log.Print(err) return -1, "", err } + var word string + if err := p.db.Get(&word, `select word from babblerWords where id = ?`, wordId); err != nil { + log.Print(err) + return -1, "", err + } + return id, word, nil } total +=freq @@ -375,13 +462,14 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error curWordId = id words = append(words, word) } else { - id, err := p.getWordId(babblerId, seed[0]) + id, err := p.getWordNodeId(babblerId, seed[0]) if err != nil { + return "", err } curWordId = id for i := 1; i < len(seed); i++ { - nextWordId, err := p.getWordId(babblerId, seed[i]) + nextWordId, err := p.getWordNodeId(babblerId, seed[i]) if err != nil { return "", err } @@ -409,14 +497,24 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error } func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { - intoString := "<" + intoName + ">" - otherString := "<" + otherName + ">" + intoStringId, err := p.getOrCreateWordNode(intoId, "<" + intoName + ">") + if err != nil { + return err + } + otherStringId, err := p.getOrCreateWordNode(otherId, "<" + otherName + ">") + if err != nil { + return err + } + + log.Print(intoStringId) + log.Print(otherStringId) mapping := map[int64]int64{} - query := fmt.Sprintf("select id, word, root, rootFrequency from babblerWords where babblerId = %d;", otherId) + query := fmt.Sprintf("select wordId, root, rootFrequency from babblerNodes where babblerId = %d;", otherId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return err } defer rows.Close() @@ -432,8 +530,9 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName for rows.Next() { word := Word{} - err = rows.Scan(&word.Id, &word.Word, &word.Root, &word.RootFrequency) + err = rows.Scan(&word.Id, &word.Root, &word.RootFrequency) if err != nil { + log.Print(err) return err } words = append(words, word) @@ -441,40 +540,38 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName for _, word := range words { - if word.Word == otherString { - word.Word = intoString + if word.Id == otherStringId { + log.Print("HERE") + word.Id = intoStringId } - doInsert := false - wordId := int64(-1) + var res sql.Result + if word.Root > 0 { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ;`, word.RootFrequency, intoId, word.Word) - rowsAffected := int64(0) - if err == nil { - rowsAffected, _ = res.RowsAffected() - } - if err != nil || rowsAffected == 0 { - doInsert = true - } else { - wordId, _ = res.LastInsertId() - } + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ? output id ;`, word.RootFrequency, intoId, word.Id) } else { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id;`, word.RootFrequency, intoId, word.Word) - if err != nil { - doInsert = true - } else { - wordId, _ = res.LastInsertId() - } + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ? output id;`, word.RootFrequency, intoId, word.Id) } - if doInsert { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Word, word.Root, word.RootFrequency) + rowsAffected := int64(-1) + if err == nil { + rowsAffected, _ = res.RowsAffected() + } + + wordId := int64(-1) + if err != nil || rowsAffected == 0 { + res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Id, word.Root, word.RootFrequency) if err != nil { + log.Print(err) return err } - wordId, _ = res.LastInsertId() } + wordId, _ = res.LastInsertId() + + if word.Id == intoStringId { + word.Id = otherStringId + } mapping[word.Id] = wordId } diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go index b16f768..38e2502 100644 --- a/plugins/babbler/commands.go +++ b/plugins/babbler/commands.go @@ -101,7 +101,10 @@ func (p *BabblerPlugin) merge(tokens []string) (string, bool) { return "merge failed.", true } - p.mergeBabblers(intoBabbler, whoBabbler, into, who) + err = p.mergeBabblers(intoBabbler, whoBabbler, into, who) + if err != nil { + return "merge failed.", true + } return "mooooiggged", true } From 9e56b54680e3fb62fcb1a26789747e8cccf7201f Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 17 May 2017 10:06:10 -0400 Subject: [PATCH 6/6] it was a disaster before, now it is less. --- plugins/babbler/babbler.go | 593 ++++++++++++++++++++----------------- 1 file changed, 314 insertions(+), 279 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 41b4f56..0c54b45 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -17,18 +17,42 @@ import ( ) var ( - NO_BABBLER = errors.New("babbler not found") + NO_BABBLER = errors.New("babbler not found") SAID_NOTHING = errors.New("hasn't said anything yet") - NEVER_SAID = errors.New("never said that") + NEVER_SAID = errors.New("never said that") ) - type BabblerPlugin struct { Bot bot.Bot db *sqlx.DB config *config.Config } +type Babbler struct { + BabblerId int64 `db:"id"` + Name string `db:"babbler"` +} + +type BabblerWord struct { + WordId int64 `db:"id"` + Word string `db:"word"` +} + +type BabblerNode struct { + NodeId int64 `db:"id"` + BabblerId int64 `db:"babblerId"` + WordId int64 `db:"wordId"` + Root int64 `db:"root"` + RootFrequency int64 `db:"rootFrequency"` +} + +type BabblerArc struct { + ArcId int64 `db:"id"` + FromNodeId int64 `db:"fromNodeId"` + ToNodeId int64 `db:"toNodeId"` + Frequency int64 `db:"frequency"` +} + func New(bot bot.Bot) *BabblerPlugin { log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -59,8 +83,8 @@ func New(bot bot.Bot) *BabblerPlugin { if _, err := bot.DB().Exec(`create table if not exists babblerArcs ( id integer primary key, - fromWordId integer, - toWordId interger, + fromNodeId integer, + toNodeId interger, frequency integer );`); err != nil { log.Fatal(err) @@ -73,6 +97,8 @@ func New(bot bot.Bot) *BabblerPlugin { config: bot.Config(), } + plugin.createNewWord("") + return plugin } @@ -119,36 +145,50 @@ func (p *BabblerPlugin) RegisterWeb() *string { return nil } -func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { - res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler) +func (p *BabblerPlugin) makeBabbler(name string) (*Babbler, error) { + res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, name) if err == nil { - id, _ := res.LastInsertId() - return id, nil - } - return -1, err -} - -func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { - id := int64(-1) - err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) - if err == sql.ErrNoRows { - return -1, NO_BABBLER - } - return id, err -} - -func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { - id, err := p.getBabbler(babbler) - if err != nil { - id, err = p.makeBabbler(babbler) + id, err := res.LastInsertId() if err != nil { - return id, err + log.Print(err) + return nil, err } - query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, babbler) - rows, err := p.db.Query(query) + return &Babbler{ + BabblerId: id, + Name: name, + }, nil + } + return nil, err +} + +func (p *BabblerPlugin) getBabbler(name string) (*Babbler, error) { + var bblr Babbler + err := p.db.QueryRowx(`select * from babblers where babbler = ? LIMIT 1;`, name).StructScan(&bblr) + if err != nil { + if err == sql.ErrNoRows { + log.Printf("failed to find babbler") + return nil, NO_BABBLER + } + log.Printf("encountered problem in babbler lookup") + log.Print(err) + return nil, err + } + return &bblr, nil +} + +func (p *BabblerPlugin) getOrCreateBabbler(name string) (*Babbler, error) { + babbler, err := p.getBabbler(name) + if err == NO_BABBLER { + babbler, err = p.makeBabbler(name) if err != nil { - //we'll just ignore this but the actual creation succeeded previously - return id, nil + log.Print(err) + return nil, err + } + + rows, err := p.db.Queryx(fmt.Sprintf("select tidbit from factoid where fact like '%s quotes';", babbler.Name)) + if err != nil { + log.Print(err) + return babbler, nil } defer rows.Close() @@ -156,401 +196,411 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { for rows.Next() { var tidbit string err := rows.Scan(&tidbit) + + log.Print(tidbit) + if err != nil { - return id, err + log.Print(err) + return babbler, err } tidbits = append(tidbits, tidbit) } for _, tidbit := range tidbits { - p.addToMarkovChain(id, tidbit) + if err = p.addToMarkovChain(babbler, tidbit); err != nil { + log.Print(err) + } } - } - return id, err + return babbler, err } -func (p *BabblerPlugin) getWordId(word string) (int64, error) { - id := int64(-1) - if err := p.db.Get(&id, `select id from babblerWords where word = ?`, word); err == sql.ErrNoRows { - return -1, NEVER_SAID - } else { - if err != nil { - log.Print(err) +func (p *BabblerPlugin) getWord(word string) (*BabblerWord, error) { + var w BabblerWord + err := p.db.QueryRowx(`select * from babblerWords where word = ? LIMIT 1;`, word).StructScan(&w) + if err != nil { + if err == sql.ErrNoRows { + return nil, NEVER_SAID } - return id, err + return nil, err } + return &w, nil } -func (p *BabblerPlugin) createNewWord(word string) (int64, error) { - if res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word); err == nil { - id, _ := res.LastInsertId() - return id, nil - } else { - if err != nil { - log.Print(err) - } - return -1, err +func (p *BabblerPlugin) createNewWord(word string) (*BabblerWord, error) { + res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word) + if err != nil { + log.Print(err) + return nil, err } + id, err := res.LastInsertId() + if err != nil { + log.Print(err) + return nil, err + } + return &BabblerWord{ + WordId: id, + Word: word, + }, nil } -func (p *BabblerPlugin) getOrCreateWord(word string) (int64, error) { - if id, err := p.getWordId(word); err == NEVER_SAID { +func (p *BabblerPlugin) getOrCreateWord(word string) (*BabblerWord, error) { + if w, err := p.getWord(word); err == NEVER_SAID { return p.createNewWord(word) - } else { - return id, err - } -} - -func (p *BabblerPlugin) getWordNodeId(babblerId int64, word string) (int64, error) { - var err error - wordId := int64(-1) - if wordId, err = p.getWordId(word); err != nil { - return -1, err - } - id := int64(-1) - if err = p.db.Get(&id, `select id from babblerNodes where babblerId = ? and wordId = ?`, babblerId, wordId); err == sql.ErrNoRows { - return -1, NEVER_SAID } else { if err != nil { log.Print(err) } - return id, err + return w, err } } -func (p *BabblerPlugin) createWordNodeId(babblerId int64, word string) (int64, error) { - var err error - wordId := int64(-1) - if wordId, err = p.getOrCreateWord(word); err != nil { - log.Print(err) - return -1, err +func (p *BabblerPlugin) getBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + w, err := p.getWord(word) + if err != nil { + return nil, err } - id := int64(-1) - res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babblerId, wordId) - if err == sql.ErrNoRows { - return -1, NEVER_SAID - } else if err == nil { - id, _ = res.LastInsertId() + + var node BabblerNode + err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, babbler.BabblerId, w.WordId).StructScan(&node) + if err != nil { + if err == sql.ErrNoRows { + return nil, NEVER_SAID + } + return nil, err } + return &node, nil +} + +func (p *BabblerPlugin) createBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + w, err := p.getOrCreateWord(word) if err != nil { log.Print(err) - } - return id, err -} - -func (p *BabblerPlugin) getOrCreateWordNode(babblerId int64, word string) (int64, error) { - if id, err := p.getWordNodeId(babblerId, word); err == NEVER_SAID { - return p.createWordNodeId(babblerId, word) - } else { - if err != nil { - log.Print(err) - } - return id, err - } -} - -func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) { - var err error - id := int64(-1) - if id, err = p.getOrCreateWordNode(babblerId, word); err != nil { - return -1, err + return nil, err } - if _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where babblerId = ? and wordId = ?;`, babblerId, id); err != nil { - if err != nil { - log.Print(err) - } - return -1, err - } - return id, nil -} - -func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) { - id := int64(-1) - err := p.db.Get(&id, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) - if err == sql.ErrNoRows { - return -1, NEVER_SAID - } + res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babbler.BabblerId, w.WordId) if err != nil { log.Print(err) + return nil, err } - return id, err + + id, err := res.LastInsertId() + if err != nil { + log.Print(err) + return nil, err + } + + return &BabblerNode{ + NodeId: id, + WordId: w.WordId, + Root: 0, + RootFrequency: 0, + }, nil } -func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, error) { - res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) +func (p *BabblerPlugin) getOrCreateBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + node, err := p.getBabblerNode(babbler, word) + if err != nil { + return p.createBabblerNode(babbler, word) + } + return node, nil +} + +func (p *BabblerPlugin) incrementRootWordFrequency(babbler *Babbler, word string) (*BabblerNode, error) { + node, err := p.getOrCreateBabblerNode(babbler, word) + if err != nil { + log.Print(err) + return nil, err + } + _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, node.NodeId) + if err != nil { + log.Print(err) + return nil, err + } + node.RootFrequency += 1 + return node, nil +} + +func (p *BabblerPlugin) getBabblerArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { + var arc BabblerArc + err := p.db.QueryRowx(`select * from babblerArcs where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId).StructScan(&arc) + if err != nil { + if err == sql.ErrNoRows { + return nil, NEVER_SAID + } + return nil, err + } + return &arc, nil +} + +func (p *BabblerPlugin) incrementWordArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { + res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId) + if err != nil { + log.Print(err) + return nil, err + } + affectedRows := int64(0) if err == nil { affectedRows, _ = res.RowsAffected() } - if err != nil || affectedRows == 0 { - res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId) + if affectedRows == 0 { + res, err = p.db.Exec(`insert into babblerArcs (fromNodeId, toNodeId, frequency) values (?, ?, 1);`, fromNode.NodeId, toNode.NodeId) if err != nil { log.Print(err) - return -1, err + return nil, err } } - id, _ := res.LastInsertId() - return id, nil + return p.getBabblerArc(fromNode, toNode) } -func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) { - return p.incrementWordArc(wordId, -1) -} - -func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) { - toWordId, err := p.getOrCreateWordNode(babblerId, toWord) +func (p *BabblerPlugin) incrementFinalWordArcHelper(babbler *Babbler, node *BabblerNode) (*BabblerArc, error) { + nextNode, err := p.getOrCreateBabblerNode(babbler, " ") if err != nil { - return -1, err + return nil, err } - _, err = p.incrementWordArc(fromWordId, toWordId) - if err != nil { - return -1, err - } - return toWordId, nil + return p.incrementWordArc(node, nextNode) } -func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { +func (p *BabblerPlugin) addToMarkovChain(babbler *Babbler, phrase string) error { words := strings.Fields(strings.ToLower(phrase)) if len(words) <= 0 { - return + return nil } - id, err := p.incrementRootWordFrequency(babblerId, words[0]) + curNode, err := p.incrementRootWordFrequency(babbler, words[0]) if err != nil { log.Print(err) - return + return err } for i := 1; i < len(words); i++ { - id, err = p.incrementWordArcHelper(babblerId, id, words[i]) + nextNode, err := p.getOrCreateBabblerNode(babbler, words[i]) if err != nil { log.Print(err) - return + return err } + _, err = p.incrementWordArc(curNode, nextNode) + if err != nil { + log.Print(err) + return err + } + curNode = nextNode } - _, err = p.incrementFinalWordArcHelper(id) + _, err = p.incrementFinalWordArcHelper(babbler, curNode) + return err } -func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) { - query := fmt.Sprintf("select id, wordId, rootFrequency from babblerNodes where babblerId = %d and root = 1", babblerId) - rows, err := p.db.Query(query) +func (p *BabblerPlugin) getWeightedRootNode(babbler *Babbler) (*BabblerNode, *BabblerWord, error) { + rows, err := p.db.Queryx(`select * from babblerNodes where babblerId = ? and root = 1;`, babbler.BabblerId) if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } defer rows.Close() - idToWordId := map[int64]int64{} - idToFreq := map[int64]int64{} + rootNodes := []*BabblerNode{} total := int64(0) for rows.Next() { - var id int64 - var wordId int64 - var rootFrequency int64 - err = rows.Scan(&id, &wordId, &rootFrequency) + var node BabblerNode + err = rows.StructScan(&node) if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } - - total += rootFrequency - idToFreq[id] = rootFrequency - idToWordId[id] = wordId + rootNodes = append(rootNodes, &node) + total += node.RootFrequency } - if total == 0 { - return -1, "", SAID_NOTHING + if len(rootNodes) == 0 { + return nil, nil, SAID_NOTHING } which := rand.Int63n(total) total = 0 - for id, freq := range idToFreq { - if total+freq >= which { - var word string - err := p.db.Get(&word, `select word from babblerWords where id = ?`, idToWordId[id]) + for _, node := range rootNodes { + total += node.RootFrequency + if total >= which { + var w BabblerWord + err := p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w) if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } - return id, word, nil + return node, &w, nil } - total += freq + } log.Fatalf("shouldn't happen") - return -1, "", errors.New("failed to find weighted root word") + return nil, nil, errors.New("failed to find weighted root word") } -func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, error) { - query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId) - rows, err := p.db.Query(query) +func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode, *BabblerWord, error) { + rows, err := p.db.Queryx(`select * from babblerArcs where fromNodeId = ?;`, fromNode.NodeId) if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } defer rows.Close() - idToFreq := map[int64]int64{} + arcs := []*BabblerArc{} total := int64(0) - for rows.Next() { - var toWordId int64 - var frequency int64 - err = rows.Scan(&toWordId, &frequency) + var arc BabblerArc + err = rows.StructScan(&arc) if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } - total += frequency - idToFreq[toWordId] = frequency + arcs = append(arcs, &arc) + total += arc.Frequency } - if total == 0 { - return -1, "", errors.New("missing arcs") + if len(arcs) == 0 { + return nil, nil, errors.New("missing arcs") } which := rand.Int63n(total) total = 0 - for id, freq := range idToFreq { - if total+freq >= which { - if id < 0 { - return -1, "", nil + for _, arc := range arcs { + + total += arc.Frequency + + if total >= which { + var node BabblerNode + err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, arc.ToNodeId).StructScan(&node) + if err != nil { + log.Print(err) + return nil, nil, err } - var wordId int64 - if err := p.db.Get(&wordId, `select wordId from babblerNodes where id = ?`, id); err != nil { + var w BabblerWord + err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w) + if err != nil { log.Print(err) - return -1, "", err + return nil, nil, err } - var word string - if err := p.db.Get(&word, `select word from babblerWords where id = ?`, wordId); err != nil { - log.Print(err) - return -1, "", err - } - - return id, word, nil + return &node, &w, nil } - total +=freq + } log.Fatalf("shouldn't happen") - return -1, "", errors.New("failed to find weighted next word") + return nil, nil, errors.New("failed to find weighted next word") } func (p *BabblerPlugin) babble(who string) (string, error) { return p.babbleSeed(who, []string{}) } -func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error) { - babblerId, err := p.getBabbler(babbler) +func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, error) { + babbler, err := p.getBabbler(babblerName) if err != nil { + log.Print(err) return "", nil } words := seed - var curWordId int64 + var curNode *BabblerNode + var curWord *BabblerWord if len(seed) == 0 { - id, word, err := p.getWeightedRootWord(babblerId) + curNode, curWord, err = p.getWeightedRootNode(babbler) if err != nil { + log.Print(err) return "", err } - curWordId = id - words = append(words, word) + words = append(words, curWord.Word) } else { - id, err := p.getWordNodeId(babblerId, seed[0]) + curNode, err = p.getBabblerNode(babbler, seed[0]) if err != nil { - + log.Print(err) return "", err } - curWordId = id for i := 1; i < len(seed); i++ { - nextWordId, err := p.getWordNodeId(babblerId, seed[i]) + nextNode, err := p.getBabblerNode(babbler, seed[i]) if err != nil { + log.Print(err) return "", err } - _, err = p.getWordArcHelper(curWordId, nextWordId) + _, err = p.getBabblerArc(curNode, nextNode) if err != nil { + log.Print(err) return "", err } - curWordId = nextWordId + curNode = nextNode } } for { - id, word, err := p.getWeightedNextWord(curWordId) + curNode, curWord, err = p.getWeightedNextWord(curNode) if err != nil { + log.Print(err) return "", err } - if id < 0 { + if curWord.Word == " " { break } - words = append(words, word) - curWordId = id + words = append(words, curWord.Word) } return strings.TrimSpace(strings.Join(words, " ")), nil } -func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { - intoStringId, err := p.getOrCreateWordNode(intoId, "<" + intoName + ">") +func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoName, otherName string) error { + intoNode, err := p.getOrCreateBabblerNode(intoBabbler, "<"+intoName+">") if err != nil { + log.Print(err) return err } - otherStringId, err := p.getOrCreateWordNode(otherId, "<" + otherName + ">") + otherNode, err := p.getOrCreateBabblerNode(otherBabbler, "<"+otherName+">") if err != nil { + log.Print(err) return err } - log.Print(intoStringId) - log.Print(otherStringId) + mapping := map[int64]*BabblerNode{} - mapping := map[int64]int64{} - - query := fmt.Sprintf("select wordId, root, rootFrequency from babblerNodes where babblerId = %d;", otherId) - rows, err := p.db.Query(query) + rows, err := p.db.Queryx("select * from babblerNodes where babblerId = ?;", otherBabbler.BabblerId) if err != nil { log.Print(err) return err } defer rows.Close() - type Word struct { - Id int64 - Word string - Root int64 - RootFrequency int64 - } - - words := []Word{} + nodes := []*BabblerNode{} for rows.Next() { - word := Word{} - err = rows.Scan(&word.Id, &word.Root, &word.RootFrequency) + var node BabblerNode + err = rows.StructScan(&node) if err != nil { log.Print(err) return err } - words = append(words, word) + nodes = append(nodes, &node) } - for _, word := range words { - - if word.Id == otherStringId { - log.Print("HERE") - word.Id = intoStringId - } - + for _, node := range nodes { var res sql.Result - if word.Root > 0 { - res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ? output id ;`, word.RootFrequency, intoId, word.Id) + if node.NodeId == otherNode.NodeId { + node.WordId = intoNode.WordId + } + + if node.Root > 0 { + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId) + if err != nil { + log.Print(err) + } } else { - res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ? output id;`, word.RootFrequency, intoId, word.Id) + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId) + if err != nil { + log.Print(err) + } } rowsAffected := int64(-1) @@ -558,65 +608,50 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName rowsAffected, _ = res.RowsAffected() } - wordId := int64(-1) if err != nil || rowsAffected == 0 { - res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Id, word.Root, word.RootFrequency) + res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoBabbler.BabblerId, node.WordId, node.Root, node.RootFrequency) if err != nil { log.Print(err) return err } } - wordId, _ = res.LastInsertId() - - if word.Id == intoStringId { - word.Id = otherStringId + var updatedNode BabblerNode + err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, intoBabbler.BabblerId, node.WordId).StructScan(&updatedNode) + if err != nil { + log.Print(err) + return err } - mapping[word.Id] = wordId + + mapping[node.NodeId] = &updatedNode } - type Arc struct { - ToWordId int64 - Frequency int64 - } - - for lookup, newArcStart := range mapping { - query = fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", lookup) - rows, err := p.db.Query(query) + for oldNodeId, newNode := range mapping { + rows, err := p.db.Queryx("select * from babblerArcs where fromNodeId = ?;", oldNodeId) if err != nil { return err } defer rows.Close() - arcs := []Arc{} + arcs := []*BabblerArc{} for rows.Next() { - var arc Arc - err = rows.Scan(&arc.ToWordId, &arc.Frequency) + var arc BabblerArc + err = rows.StructScan(&arc) if err != nil { + log.Print(err) return err } - arcs = append(arcs, arc) + arcs = append(arcs, &arc) } for _, arc := range arcs { - newArcEnd := int64(-1) //handle end arcs - if arc.ToWordId >= 0 { - newArcEnd = mapping[arc.ToWordId] - } - - res, err := p.db.Exec(`update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ?`, arc.Frequency, newArcStart, newArcEnd) - rowsAffected := int64(0) - if err == nil { - rowsAffected, _ = res.RowsAffected() - } - if err != nil || rowsAffected == 0 { - _, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?);`, newArcStart, newArcEnd, arc.Frequency) - if err != nil { - return err - } + _, err := p.incrementWordArc(newNode, mapping[arc.ToNodeId]) + if err != nil { + return err } } } - return nil + + return err }