tldr: reverse enter and respect length

This commit is contained in:
Chris Sexton 2024-01-09 15:13:43 -05:00
parent 5acf14b0ae
commit 852239e89d
2 changed files with 27 additions and 20 deletions

View File

@ -22,14 +22,13 @@ import (
type TLDRPlugin struct { type TLDRPlugin struct {
b bot.Bot b bot.Bot
c *config.Config c *config.Config
history []history history map[string][]history
index int index int
lastRequest time.Time lastRequest time.Time
} }
type history struct { type history struct {
timestamp time.Time timestamp time.Time
channel string
user string user string
body string body string
} }
@ -38,7 +37,7 @@ func New(b bot.Bot) *TLDRPlugin {
plugin := &TLDRPlugin{ plugin := &TLDRPlugin{
b: b, b: b,
c: b.Config(), c: b.Config(),
history: []history{}, history: map[string][]history{},
index: 0, index: 0,
lastRequest: time.Now().Add(-24 * time.Hour), lastRequest: time.Now().Add(-24 * time.Hour),
} }
@ -82,10 +81,9 @@ func (p *TLDRPlugin) record(r bot.Request) bool {
hist := history{ hist := history{
body: strings.ToLower(r.Msg.Body), body: strings.ToLower(r.Msg.Body),
user: r.Msg.User.Name, user: r.Msg.User.Name,
channel: r.Msg.Channel,
timestamp: time.Now(), timestamp: time.Now(),
} }
p.addHistory(hist) p.addHistory(r.Msg.Channel, hist)
return false return false
} }
@ -127,7 +125,7 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
score := docsOverTopics.At(topic, doc) score := docsOverTopics.At(topic, doc)
if score > minScore { if score > minScore {
bestScores[topic][minIndex] = score bestScores[topic][minIndex] = score
bestDocs[topic][minIndex] = p.history[doc] bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
minScore, minIndex = min(bestScores[topic]) minScore, minIndex = min(bestScores[topic])
} }
} }
@ -165,19 +163,19 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
return true return true
} }
func (p *TLDRPlugin) addHistory(hist history) { func (p *TLDRPlugin) addHistory(ch string, hist history) {
p.history = append(p.history, hist) p.history[ch] = append(p.history[ch], hist)
sz := len(p.history) sz := len(p.history[ch])
max := p.b.Config().GetInt("TLDR.HistorySize", 1000) max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24)) keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
// Clamp the size of the history // Clamp the size of the history
if sz > max { if sz > max {
p.history = p.history[len(p.history)-max:] p.history[ch] = p.history[ch][len(p.history)-max:]
} }
// Remove old entries // Remove old entries
yesterday := time.Now().Add(-keepHrs * time.Hour) yesterday := time.Now().Add(-keepHrs * time.Hour)
begin := 0 begin := 0
for i, m := range p.history { for i, m := range p.history[ch] {
if !m.timestamp.Before(yesterday) { if !m.timestamp.Before(yesterday) {
begin = i - 1 // should keep this message begin = i - 1 // should keep this message
if begin < 0 { if begin < 0 {
@ -186,14 +184,16 @@ func (p *TLDRPlugin) addHistory(hist history) {
break break
} }
} }
p.history = p.history[begin:] p.history[ch] = p.history[ch][begin:]
} }
func (p *TLDRPlugin) getTopics() []string { func (p *TLDRPlugin) getTopics() []string {
hist := []string{} hist := []string{}
for _, h := range p.history { for _, ch := range p.history {
for _, h := range ch {
hist = append(hist, h.body) hist = append(hist, h.body)
} }
}
return hist return hist
} }
@ -216,6 +216,7 @@ func min(slice []float64) (float64, int) {
} }
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool { func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
ch := r.Msg.Channel
c, err := p.getClient() c, err := p.getClient()
if err != nil { if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
@ -227,10 +228,14 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
data := p.c.GetMap("tldr.promptdata", map[string]string{}) data := p.c.GetMap("tldr.promptdata", map[string]string{})
promptTpl.Execute(&prompt, data) promptTpl.Execute(&prompt, data)
backlog := "" backlog := ""
for _, h := range p.history { maxLen := p.c.GetInt("tldr.maxgpt", 4096)
if h.channel == r.Msg.Channel { for i := len(p.history[ch]) - 1; i >= 0; i-- {
backlog += fmt.Sprintf("%s: %s\n", h.user, h.body) h := p.history[ch][i]
str := fmt.Sprintf("%s: %s\n", h.user, h.body)
if len(backlog) > maxLen {
break
} }
backlog = str + backlog
} }
sess := c.NewChatSession(prompt.String()) sess := c.NewChatSession(prompt.String())
completion, err := sess.Complete(context.TODO(), backlog) completion, err := sess.Complete(context.TODO(), backlog)

View File

@ -20,6 +20,8 @@ func init() {
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
} }
var ch = "test"
func makeMessageBy(payload, by string) bot.Request { func makeMessageBy(payload, by string) bot.Request {
isCmd := strings.HasPrefix(payload, "!") isCmd := strings.HasPrefix(payload, "!")
if isCmd { if isCmd {
@ -31,7 +33,7 @@ func makeMessageBy(payload, by string) bot.Request {
Kind: bot.Message, Kind: bot.Message,
Msg: msg.Message{ Msg: msg.Message{
User: &user.User{Name: by}, User: &user.User{Name: by},
Channel: "test", Channel: ch,
Body: payload, Body: payload,
Command: isCmd, Command: isCmd,
}, },
@ -61,7 +63,7 @@ func TestAddHistoryLimitsDays(t *testing.T) {
user: "tester", user: "tester",
timestamp: t0.Add(time.Duration(i) * time.Hour), timestamp: t0.Add(time.Duration(i) * time.Hour),
} }
c.addHistory(hist) c.addHistory(ch, hist)
} }
assert.Len(t, c.history, expected, "%d != %d", len(c.history), expected) assert.Len(t, c.history[ch], expected, "%d != %d", len(c.history), expected)
} }