catbase/plugins/llm/gpt.go

167 lines
4.1 KiB
Go
Raw Normal View History

2024-05-11 17:56:29 +00:00
package llm
2024-05-11 14:37:57 +00:00
import (
"errors"
"fmt"
2024-09-27 13:51:25 +00:00
"github.com/google/generative-ai-go/genai"
2024-05-11 14:37:57 +00:00
"github.com/rs/zerolog/log"
"github.com/velour/catbase/bot"
"github.com/velour/catbase/config"
"regexp"
2024-09-27 21:12:47 +00:00
"time"
2024-05-11 14:37:57 +00:00
)
// OpenAI API endpoints. gpt3URL is a fmt format string with a single %s verb;
// given the /engines/ path segment it presumably expects an engine name —
// confirm against the caller.
const (
	gpt3URL    = "https://api.openai.com/v1/engines/%s/completions"
	gpt3ModURL = "https://api.openai.com/v1/moderations"
)
2024-05-11 17:56:29 +00:00
type LLMPlugin struct {
2024-05-11 14:37:57 +00:00
b bot.Bot
c *config.Config
h bot.HandlerTable
chatCount int
chatHistory []chatEntry
2024-09-27 13:51:25 +00:00
geminiClient *genai.Client
2024-05-11 14:37:57 +00:00
}
type chatEntry struct {
2024-09-27 21:12:47 +00:00
Role string `json:"role"`
Content string `json:"content"`
TS time.Time `json:"ts"`
2024-05-11 14:37:57 +00:00
}
2024-05-11 17:56:29 +00:00
func New(b bot.Bot) *LLMPlugin {
p := &LLMPlugin{
2024-05-11 14:37:57 +00:00
b: b,
c: b.Config(),
}
p.register()
return p
}
2024-05-11 17:56:29 +00:00
func (p *LLMPlugin) register() {
2024-05-11 14:37:57 +00:00
p.h = bot.HandlerTable{
2024-05-11 18:29:43 +00:00
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
HelpText: "set the ChatGPT prompt",
Handler: p.setPromptMessage,
},
2024-05-11 18:12:04 +00:00
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^llm (?P<text>.*)`),
2024-05-24 14:32:52 +00:00
HelpText: "chat completion using first-available AI",
2024-09-27 13:51:25 +00:00
Handler: p.geminiChatMessage,
2024-05-11 18:12:04 +00:00
},
2024-05-11 14:37:57 +00:00
{
Kind: bot.Message, IsCmd: true,
2024-05-24 14:32:52 +00:00
Regex: regexp.MustCompile(`(?is)^gpt4 (?P<text>.*)`),
HelpText: "chat completion using OpenAI",
Handler: p.gptMessage,
2024-05-11 14:37:57 +00:00
},
{
Kind: bot.Message, IsCmd: true,
2024-05-24 14:35:49 +00:00
Regex: regexp.MustCompile(`(?is)^llm-puke$`),
HelpText: "clear chat history",
Handler: p.puke,
2024-05-11 14:37:57 +00:00
},
}
p.b.RegisterTable(p, p.h)
}
2024-05-11 17:56:29 +00:00
func (p *LLMPlugin) setPromptMessage(r bot.Request) bool {
2024-09-27 15:40:48 +00:00
p.c.Set("gemini.systemprompt", r.Values["text"])
2024-09-27 15:46:34 +00:00
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Okay. I set the prompt to: "%s"`, r.Values["text"]))
2024-05-11 14:37:57 +00:00
return true
}
2024-09-27 21:12:47 +00:00
// defaultDuration is the chat-history horizon used when the "gemini.horizon"
// setting is absent or not a valid duration.
const defaultDuration time.Duration = 15 * time.Minute
// getChatHistory returns the chat entries whose timestamps fall within the
// configured horizon. The horizon comes from "gemini.horizon" (a
// time.ParseDuration string) and defaults to defaultDuration. Entries stamped
// at or before the cutoff, including any with a zero timestamp, are omitted.
// The result is a new, non-nil slice; p.chatHistory itself is not modified.
func (p *LLMPlugin) getChatHistory() []chatEntry {
	horizonTxt := p.c.Get("gemini.horizon", defaultDuration.String())
	horizon, err := time.ParseDuration(horizonTxt)
	if err != nil {
		// Keep chatting with the default, but surface the misconfiguration
		// instead of silently ignoring it.
		log.Warn().Err(err).Str("horizon", horizonTxt).Msg("invalid gemini.horizon; using default")
		horizon = defaultDuration
	}

	// Take one snapshot of "now" so every entry is judged against the same cutoff.
	cutoff := time.Now().Add(-horizon)
	output := []chatEntry{}
	for _, e := range p.chatHistory {
		if e.TS.After(cutoff) {
			output = append(output, e)
		}
	}
	return output
}
// addChatHistoryUser records content as a "user" turn via addChatHistory,
// which supplies the timestamp.
func (p *LLMPlugin) addChatHistoryUser(content string) {
	entry := chatEntry{Content: content, Role: "user"}
	p.addChatHistory(entry)
}
// addChatHistory stamps content with the current time and appends it to the
// chat history.
//
// While appending, it also drops entries that getChatHistory would already
// exclude (older than the configured horizon, or never timestamped). Those
// entries can no longer be returned unless the horizon is later lengthened.
// Without this pruning, nothing except an explicit history clear ever removed
// them, so p.chatHistory grew for the lifetime of the bot.
func (p *LLMPlugin) addChatHistory(content chatEntry) {
	content.TS = time.Now()
	p.chatHistory = append(p.getChatHistory(), content)
}
2024-09-27 13:51:25 +00:00
func (p *LLMPlugin) geminiChatMessage(r bot.Request) bool {
if p.geminiClient == nil && p.geminiConnect() != nil {
log.Error().Msgf("Could not connect to Gemini")
return p.gptMessage(r)
}
chatResp, err := p.gemini(r.Values["text"])
if err != nil {
log.Error().Err(err).Send()
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Problem with Gemini: %s", err))
return true
}
2024-09-27 21:12:47 +00:00
p.addChatHistoryUser(r.Values["text"])
p.addChatHistory(chatResp)
2024-09-27 13:51:25 +00:00
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, chatResp.Content)
log.Info().Msgf("Successfully used Gemini")
return true
}
func (p *LLMPlugin) ollamaChatMessage(r bot.Request) bool {
2024-05-11 14:37:57 +00:00
p.chatHistory = append(p.chatHistory, chatEntry{
Role: "user",
Content: r.Values["text"],
})
2024-05-11 18:29:43 +00:00
maxHist := p.c.GetInt("gpt.maxhist", 10)
if len(p.chatHistory) > maxHist {
p.chatHistory = p.chatHistory[len(p.chatHistory)-maxHist:]
}
2024-05-11 14:37:57 +00:00
chatResp, err := p.llama()
if err == nil {
p.chatHistory = append(p.chatHistory, chatResp)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, chatResp.Content)
return true
} else if !errors.Is(err, InstanceNotFoundError) {
log.Error().Err(err).Msgf("error contacting llama")
} else {
2024-09-27 13:51:25 +00:00
log.Error().Msgf("llama is currently down")
2024-05-11 14:37:57 +00:00
}
2024-05-24 14:32:52 +00:00
return p.gptMessage(r)
}
// gptMessage answers a request with OpenAI via p.chatGPT and sends the reply
// to the channel. It always reports the message as handled.
//
// On error, "Error: <err>" is sent and logged. It is not stored in the chat
// history, so later turns never see an error message posing as an assistant
// reply. Successful replies are recorded through addChatHistory, which
// timestamps them so they fall inside the getChatHistory horizon like every
// other entry.
func (p *LLMPlugin) gptMessage(r bot.Request) bool {
	resp, err := p.chatGPT(r.Values["text"])
	if err != nil {
		log.Error().Err(err).Msg("chatGPT request failed")
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err))
		return true
	}

	p.addChatHistory(chatEntry{
		Role:    "assistant",
		Content: resp,
	})
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, resp)
	return true
}
2024-05-24 14:35:49 +00:00
// puke discards the entire in-memory chat history and tells the channel how
// many entries were forgotten. It always reports the message as handled.
func (p *LLMPlugin) puke(r bot.Request) bool {
	forgotten := len(p.chatHistory)
	p.chatHistory = []chatEntry{}

	msg := fmt.Sprintf("I just forgot %d lines of chat history.", forgotten)
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
	return true
}