diff --git a/plugins/gpt3/gpt3.go b/plugins/gpt3/gpt3.go index 7a41c5c..f4a9f30 100644 --- a/plugins/gpt3/gpt3.go +++ b/plugins/gpt3/gpt3.go @@ -14,7 +14,7 @@ import ( "github.com/velour/catbase/config" ) -const gpt3URL = "https://api.openai.com/v1/engines/ada/completions" +const gpt3URL = "https://api.openai.com/v1/engines/%s/completions" type GPT3Plugin struct { b bot.Bot @@ -35,7 +35,7 @@ func (p *GPT3Plugin) register() { p.h = bot.HandlerTable{ { Kind: bot.Message, IsCmd: true, - Regex: regexp.MustCompile(`^gpt3 (?P.*)$`), + Regex: regexp.MustCompile(`(?is)^gpt3 (?P.*)`), HelpText: "request text completion", Handler: p.message, }, @@ -53,12 +53,13 @@ func (p *GPT3Plugin) message(r bot.Request) bool { Temperature: p.c.GetFloat64("gpt3.temperature", 1), TopP: p.c.GetFloat64("gpt3.top_p", 1), N: p.c.GetInt("gpt3.n", 1), - Stop: "\n", + Stop: p.c.GetArray("gpt3.stop", []string{"\n"}), Echo: true, } postBody, _ := json.Marshal(postStruct) client := &http.Client{} - req, err := http.NewRequest("POST", gpt3URL, bytes.NewBuffer(postBody)) + u := fmt.Sprintf(gpt3URL, p.c.Get("gpt3.engine", "ada")) + req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody)) if err != nil { log.Error().Err(err).Msg("could not make gpt3 request") return false @@ -101,7 +102,7 @@ type gpt3Request struct { N int `json:"n"` Stream bool `json:"stream"` Logprobs interface{} `json:"logprobs"` - Stop string `json:"stop"` + Stop []string `json:"stop"` Echo bool `json:"echo"` }