Compare commits

..

4 Commits

Author SHA1 Message Date
Chris Sexton 6f3ba974e6 llm: add health check before request 2024-05-11 14:41:17 -04:00
Chris Sexton f5fc3b542e llm: cull history and use a prompt 2024-05-11 14:41:17 -04:00
Chris Sexton 276f6c188e llm: failover locally 2024-05-11 14:41:17 -04:00
Chris Sexton ed4136aa60 gpt: Rename package to llm 2024-05-11 14:41:17 -04:00
4 changed files with 76 additions and 30 deletions

View File

@ -1,4 +1,4 @@
package gpt package llm
import ( import (
"context" "context"
@ -10,7 +10,7 @@ import (
var session openai.ChatSession var session openai.ChatSession
var client *openai.Client var client *openai.Client
func (p *GPTPlugin) getClient() (*openai.Client, error) { func (p *LLMPlugin) getClient() (*openai.Client, error) {
token := p.c.Get("gpt.token", "") token := p.c.Get("gpt.token", "")
if token == "" { if token == "" {
return nil, fmt.Errorf("no GPT token given") return nil, fmt.Errorf("no GPT token given")
@ -18,7 +18,7 @@ func (p *GPTPlugin) getClient() (*openai.Client, error) {
return openai.NewClient(token) return openai.NewClient(token)
} }
func (p *GPTPlugin) chatGPT(request string) (string, error) { func (p *LLMPlugin) chatGPT(request string) (string, error) {
if client == nil { if client == nil {
if err := p.setPrompt(p.getDefaultPrompt()); err != nil { if err := p.setPrompt(p.getDefaultPrompt()); err != nil {
return "", err return "", err
@ -32,11 +32,11 @@ func (p *GPTPlugin) chatGPT(request string) (string, error) {
return session.Complete(context.Background(), request) return session.Complete(context.Background(), request)
} }
func (p *GPTPlugin) getDefaultPrompt() string { func (p *LLMPlugin) getDefaultPrompt() string {
return p.c.Get("gpt.prompt", "") return p.c.Get("gpt.prompt", "")
} }
func (p *GPTPlugin) setPrompt(prompt string) error { func (p *LLMPlugin) setPrompt(prompt string) error {
var err error var err error
client, err = p.getClient() client, err = p.getClient()
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package gpt package llm
import ( import (
"errors" "errors"
@ -13,7 +13,7 @@ import (
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions" const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
const gpt3ModURL = "https://api.openai.com/v1/moderations" const gpt3ModURL = "https://api.openai.com/v1/moderations"
type GPTPlugin struct { type LLMPlugin struct {
b bot.Bot b bot.Bot
c *config.Config c *config.Config
h bot.HandlerTable h bot.HandlerTable
@ -27,8 +27,8 @@ type chatEntry struct {
Content string `json:"content"` Content string `json:"content"`
} }
func New(b bot.Bot) *GPTPlugin { func New(b bot.Bot) *LLMPlugin {
p := &GPTPlugin{ p := &LLMPlugin{
b: b, b: b,
c: b.Config(), c: b.Config(),
} }
@ -36,8 +36,20 @@ func New(b bot.Bot) *GPTPlugin {
return p return p
} }
func (p *GPTPlugin) register() { func (p *LLMPlugin) register() {
p.h = bot.HandlerTable{ p.h = bot.HandlerTable{
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
HelpText: "set the ChatGPT prompt",
Handler: p.setPromptMessage,
},
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^llm (?P<text>.*)`),
HelpText: "chat completion",
Handler: p.chatMessageForce,
},
{ {
Kind: bot.Message, IsCmd: true, Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`), Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`),
@ -50,17 +62,11 @@ func (p *GPTPlugin) register() {
HelpText: "chat completion", HelpText: "chat completion",
Handler: p.chatMessageForce, Handler: p.chatMessageForce,
}, },
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
HelpText: "set the ChatGPT prompt",
Handler: p.setPromptMessage,
},
} }
p.b.RegisterTable(p, p.h) p.b.RegisterTable(p, p.h)
} }
func (p *GPTPlugin) setPromptMessage(r bot.Request) bool { func (p *LLMPlugin) setPromptMessage(r bot.Request) bool {
prompt := r.Values["text"] prompt := r.Values["text"]
if err := p.setPrompt(prompt); err != nil { if err := p.setPrompt(prompt); err != nil {
resp := fmt.Sprintf("Error: %s", err) resp := fmt.Sprintf("Error: %s", err)
@ -70,7 +76,7 @@ func (p *GPTPlugin) setPromptMessage(r bot.Request) bool {
return true return true
} }
func (p *GPTPlugin) chatMessage(r bot.Request) bool { func (p *LLMPlugin) chatMessage(r bot.Request) bool {
if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) { if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) {
log.Debug().Msgf("%s silenced", r.Msg.Channel) log.Debug().Msgf("%s silenced", r.Msg.Channel)
return true return true
@ -78,11 +84,15 @@ func (p *GPTPlugin) chatMessage(r bot.Request) bool {
return p.chatMessageForce(r) return p.chatMessageForce(r)
} }
func (p *GPTPlugin) chatMessageForce(r bot.Request) bool { func (p *LLMPlugin) chatMessageForce(r bot.Request) bool {
p.chatHistory = append(p.chatHistory, chatEntry{ p.chatHistory = append(p.chatHistory, chatEntry{
Role: "user", Role: "user",
Content: r.Values["text"], Content: r.Values["text"],
}) })
maxHist := p.c.GetInt("gpt.maxhist", 10)
if len(p.chatHistory) > maxHist {
p.chatHistory = p.chatHistory[len(p.chatHistory)-maxHist:]
}
chatResp, err := p.llama() chatResp, err := p.llama()
if err == nil { if err == nil {
p.chatHistory = append(p.chatHistory, chatResp) p.chatHistory = append(p.chatHistory, chatResp)

View File

@ -1,4 +1,4 @@
package gpt package llm
import ( import (
"bytes" "bytes"
@ -11,24 +11,60 @@ import (
) )
var InstanceNotFoundError = errors.New("instance not found") var InstanceNotFoundError = errors.New("instance not found")
var empty = chatEntry{} var empty = llamaResponse{}
func (g *GPTPlugin) llama() (chatEntry, error) { func (g *LLMPlugin) llama() (chatEntry, error) {
llamaURL := g.c.Get("gpt.llamaurl", "") llamaURL := g.c.GetArray("gpt.llamaurls", []string{})
if llamaURL == "" { if len(llamaURL) == 0 {
return empty, fmt.Errorf("could not find llama url") return chatEntry{}, fmt.Errorf("could not find llama url")
} }
llamaModel := g.c.Get("gpt.llamamodel", "") llamaModel := g.c.Get("gpt.llamamodel", "")
if llamaModel == "" { if llamaModel == "" {
return empty, fmt.Errorf("could not find llama model") return chatEntry{}, fmt.Errorf("could not find llama model")
} }
prompt := g.c.Get("gpt.lastprompt", g.c.Get("gpt.prompt", ""))
hist := []chatEntry{{
Role: "system",
Content: prompt,
}}
hist = append(hist, g.chatHistory...)
req := llamaRequest{ req := llamaRequest{
Model: llamaModel, Model: llamaModel,
Messages: g.chatHistory, Messages: hist,
Stream: false, Stream: false,
} }
for _, u := range llamaURL {
if err := g.healthCheck(u); err != nil {
continue
}
llamaResp, err := mkRequest(u, req)
if err != nil {
continue
}
return llamaResp.Message, nil
}
return chatEntry{}, InstanceNotFoundError
}
// healthCheck reports whether the llama instance at llamaURL answers an HTTP
// request.
//
// It issues a single GET against llamaURL using a client whose timeout is the
// "gpt.timeout" config value interpreted as milliseconds (default 1000). Any
// response counts as healthy: the status code is not inspected, so only
// request-construction failures and transport errors (including timeouts) are
// returned.
func (p *LLMPlugin) healthCheck(llamaURL string) error {
	timeout := p.c.GetInt("gpt.timeout", 1000)

	// A malformed URL makes NewRequest return a nil request alongside the
	// error; handing that nil request to Do would dereference it, so bail out
	// here instead of ignoring the error.
	req, err := http.NewRequest(http.MethodGet, llamaURL, nil)
	if err != nil {
		return fmt.Errorf("could not build health check request: %w", err)
	}

	// Named httpClient to avoid shadowing the package-level client variable.
	httpClient := http.Client{
		Timeout: time.Duration(timeout) * time.Millisecond,
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return err
	}
	// The body content is irrelevant for a reachability probe, but it must be
	// closed so the underlying connection is released.
	_ = resp.Body.Close()
	return nil
}
func mkRequest(llamaURL string, req llamaRequest) (llamaResponse, error) {
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return empty, fmt.Errorf("could not marshal llama request: %w", err) return empty, fmt.Errorf("could not marshal llama request: %w", err)
@ -50,7 +86,7 @@ func (g *GPTPlugin) llama() (chatEntry, error) {
return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body)) return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body))
} }
return llamaResp.Message, nil return llamaResp, nil
} }
type llamaRequest struct { type llamaRequest struct {

View File

@ -22,10 +22,10 @@ import (
"github.com/velour/catbase/plugins/giphy" "github.com/velour/catbase/plugins/giphy"
"github.com/velour/catbase/plugins/git" "github.com/velour/catbase/plugins/git"
"github.com/velour/catbase/plugins/goals" "github.com/velour/catbase/plugins/goals"
"github.com/velour/catbase/plugins/gpt"
"github.com/velour/catbase/plugins/inventory" "github.com/velour/catbase/plugins/inventory"
"github.com/velour/catbase/plugins/last" "github.com/velour/catbase/plugins/last"
"github.com/velour/catbase/plugins/leftpad" "github.com/velour/catbase/plugins/leftpad"
"github.com/velour/catbase/plugins/llm"
"github.com/velour/catbase/plugins/mayi" "github.com/velour/catbase/plugins/mayi"
"github.com/velour/catbase/plugins/meme" "github.com/velour/catbase/plugins/meme"
"github.com/velour/catbase/plugins/nerdepedia" "github.com/velour/catbase/plugins/nerdepedia"
@ -101,7 +101,7 @@ func Register(b bot.Bot) {
b.AddPlugin(topic.New(b)) b.AddPlugin(topic.New(b))
b.AddPlugin(talker.New(b)) b.AddPlugin(talker.New(b))
b.AddPlugin(fact.New(b)) b.AddPlugin(fact.New(b))
b.AddPlugin(gpt.New(b)) b.AddPlugin(llm.New(b))
// catches anything left, will always return true // catches anything left, will always return true
b.AddPlugin(deadend.New(b)) b.AddPlugin(deadend.New(b))
} }