mirror of https://github.com/velour/catbase.git
Compare commits
No commits in common. "68738f847b3dff439ee442d3a9e8cea057762c84" and "91d21c10766d8578705ee8152ba2a85446320ed6" have entirely different histories.
68738f847b
...
91d21c1076
|
@ -4,12 +4,10 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/velour/catbase/bot"
|
"github.com/velour/catbase/bot"
|
||||||
|
@ -17,7 +15,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
|
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
|
||||||
const gpt3ModURL = "https://api.openai.com/v1/moderations"
|
|
||||||
|
|
||||||
type GPT3Plugin struct {
|
type GPT3Plugin struct {
|
||||||
b bot.Bot
|
b bot.Bot
|
||||||
|
@ -49,15 +46,7 @@ func (p *GPT3Plugin) register() {
|
||||||
|
|
||||||
func (p *GPT3Plugin) message(r bot.Request) bool {
|
func (p *GPT3Plugin) message(r bot.Request) bool {
|
||||||
stem := r.Values["text"]
|
stem := r.Values["text"]
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, p.mkRequest(stem))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *GPT3Plugin) mkRequest(stem string) string {
|
|
||||||
log.Debug().Msgf("Got GPT3 request: %s", stem)
|
log.Debug().Msgf("Got GPT3 request: %s", stem)
|
||||||
if err := p.checkStem(stem); err != nil {
|
|
||||||
return "GPT3 moderation " + err.Error()
|
|
||||||
}
|
|
||||||
postStruct := gpt3Request{
|
postStruct := gpt3Request{
|
||||||
Prompt: stem,
|
Prompt: stem,
|
||||||
MaxTokens: p.c.GetInt("gpt3.tokens", 16),
|
MaxTokens: p.c.GetInt("gpt3.tokens", 16),
|
||||||
|
@ -73,21 +62,22 @@ func (p *GPT3Plugin) mkRequest(stem string) string {
|
||||||
req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody))
|
req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("could not make gpt3 request")
|
log.Error().Err(err).Msg("could not make gpt3 request")
|
||||||
return err.Error()
|
return false
|
||||||
}
|
}
|
||||||
gpt3Key := p.c.Get("gpt3.bearer", "")
|
gpt3Key := p.c.Get("gpt3.bearer", "")
|
||||||
if gpt3Key == "" {
|
if gpt3Key == "" {
|
||||||
log.Error().Msgf("no GPT3 key given")
|
log.Error().Msgf("no GPT3 key given")
|
||||||
return "No GPT3 API key"
|
return false
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
log.Error().Err(err).Msg("could not make gpt3 request")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
resBody, _ := io.ReadAll(res.Body)
|
resBody, _ := ioutil.ReadAll(res.Body)
|
||||||
gpt3Resp := gpt3Response{}
|
gpt3Resp := gpt3Response{}
|
||||||
err = json.Unmarshal(resBody, &gpt3Resp)
|
err = json.Unmarshal(resBody, &gpt3Resp)
|
||||||
|
|
||||||
|
@ -100,53 +90,8 @@ func (p *GPT3Plugin) mkRequest(stem string) string {
|
||||||
if len(gpt3Resp.Choices) > 0 {
|
if len(gpt3Resp.Choices) > 0 {
|
||||||
msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text
|
msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text
|
||||||
}
|
}
|
||||||
return msg
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
}
|
return true
|
||||||
|
|
||||||
func (p *GPT3Plugin) checkStem(stem string) error {
|
|
||||||
if !p.c.GetBool("gpt3.moderation", true) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
postBody, _ := json.Marshal(gpt3ModRequest{Input: stem})
|
|
||||||
client := &http.Client{}
|
|
||||||
req, err := http.NewRequest("POST", gpt3ModURL, bytes.NewBuffer(postBody))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
gpt3Key := p.c.Get("gpt3.bearer", "")
|
|
||||||
if gpt3Key == "" {
|
|
||||||
return fmt.Errorf("no GPT3 API key")
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resBody, _ := io.ReadAll(res.Body)
|
|
||||||
log.Debug().Str("resBody", string(resBody)).Msg("res")
|
|
||||||
gpt3Resp := gpt3Moderation{}
|
|
||||||
err = json.Unmarshal(resBody, &gpt3Resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Interface("GPT3 Moderation", gpt3Resp).Msg("Moderation result")
|
|
||||||
for _, res := range gpt3Resp.Results {
|
|
||||||
if res.Flagged {
|
|
||||||
list := ""
|
|
||||||
categories := reflect.ValueOf(res.Categories)
|
|
||||||
fields := reflect.VisibleFields(reflect.TypeOf(res.Categories))
|
|
||||||
for i := 0; i < categories.NumField(); i++ {
|
|
||||||
if categories.Field(i).Bool() {
|
|
||||||
list += fields[i].Name + ", "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
list = strings.TrimSuffix(list, ", ")
|
|
||||||
return fmt.Errorf("flagged: %s", list)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type gpt3Request struct {
|
type gpt3Request struct {
|
||||||
|
@ -161,10 +106,6 @@ type gpt3Request struct {
|
||||||
Echo bool `json:"echo"`
|
Echo bool `json:"echo"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type gpt3ModRequest struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type gpt3Response struct {
|
type gpt3Response struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
|
@ -177,29 +118,3 @@ type gpt3Response struct {
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
} `json:"choices"`
|
} `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type gpt3Moderation struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Results []struct {
|
|
||||||
Categories struct {
|
|
||||||
Hate bool `json:"hate"`
|
|
||||||
HateThreatening bool `json:"hate/threatening"`
|
|
||||||
SelfHarm bool `json:"self-harm"`
|
|
||||||
Sexual bool `json:"sexual"`
|
|
||||||
SexualMinors bool `json:"sexual/minors"`
|
|
||||||
Violence bool `json:"violence"`
|
|
||||||
ViolenceGraphic bool `json:"violence/graphic"`
|
|
||||||
} `json:"categories"`
|
|
||||||
CategoryScores struct {
|
|
||||||
Hate float64 `json:"hate"`
|
|
||||||
HateThreatening float64 `json:"hate/threatening"`
|
|
||||||
SelfHarm float64 `json:"self-harm"`
|
|
||||||
Sexual float64 `json:"sexual"`
|
|
||||||
SexualMinors float64 `json:"sexual/minors"`
|
|
||||||
Violence float64 `json:"violence"`
|
|
||||||
ViolenceGraphic float64 `json:"violence/graphic"`
|
|
||||||
} `json:"category_scores"`
|
|
||||||
Flagged bool `json:"flagged"`
|
|
||||||
} `json:"results"`
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue