mirror of https://github.com/velour/catbase.git
tldr: reverse enter and respect length
This commit is contained in:
parent
5acf14b0ae
commit
852239e89d
|
@ -22,14 +22,13 @@ import (
|
||||||
type TLDRPlugin struct {
|
type TLDRPlugin struct {
|
||||||
b bot.Bot
|
b bot.Bot
|
||||||
c *config.Config
|
c *config.Config
|
||||||
history []history
|
history map[string][]history
|
||||||
index int
|
index int
|
||||||
lastRequest time.Time
|
lastRequest time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type history struct {
|
type history struct {
|
||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
channel string
|
|
||||||
user string
|
user string
|
||||||
body string
|
body string
|
||||||
}
|
}
|
||||||
|
@ -38,7 +37,7 @@ func New(b bot.Bot) *TLDRPlugin {
|
||||||
plugin := &TLDRPlugin{
|
plugin := &TLDRPlugin{
|
||||||
b: b,
|
b: b,
|
||||||
c: b.Config(),
|
c: b.Config(),
|
||||||
history: []history{},
|
history: map[string][]history{},
|
||||||
index: 0,
|
index: 0,
|
||||||
lastRequest: time.Now().Add(-24 * time.Hour),
|
lastRequest: time.Now().Add(-24 * time.Hour),
|
||||||
}
|
}
|
||||||
|
@ -82,10 +81,9 @@ func (p *TLDRPlugin) record(r bot.Request) bool {
|
||||||
hist := history{
|
hist := history{
|
||||||
body: strings.ToLower(r.Msg.Body),
|
body: strings.ToLower(r.Msg.Body),
|
||||||
user: r.Msg.User.Name,
|
user: r.Msg.User.Name,
|
||||||
channel: r.Msg.Channel,
|
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
p.addHistory(hist)
|
p.addHistory(r.Msg.Channel, hist)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -127,7 +125,7 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
|
||||||
score := docsOverTopics.At(topic, doc)
|
score := docsOverTopics.At(topic, doc)
|
||||||
if score > minScore {
|
if score > minScore {
|
||||||
bestScores[topic][minIndex] = score
|
bestScores[topic][minIndex] = score
|
||||||
bestDocs[topic][minIndex] = p.history[doc]
|
bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
|
||||||
minScore, minIndex = min(bestScores[topic])
|
minScore, minIndex = min(bestScores[topic])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,19 +163,19 @@ func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) addHistory(hist history) {
|
func (p *TLDRPlugin) addHistory(ch string, hist history) {
|
||||||
p.history = append(p.history, hist)
|
p.history[ch] = append(p.history[ch], hist)
|
||||||
sz := len(p.history)
|
sz := len(p.history[ch])
|
||||||
max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
|
max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
|
||||||
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
|
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
|
||||||
// Clamp the size of the history
|
// Clamp the size of the history
|
||||||
if sz > max {
|
if sz > max {
|
||||||
p.history = p.history[len(p.history)-max:]
|
p.history[ch] = p.history[ch][len(p.history)-max:]
|
||||||
}
|
}
|
||||||
// Remove old entries
|
// Remove old entries
|
||||||
yesterday := time.Now().Add(-keepHrs * time.Hour)
|
yesterday := time.Now().Add(-keepHrs * time.Hour)
|
||||||
begin := 0
|
begin := 0
|
||||||
for i, m := range p.history {
|
for i, m := range p.history[ch] {
|
||||||
if !m.timestamp.Before(yesterday) {
|
if !m.timestamp.Before(yesterday) {
|
||||||
begin = i - 1 // should keep this message
|
begin = i - 1 // should keep this message
|
||||||
if begin < 0 {
|
if begin < 0 {
|
||||||
|
@ -186,13 +184,15 @@ func (p *TLDRPlugin) addHistory(hist history) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.history = p.history[begin:]
|
p.history[ch] = p.history[ch][begin:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) getTopics() []string {
|
func (p *TLDRPlugin) getTopics() []string {
|
||||||
hist := []string{}
|
hist := []string{}
|
||||||
for _, h := range p.history {
|
for _, ch := range p.history {
|
||||||
hist = append(hist, h.body)
|
for _, h := range ch {
|
||||||
|
hist = append(hist, h.body)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return hist
|
return hist
|
||||||
}
|
}
|
||||||
|
@ -216,6 +216,7 @@ func min(slice []float64) (float64, int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
||||||
|
ch := r.Msg.Channel
|
||||||
c, err := p.getClient()
|
c, err := p.getClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
|
||||||
|
@ -227,10 +228,14 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
||||||
data := p.c.GetMap("tldr.promptdata", map[string]string{})
|
data := p.c.GetMap("tldr.promptdata", map[string]string{})
|
||||||
promptTpl.Execute(&prompt, data)
|
promptTpl.Execute(&prompt, data)
|
||||||
backlog := ""
|
backlog := ""
|
||||||
for _, h := range p.history {
|
maxLen := p.c.GetInt("tldr.maxgpt", 4096)
|
||||||
if h.channel == r.Msg.Channel {
|
for i := len(p.history[ch]) - 1; i >= 0; i-- {
|
||||||
backlog += fmt.Sprintf("%s: %s\n", h.user, h.body)
|
h := p.history[ch][i]
|
||||||
|
str := fmt.Sprintf("%s: %s\n", h.user, h.body)
|
||||||
|
if len(backlog) > maxLen {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
backlog = str + backlog
|
||||||
}
|
}
|
||||||
sess := c.NewChatSession(prompt.String())
|
sess := c.NewChatSession(prompt.String())
|
||||||
completion, err := sess.Complete(context.TODO(), backlog)
|
completion, err := sess.Complete(context.TODO(), backlog)
|
||||||
|
|
|
@ -20,6 +20,8 @@ func init() {
|
||||||
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ch = "test"
|
||||||
|
|
||||||
func makeMessageBy(payload, by string) bot.Request {
|
func makeMessageBy(payload, by string) bot.Request {
|
||||||
isCmd := strings.HasPrefix(payload, "!")
|
isCmd := strings.HasPrefix(payload, "!")
|
||||||
if isCmd {
|
if isCmd {
|
||||||
|
@ -31,7 +33,7 @@ func makeMessageBy(payload, by string) bot.Request {
|
||||||
Kind: bot.Message,
|
Kind: bot.Message,
|
||||||
Msg: msg.Message{
|
Msg: msg.Message{
|
||||||
User: &user.User{Name: by},
|
User: &user.User{Name: by},
|
||||||
Channel: "test",
|
Channel: ch,
|
||||||
Body: payload,
|
Body: payload,
|
||||||
Command: isCmd,
|
Command: isCmd,
|
||||||
},
|
},
|
||||||
|
@ -61,7 +63,7 @@ func TestAddHistoryLimitsDays(t *testing.T) {
|
||||||
user: "tester",
|
user: "tester",
|
||||||
timestamp: t0.Add(time.Duration(i) * time.Hour),
|
timestamp: t0.Add(time.Duration(i) * time.Hour),
|
||||||
}
|
}
|
||||||
c.addHistory(hist)
|
c.addHistory(ch, hist)
|
||||||
}
|
}
|
||||||
assert.Len(t, c.history, expected, "%d != %d", len(c.history), expected)
|
assert.Len(t, c.history[ch], expected, "%d != %d", len(c.history), expected)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue