diff --git a/plugins/gpt2/gpt2.go b/plugins/gpt2/gpt2.go index da03f69..4e25355 100644 --- a/plugins/gpt2/gpt2.go +++ b/plugins/gpt2/gpt2.go @@ -29,11 +29,13 @@ func New(b bot.Bot) *GPT2Plugin { return p } +const prefix = "gpt2" + func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool { ch := message.Channel lowerBody := strings.ToLower(message.Body) - if message.Command && strings.HasPrefix(lowerBody, "gpt2") { - input := strings.TrimPrefix(lowerBody, "gpt2") + if message.Command && strings.HasPrefix(lowerBody, prefix) { + input := message.Body[len(prefix)+1:] p.b.Send(c, bot.Message, ch, p.getGPTText(input)) return true } @@ -46,6 +48,8 @@ func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, a return true } +const separator = "<|endoftext|>" + func (p *GPT2Plugin) getGPTText(input string) string { serviceURI := p.c.Get("gpt.service", "") if serviceURI == "" { @@ -62,7 +66,16 @@ func (p *GPT2Plugin) getGPTText(input string) string { return fmt.Sprintf("Error reading GPT2 response: %s", err) } resp.Body.Close() - txt := string(body) + txt := p.cleanup(input + string(body)) + return txt +} + +func (p *GPT2Plugin) cleanup(txt string) string { + txt = strings.Split(txt, separator)[0] + if !strings.HasSuffix(txt, ".") && strings.Count(txt, ".") > 1 { + idx := strings.LastIndexAny(txt, ".") + txt = txt[:idx+1] + } txt = strings.TrimSpace(txt) return txt }