catbase/plugins/llm/llama.go

package llm
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)
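
// InstanceNotFoundError is returned when no configured llama instance can
// serve the request.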
var InstanceNotFoundError = errors.New("instance not found")
var empty = llamaResponse{}
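
// llama sends the system prompt plus the accumulated chat history to the
// first healthy llama instance and returns the model's reply.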
func (p *LLMPlugin) llama() (chatEntry, error) {
	llamaURLs := p.c.GetArray("gpt.llamaurls", []string{})
	if len(llamaURLs) == 0 {
		return chatEntry{}, fmt.Errorf("could not find llama urls")
	}
	llamaModel := p.c.Get("gpt.llamamodel", "")
	if llamaModel == "" {
		return chatEntry{}, fmt.Errorf("could not find llama model")
	}

	// Seed the conversation with the system prompt, then replay the chat history.
	prompt := p.c.Get("gpt.lastprompt", p.c.Get("gpt.prompt", ""))
	hist := []chatEntry{{
		Role:    "system",
		Content: prompt,
	}}
	hist = append(hist, p.chatHistory...)

	req := llamaRequest{
		Model:    llamaModel,
		Messages: hist,
		Stream:   false,
	}

	// Try each configured instance in order, skipping any that fail a health check.
	for _, u := range llamaURLs {
		if err := p.healthCheck(u); err != nil {
			continue
		}
		llamaResp, err := mkRequest(u, req)
		if err != nil {
			continue
		}
		return llamaResp.Message, nil
	}
	return chatEntry{}, InstanceNotFoundError
}
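
// healthCheck reports whether the llama instance at llamaURL answers an HTTP
// GET within the configured timeout.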
func (p *LLMPlugin) healthCheck(llamaURL string) error {
	timeout := p.c.GetInt("gpt.timeout", 1000)
	req, err := http.NewRequest(http.MethodGet, llamaURL, nil)
	if err != nil {
		return err
	}
	client := http.Client{
		Timeout: time.Duration(timeout) * time.Millisecond,
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	// The body is not needed; close it so the connection can be reused.
	resp.Body.Close()
	return nil
}
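
// mkRequest POSTs a chat request to a single llama instance and decodes the
// JSON response.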
func mkRequest(llamaURL string, req llamaRequest) (llamaResponse, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return empty, fmt.Errorf("could not marshal llama request: %w", err)
	}
	resp, err := http.Post(llamaURL, "application/json", bytes.NewBuffer(body))
	if err != nil {
		return empty, fmt.Errorf("could not post llama request: %w", err)
	}
	defer resp.Body.Close()
	// Treat 503 as an unavailable instance so the caller can try the next URL.
	if resp.StatusCode == http.StatusServiceUnavailable {
		return empty, InstanceNotFoundError
	}
	body, err = io.ReadAll(resp.Body)
	if err != nil {
		return empty, fmt.Errorf("could not read llama response: %w", err)
	}
	llamaResp := llamaResponse{}
	if err := json.Unmarshal(body, &llamaResp); err != nil {
		return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body))
	}
	return llamaResp, nil
}

type llamaRequest struct {
	Model    string      `json:"model"`
	Stream   bool        `json:"stream"`
	Messages []chatEntry `json:"messages"`
}
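
// llamaResponse holds a non-streaming chat reply; the field set matches the
// response shape of an Ollama-style /api/chat endpoint.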
type llamaResponse struct {
	Model              string    `json:"model"`
	CreatedAt          time.Time `json:"created_at"`
	Message            chatEntry `json:"message"`
	DoneReason         string    `json:"done_reason"`
	Done               bool      `json:"done"`
	TotalDuration      int64     `json:"total_duration"`
	LoadDuration       int       `json:"load_duration"`
	PromptEvalDuration int       `json:"prompt_eval_duration"`
	EvalCount          int       `json:"eval_count"`
	EvalDuration       int64     `json:"eval_duration"`
}