diff --git a/plugins/tldr/tldr.go b/plugins/tldr/tldr.go index 6dbebed..90b20bd 100644 --- a/plugins/tldr/tldr.go +++ b/plugins/tldr/tldr.go @@ -22,14 +22,13 @@ import ( type TLDRPlugin struct { b bot.Bot c *config.Config - history []history + history map[string][]history index int lastRequest time.Time } type history struct { timestamp time.Time - channel string user string body string } @@ -38,7 +37,7 @@ func New(b bot.Bot) *TLDRPlugin { plugin := &TLDRPlugin{ b: b, c: b.Config(), - history: []history{}, + history: map[string][]history{}, index: 0, lastRequest: time.Now().Add(-24 * time.Hour), } @@ -82,10 +81,9 @@ func (p *TLDRPlugin) record(r bot.Request) bool { hist := history{ body: strings.ToLower(r.Msg.Body), user: r.Msg.User.Name, - channel: r.Msg.Channel, timestamp: time.Now(), } - p.addHistory(hist) + p.addHistory(r.Msg.Channel, hist) return false } @@ -127,7 +125,7 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool { score := docsOverTopics.At(topic, doc) if score > minScore { bestScores[topic][minIndex] = score - bestDocs[topic][minIndex] = p.history[doc] + bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc] minScore, minIndex = min(bestScores[topic]) } } @@ -165,19 +163,19 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool { return true } -func (p *TLDRPlugin) addHistory(hist history) { - p.history = append(p.history, hist) - sz := len(p.history) +func (p *TLDRPlugin) addHistory(ch string, hist history) { + p.history[ch] = append(p.history[ch], hist) + sz := len(p.history[ch]) max := p.b.Config().GetInt("TLDR.HistorySize", 1000) keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24)) // Clamp the size of the history if sz > max { - p.history = p.history[len(p.history)-max:] + p.history[ch] = p.history[ch][len(p.history)-max:] } // Remove old entries yesterday := time.Now().Add(-keepHrs * time.Hour) begin := 0 - for i, m := range p.history { + for i, m := range p.history[ch] { if !m.timestamp.Before(yesterday) { begin = i - 1 // should keep this message if begin < 0 { @@ -186,13 +184,15 @@ func (p *TLDRPlugin) addHistory(hist history) { break } } - p.history = p.history[begin:] + p.history[ch] = p.history[ch][begin:] } func (p *TLDRPlugin) getTopics() []string { hist := []string{} - for _, h := range p.history { - hist = append(hist, h.body) + for _, ch := range p.history { + for _, h := range ch { + hist = append(hist, h.body) + } } return hist } @@ -216,6 +216,7 @@ func min(slice []float64) (float64, int) { } func (p *TLDRPlugin) betterTLDR(r bot.Request) bool { + ch := r.Msg.Channel c, err := p.getClient() if err != nil { p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client") @@ -227,10 +228,14 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool { data := p.c.GetMap("tldr.promptdata", map[string]string{}) promptTpl.Execute(&prompt, data) backlog := "" - for _, h := range p.history { - if h.channel == r.Msg.Channel { - backlog += fmt.Sprintf("%s: %s\n", h.user, h.body) + maxLen := p.c.GetInt("tldr.maxgpt", 4096) + for i := len(p.history[ch]) - 1; i >= 0; i-- { + h := p.history[ch][i] + str := fmt.Sprintf("%s: %s\n", h.user, h.body) + if len(backlog) > maxLen { + break } + backlog = str + backlog } sess := c.NewChatSession(prompt.String()) completion, err := sess.Complete(context.TODO(), backlog) diff --git a/plugins/tldr/tldr_test.go b/plugins/tldr/tldr_test.go index 4b7f848..fd9fa41 100644 --- a/plugins/tldr/tldr_test.go +++ b/plugins/tldr/tldr_test.go @@ -20,6 +20,8 @@ func init() { log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } +var ch = "test" + func makeMessageBy(payload, by string) bot.Request { isCmd := strings.HasPrefix(payload, "!") if isCmd { @@ -31,7 +33,7 @@ func makeMessageBy(payload, by string) bot.Request { Kind: bot.Message, Msg: msg.Message{ User: &user.User{Name: by}, - Channel: "test", + Channel: ch, Body: payload, Command: isCmd, }, @@ -61,7 +63,7 @@ func TestAddHistoryLimitsDays(t *testing.T) { user: "tester", timestamp: t0.Add(time.Duration(i) * time.Hour), } - c.addHistory(hist) + c.addHistory(ch, hist) } - assert.Len(t, c.history, expected, "%d != %d", len(c.history), expected) + assert.Len(t, c.history[ch], expected, "%d != %d", len(c.history), expected) }