mirror of https://github.com/velour/catbase.git
Compare commits
4 Commits
d6bb334a14
...
6f3ba974e6
Author | SHA1 | Date |
---|---|---|
Chris Sexton | 6f3ba974e6 | |
Chris Sexton | f5fc3b542e | |
Chris Sexton | 276f6c188e | |
Chris Sexton | ed4136aa60 |
|
@ -1,4 +1,4 @@
|
||||||
package gpt
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -10,7 +10,7 @@ import (
|
||||||
var session openai.ChatSession
|
var session openai.ChatSession
|
||||||
var client *openai.Client
|
var client *openai.Client
|
||||||
|
|
||||||
func (p *GPTPlugin) getClient() (*openai.Client, error) {
|
func (p *LLMPlugin) getClient() (*openai.Client, error) {
|
||||||
token := p.c.Get("gpt.token", "")
|
token := p.c.Get("gpt.token", "")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil, fmt.Errorf("no GPT token given")
|
return nil, fmt.Errorf("no GPT token given")
|
||||||
|
@ -18,7 +18,7 @@ func (p *GPTPlugin) getClient() (*openai.Client, error) {
|
||||||
return openai.NewClient(token)
|
return openai.NewClient(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) chatGPT(request string) (string, error) {
|
func (p *LLMPlugin) chatGPT(request string) (string, error) {
|
||||||
if client == nil {
|
if client == nil {
|
||||||
if err := p.setPrompt(p.getDefaultPrompt()); err != nil {
|
if err := p.setPrompt(p.getDefaultPrompt()); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -32,11 +32,11 @@ func (p *GPTPlugin) chatGPT(request string) (string, error) {
|
||||||
return session.Complete(context.Background(), request)
|
return session.Complete(context.Background(), request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) getDefaultPrompt() string {
|
func (p *LLMPlugin) getDefaultPrompt() string {
|
||||||
return p.c.Get("gpt.prompt", "")
|
return p.c.Get("gpt.prompt", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) setPrompt(prompt string) error {
|
func (p *LLMPlugin) setPrompt(prompt string) error {
|
||||||
var err error
|
var err error
|
||||||
client, err = p.getClient()
|
client, err = p.getClient()
|
||||||
if err != nil {
|
if err != nil {
|
|
@ -1,4 +1,4 @@
|
||||||
package gpt
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -13,7 +13,7 @@ import (
|
||||||
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
|
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
|
||||||
const gpt3ModURL = "https://api.openai.com/v1/moderations"
|
const gpt3ModURL = "https://api.openai.com/v1/moderations"
|
||||||
|
|
||||||
type GPTPlugin struct {
|
type LLMPlugin struct {
|
||||||
b bot.Bot
|
b bot.Bot
|
||||||
c *config.Config
|
c *config.Config
|
||||||
h bot.HandlerTable
|
h bot.HandlerTable
|
||||||
|
@ -27,8 +27,8 @@ type chatEntry struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(b bot.Bot) *GPTPlugin {
|
func New(b bot.Bot) *LLMPlugin {
|
||||||
p := &GPTPlugin{
|
p := &LLMPlugin{
|
||||||
b: b,
|
b: b,
|
||||||
c: b.Config(),
|
c: b.Config(),
|
||||||
}
|
}
|
||||||
|
@ -36,8 +36,20 @@ func New(b bot.Bot) *GPTPlugin {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) register() {
|
func (p *LLMPlugin) register() {
|
||||||
p.h = bot.HandlerTable{
|
p.h = bot.HandlerTable{
|
||||||
|
{
|
||||||
|
Kind: bot.Message, IsCmd: true,
|
||||||
|
Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
|
||||||
|
HelpText: "set the ChatGPT prompt",
|
||||||
|
Handler: p.setPromptMessage,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: bot.Message, IsCmd: true,
|
||||||
|
Regex: regexp.MustCompile(`(?is)^llm (?P<text>.*)`),
|
||||||
|
HelpText: "chat completion",
|
||||||
|
Handler: p.chatMessageForce,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Kind: bot.Message, IsCmd: true,
|
Kind: bot.Message, IsCmd: true,
|
||||||
Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`),
|
Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`),
|
||||||
|
@ -50,17 +62,11 @@ func (p *GPTPlugin) register() {
|
||||||
HelpText: "chat completion",
|
HelpText: "chat completion",
|
||||||
Handler: p.chatMessageForce,
|
Handler: p.chatMessageForce,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Kind: bot.Message, IsCmd: true,
|
|
||||||
Regex: regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
|
|
||||||
HelpText: "set the ChatGPT prompt",
|
|
||||||
Handler: p.setPromptMessage,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
p.b.RegisterTable(p, p.h)
|
p.b.RegisterTable(p, p.h)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) setPromptMessage(r bot.Request) bool {
|
func (p *LLMPlugin) setPromptMessage(r bot.Request) bool {
|
||||||
prompt := r.Values["text"]
|
prompt := r.Values["text"]
|
||||||
if err := p.setPrompt(prompt); err != nil {
|
if err := p.setPrompt(prompt); err != nil {
|
||||||
resp := fmt.Sprintf("Error: %s", err)
|
resp := fmt.Sprintf("Error: %s", err)
|
||||||
|
@ -70,7 +76,7 @@ func (p *GPTPlugin) setPromptMessage(r bot.Request) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) chatMessage(r bot.Request) bool {
|
func (p *LLMPlugin) chatMessage(r bot.Request) bool {
|
||||||
if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) {
|
if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) {
|
||||||
log.Debug().Msgf("%s silenced", r.Msg.Channel)
|
log.Debug().Msgf("%s silenced", r.Msg.Channel)
|
||||||
return true
|
return true
|
||||||
|
@ -78,11 +84,15 @@ func (p *GPTPlugin) chatMessage(r bot.Request) bool {
|
||||||
return p.chatMessageForce(r)
|
return p.chatMessageForce(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GPTPlugin) chatMessageForce(r bot.Request) bool {
|
func (p *LLMPlugin) chatMessageForce(r bot.Request) bool {
|
||||||
p.chatHistory = append(p.chatHistory, chatEntry{
|
p.chatHistory = append(p.chatHistory, chatEntry{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: r.Values["text"],
|
Content: r.Values["text"],
|
||||||
})
|
})
|
||||||
|
maxHist := p.c.GetInt("gpt.maxhist", 10)
|
||||||
|
if len(p.chatHistory) > maxHist {
|
||||||
|
p.chatHistory = p.chatHistory[len(p.chatHistory)-maxHist:]
|
||||||
|
}
|
||||||
chatResp, err := p.llama()
|
chatResp, err := p.llama()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.chatHistory = append(p.chatHistory, chatResp)
|
p.chatHistory = append(p.chatHistory, chatResp)
|
|
@ -1,4 +1,4 @@
|
||||||
package gpt
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -11,24 +11,60 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var InstanceNotFoundError = errors.New("instance not found")
|
var InstanceNotFoundError = errors.New("instance not found")
|
||||||
var empty = chatEntry{}
|
var empty = llamaResponse{}
|
||||||
|
|
||||||
func (g *GPTPlugin) llama() (chatEntry, error) {
|
func (g *LLMPlugin) llama() (chatEntry, error) {
|
||||||
llamaURL := g.c.Get("gpt.llamaurl", "")
|
llamaURL := g.c.GetArray("gpt.llamaurls", []string{})
|
||||||
if llamaURL == "" {
|
if len(llamaURL) == 0 {
|
||||||
return empty, fmt.Errorf("could not find llama url")
|
return chatEntry{}, fmt.Errorf("could not find llama url")
|
||||||
}
|
}
|
||||||
llamaModel := g.c.Get("gpt.llamamodel", "")
|
llamaModel := g.c.Get("gpt.llamamodel", "")
|
||||||
if llamaModel == "" {
|
if llamaModel == "" {
|
||||||
return empty, fmt.Errorf("could not find llama model")
|
return chatEntry{}, fmt.Errorf("could not find llama model")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompt := g.c.Get("gpt.lastprompt", g.c.Get("gpt.prompt", ""))
|
||||||
|
hist := []chatEntry{{
|
||||||
|
Role: "system",
|
||||||
|
Content: prompt,
|
||||||
|
}}
|
||||||
|
hist = append(hist, g.chatHistory...)
|
||||||
|
|
||||||
req := llamaRequest{
|
req := llamaRequest{
|
||||||
Model: llamaModel,
|
Model: llamaModel,
|
||||||
Messages: g.chatHistory,
|
Messages: hist,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, u := range llamaURL {
|
||||||
|
if err := g.healthCheck(u); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
llamaResp, err := mkRequest(u, req)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return llamaResp.Message, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return chatEntry{}, InstanceNotFoundError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *LLMPlugin) healthCheck(llamaURL string) error {
|
||||||
|
timeout := p.c.GetInt("gpt.timeout", 1000)
|
||||||
|
req, _ := http.NewRequest("get", llamaURL, nil)
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: time.Duration(timeout) * time.Millisecond,
|
||||||
|
}
|
||||||
|
_, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mkRequest(llamaURL string, req llamaRequest) (llamaResponse, error) {
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return empty, fmt.Errorf("could not marshal llama request: %w", err)
|
return empty, fmt.Errorf("could not marshal llama request: %w", err)
|
||||||
|
@ -50,7 +86,7 @@ func (g *GPTPlugin) llama() (chatEntry, error) {
|
||||||
return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body))
|
return empty, fmt.Errorf("could not unmarshal llama response: %w, raw: %s", err, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
return llamaResp.Message, nil
|
return llamaResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type llamaRequest struct {
|
type llamaRequest struct {
|
|
@ -22,10 +22,10 @@ import (
|
||||||
"github.com/velour/catbase/plugins/giphy"
|
"github.com/velour/catbase/plugins/giphy"
|
||||||
"github.com/velour/catbase/plugins/git"
|
"github.com/velour/catbase/plugins/git"
|
||||||
"github.com/velour/catbase/plugins/goals"
|
"github.com/velour/catbase/plugins/goals"
|
||||||
"github.com/velour/catbase/plugins/gpt"
|
|
||||||
"github.com/velour/catbase/plugins/inventory"
|
"github.com/velour/catbase/plugins/inventory"
|
||||||
"github.com/velour/catbase/plugins/last"
|
"github.com/velour/catbase/plugins/last"
|
||||||
"github.com/velour/catbase/plugins/leftpad"
|
"github.com/velour/catbase/plugins/leftpad"
|
||||||
|
"github.com/velour/catbase/plugins/llm"
|
||||||
"github.com/velour/catbase/plugins/mayi"
|
"github.com/velour/catbase/plugins/mayi"
|
||||||
"github.com/velour/catbase/plugins/meme"
|
"github.com/velour/catbase/plugins/meme"
|
||||||
"github.com/velour/catbase/plugins/nerdepedia"
|
"github.com/velour/catbase/plugins/nerdepedia"
|
||||||
|
@ -101,7 +101,7 @@ func Register(b bot.Bot) {
|
||||||
b.AddPlugin(topic.New(b))
|
b.AddPlugin(topic.New(b))
|
||||||
b.AddPlugin(talker.New(b))
|
b.AddPlugin(talker.New(b))
|
||||||
b.AddPlugin(fact.New(b))
|
b.AddPlugin(fact.New(b))
|
||||||
b.AddPlugin(gpt.New(b))
|
b.AddPlugin(llm.New(b))
|
||||||
// catches anything left, will always return true
|
// catches anything left, will always return true
|
||||||
b.AddPlugin(deadend.New(b))
|
b.AddPlugin(deadend.New(b))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue