gpt2: clean up text

This commit is contained in:
Chris Sexton 2020-08-21 10:19:36 -04:00 committed by Chris Sexton
parent c6cb8053c0
commit cdeae1b019
1 changed files with 16 additions and 3 deletions

View File

@ -29,11 +29,13 @@ func New(b bot.Bot) *GPT2Plugin {
return p return p
} }
const prefix = "gpt2"
func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool { func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
ch := message.Channel ch := message.Channel
lowerBody := strings.ToLower(message.Body) lowerBody := strings.ToLower(message.Body)
if message.Command && strings.HasPrefix(lowerBody, "gpt2") { if message.Command && strings.HasPrefix(lowerBody, prefix) {
input := strings.TrimPrefix(lowerBody, "gpt2") input := message.Body[len(prefix)+1:]
p.b.Send(c, bot.Message, ch, p.getGPTText(input)) p.b.Send(c, bot.Message, ch, p.getGPTText(input))
return true return true
} }
@ -46,6 +48,8 @@ func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, a
return true return true
} }
const separator = "<|endoftext|>"
func (p *GPT2Plugin) getGPTText(input string) string { func (p *GPT2Plugin) getGPTText(input string) string {
serviceURI := p.c.Get("gpt.service", "") serviceURI := p.c.Get("gpt.service", "")
if serviceURI == "" { if serviceURI == "" {
@ -62,7 +66,16 @@ func (p *GPT2Plugin) getGPTText(input string) string {
return fmt.Sprintf("Error reading GPT2 response: %s", err) return fmt.Sprintf("Error reading GPT2 response: %s", err)
} }
resp.Body.Close() resp.Body.Close()
txt := string(body) txt := p.cleanup(input + string(body))
return txt
}
func (p *GPT2Plugin) cleanup(txt string) string {
txt = strings.Split(txt, separator)[0]
if !strings.HasSuffix(txt, ".") && strings.Count(txt, ".") > 1 {
idx := strings.LastIndexAny(txt, ".")
txt = txt[:idx+1]
}
txt = strings.TrimSpace(txt) txt = strings.TrimSpace(txt)
return txt return txt
} }