mirror of https://github.com/velour/catbase.git
probably not a huge memory savings, but only represent each unique word once and then reference them by their key for each babbler
This commit is contained in:
parent
efde5d804d
commit
e97ce576b8
|
@ -39,10 +39,18 @@ func New(bot bot.Bot) *BabblerPlugin {
|
||||||
);`); err != nil {
|
);`); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := bot.DB().Exec(`create table if not exists babblerWords (
|
if _, err := bot.DB().Exec(`create table if not exists babblerWords (
|
||||||
|
id integer primary key,
|
||||||
|
word string
|
||||||
|
);`); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := bot.DB().Exec(`create table if not exists babblerNodes (
|
||||||
id integer primary key,
|
id integer primary key,
|
||||||
babblerId integer,
|
babblerId integer,
|
||||||
word string,
|
wordId integer,
|
||||||
root integer,
|
root integer,
|
||||||
rootFrequency integer
|
rootFrequency integer
|
||||||
);`); err != nil {
|
);`); err != nil {
|
||||||
|
@ -162,39 +170,97 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) {
|
func (p *BabblerPlugin) getWordId(word string) (int64, error) {
|
||||||
id := int64(-1)
|
id := int64(-1)
|
||||||
err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word)
|
if err := p.db.Get(&id, `select id from babblerWords where word = ?`, word); err == sql.ErrNoRows {
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return -1, NEVER_SAID
|
return -1, NEVER_SAID
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
}
|
}
|
||||||
return id, err
|
return id, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) {
|
func (p *BabblerPlugin) createNewWord(word string) (int64, error) {
|
||||||
res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word)
|
if res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word); err == nil {
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
id, _ := res.LastInsertId()
|
id, _ := res.LastInsertId()
|
||||||
return id, nil
|
return id, nil
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) {
|
func (p *BabblerPlugin) getOrCreateWord(word string) (int64, error) {
|
||||||
id, err := p.getWordId(babblerId, word)
|
if id, err := p.getWordId(word); err == NEVER_SAID {
|
||||||
|
return p.createNewWord(word)
|
||||||
|
} else {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getWordNodeId(babblerId int64, word string) (int64, error) {
|
||||||
|
var err error
|
||||||
|
wordId := int64(-1)
|
||||||
|
if wordId, err = p.getWordId(word); err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
id := int64(-1)
|
||||||
|
if err = p.db.Get(&id, `select id from babblerNodes where babblerId = ? and wordId = ?`, babblerId, wordId); err == sql.ErrNoRows {
|
||||||
|
return -1, NEVER_SAID
|
||||||
|
} else {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return p.createNewWord(babblerId, word)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
return id, err
|
return id, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) createWordNodeId(babblerId int64, word string) (int64, error) {
|
||||||
|
var err error
|
||||||
|
wordId := int64(-1)
|
||||||
|
if wordId, err = p.getOrCreateWord(word); err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
id := int64(-1)
|
||||||
|
res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babblerId, wordId)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return -1, NEVER_SAID
|
||||||
|
} else if err == nil {
|
||||||
|
id, _ = res.LastInsertId()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getOrCreateWordNode(babblerId int64, word string) (int64, error) {
|
||||||
|
if id, err := p.getWordNodeId(babblerId, word); err == NEVER_SAID {
|
||||||
|
return p.createWordNodeId(babblerId, word)
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
|
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
|
||||||
id, err := p.getOrCreateWord(babblerId, word)
|
var err error
|
||||||
if err != nil {
|
id := int64(-1)
|
||||||
|
if id, err = p.getOrCreateWordNode(babblerId, word); err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
_, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, id)
|
|
||||||
|
if _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where babblerId = ? and wordId = ?;`, babblerId, id); err != nil {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
return id, nil
|
return id, nil
|
||||||
|
@ -206,6 +272,9 @@ func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, err
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return -1, NEVER_SAID
|
return -1, NEVER_SAID
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +288,7 @@ func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, err
|
||||||
if err != nil || affectedRows == 0 {
|
if err != nil || affectedRows == 0 {
|
||||||
res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId)
|
res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -232,7 +301,7 @@ func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) {
|
func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) {
|
||||||
toWordId, err := p.getOrCreateWord(babblerId, toWord)
|
toWordId, err := p.getOrCreateWordNode(babblerId, toWord)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
@ -252,12 +321,14 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) {
|
||||||
|
|
||||||
id, err := p.incrementRootWordFrequency(babblerId, words[0])
|
id, err := p.incrementRootWordFrequency(babblerId, words[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 1; i < len(words); i++ {
|
for i := 1; i < len(words); i++ {
|
||||||
id, err = p.incrementWordArcHelper(babblerId, id, words[i])
|
id, err = p.incrementWordArcHelper(babblerId, id, words[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -266,29 +337,31 @@ func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) {
|
func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) {
|
||||||
query := fmt.Sprintf("select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1", babblerId)
|
query := fmt.Sprintf("select id, wordId, rootFrequency from babblerNodes where babblerId = %d and root = 1", babblerId)
|
||||||
rows, err := p.db.Query(query)
|
rows, err := p.db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return -1, "", err
|
return -1, "", err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
idToWord := map[int64]string{}
|
idToWordId := map[int64]int64{}
|
||||||
idToFreq := map[int64]int64{}
|
idToFreq := map[int64]int64{}
|
||||||
total := int64(0)
|
total := int64(0)
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id int64
|
var id int64
|
||||||
var word string
|
var wordId int64
|
||||||
var rootFrequency int64
|
var rootFrequency int64
|
||||||
err = rows.Scan(&id, &word, &rootFrequency)
|
err = rows.Scan(&id, &wordId, &rootFrequency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return -1, "", err
|
return -1, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
total += rootFrequency
|
total += rootFrequency
|
||||||
idToFreq[id] = rootFrequency
|
idToFreq[id] = rootFrequency
|
||||||
idToWord[id] = word
|
idToWordId[id] = wordId
|
||||||
}
|
}
|
||||||
|
|
||||||
if total == 0 {
|
if total == 0 {
|
||||||
|
@ -299,7 +372,13 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err
|
||||||
total = 0
|
total = 0
|
||||||
for id, freq := range idToFreq {
|
for id, freq := range idToFreq {
|
||||||
if total+freq >= which {
|
if total+freq >= which {
|
||||||
return id, idToWord[id], nil
|
var word string
|
||||||
|
err := p.db.Get(&word, `select word from babblerWords where id = ?`, idToWordId[id])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
return id, word, nil
|
||||||
}
|
}
|
||||||
total += freq
|
total += freq
|
||||||
}
|
}
|
||||||
|
@ -311,6 +390,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er
|
||||||
query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId)
|
query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId)
|
||||||
rows, err := p.db.Query(query)
|
rows, err := p.db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return -1, "", err
|
return -1, "", err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
@ -323,6 +403,7 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er
|
||||||
var frequency int64
|
var frequency int64
|
||||||
err = rows.Scan(&toWordId, &frequency)
|
err = rows.Scan(&toWordId, &frequency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return -1, "", err
|
return -1, "", err
|
||||||
}
|
}
|
||||||
total += frequency
|
total += frequency
|
||||||
|
@ -341,11 +422,17 @@ func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, er
|
||||||
return -1, "", nil
|
return -1, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var word string
|
var wordId int64
|
||||||
err := p.db.Get(&word, `select word from babblerWords where id = ?`, id)
|
if err := p.db.Get(&wordId, `select wordId from babblerNodes where id = ?`, id); err != nil {
|
||||||
if err != nil {
|
log.Print(err)
|
||||||
return -1, "", err
|
return -1, "", err
|
||||||
}
|
}
|
||||||
|
var word string
|
||||||
|
if err := p.db.Get(&word, `select word from babblerWords where id = ?`, wordId); err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return -1, "", err
|
||||||
|
}
|
||||||
|
|
||||||
return id, word, nil
|
return id, word, nil
|
||||||
}
|
}
|
||||||
total +=freq
|
total +=freq
|
||||||
|
@ -375,13 +462,14 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error
|
||||||
curWordId = id
|
curWordId = id
|
||||||
words = append(words, word)
|
words = append(words, word)
|
||||||
} else {
|
} else {
|
||||||
id, err := p.getWordId(babblerId, seed[0])
|
id, err := p.getWordNodeId(babblerId, seed[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
curWordId = id
|
curWordId = id
|
||||||
for i := 1; i < len(seed); i++ {
|
for i := 1; i < len(seed); i++ {
|
||||||
nextWordId, err := p.getWordId(babblerId, seed[i])
|
nextWordId, err := p.getWordNodeId(babblerId, seed[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -409,14 +497,24 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error {
|
func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error {
|
||||||
intoString := "<" + intoName + ">"
|
intoStringId, err := p.getOrCreateWordNode(intoId, "<" + intoName + ">")
|
||||||
otherString := "<" + otherName + ">"
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
otherStringId, err := p.getOrCreateWordNode(otherId, "<" + otherName + ">")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Print(intoStringId)
|
||||||
|
log.Print(otherStringId)
|
||||||
|
|
||||||
mapping := map[int64]int64{}
|
mapping := map[int64]int64{}
|
||||||
|
|
||||||
query := fmt.Sprintf("select id, word, root, rootFrequency from babblerWords where babblerId = %d;", otherId)
|
query := fmt.Sprintf("select wordId, root, rootFrequency from babblerNodes where babblerId = %d;", otherId)
|
||||||
rows, err := p.db.Query(query)
|
rows, err := p.db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
@ -432,8 +530,9 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
word := Word{}
|
word := Word{}
|
||||||
err = rows.Scan(&word.Id, &word.Word, &word.Root, &word.RootFrequency)
|
err = rows.Scan(&word.Id, &word.Root, &word.RootFrequency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
words = append(words, word)
|
words = append(words, word)
|
||||||
|
@ -441,40 +540,38 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName
|
||||||
|
|
||||||
for _, word := range words {
|
for _, word := range words {
|
||||||
|
|
||||||
if word.Word == otherString {
|
if word.Id == otherStringId {
|
||||||
word.Word = intoString
|
log.Print("HERE")
|
||||||
|
word.Id = intoStringId
|
||||||
}
|
}
|
||||||
|
|
||||||
doInsert := false
|
var res sql.Result
|
||||||
wordId := int64(-1)
|
|
||||||
if word.Root > 0 {
|
if word.Root > 0 {
|
||||||
res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ;`, word.RootFrequency, intoId, word.Word)
|
res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ? output id ;`, word.RootFrequency, intoId, word.Id)
|
||||||
rowsAffected := int64(0)
|
} else {
|
||||||
|
res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ? output id;`, word.RootFrequency, intoId, word.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected := int64(-1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rowsAffected, _ = res.RowsAffected()
|
rowsAffected, _ = res.RowsAffected()
|
||||||
}
|
}
|
||||||
if err != nil || rowsAffected == 0 {
|
|
||||||
doInsert = true
|
|
||||||
} else {
|
|
||||||
wordId, _ = res.LastInsertId()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id;`, word.RootFrequency, intoId, word.Word)
|
|
||||||
if err != nil {
|
|
||||||
doInsert = true
|
|
||||||
} else {
|
|
||||||
wordId, _ = res.LastInsertId()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if doInsert {
|
wordId := int64(-1)
|
||||||
res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Word, word.Root, word.RootFrequency)
|
if err != nil || rowsAffected == 0 {
|
||||||
|
res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Id, word.Root, word.RootFrequency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
wordId, _ = res.LastInsertId()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wordId, _ = res.LastInsertId()
|
||||||
|
|
||||||
|
if word.Id == intoStringId {
|
||||||
|
word.Id = otherStringId
|
||||||
|
}
|
||||||
mapping[word.Id] = wordId
|
mapping[word.Id] = wordId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,10 @@ func (p *BabblerPlugin) merge(tokens []string) (string, bool) {
|
||||||
return "merge failed.", true
|
return "merge failed.", true
|
||||||
}
|
}
|
||||||
|
|
||||||
p.mergeBabblers(intoBabbler, whoBabbler, into, who)
|
err = p.mergeBabblers(intoBabbler, whoBabbler, into, who)
|
||||||
|
if err != nil {
|
||||||
|
return "merge failed.", true
|
||||||
|
}
|
||||||
|
|
||||||
return "mooooiggged", true
|
return "mooooiggged", true
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue