From 276f6c188ef998bc5686ea1e8cc43dfb3a8c9cf7 Mon Sep 17 00:00:00 2001 From: Chris Sexton <3216719+chrissexton@users.noreply.github.com> Date: Sat, 11 May 2024 14:12:04 -0400 Subject: [PATCH] llm: failover locally --- plugins/llm/gpt.go | 6 ++++++ plugins/llm/llama.go | 25 +++++++++++++++++++------ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/plugins/llm/gpt.go b/plugins/llm/gpt.go index 21b0d04..7a5e84a 100644 --- a/plugins/llm/gpt.go +++ b/plugins/llm/gpt.go @@ -38,6 +38,12 @@ func New(b bot.Bot) *LLMPlugin { func (p *LLMPlugin) register() { p.h = bot.HandlerTable{ + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`(?is)^llm (?P.*)`), + HelpText: "chat completion", + Handler: p.chatMessageForce, + }, { Kind: bot.Message, IsCmd: true, Regex: regexp.MustCompile(`(?is)^gpt (?P.*)`), diff --git a/plugins/llm/llama.go b/plugins/llm/llama.go index 975e854..edc9542 100644 --- a/plugins/llm/llama.go +++ b/plugins/llm/llama.go @@ -11,16 +11,16 @@ import ( ) var InstanceNotFoundError = errors.New("instance not found") -var empty = chatEntry{} +var empty = llamaResponse{} func (g *LLMPlugin) llama() (chatEntry, error) { - llamaURL := g.c.Get("gpt.llamaurl", "") - if llamaURL == "" { - return empty, fmt.Errorf("could not find llama url") + llamaURL := g.c.GetArray("gpt.llamaurls", []string{}) + if len(llamaURL) == 0 { + return chatEntry{}, fmt.Errorf("could not find llama url") } llamaModel := g.c.Get("gpt.llamamodel", "") if llamaModel == "" { - return empty, fmt.Errorf("could not find llama model") + return chatEntry{}, fmt.Errorf("could not find llama model") } req := llamaRequest{ @@ -29,6 +29,19 @@ func (g *LLMPlugin) llama() (chatEntry, error) { Stream: false, } + for _, u := range llamaURL { + llamaResp, err := mkRequest(u, req) + if err != nil { + continue + } + + return llamaResp.Message, nil + } + + return chatEntry{}, InstanceNotFoundError +} + +func mkRequest(llamaURL string, req llamaRequest) (llamaResponse, error) { body, err := json.Marshal(req) if err != nil { return empty, fmt.Errorf("could not marshal llama request: %w", err) @@ -50,7 +63,7 @@ func (g *LLMPlugin) llama() (chatEntry, error) { return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body)) } - return llamaResp.Message, nil + return llamaResp, nil } type llamaRequest struct {