From e97ce576b8fa0ea2af3e0dac2868794f759bee14 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Wed, 10 May 2017 11:59:35 -0400 Subject: [PATCH] probably not a huge memory savings, but only represent each unique word once and then reference them by their key for each babbler --- plugins/babbler/babbler.go | 211 ++++++++++++++++++++++++++---------- plugins/babbler/commands.go | 5 +- 2 files changed, 158 insertions(+), 58 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 6073d2d..41b4f56 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -39,10 +39,18 @@ func New(bot bot.Bot) *BabblerPlugin { );`); err != nil { log.Fatal(err) } + if _, err := bot.DB().Exec(`create table if not exists babblerWords ( + id integer primary key, + word string + );`); err != nil { + log.Fatal(err) + } + + if _, err := bot.DB().Exec(`create table if not exists babblerNodes ( id integer primary key, babblerId integer, - word string, + wordId integer, root integer, rootFrequency integer );`); err != nil { @@ -162,39 +170,97 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { return id, err } -func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) { +func (p *BabblerPlugin) getWordId(word string) (int64, error) { id := int64(-1) - err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word) - if err == sql.ErrNoRows { + if err := p.db.Get(&id, `select id from babblerWords where word = ?`, word); err == sql.ErrNoRows { return -1, NEVER_SAID + } else { + if err != nil { + log.Print(err) + } + return id, err } - return id, err } -func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word) - if err != nil { +func (p *BabblerPlugin) createNewWord(word string) (int64, error) { + if res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word); err == nil { + id, _ := res.LastInsertId() + return id, nil + } else { + if err != nil { + log.Print(err) + } return -1, err } - id, _ := res.LastInsertId() - return id, nil } -func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) { - id, err := p.getWordId(babblerId, word) +func (p *BabblerPlugin) getOrCreateWord(word string) (int64, error) { + if id, err := p.getWordId(word); err == NEVER_SAID { + return p.createNewWord(word) + } else { + return id, err + } +} + +func (p *BabblerPlugin) getWordNodeId(babblerId int64, word string) (int64, error) { + var err error + wordId := int64(-1) + if wordId, err = p.getWordId(word); err != nil { + return -1, err + } + id := int64(-1) + if err = p.db.Get(&id, `select id from babblerNodes where babblerId = ? and wordId = ?`, babblerId, wordId); err == sql.ErrNoRows { + return -1, NEVER_SAID + } else { + if err != nil { + log.Print(err) + } + return id, err + } +} + +func (p *BabblerPlugin) createWordNodeId(babblerId int64, word string) (int64, error) { + var err error + wordId := int64(-1) + if wordId, err = p.getOrCreateWord(word); err != nil { + log.Print(err) + return -1, err + } + id := int64(-1) + res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babblerId, wordId) + if err == sql.ErrNoRows { + return -1, NEVER_SAID + } else if err == nil { + id, _ = res.LastInsertId() + } if err != nil { - return p.createNewWord(babblerId, word) + log.Print(err) } return id, err } +func (p *BabblerPlugin) getOrCreateWordNode(babblerId int64, word string) (int64, error) { + if id, err := p.getWordNodeId(babblerId, word); err == NEVER_SAID { + return p.createWordNodeId(babblerId, word) + } else { + if err != nil { + log.Print(err) + } + return id, err + } +} + func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) { - id, err := p.getOrCreateWord(babblerId, word) - if err != nil { + var err error + id := int64(-1) + if id, err = p.getOrCreateWordNode(babblerId, word); err != nil { return -1, err } - _, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, id) - if err != nil { + + if _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where babblerId = ? and wordId = ?;`, babblerId, id); err != nil { + if err != nil { + log.Print(err) + } return -1, err } return id, nil @@ -206,6 +272,9 @@ func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, err if err == sql.ErrNoRows { return -1, NEVER_SAID } + if err != nil { + log.Print(err) + } return id, err } @@ -219,7 +288,7 @@ func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, err if err != nil || affectedRows == 0 { res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId) if err != nil { - + log.Print(err) return -1, err } } @@ -232,7 +301,7 @@ func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) } func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) { - toWordId, err := p.getOrCreateWord(babblerId, toWord) + toWordId, err := p.getOrCreateWordNode(babblerId, toWord) if err != nil { return -1, err } @@ -252,12 +321,14 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { id, err := p.incrementRootWordFrequency(babblerId, words[0]) if err != nil { + log.Print(err) return } for i := 1; i < len(words); i++ { id, err = p.incrementWordArcHelper(babblerId, id, words[i]) if err != nil { + log.Print(err) return } } @@ -266,29 +337,31 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { } func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) { - query := fmt.Sprintf("select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1", babblerId) + query := fmt.Sprintf("select id, wordId, rootFrequency from babblerNodes where babblerId = %d and root = 1", babblerId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return -1, "", err } defer rows.Close() - idToWord := map[int64]string{} + idToWordId := map[int64]int64{} idToFreq := map[int64]int64{} total := int64(0) for rows.Next() { var id int64 - var word string + var wordId int64 var rootFrequency int64 - err = rows.Scan(&id, &word, &rootFrequency) + err = rows.Scan(&id, &wordId, &rootFrequency) if err != nil { + log.Print(err) return -1, "", err } total += rootFrequency idToFreq[id] = rootFrequency - idToWord[id] = word + idToWordId[id] = wordId } if total == 0 { @@ -299,7 +372,13 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err total = 0 for id, freq := range idToFreq { if total+freq >= which { - return id, idToWord[id], nil + var word string + err := p.db.Get(&word, `select word from babblerWords where id = ?`, idToWordId[id]) + if err != nil { + log.Print(err) + return -1, "", err + } + return id, word, nil } total += freq } @@ -311,6 +390,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return -1, "", err } defer rows.Close() @@ -323,6 +403,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er var frequency int64 err = rows.Scan(&toWordId, &frequency) if err != nil { + log.Print(err) return -1, "", err } total += frequency @@ -341,11 +422,17 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er return -1, "", nil } - var word string - err := p.db.Get(&word, `select word from babblerWords where id = ?`, id) - if err != nil { + var wordId int64 + if err := p.db.Get(&wordId, `select wordId from babblerNodes where id = ?`, id); err != nil { + log.Print(err) return -1, "", err } + var word string + if err := p.db.Get(&word, `select word from babblerWords where id = ?`, wordId); err != nil { + log.Print(err) + return -1, "", err + } + return id, word, nil } total +=freq @@ -375,13 +462,14 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error curWordId = id words = append(words, word) } else { - id, err := p.getWordId(babblerId, seed[0]) + id, err := p.getWordNodeId(babblerId, seed[0]) if err != nil { + return "", err } curWordId = id for i := 1; i < len(seed); i++ { - nextWordId, err := p.getWordId(babblerId, seed[i]) + nextWordId, err := p.getWordNodeId(babblerId, seed[i]) if err != nil { return "", err } @@ -409,14 +497,24 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error } func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { - intoString := "<" + intoName + ">" - otherString := "<" + otherName + ">" + intoStringId, err := p.getOrCreateWordNode(intoId, "<" + intoName + ">") + if err != nil { + return err + } + otherStringId, err := p.getOrCreateWordNode(otherId, "<" + otherName + ">") + if err != nil { + return err + } + + log.Print(intoStringId) + log.Print(otherStringId) mapping := map[int64]int64{} - query := fmt.Sprintf("select id, word, root, rootFrequency from babblerWords where babblerId = %d;", otherId) + query := fmt.Sprintf("select wordId, root, rootFrequency from babblerNodes where babblerId = %d;", otherId) rows, err := p.db.Query(query) if err != nil { + log.Print(err) return err } defer rows.Close() @@ -432,8 +530,9 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName for rows.Next() { word := Word{} - err = rows.Scan(&word.Id, &word.Word, &word.Root, &word.RootFrequency) + err = rows.Scan(&word.Id, &word.Root, &word.RootFrequency) if err != nil { + log.Print(err) return err } words = append(words, word) @@ -441,40 +540,38 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName for _, word := range words { - if word.Word == otherString { - word.Word = intoString + if word.Id == otherStringId { + log.Print("HERE") + word.Id = intoStringId } - doInsert := false - wordId := int64(-1) + var res sql.Result + if word.Root > 0 { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ;`, word.RootFrequency, intoId, word.Word) - rowsAffected := int64(0) - if err == nil { - rowsAffected, _ = res.RowsAffected() - } - if err != nil || rowsAffected == 0 { - doInsert = true - } else { - wordId, _ = res.LastInsertId() - } + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ? output id ;`, word.RootFrequency, intoId, word.Id) } else { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id;`, word.RootFrequency, intoId, word.Word) - if err != nil { - doInsert = true - } else { - wordId, _ = res.LastInsertId() - } + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ? output id;`, word.RootFrequency, intoId, word.Id) } - if doInsert { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Word, word.Root, word.RootFrequency) + rowsAffected := int64(-1) + if err == nil { + rowsAffected, _ = res.RowsAffected() + } + + wordId := int64(-1) + if err != nil || rowsAffected == 0 { + res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Id, word.Root, word.RootFrequency) if err != nil { + log.Print(err) return err } - wordId, _ = res.LastInsertId() } + wordId, _ = res.LastInsertId() + + if word.Id == intoStringId { + word.Id = otherStringId + } mapping[word.Id] = wordId } diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go index b16f768..38e2502 100644 --- a/plugins/babbler/commands.go +++ b/plugins/babbler/commands.go @@ -101,7 +101,10 @@ func (p *BabblerPlugin) merge(tokens []string) (string, bool) { return "merge failed.", true } - p.mergeBabblers(intoBabbler, whoBabbler, into, who) + err = p.mergeBabblers(intoBabbler, whoBabbler, into, who) + if err != nil { + return "merge failed.", true + } return "mooooiggged", true }