catbase/plugins/tldr/tldr.go

262 lines
6.4 KiB
Go
Raw Normal View History

package tldr
import (
2024-01-05 16:44:05 +00:00
"bytes"
"context"
"fmt"
2024-01-05 16:44:05 +00:00
"github.com/andrewstuart/openai"
"github.com/velour/catbase/config"
"regexp"
"strings"
2024-01-05 16:44:05 +00:00
"text/template"
"time"
"github.com/velour/catbase/bot"
"github.com/velour/catbase/bot/msg"
"github.com/rs/zerolog/log"
"github.com/james-bowman/nlp"
)
type TLDRPlugin struct {
2024-01-05 16:44:05 +00:00
b bot.Bot
c *config.Config
2024-01-09 20:13:43 +00:00
history map[string][]history
index int
lastRequest time.Time
}
// history is one recorded chat message.
type history struct {
	timestamp time.Time // when the message was recorded
	user      string    // sender's name (taken from the message's User.Name)
	body      string    // message text, lower-cased when recorded
}
func New(b bot.Bot) *TLDRPlugin {
plugin := &TLDRPlugin{
2024-01-05 16:44:05 +00:00
b: b,
c: b.Config(),
2024-01-09 20:13:43 +00:00
history: map[string][]history{},
index: 0,
lastRequest: time.Now().Add(-24 * time.Hour),
}
2024-01-05 16:44:05 +00:00
plugin.register()
return plugin
}
2024-01-05 16:44:05 +00:00
// register installs the plugin's handlers with the bot:
//   - "old tl;dr" (command)  -> tldrCmd
//   - "tl;dr" (command)      -> betterTLDR
//   - any message (non-cmd)  -> record, which stores channel history
//
// It also registers the help handler. Both command patterns are unanchored,
// so "old tl;dr" also matches `tl;dr`; the entries keep this order.
// NOTE(review): which entry wins for "old tl;dr" depends on the bot's
// dispatch rules — confirm.
func (p *TLDRPlugin) register() {
	oldSummary := regexp.MustCompile(`old tl;dr`)
	summary := regexp.MustCompile(`tl;dr`)
	anything := regexp.MustCompile(`.*`)

	table := bot.HandlerTable{
		{
			Kind:     bot.Message,
			IsCmd:    true,
			Regex:    oldSummary,
			HelpText: "Get a rather inaccurate summary of the channel",
			Handler:  p.tldrCmd,
		},
		{
			Kind:     bot.Message,
			IsCmd:    true,
			Regex:    summary,
			HelpText: "Get a summary of the channel",
			Handler:  p.betterTLDR,
		},
		{
			Kind:    bot.Message,
			IsCmd:   false,
			Regex:   anything,
			Handler: p.record,
		},
	}
	p.b.RegisterTable(p, table)
	p.b.Register(p, bot.Help, p.help)
}
func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
return true
2024-01-05 16:44:05 +00:00
}
return false
}
2019-03-28 21:38:04 +00:00
2024-01-05 16:44:05 +00:00
func (p *TLDRPlugin) record(r bot.Request) bool {
hist := history{
body: strings.ToLower(r.Msg.Body),
user: r.Msg.User.Name,
timestamp: time.Now(),
}
2024-01-09 20:13:43 +00:00
p.addHistory(r.Msg.Channel, hist)
2024-01-05 16:44:05 +00:00
return false
}
2024-01-05 16:44:05 +00:00
func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
p.lastRequest = time.Now()
nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
2019-03-22 01:46:28 +00:00
2024-01-05 16:44:05 +00:00
stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
if len(stopWordSlice) == 0 {
stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
}
2024-01-05 16:44:05 +00:00
vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
lda := nlp.NewLatentDirichletAllocation(nTopics)
pipeline := nlp.NewPipeline(vectoriser, lda)
docsOverTopics, err := pipeline.FitTransform(p.getTopics()...)
2024-01-05 16:44:05 +00:00
if err != nil {
log.Error().Err(err)
return false
}
2024-01-05 16:44:05 +00:00
bestScores := make([][]float64, nTopics)
bestDocs := make([][]history, nTopics)
2024-01-05 16:44:05 +00:00
supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
for i := 0; i < nTopics; i++ {
bestScores[i] = make([]float64, supportingDocs)
bestDocs[i] = make([]history, supportingDocs)
}
2024-01-05 16:44:05 +00:00
dr, dc := docsOverTopics.Dims()
for topic := 0; topic < dr; topic++ {
minScore, minIndex := min(bestScores[topic])
for doc := 0; doc < dc; doc++ {
score := docsOverTopics.At(topic, doc)
if score > minScore {
bestScores[topic][minIndex] = score
2024-01-09 20:13:43 +00:00
bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
2024-01-05 16:44:05 +00:00
minScore, minIndex = min(bestScores[topic])
}
}
2024-01-05 16:44:05 +00:00
}
2024-01-05 16:44:05 +00:00
topicsOverWords := lda.Components()
tr, tc := topicsOverWords.Dims()
2024-01-05 16:44:05 +00:00
vocab := make([]string, len(vectoriser.Vocabulary))
for k, v := range vectoriser.Vocabulary {
vocab[v] = k
}
2024-01-05 16:44:05 +00:00
response := "Here you go captain 'too good to read backlog':\n"
for topic := 0; topic < tr; topic++ {
bestScore := -1.
bestTopic := ""
for word := 0; word < tc; word++ {
score := topicsOverWords.At(topic, word)
if score > bestScore {
bestScore = score
bestTopic = vocab[word]
}
}
response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic)
for i := range bestDocs[topic] {
response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body)
}
}
2024-01-05 16:44:05 +00:00
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
return true
}
2024-01-09 20:13:43 +00:00
func (p *TLDRPlugin) addHistory(ch string, hist history) {
p.history[ch] = append(p.history[ch], hist)
sz := len(p.history[ch])
2024-01-05 16:44:05 +00:00
max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
// Clamp the size of the history
if sz > max {
2024-01-09 20:13:43 +00:00
p.history[ch] = p.history[ch][len(p.history)-max:]
}
// Remove old entries
yesterday := time.Now().Add(-keepHrs * time.Hour)
begin := 0
2024-01-09 20:13:43 +00:00
for i, m := range p.history[ch] {
if !m.timestamp.Before(yesterday) {
begin = i - 1 // should keep this message
if begin < 0 {
begin = 0
}
break
}
}
2024-01-09 20:13:43 +00:00
p.history[ch] = p.history[ch][begin:]
}
func (p *TLDRPlugin) getTopics() []string {
hist := []string{}
2024-01-09 20:13:43 +00:00
for _, ch := range p.history {
for _, h := range ch {
hist = append(hist, h.body)
}
}
return hist
}
// Help responds to help requests. Every plugin must implement a help function.
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
2024-01-05 16:44:05 +00:00
p.b.Send(c, bot.Message, message.Channel, "tl;dr")
return true
}
2019-03-22 01:46:28 +00:00
// min returns the smallest value in slice and the index of its first
// occurrence. For an empty slice it returns (1, -1); callers must check for
// the -1 index before using it.
//
// The scan is seeded from the first element rather than a fixed sentinel.
// A sentinel of 1 returned (1, -1) whenever every element was >= 1, even
// though the slice was non-empty.
func min(slice []float64) (float64, int) {
	if len(slice) == 0 {
		return 1, -1
	}
	minVal, minIndex := slice[0], 0
	for i, v := range slice[1:] {
		if v < minVal {
			minVal, minIndex = v, i+1
		}
	}
	return minVal, minIndex
}
2024-01-05 16:44:05 +00:00
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
2024-01-09 20:13:43 +00:00
ch := r.Msg.Channel
2024-01-05 16:44:05 +00:00
c, err := p.getClient()
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
return true
}
promptConfig := p.c.Get("tldr.prompttemplate", "Summarize the following conversation:\n")
promptTpl := template.Must(template.New("gptprompt").Parse(promptConfig))
prompt := bytes.Buffer{}
data := p.c.GetMap("tldr.promptdata", map[string]string{})
promptTpl.Execute(&prompt, data)
backlog := ""
2024-01-09 20:13:43 +00:00
maxLen := p.c.GetInt("tldr.maxgpt", 4096)
for i := len(p.history[ch]) - 1; i >= 0; i-- {
h := p.history[ch][i]
str := fmt.Sprintf("%s: %s\n", h.user, h.body)
if len(backlog) > maxLen {
break
2024-01-06 00:00:41 +00:00
}
2024-01-09 20:13:43 +00:00
backlog = str + backlog
2024-01-05 16:44:05 +00:00
}
sess := c.NewChatSession(prompt.String())
completion, err := sess.Complete(context.TODO(), backlog)
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
return true
}
log.Debug().
Str("prompt", prompt.String()).
Str("backlog", backlog).
Str("completion", completion).
Msgf("tl;dr")
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
return true
}
// getClient builds an OpenAI client from the gpt.token config value. It
// returns an error without contacting anything when the token is unset.
func (p *TLDRPlugin) getClient() (*openai.Client, error) {
	if token := p.c.Get("gpt.token", ""); token != "" {
		return openai.NewClient(token)
	}
	return nil, fmt.Errorf("no GPT token given")
}