diff --git a/plugins/gpt2/gpt2.go b/plugins/gpt2/gpt2.go index 215a709..bd5b4a8 100644 --- a/plugins/gpt2/gpt2.go +++ b/plugins/gpt2/gpt2.go @@ -1,10 +1,10 @@ package gpt2 import ( + "bytes" + "encoding/json" "fmt" - "io/ioutil" "net/http" - "net/url" "strings" "github.com/velour/catbase/bot" @@ -54,24 +54,46 @@ func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, a const separator = "<|endoftext|>" -func (p *GPT2Plugin) getGPTText(input string) (string, error) { - serviceURI := p.c.Get("gpt.service", "") +func (p *GPT2Plugin) getGPTText(prefix string) (string, error) { + serviceURI := p.c.Get("gpt2.service", "") + if serviceURI == "" { return "", fmt.Errorf("cannot contact GPT2 service") } - values := url.Values{} - values.Add("text", input) - resp, err := http.PostForm(serviceURI, values) + + args := struct { + Prefix string `json:"prefix"` + Length int `json:"length"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + TopK float64 `json:"top_k"` + }{ + Prefix: prefix, + Length: p.c.GetInt("gpt2.length", 50), + Temperature: p.c.GetFloat64("gpt2.temperature", 0.7), + TopK: p.c.GetFloat64("gpt2.topk", 0), + TopP: p.c.GetFloat64("gpt2.topp", 0), + } + + values, _ := json.Marshal(args) + + resp, err := http.Post(serviceURI, "application/json", bytes.NewBuffer(values)) if err != nil { return "", fmt.Errorf("error retrieving GPT2 response: %s", err) } - body, err := ioutil.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("error reading GPT2 response: %s", err) } - resp.Body.Close() - txt := p.cleanup(input + string(body)) - return txt, nil + defer resp.Body.Close() + dec := json.NewDecoder(resp.Body) + output := struct { + Text string `json:"text"` + }{} + err = dec.Decode(&output) + if err != nil { + return "", err + } + return p.cleanup(output.Text), nil } func (p *GPT2Plugin) cleanup(txt string) string {