Compare commits

..

No commits in common. "5acf14b0ae12268cefeb7abfd69160e490ab5692" and "494c9e87d6fe992d441fe456c16f0289ffcc3df5" have entirely different histories.

2 changed files with 134 additions and 176 deletions

View File

@ -1,14 +1,8 @@
package tldr package tldr
import ( import (
"bytes"
"context"
"fmt" "fmt"
"github.com/andrewstuart/openai"
"github.com/velour/catbase/config"
"regexp"
"strings" "strings"
"text/template"
"time" "time"
"github.com/velour/catbase/bot" "github.com/velour/catbase/bot"
@ -20,8 +14,7 @@ import (
) )
type TLDRPlugin struct { type TLDRPlugin struct {
b bot.Bot bot bot.Bot
c *config.Config
history []history history []history
index int index int
lastRequest time.Time lastRequest time.Time
@ -29,60 +22,106 @@ type TLDRPlugin struct {
type history struct { type history struct {
timestamp time.Time timestamp time.Time
channel string
user string user string
body string body string
} }
func New(b bot.Bot) *TLDRPlugin { func New(b bot.Bot) *TLDRPlugin {
plugin := &TLDRPlugin{ plugin := &TLDRPlugin{
b: b, bot: b,
c: b.Config(),
history: []history{}, history: []history{},
index: 0, index: 0,
lastRequest: time.Now().Add(-24 * time.Hour), lastRequest: time.Now().Add(-24 * time.Hour),
} }
plugin.register() b.Register(plugin, bot.Message, plugin.message)
b.Register(plugin, bot.Help, plugin.help)
return plugin return plugin
} }
func (p *TLDRPlugin) register() { func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
p.b.RegisterTable(p, bot.HandlerTable{ timeLimit := time.Duration(p.bot.Config().GetInt("TLDR.HourLimit", 1))
{ lowercaseMessage := strings.ToLower(message.Body)
Kind: bot.Message, IsCmd: true, if lowercaseMessage == "tl;dr" && p.lastRequest.After(time.Now().Add(-timeLimit*time.Hour)) {
Regex: regexp.MustCompile(`old tl;dr`), p.bot.Send(c, bot.Message, message.Channel, "Slow down, cowboy. Read that tiny backlog.")
HelpText: "Get a rather inaccurate summary of the channel", return true
Handler: p.tldrCmd, } else if lowercaseMessage == "tl;dr" {
}, p.lastRequest = time.Now()
{ nTopics := p.bot.Config().GetInt("TLDR.Topics", 5)
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`tl;dr`), stopWordSlice := p.bot.Config().GetArray("TLDR.StopWords", []string{})
HelpText: "Get a summary of the channel", if len(stopWordSlice) == 0 {
Handler: p.betterTLDR, stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
}, p.bot.Config().SetArray("TLDR.StopWords", stopWordSlice)
{ }
Kind: bot.Message, IsCmd: false,
Regex: regexp.MustCompile(`.*`), vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
Handler: p.record, lda := nlp.NewLatentDirichletAllocation(nTopics)
}, pipeline := nlp.NewPipeline(vectoriser, lda)
}) docsOverTopics, err := pipeline.FitTransform(p.getTopics()...)
p.b.Register(p, bot.Help, p.help)
} if err != nil {
log.Error().Err(err)
return false
}
bestScores := make([][]float64, nTopics)
bestDocs := make([][]history, nTopics)
supportingDocs := p.bot.Config().GetInt("TLDR.Support", 3)
for i := 0; i < nTopics; i++ {
bestScores[i] = make([]float64, supportingDocs)
bestDocs[i] = make([]history, supportingDocs)
}
dr, dc := docsOverTopics.Dims()
for topic := 0; topic < dr; topic++ {
minScore, minIndex := min(bestScores[topic])
for doc := 0; doc < dc; doc++ {
score := docsOverTopics.At(topic, doc)
if score > minScore {
bestScores[topic][minIndex] = score
bestDocs[topic][minIndex] = p.history[doc]
minScore, minIndex = min(bestScores[topic])
}
}
}
topicsOverWords := lda.Components()
tr, tc := topicsOverWords.Dims()
vocab := make([]string, len(vectoriser.Vocabulary))
for k, v := range vectoriser.Vocabulary {
vocab[v] = k
}
response := "Here you go captain 'too good to read backlog':\n"
for topic := 0; topic < tr; topic++ {
bestScore := -1.
bestTopic := ""
for word := 0; word < tc; word++ {
score := topicsOverWords.At(topic, word)
if score > bestScore {
bestScore = score
bestTopic = vocab[word]
}
}
response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic)
for i := range bestDocs[topic] {
response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body)
}
}
p.bot.Send(c, bot.Message, message.Channel, response)
func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
return true return true
} }
return false
}
func (p *TLDRPlugin) record(r bot.Request) bool {
hist := history{ hist := history{
body: strings.ToLower(r.Msg.Body), body: lowercaseMessage,
user: r.Msg.User.Name, user: message.User.Name,
channel: r.Msg.Channel,
timestamp: time.Now(), timestamp: time.Now(),
} }
p.addHistory(hist) p.addHistory(hist)
@ -90,86 +129,11 @@ func (p *TLDRPlugin) record(r bot.Request) bool {
return false return false
} }
func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
p.lastRequest = time.Now()
nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
if len(stopWordSlice) == 0 {
stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
}
vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
lda := nlp.NewLatentDirichletAllocation(nTopics)
pipeline := nlp.NewPipeline(vectoriser, lda)
docsOverTopics, err := pipeline.FitTransform(p.getTopics()...)
if err != nil {
log.Error().Err(err)
return false
}
bestScores := make([][]float64, nTopics)
bestDocs := make([][]history, nTopics)
supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
for i := 0; i < nTopics; i++ {
bestScores[i] = make([]float64, supportingDocs)
bestDocs[i] = make([]history, supportingDocs)
}
dr, dc := docsOverTopics.Dims()
for topic := 0; topic < dr; topic++ {
minScore, minIndex := min(bestScores[topic])
for doc := 0; doc < dc; doc++ {
score := docsOverTopics.At(topic, doc)
if score > minScore {
bestScores[topic][minIndex] = score
bestDocs[topic][minIndex] = p.history[doc]
minScore, minIndex = min(bestScores[topic])
}
}
}
topicsOverWords := lda.Components()
tr, tc := topicsOverWords.Dims()
vocab := make([]string, len(vectoriser.Vocabulary))
for k, v := range vectoriser.Vocabulary {
vocab[v] = k
}
response := "Here you go captain 'too good to read backlog':\n"
for topic := 0; topic < tr; topic++ {
bestScore := -1.
bestTopic := ""
for word := 0; word < tc; word++ {
score := topicsOverWords.At(topic, word)
if score > bestScore {
bestScore = score
bestTopic = vocab[word]
}
}
response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic)
for i := range bestDocs[topic] {
response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body)
}
}
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
return true
}
func (p *TLDRPlugin) addHistory(hist history) { func (p *TLDRPlugin) addHistory(hist history) {
p.history = append(p.history, hist) p.history = append(p.history, hist)
sz := len(p.history) sz := len(p.history)
max := p.b.Config().GetInt("TLDR.HistorySize", 1000) max := p.bot.Config().GetInt("TLDR.HistorySize", 1000)
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24)) keepHrs := time.Duration(p.bot.Config().GetInt("TLDR.KeepHours", 24))
// Clamp the size of the history // Clamp the size of the history
if sz > max { if sz > max {
p.history = p.history[len(p.history)-max:] p.history = p.history[len(p.history)-max:]
@ -199,7 +163,7 @@ func (p *TLDRPlugin) getTopics() []string {
// Help responds to help requests. Every plugin must implement a help function. // Help responds to help requests. Every plugin must implement a help function.
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
p.b.Send(c, bot.Message, message.Channel, "tl;dr") p.bot.Send(c, bot.Message, message.Channel, "tl;dr")
return true return true
} }
@ -214,43 +178,3 @@ func min(slice []float64) (float64, int) {
} }
return minVal, minIndex return minVal, minIndex
} }
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
c, err := p.getClient()
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
return true
}
promptConfig := p.c.Get("tldr.prompttemplate", "Summarize the following conversation:\n")
promptTpl := template.Must(template.New("gptprompt").Parse(promptConfig))
prompt := bytes.Buffer{}
data := p.c.GetMap("tldr.promptdata", map[string]string{})
promptTpl.Execute(&prompt, data)
backlog := ""
for _, h := range p.history {
if h.channel == r.Msg.Channel {
backlog += fmt.Sprintf("%s: %s\n", h.user, h.body)
}
}
sess := c.NewChatSession(prompt.String())
completion, err := sess.Complete(context.TODO(), backlog)
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
return true
}
log.Debug().
Str("prompt", prompt.String()).
Str("backlog", backlog).
Str("completion", completion).
Msgf("tl;dr")
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
return true
}
func (p *TLDRPlugin) getClient() (*openai.Client, error) {
token := p.c.Get("gpt.token", "")
if token == "" {
return nil, fmt.Errorf("no GPT token given")
}
return openai.NewClient(token)
}

View File

@ -20,25 +20,20 @@ func init() {
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
} }
func makeMessageBy(payload, by string) bot.Request { func makeMessageBy(payload, by string) (bot.Connector, bot.Kind, msg.Message) {
isCmd := strings.HasPrefix(payload, "!") isCmd := strings.HasPrefix(payload, "!")
if isCmd { if isCmd {
payload = payload[1:] payload = payload[1:]
} }
return &cli.CliPlugin{}, bot.Message, msg.Message{
return bot.Request{ User: &user.User{Name: by},
Conn: &cli.CliPlugin{}, Channel: "test",
Kind: bot.Message, Body: payload,
Msg: msg.Message{ Command: isCmd,
User: &user.User{Name: by},
Channel: "test",
Body: payload,
Command: isCmd,
},
} }
} }
func makeMessage(payload string) bot.Request { func makeMessage(payload string) (bot.Connector, bot.Kind, msg.Message) {
return makeMessageBy(payload, "tester") return makeMessageBy(payload, "tester")
} }
@ -48,12 +43,51 @@ func setup(t *testing.T) (*TLDRPlugin, *bot.MockBot) {
return r, mb return r, mb
} }
func Test(t *testing.T) {
c, mb := setup(t)
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
res = c.message(makeMessage("The cow jumped over the moon"))
res = c.message(makeMessage("The little dog laughed to see such fun"))
res = c.message(makeMessage("tl;dr"))
assert.True(t, res)
assert.Len(t, mb.Messages, 1)
}
func TestDoubleUp(t *testing.T) {
c, mb := setup(t)
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
res = c.message(makeMessage("The cow jumped over the moon"))
res = c.message(makeMessage("The little dog laughed to see such fun"))
res = c.message(makeMessage("tl;dr"))
res = c.message(makeMessage("tl;dr"))
assert.True(t, res)
assert.Len(t, mb.Messages, 2)
assert.Contains(t, mb.Messages[1], "Slow down, cowboy.")
}
func TestAddHistoryLimitsMessages(t *testing.T) {
c, _ := setup(t)
max := 1000
c.bot.Config().Set("TLDR.HistorySize", strconv.Itoa(max))
c.bot.Config().Set("TLDR.KeepHours", "24")
t0 := time.Now().Add(-24 * time.Hour)
for i := 0; i < max*2; i++ {
hist := history{
body: "test",
user: "tester",
timestamp: t0.Add(time.Duration(i) * time.Second),
}
c.addHistory(hist)
}
assert.Len(t, c.history, max)
}
func TestAddHistoryLimitsDays(t *testing.T) { func TestAddHistoryLimitsDays(t *testing.T) {
c, _ := setup(t) c, _ := setup(t)
hrs := 24 hrs := 24
expected := 24 expected := 24
c.b.Config().Set("TLDR.HistorySize", "100") c.bot.Config().Set("TLDR.HistorySize", "100")
c.b.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs)) c.bot.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs))
t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour) t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour)
for i := 0; i < 48; i++ { for i := 0; i < 48; i++ {
hist := history{ hist := history{