diff --git a/plugins/llm/gpt.go b/plugins/llm/gpt.go index 7a5e84a..b0ad39b 100644 --- a/plugins/llm/gpt.go +++ b/plugins/llm/gpt.go @@ -38,6 +38,12 @@ func New(b bot.Bot) *LLMPlugin { func (p *LLMPlugin) register() { p.h = bot.HandlerTable{ + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P.*)`), + HelpText: "set the ChatGPT prompt", + Handler: p.setPromptMessage, + }, { Kind: bot.Message, IsCmd: true, Regex: regexp.MustCompile(`(?is)^llm (?P.*)`), @@ -56,12 +62,6 @@ func (p *LLMPlugin) register() { HelpText: "chat completion", Handler: p.chatMessageForce, }, - { - Kind: bot.Message, IsCmd: true, - Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P.*)`), - HelpText: "set the ChatGPT prompt", - Handler: p.setPromptMessage, - }, } p.b.RegisterTable(p, p.h) } @@ -89,6 +89,10 @@ func (p *LLMPlugin) chatMessageForce(r bot.Request) bool { Role: "user", Content: r.Values["text"], }) + maxHist := p.c.GetInt("gpt.maxhist", 10) + if len(p.chatHistory) > maxHist { + p.chatHistory = p.chatHistory[len(p.chatHistory)-maxHist:] + } chatResp, err := p.llama() if err == nil { p.chatHistory = append(p.chatHistory, chatResp) diff --git a/plugins/llm/llama.go b/plugins/llm/llama.go index edc9542..9e016b3 100644 --- a/plugins/llm/llama.go +++ b/plugins/llm/llama.go @@ -23,9 +23,16 @@ func (g *LLMPlugin) llama() (chatEntry, error) { return chatEntry{}, fmt.Errorf("could not find llama model") } + prompt := g.c.Get("gpt.lastprompt", g.c.Get("gpt.prompt", "")) + hist := []chatEntry{{ + Role: "system", + Content: prompt, + }} + hist = append(hist, g.chatHistory...) + req := llamaRequest{ Model: llamaModel, - Messages: g.chatHistory, + Messages: hist, Stream: false, }