2016-05-09 20:45:02 +00:00
// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors.
package babbler
import (
2017-05-10 12:41:41 +00:00
"database/sql"
2017-05-09 23:41:40 +00:00
"errors"
2016-05-09 20:45:02 +00:00
"fmt"
2017-05-09 23:41:40 +00:00
"log"
2016-05-09 20:45:02 +00:00
"math/rand"
"strings"
"github.com/jmoiron/sqlx"
"github.com/velour/catbase/bot"
"github.com/velour/catbase/bot/msg"
2016-05-11 01:15:52 +00:00
"github.com/velour/catbase/config"
2016-05-09 20:45:02 +00:00
)
2017-05-10 13:42:27 +00:00
// Sentinel errors produced by the babbler's lookup helpers.
var (
	// NEVER_SAID is returned by getWordId and getWordArcHelper when the
	// requested word, or word-to-word transition, has no row.
	NEVER_SAID = errors.New("never said that")

	// NO_BABBLER is returned by getBabbler when no babbler row has the name.
	NO_BABBLER = errors.New("babbler not found")

	// SAID_NOTHING is returned by getWeightedRootWord when the babbler has no
	// root (phrase-starting) words with any weight.
	SAID_NOTHING = errors.New("hasn't said anything yet")
)
2016-05-09 20:45:02 +00:00
// BabblerPlugin learns word-to-word transitions from what users say and
// generates ("babbles") new phrases from them.
type BabblerPlugin struct {
	// Bot is the host bot; used to send replies (SendMessage).
	Bot bot.Bot
	// db is the bot's database handle (set from Bot.DB() in New); all
	// babbler tables are read and written through it.
	db *sqlx.DB
	// config is the bot's configuration (set from Bot.Config() in New).
	config *config.Config
}
2017-05-09 23:41:40 +00:00
func New ( bot bot . Bot ) * BabblerPlugin {
log . SetFlags ( log . LstdFlags | log . Lshortfile )
if bot . DBVersion ( ) == 1 {
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblers (
id integer primary key ,
babbler string
) ; ` ) ; err != nil {
log . Fatal ( err )
}
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblerWords (
id integer primary key ,
babblerId integer ,
word string ,
root integer ,
rootFrequency integer
) ; ` ) ; err != nil {
log . Fatal ( err )
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblerArcs (
id integer primary key ,
fromWordId integer ,
toWordId interger ,
frequency integer
) ; ` ) ; err != nil {
log . Fatal ( err )
}
}
2016-05-09 20:45:02 +00:00
plugin := & BabblerPlugin {
2017-05-09 23:41:40 +00:00
Bot : bot ,
db : bot . DB ( ) ,
config : bot . Config ( ) ,
2016-05-09 20:45:02 +00:00
}
2016-05-11 01:15:52 +00:00
return plugin
}
2016-05-09 20:45:02 +00:00
func ( p * BabblerPlugin ) Message ( message msg . Message ) bool {
lowercase := strings . ToLower ( message . Body )
tokens := strings . Fields ( lowercase )
2016-10-19 23:34:16 +00:00
numTokens := len ( tokens )
2016-05-09 20:45:02 +00:00
2017-05-10 13:42:27 +00:00
saidSomething := false
saidWhat := ""
2017-05-09 23:41:40 +00:00
2017-05-10 13:42:27 +00:00
if numTokens >= 2 && tokens [ 1 ] == "says" {
saidWhat , saidSomething = p . getBabble ( tokens )
2016-05-11 17:07:16 +00:00
} else if len ( tokens ) == 4 && strings . Index ( lowercase , "initialize babbler for " ) == 0 {
2017-05-10 13:42:27 +00:00
saidWhat , saidSomething = p . initializeBabbler ( tokens )
2016-05-11 17:07:16 +00:00
} else if strings . Index ( lowercase , "batch learn for " ) == 0 {
2017-05-10 13:42:27 +00:00
saidWhat , saidSomething = p . batchLearn ( tokens )
2016-05-26 15:06:22 +00:00
} else if len ( tokens ) == 5 && strings . Index ( lowercase , "merge babbler" ) == 0 {
2017-05-10 13:42:27 +00:00
saidWhat , saidSomething = p . merge ( tokens )
2016-05-11 16:11:21 +00:00
} else {
2017-05-10 13:42:27 +00:00
//this should always return "", false
saidWhat , saidSomething = p . addToBabbler ( message . User . Name , lowercase )
2016-05-09 20:45:02 +00:00
}
2017-05-10 13:42:27 +00:00
if saidSomething {
p . Bot . SendMessage ( message . Channel , saidWhat )
}
return saidSomething
2016-05-09 20:45:02 +00:00
}
func ( p * BabblerPlugin ) Help ( channel string , parts [ ] string ) {
2016-05-11 01:15:52 +00:00
p . Bot . SendMessage ( channel , "initialize babbler for seabass\n\nseabass says" )
2016-05-09 20:45:02 +00:00
}
// Event never handles anything: the babbler reacts only to Message.
func (p *BabblerPlugin) Event(kind string, message msg.Message) bool {
	const handled = false
	return handled
}
// BotMessage ignores the bot's own messages; they are never learned from.
func (p *BabblerPlugin) BotMessage(message msg.Message) bool {
	const handled = false
	return handled
}
// RegisterWeb registers no web endpoints and therefore returns nil.
func (p *BabblerPlugin) RegisterWeb() *string {
	var route *string
	return route
}
2017-05-09 23:41:40 +00:00
// makeBabbler inserts a babblers row named babbler and returns its new id.
// On failure, whether of the insert or of reading back the id, it returns
// -1 and the error. Previously the LastInsertId error was silently dropped.
func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) {
	res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler)
	if err != nil {
		return -1, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return -1, err
	}
	return id, nil
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getBabbler ( babbler string ) ( int64 , error ) {
id := int64 ( - 1 )
err := p . db . Get ( & id , ` select id from babblers where babbler = ? ` , babbler )
2017-05-10 14:11:49 +00:00
if err == sql . ErrNoRows {
2017-05-10 13:42:27 +00:00
return - 1 , NO_BABBLER
2017-05-09 23:41:40 +00:00
}
return id , err
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getOrCreateBabbler ( babbler string ) ( int64 , error ) {
id , err := p . getBabbler ( babbler )
if err != nil {
id , err = p . makeBabbler ( babbler )
if err != nil {
return id , err
}
query := fmt . Sprintf ( ` select tidbit from factoid where fact like '%s quotes'; ` , babbler )
rows , err := p . db . Query ( query )
if err != nil {
//we'll just ignore this but the actual creation succeeded previously
return id , nil
}
2017-05-10 13:50:37 +00:00
defer rows . Close ( )
2017-05-10 12:41:41 +00:00
2017-05-10 13:50:37 +00:00
tidbits := [ ] string { }
2017-05-09 23:41:40 +00:00
for rows . Next ( ) {
var tidbit string
err := rows . Scan ( & tidbit )
if err != nil {
return id , err
2016-05-09 20:45:02 +00:00
}
2017-05-10 13:50:37 +00:00
tidbits = append ( tidbits , tidbit )
}
for _ , tidbit := range tidbits {
2017-05-09 23:41:40 +00:00
p . addToMarkovChain ( id , tidbit )
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return id , err
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getWordId ( babblerId int64 , word string ) ( int64 , error ) {
id := int64 ( - 1 )
err := p . db . Get ( & id , ` select id from babblerWords where babblerId = ? and word = ? ` , babblerId , word )
2017-05-10 14:11:49 +00:00
if err == sql . ErrNoRows {
return - 1 , NEVER_SAID
}
2017-05-09 23:41:40 +00:00
return id , err
}
// createNewWord inserts word into babblerId's vocabulary as a non-root word
// (root = 0, rootFrequency = 0) and returns its new id. On failure, whether
// of the insert or of reading back the id, it returns -1 and the error.
// Previously the LastInsertId error was silently dropped.
func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) {
	res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word)
	if err != nil {
		return -1, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return -1, err
	}
	return id, nil
}
// getOrCreateWord returns the id of word in babblerId's vocabulary, inserting
// the word via createNewWord when it is new.
//
// Only NEVER_SAID (word not found) triggers the insert. Any other lookup
// error is returned with id -1, so a transient DB failure cannot create a
// duplicate word row.
func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) {
	id, err := p.getWordId(babblerId, word)
	switch err {
	case nil:
		return id, nil
	case NEVER_SAID:
		return p.createNewWord(babblerId, word)
	default:
		return -1, err
	}
}
// incrementRootWordFrequency records one more phrase starting with word for
// babblerId. It creates the word if needed, flags it as a root (root = 1)
// and bumps rootFrequency by one. Returns the word's id, or -1 on error.
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
	wordID, err := p.getOrCreateWord(babblerId, word)
	if err == nil {
		_, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, wordID)
	}
	if err != nil {
		return -1, err
	}
	return wordID, nil
}
func ( p * BabblerPlugin ) getWordArcHelper ( fromWordId , toWordId int64 ) ( int64 , error ) {
id := int64 ( - 1 )
err := p . db . Get ( & id , ` select id from babblerArcs where fromWordId = ? and toWordId = ? ` , fromWordId , toWordId )
2017-05-10 14:11:49 +00:00
if err == sql . ErrNoRows {
return - 1 , NEVER_SAID
}
2017-05-09 23:41:40 +00:00
return id , err
}
func ( p * BabblerPlugin ) incrementWordArc ( fromWordId , toWordId int64 ) ( int64 , error ) {
res , err := p . db . Exec ( ` update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ? ` , fromWordId , toWordId )
affectedRows := int64 ( 0 )
if err == nil {
affectedRows , _ = res . RowsAffected ( )
}
if err != nil || affectedRows == 0 {
res , err = p . db . Exec ( ` insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1); ` , fromWordId , toWordId )
if err != nil {
return - 1 , err
2016-05-09 20:45:02 +00:00
}
}
2017-05-09 23:41:40 +00:00
id , _ := res . LastInsertId ( )
return id , nil
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
// incrementFinalWordArcHelper records that a phrase ended after wordId, by
// bumping the arc from wordId to the -1 end-of-phrase marker.
func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) {
	const endOfPhrase int64 = -1
	return p.incrementWordArc(wordId, endOfPhrase)
}
func ( p * BabblerPlugin ) incrementWordArcHelper ( babblerId , fromWordId int64 , toWord string ) ( int64 , error ) {
toWordId , err := p . getOrCreateWord ( babblerId , toWord )
if err != nil {
return - 1 , err
2016-10-19 23:34:16 +00:00
}
2017-05-09 23:41:40 +00:00
_ , err = p . incrementWordArc ( fromWordId , toWordId )
if err != nil {
return - 1 , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return toWordId , nil
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) addToMarkovChain ( babblerId int64 , phrase string ) {
words := strings . Fields ( strings . ToLower ( phrase ) )
2017-05-10 13:42:27 +00:00
if len ( words ) <= 0 {
return
}
2017-05-09 23:41:40 +00:00
id , err := p . incrementRootWordFrequency ( babblerId , words [ 0 ] )
2016-05-09 20:45:02 +00:00
if err != nil {
2017-05-09 23:41:40 +00:00
return
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
for i := 1 ; i < len ( words ) ; i ++ {
id , err = p . incrementWordArcHelper ( babblerId , id , words [ i ] )
if err != nil {
return
}
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
_ , err = p . incrementFinalWordArcHelper ( id )
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getWeightedRootWord ( babblerId int64 ) ( int64 , string , error ) {
query := fmt . Sprintf ( "select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1" , babblerId )
rows , err := p . db . Query ( query )
if err != nil {
return - 1 , "" , err
}
defer rows . Close ( )
idToWord := map [ int64 ] string { }
idToFreq := map [ int64 ] int64 { }
total := int64 ( 0 )
for rows . Next ( ) {
var id int64
var word string
var rootFrequency int64
err = rows . Scan ( & id , & word , & rootFrequency )
2016-05-09 20:45:02 +00:00
if err != nil {
2017-05-09 23:41:40 +00:00
return - 1 , "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
total += rootFrequency
idToFreq [ id ] = rootFrequency
idToWord [ id ] = word
2016-05-09 20:45:02 +00:00
}
2017-05-10 12:41:41 +00:00
if total == 0 {
2017-05-10 13:42:27 +00:00
return - 1 , "" , SAID_NOTHING
2017-05-10 12:41:41 +00:00
}
2017-05-09 23:41:40 +00:00
which := rand . Int63n ( total )
total = 0
for id , freq := range idToFreq {
if total + freq >= which {
return id , idToWord [ id ] , nil
}
total += freq
}
log . Fatalf ( "shouldn't happen" )
return - 1 , "" , errors . New ( "failed to find weighted root word" )
2016-10-19 23:34:16 +00:00
}
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getWeightedNextWord ( fromWordId int64 ) ( int64 , string , error ) {
query := fmt . Sprintf ( "select toWordId, frequency from babblerArcs where fromWordId = %d;" , fromWordId )
rows , err := p . db . Query ( query )
if err != nil {
return - 1 , "" , err
}
defer rows . Close ( )
idToFreq := map [ int64 ] int64 { }
total := int64 ( 0 )
for rows . Next ( ) {
var toWordId int64
var frequency int64
err = rows . Scan ( & toWordId , & frequency )
if err != nil {
return - 1 , "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
total += frequency
idToFreq [ toWordId ] = frequency
}
2017-04-07 19:35:18 +00:00
2017-05-10 12:41:41 +00:00
if total == 0 {
return - 1 , "" , errors . New ( "missing arcs" )
}
2017-05-09 23:41:40 +00:00
which := rand . Int63n ( total )
total = 0
for id , freq := range idToFreq {
if total + freq >= which {
if id < 0 {
return - 1 , "" , nil
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
var word string
err := p . db . Get ( & word , ` select word from babblerWords where id = ? ` , id )
if err != nil {
return - 1 , "" , err
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
return id , word , nil
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
total += freq
}
log . Fatalf ( "shouldn't happen" )
return - 1 , "" , errors . New ( "failed to find weighted next word" )
}
2017-04-07 19:35:18 +00:00
2017-05-09 23:41:40 +00:00
// babble generates a phrase for who, starting from a weighted-random root
// word (no seed words).
func (p *BabblerPlugin) babble(who string) (string, error) {
	var noSeed []string
	return p.babbleSeed(who, noSeed)
}
func ( p * BabblerPlugin ) babbleSeed ( babbler string , seed [ ] string ) ( string , error ) {
babblerId , err := p . getBabbler ( babbler )
if err != nil {
return "" , nil
}
words := seed
var curWordId int64
if len ( seed ) == 0 {
id , word , err := p . getWeightedRootWord ( babblerId )
if err != nil {
return "" , err
}
curWordId = id
words = append ( words , word )
} else {
id , err := p . getWordId ( babblerId , seed [ 0 ] )
if err != nil {
return "" , err
}
curWordId = id
for i := 1 ; i < len ( seed ) ; i ++ {
nextWordId , err := p . getWordId ( babblerId , seed [ i ] )
if err != nil {
return "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
_ , err = p . getWordArcHelper ( curWordId , nextWordId )
if err != nil {
return "" , err
}
curWordId = nextWordId
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
for {
id , word , err := p . getWeightedNextWord ( curWordId )
if err != nil {
return "" , err
}
if id < 0 {
break
}
words = append ( words , word )
curWordId = id
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return strings . TrimSpace ( strings . Join ( words , " " ) ) , nil
2016-05-11 01:15:52 +00:00
}
2016-05-26 15:06:22 +00:00
2017-05-10 13:42:27 +00:00
func ( p * BabblerPlugin ) mergeBabblers ( intoId , otherId int64 , intoName , otherName string ) error {
2017-05-09 23:41:40 +00:00
intoString := "<" + intoName + ">"
otherString := "<" + otherName + ">"
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
mapping := map [ int64 ] int64 { }
query := fmt . Sprintf ( "select id, word, root, rootFrequency from babblerWords where babblerId = %d;" , otherId )
rows , err := p . db . Query ( query )
if err != nil {
return err
}
defer rows . Close ( )
type Word struct {
Id int64
Word string
Root int64
RootFrequency int64
}
words := [ ] Word { }
for rows . Next ( ) {
word := Word { }
err = rows . Scan ( & word . Id , & word . Word , & word . Root , & word . RootFrequency )
if err != nil {
return err
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
words = append ( words , word )
}
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
for _ , word := range words {
if word . Word == otherString {
word . Word = intoString
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
doInsert := false
wordId := int64 ( - 1 )
if word . Root > 0 {
res , err := p . db . Exec ( ` update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ; ` , word . RootFrequency , intoId , word . Word )
rowsAffected := int64 ( 0 )
if err == nil {
rowsAffected , _ = res . RowsAffected ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
if err != nil || rowsAffected == 0 {
doInsert = true
} else {
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
} else {
res , err := p . db . Exec ( ` update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id; ` , word . RootFrequency , intoId , word . Word )
if err != nil {
doInsert = true
} else {
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
}
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
if doInsert {
res , err := p . db . Exec ( ` insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ; ` , intoId , word . Word , word . Root , word . RootFrequency )
if err != nil {
return err
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
mapping [ word . Id ] = wordId
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
type Arc struct {
ToWordId int64
Frequency int64
}
for lookup , newArcStart := range mapping {
query = fmt . Sprintf ( "select toWordId, frequency from babblerArcs where fromWordId = %d;" , lookup )
rows , err := p . db . Query ( query )
if err != nil {
return err
}
defer rows . Close ( )
arcs := [ ] Arc { }
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
for rows . Next ( ) {
var arc Arc
err = rows . Scan ( & arc . ToWordId , & arc . Frequency )
if err != nil {
return err
}
arcs = append ( arcs , arc )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
for _ , arc := range arcs {
newArcEnd := int64 ( - 1 ) //handle end arcs
if arc . ToWordId >= 0 {
newArcEnd = mapping [ arc . ToWordId ]
}
res , err := p . db . Exec ( ` update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ? ` , arc . Frequency , newArcStart , newArcEnd )
rowsAffected := int64 ( 0 )
if err == nil {
rowsAffected , _ = res . RowsAffected ( )
}
if err != nil || rowsAffected == 0 {
_ , err = p . db . Exec ( ` insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?); ` , newArcStart , newArcEnd , arc . Frequency )
if err != nil {
return err
}
2016-05-26 15:06:22 +00:00
}
}
}
2017-05-09 23:41:40 +00:00
return nil
2016-05-26 15:06:22 +00:00
}