clean up and ignore empty strings when adding to the markov chain

This commit is contained in:
skkiesel 2017-05-10 09:42:27 -04:00
parent 8c98734e7e
commit f28dabf0ba
2 changed files with 128 additions and 103 deletions

View File

@ -16,9 +16,9 @@ import (
"github.com/velour/catbase/config" "github.com/velour/catbase/config"
) )
const ( var (
NO_BABBLER = "babbler not found" NO_BABBLER = errors.New("babbler not found")
SAID_NOTHING = "hasn't said anything yet" SAID_NOTHING = errors.New("hasn't said anything yet")
) )
@ -72,108 +72,26 @@ func (p *BabblerPlugin) Message(message msg.Message) bool {
tokens := strings.Fields(lowercase) tokens := strings.Fields(lowercase)
numTokens := len(tokens) numTokens := len(tokens)
saidSomething := false
saidWhat := ""
if numTokens >= 2 && tokens[1] == "says" { if numTokens >= 2 && tokens[1] == "says" {
who := tokens[0] saidWhat, saidSomething = p.getBabble(tokens)
_, err := p.getBabbler(who)
if err != nil {
if err.Error() == NO_BABBLER {
p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s babbler not found.", who))
return true
}
return false
}
var saying string
if len(tokens) == 2 {
saying, err = p.babble(who)
} else {
saying, err = p.babbleSeed(who, tokens[2:])
}
if err != nil {
if err.Error() == SAID_NOTHING {
p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s hasn't said anything yet.", who))
return true
}
return false
}
if saying != "" {
p.Bot.SendMessage(message.Channel, saying)
return true
}
} else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 {
who := tokens[3] saidWhat, saidSomething = p.initializeBabbler(tokens)
_, err := p.getOrCreateBabbler(who)
if err != nil {
p.Bot.SendMessage(message.Channel, "babbler initialization failed.")
return true
}
p.Bot.SendMessage(message.Channel, "okay.")
return true
} else if strings.Index(lowercase, "batch learn for ") == 0 { } else if strings.Index(lowercase, "batch learn for ") == 0 {
who := tokens[3] saidWhat, saidSomething = p.batchLearn(tokens)
babblerId, err := p.getOrCreateBabbler(who)
if err != nil {
p.Bot.SendMessage(message.Channel, "batch learn failed.")
return true
}
body := strings.Join(tokens[4:], " ")
body = strings.ToLower(body)
for _, a := range strings.Split(body, ".") {
for _, b := range strings.Split(a, "!") {
for _, c := range strings.Split(b, "?") {
for _, d := range strings.Split(c, "\n") {
trimmed := strings.TrimSpace(d)
if trimmed != "" {
p.addToMarkovChain(babblerId, trimmed)
}
}
}
}
}
p.Bot.SendMessage(message.Channel, "phew that was tiring.")
return true
} else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 { } else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 {
if tokens[3] != "into" { saidWhat, saidSomething = p.merge(tokens)
p.Bot.SendMessage(message.Channel, "try using 'merge babbler [x] into [y]'")
return true
}
who := tokens[2]
into := tokens[4]
if who == into {
p.Bot.SendMessage(message.Channel, "that's annoying. stop it.")
return true
}
whoBabbler, err := p.getBabbler(who)
if err != nil {
p.Bot.SendMessage(message.Channel, "merge failed.")
return true
}
intoBabbler, err := p.getOrCreateBabbler(into)
if err != nil {
p.Bot.SendMessage(message.Channel, "merge failed.")
return true
}
p.merge(intoBabbler, whoBabbler, into, who)
p.Bot.SendMessage(message.Channel, "mooooiggged")
return true
} else { } else {
babblerId, err := p.getOrCreateBabbler(message.User.Name) //this should always return "", false
if err == nil { saidWhat, saidSomething = p.addToBabbler(message.User.Name, lowercase)
p.addToMarkovChain(babblerId, lowercase)
}
} }
return false if saidSomething {
p.Bot.SendMessage(message.Channel, saidWhat)
}
return saidSomething
} }
func (p *BabblerPlugin) Help(channel string, parts []string) { func (p *BabblerPlugin) Help(channel string, parts []string) {
@ -205,7 +123,7 @@ func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) {
id := int64(-1) id := int64(-1)
err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler)
if err != nil && err == sql.ErrNoRows { if err != nil && err == sql.ErrNoRows {
return -1, errors.New(NO_BABBLER) return -1, NO_BABBLER
} }
return id, err return id, err
} }
@ -315,6 +233,10 @@ func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWo
func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) {
words := strings.Fields(strings.ToLower(phrase)) words := strings.Fields(strings.ToLower(phrase))
if len(words) <= 0 {
return
}
id, err := p.incrementRootWordFrequency(babblerId, words[0]) id, err := p.incrementRootWordFrequency(babblerId, words[0])
if err != nil { if err != nil {
return return
@ -357,7 +279,7 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err
} }
if total == 0 { if total == 0 {
return -1, "", errors.New(SAID_NOTHING) return -1, "", SAID_NOTHING
} }
which := rand.Int63n(total) which := rand.Int63n(total)
@ -473,7 +395,7 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error
return strings.TrimSpace(strings.Join(words, " ")), nil return strings.TrimSpace(strings.Join(words, " ")), nil
} }
func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) error { func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error {
intoString := "<" + intoName + ">" intoString := "<" + intoName + ">"
otherString := "<" + otherName + ">" otherString := "<" + otherName + ">"
@ -540,8 +462,6 @@ func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string)
wordId, _ = res.LastInsertId() wordId, _ = res.LastInsertId()
} }
log.Printf("%s %d -> %d\n", word.Word, word.Id, wordId)
mapping[word.Id] = wordId mapping[word.Id] = wordId
} }

105
plugins/babbler/commands.go Normal file
View File

@ -0,0 +1,105 @@
package babbler
import (
"fmt"
"strings"
)
func (p *BabblerPlugin) initializeBabbler(tokens []string) (string, bool) {
who := tokens[3]
_, err := p.getOrCreateBabbler(who)
if err != nil {
return "babbler initialization failed.", true
}
return "okay.", true
}
func (p *BabblerPlugin) addToBabbler(babblerName, whatWasSaid string) (string, bool) {
babblerId, err := p.getOrCreateBabbler(babblerName)
if err == nil {
p.addToMarkovChain(babblerId, whatWasSaid)
}
return "", false
}
func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) {
who := tokens[0]
_, err := p.getBabbler(who)
if err != nil {
if err == NO_BABBLER {
return fmt.Sprintf("%s babbler not found.", who), true
}
} else {
var saying string
if len(tokens) == 2 {
saying, err = p.babble(who)
} else {
saying, err = p.babbleSeed(who, tokens[2:])
}
if err != nil {
if err == SAID_NOTHING {
return fmt.Sprintf("%s hasn't said anything yet.", who), true
}
} else if saying != "" {
return saying, true
}
}
return "", false
}
func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) {
who := tokens[3]
babblerId, err := p.getOrCreateBabbler(who)
if err != nil {
return "batch learn failed.", true
}
body := strings.Join(tokens[4:], " ")
body = strings.ToLower(body)
for _, a := range strings.Split(body, ".") {
for _, b := range strings.Split(a, "!") {
for _, c := range strings.Split(b, "?") {
for _, d := range strings.Split(c, "\n") {
trimmed := strings.TrimSpace(d)
if trimmed != "" {
p.addToMarkovChain(babblerId, trimmed)
}
}
}
}
}
return "phew that was tiring.", true
}
func (p *BabblerPlugin) merge(tokens []string) (string, bool) {
if tokens[3] != "into" {
return "try using 'merge babbler [x] into [y]'", true
}
who := tokens[2]
into := tokens[4]
if who == into {
return "that's annoying. stop it.", true
}
whoBabbler, err := p.getBabbler(who)
if err != nil {
if err == NO_BABBLER {
return fmt.Sprintf("%s babbler not found.", who), true
}
return "merge failed.", true
}
intoBabbler, err := p.getOrCreateBabbler(into)
if err != nil {
return "merge failed.", true
}
p.mergeBabblers(intoBabbler, whoBabbler, into, who)
return "mooooiggged", true
}