diff --git a/plugins/gpt2/gpt2.go b/plugins/gpt2/gpt2.go index 4e25355..215a709 100644 --- a/plugins/gpt2/gpt2.go +++ b/plugins/gpt2/gpt2.go @@ -36,7 +36,11 @@ func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message lowerBody := strings.ToLower(message.Body) if message.Command && strings.HasPrefix(lowerBody, prefix) { input := message.Body[len(prefix)+1:] - p.b.Send(c, bot.Message, ch, p.getGPTText(input)) + txt, err := p.getGPTText(input) + if err != nil { + txt = p.c.Get("gpt.error", "The GPT service is unavailable.") + } + p.b.Send(c, bot.Message, ch, txt) return true } return false @@ -50,24 +54,24 @@ func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, a const separator = "<|endoftext|>" -func (p *GPT2Plugin) getGPTText(input string) string { +func (p *GPT2Plugin) getGPTText(input string) (string, error) { serviceURI := p.c.Get("gpt.service", "") if serviceURI == "" { - return "Cannot contact GPT2 service." + return "", fmt.Errorf("cannot contact GPT2 service") } values := url.Values{} values.Add("text", input) resp, err := http.PostForm(serviceURI, values) if err != nil { - return fmt.Sprintf("Error retrieving GPT2 response: %s", err) + return "", fmt.Errorf("error retrieving GPT2 response: %s", err) } body, err := ioutil.ReadAll(resp.Body) if err != nil { - return fmt.Sprintf("Error reading GPT2 response: %s", err) + return "", fmt.Errorf("error reading GPT2 response: %s", err) } resp.Body.Close() txt := p.cleanup(input + string(body)) - return txt + return txt, nil } func (p *GPT2Plugin) cleanup(txt string) string {