mirror of https://github.com/velour/catbase.git
Compare commits
3 Commits
2045fc4591
...
0cd1b5e7ed
Author | SHA1 | Date |
---|---|---|
|
0cd1b5e7ed | |
|
086bc9f0f9 | |
|
1e5a82d281 |
1
go.mod
1
go.mod
|
@ -31,6 +31,7 @@ require (
|
||||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
|
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
|
||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 // indirect
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 // indirect
|
||||||
|
github.com/itchyny/gojq v0.12.3
|
||||||
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6
|
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6
|
||||||
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 // indirect
|
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 // indirect
|
||||||
github.com/jmoiron/sqlx v1.2.0
|
github.com/jmoiron/sqlx v1.2.0
|
||||||
|
|
12
go.sum
12
go.sum
|
@ -70,6 +70,7 @@ github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 h1:EvokxLQsaaQjcWVWSV
|
||||||
github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82/go.mod h1:PxC8OnwL11+aosOB5+iEPoV3picfs8tUpkVd0pDo+Kg=
|
github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82/go.mod h1:PxC8OnwL11+aosOB5+iEPoV3picfs8tUpkVd0pDo+Kg=
|
||||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0soOeia9UZSvYBvETVHZrugUowJ7M=
|
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0soOeia9UZSvYBvETVHZrugUowJ7M=
|
||||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029/go.mod h1:Pu4dmpkhSyOzRwuXkOgAvijx4o+4YMUJJo9OvPYMkks=
|
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029/go.mod h1:Pu4dmpkhSyOzRwuXkOgAvijx4o+4YMUJJo9OvPYMkks=
|
||||||
|
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
||||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
|
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
|
||||||
|
@ -78,6 +79,11 @@ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0U
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 h1:KUDFlmBg2buRWNzIcwLlKvfcnujcHQRQ1As1LoaCLAM=
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 h1:KUDFlmBg2buRWNzIcwLlKvfcnujcHQRQ1As1LoaCLAM=
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o=
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o=
|
||||||
|
github.com/itchyny/go-flags v1.5.0/go.mod h1:lenkYuCobuxLBAd/HGFE4LRoW8D3B6iXRQfWYJ+MNbA=
|
||||||
|
github.com/itchyny/gojq v0.12.3 h1:s7jTCyOk/dy5bnDIScj24YX4Cr1yhEO2iW/bQT4Pm2s=
|
||||||
|
github.com/itchyny/gojq v0.12.3/go.mod h1:mi4PdXSlFllHyByM68JKUrbiArtEdEnNEmjbwxcQKAg=
|
||||||
|
github.com/itchyny/timefmt-go v0.1.2 h1:q0Xa4P5it6K6D7ISsbLAMwx1PnWlixDcJL6/sFs93Hs=
|
||||||
|
github.com/itchyny/timefmt-go v0.1.2/go.mod h1:0osSSCQSASBJMsIZnhAaF1C2fCBTJZXrnj37mG8/c+A=
|
||||||
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6 h1:k8+n5sfvxlixRNVkbelPGzEYjbGIKaBnRzRlx2NCtYA=
|
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6 h1:k8+n5sfvxlixRNVkbelPGzEYjbGIKaBnRzRlx2NCtYA=
|
||||||
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6/go.mod h1:kixuaexEqWB+mHZNysgnb6mqgGIT25WvD1/tFRRt0J0=
|
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6/go.mod h1:kixuaexEqWB+mHZNysgnb6mqgGIT25WvD1/tFRRt0J0=
|
||||||
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 h1:ph/BDQQDL41apnHSN48I5GyNOQXXAlc79HwGqDSXCss=
|
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 h1:ph/BDQQDL41apnHSN48I5GyNOQXXAlc79HwGqDSXCss=
|
||||||
|
@ -101,6 +107,7 @@ github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+v
|
||||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
||||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||||
|
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||||
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
|
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
|
||||||
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
|
@ -183,8 +190,11 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8=
|
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8=
|
||||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210301091718-77cc2087c03b h1:kHlr0tATeLRMEiZJu5CknOw/E8V6h69sXXQFGoPtjcc=
|
||||||
|
golang.org/x/sys v0.0.0-20210301091718-77cc2087c03b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
@ -194,6 +204,7 @@ golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGm
|
||||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
|
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
|
||||||
gonum.org/v1/gonum v0.6.0 h1:DJy6UzXbahnGUf1ujUNkh/NEtK14qMo2nvlBPs4U5yw=
|
gonum.org/v1/gonum v0.6.0 h1:DJy6UzXbahnGUf1ujUNkh/NEtK14qMo2nvlBPs4U5yw=
|
||||||
gonum.org/v1/gonum v0.6.0/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU=
|
gonum.org/v1/gonum v0.6.0/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU=
|
||||||
|
@ -212,4 +223,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
|
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
|
||||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
|
|
4
main.go
4
main.go
|
@ -13,8 +13,8 @@ import (
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
"github.com/velour/catbase/connectors/discord"
|
"github.com/velour/catbase/connectors/discord"
|
||||||
"github.com/velour/catbase/plugins/giphy"
|
"github.com/velour/catbase/plugins/giphy"
|
||||||
"github.com/velour/catbase/plugins/gpt2"
|
|
||||||
"github.com/velour/catbase/plugins/last"
|
"github.com/velour/catbase/plugins/last"
|
||||||
|
"github.com/velour/catbase/plugins/rest"
|
||||||
|
|
||||||
"github.com/velour/catbase/plugins/achievements"
|
"github.com/velour/catbase/plugins/achievements"
|
||||||
"github.com/velour/catbase/plugins/aoc"
|
"github.com/velour/catbase/plugins/aoc"
|
||||||
|
@ -122,7 +122,6 @@ func main() {
|
||||||
|
|
||||||
b.AddPlugin(admin.New(b))
|
b.AddPlugin(admin.New(b))
|
||||||
b.AddPlugin(giphy.New(b))
|
b.AddPlugin(giphy.New(b))
|
||||||
b.AddPlugin(gpt2.New(b))
|
|
||||||
b.AddPlugin(emojifyme.New(b))
|
b.AddPlugin(emojifyme.New(b))
|
||||||
b.AddPlugin(last.New(b))
|
b.AddPlugin(last.New(b))
|
||||||
b.AddPlugin(first.New(b))
|
b.AddPlugin(first.New(b))
|
||||||
|
@ -158,6 +157,7 @@ func main() {
|
||||||
b.AddPlugin(achievements.New(b))
|
b.AddPlugin(achievements.New(b))
|
||||||
b.AddPlugin(sms.New(b))
|
b.AddPlugin(sms.New(b))
|
||||||
b.AddPlugin(countdown.New(b))
|
b.AddPlugin(countdown.New(b))
|
||||||
|
b.AddPlugin(rest.New(b))
|
||||||
// catches anything left, will always return true
|
// catches anything left, will always return true
|
||||||
b.AddPlugin(fact.New(b))
|
b.AddPlugin(fact.New(b))
|
||||||
|
|
||||||
|
|
|
@ -1,103 +0,0 @@
|
||||||
package gpt2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/velour/catbase/bot"
|
|
||||||
"github.com/velour/catbase/bot/msg"
|
|
||||||
"github.com/velour/catbase/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
type GPT2Plugin struct {
|
|
||||||
b bot.Bot
|
|
||||||
c *config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(b bot.Bot) *GPT2Plugin {
|
|
||||||
p := &GPT2Plugin{
|
|
||||||
b: b,
|
|
||||||
c: b.Config(),
|
|
||||||
}
|
|
||||||
|
|
||||||
b.RegisterRegexCmd(p, bot.Message, gpt2Regex, p.gpt2Cmd)
|
|
||||||
b.Register(p, bot.Help, p.help)
|
|
||||||
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
var gpt2Regex = regexp.MustCompile(`(?i)^gpt2 (?P<input>.*)$`)
|
|
||||||
|
|
||||||
func (p *GPT2Plugin) gpt2Cmd(r bot.Request) bool {
|
|
||||||
input := r.Values["input"]
|
|
||||||
txt, err := p.getGPTText(input)
|
|
||||||
if err != nil {
|
|
||||||
txt = p.c.Get("gpt.error", "The GPT service is unavailable.")
|
|
||||||
}
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, txt)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
|
|
||||||
txt := "Invoke the GPT-2 API with: `!gpt2 <your seed text>"
|
|
||||||
p.b.Send(c, bot.Message, message.Channel, txt)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
const separator = "<|endoftext|>"
|
|
||||||
|
|
||||||
func (p *GPT2Plugin) getGPTText(prefix string) (string, error) {
|
|
||||||
serviceURI := p.c.Get("gpt2.service", "")
|
|
||||||
|
|
||||||
if serviceURI == "" {
|
|
||||||
return "", fmt.Errorf("cannot contact GPT2 service")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := struct {
|
|
||||||
Prefix string `json:"prefix"`
|
|
||||||
Length int `json:"length"`
|
|
||||||
Temperature float64 `json:"temperature"`
|
|
||||||
TopP float64 `json:"top_p"`
|
|
||||||
TopK float64 `json:"top_k"`
|
|
||||||
}{
|
|
||||||
Prefix: prefix,
|
|
||||||
Length: p.c.GetInt("gpt2.length", 50),
|
|
||||||
Temperature: p.c.GetFloat64("gpt2.temperature", 0.7),
|
|
||||||
TopK: p.c.GetFloat64("gpt2.topk", 0),
|
|
||||||
TopP: p.c.GetFloat64("gpt2.topp", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
values, _ := json.Marshal(args)
|
|
||||||
|
|
||||||
resp, err := http.Post(serviceURI, "application/json", bytes.NewBuffer(values))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error retrieving GPT2 response: %s", err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error reading GPT2 response: %s", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
dec := json.NewDecoder(resp.Body)
|
|
||||||
output := struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
}{}
|
|
||||||
err = dec.Decode(&output)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return p.cleanup(output.Text), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *GPT2Plugin) cleanup(txt string) string {
|
|
||||||
txt = strings.Split(txt, separator)[0]
|
|
||||||
if !strings.HasSuffix(txt, ".") && strings.Count(txt, ".") > 1 {
|
|
||||||
idx := strings.LastIndexAny(txt, ".")
|
|
||||||
txt = txt[:idx+1]
|
|
||||||
}
|
|
||||||
txt = strings.TrimSpace(txt)
|
|
||||||
return txt
|
|
||||||
}
|
|
|
@ -0,0 +1,303 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/itchyny/gojq"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/velour/catbase/bot"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RestPlugin struct {
|
||||||
|
b bot.Bot
|
||||||
|
db *sqlx.DB
|
||||||
|
|
||||||
|
handlers bot.HandlerTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(b bot.Bot) *RestPlugin {
|
||||||
|
p := &RestPlugin{
|
||||||
|
b: b,
|
||||||
|
db: b.DB(),
|
||||||
|
handlers: bot.HandlerTable{},
|
||||||
|
}
|
||||||
|
p.setupDB()
|
||||||
|
p.register()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) setupDB() {
|
||||||
|
tx := p.db.MustBegin()
|
||||||
|
tx.MustExec(`
|
||||||
|
create table if not exists wires (
|
||||||
|
id integer primary key autoincrement,
|
||||||
|
url text not null,
|
||||||
|
parse_regex text not null,
|
||||||
|
return_field text not null
|
||||||
|
)`)
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) register() {
|
||||||
|
p.handlers = bot.HandlerTable{
|
||||||
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
|
||||||
|
Regex: regexp.MustCompile("(?i)^list wires$"),
|
||||||
|
HelpText: "Lists all REST functions",
|
||||||
|
Handler: p.listWires},
|
||||||
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
|
||||||
|
Regex: regexp.MustCompile(`(?i)^rm wire (?P<id>\d+)$`),
|
||||||
|
HelpText: "Removes a wire by ID (use list to view)",
|
||||||
|
Handler: p.rmWire},
|
||||||
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
|
||||||
|
Regex: regexp.MustCompile("(?i)^wire `(?P<parse>[^`]+)` to (?P<url>\\S+) `(?P<returnField>[^`]+)`$"),
|
||||||
|
HelpText: "Registers a new REST function",
|
||||||
|
Handler: p.handleWire},
|
||||||
|
}
|
||||||
|
p.b.RegisterTable(p, p.handlers)
|
||||||
|
wires, err := p.getWires()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
for _, w := range wires {
|
||||||
|
p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(w))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ScanableRegexp struct {
|
||||||
|
*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScanableRegexp) Scan(src interface{}) error {
|
||||||
|
var source string
|
||||||
|
switch src.(type) {
|
||||||
|
case string:
|
||||||
|
source = src.(string)
|
||||||
|
default:
|
||||||
|
return errors.New("incompatible type for ScanableRegexp")
|
||||||
|
}
|
||||||
|
r, err := regexp.Compile(source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.Regexp = r
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ScanableURL struct {
|
||||||
|
*url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScanableURL) Scan(src interface{}) error {
|
||||||
|
var source string
|
||||||
|
switch src.(type) {
|
||||||
|
case string:
|
||||||
|
source = src.(string)
|
||||||
|
default:
|
||||||
|
return errors.New("incompatible type for ScanableURL")
|
||||||
|
}
|
||||||
|
u, err := url.Parse(source)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.URL = u
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type wire struct {
|
||||||
|
// ID
|
||||||
|
ID sql.NullInt64
|
||||||
|
// The URL to make a request to
|
||||||
|
URL ScanableURL
|
||||||
|
// The regex which will trigger this REST action
|
||||||
|
ParseRegex ScanableRegexp `db:"parse_regex"`
|
||||||
|
// The JSON field that will contain the REST return value
|
||||||
|
ReturnField string `db:"return_field"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w wire) String() string {
|
||||||
|
msg := "Wire:"
|
||||||
|
msg += fmt.Sprintf("\nURL: %s", w.URL)
|
||||||
|
msg += fmt.Sprintf("\nParsing to trigger: `%s`", w.ParseRegex)
|
||||||
|
msg += fmt.Sprintf("\nReturn field: `%s`", w.ReturnField)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) getWires() ([]*wire, error) {
|
||||||
|
wires := []*wire{}
|
||||||
|
err := p.db.Select(&wires, `select * from wires`)
|
||||||
|
return wires, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) deleteWire(id int64) error {
|
||||||
|
_, err := p.db.Exec(`delete from wires where id=?`, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wire) Update(db *sqlx.DB) error {
|
||||||
|
if !w.ID.Valid {
|
||||||
|
return w.Save(db)
|
||||||
|
}
|
||||||
|
id, _ := w.ID.Value()
|
||||||
|
_, err := db.Exec(`update wires set url=?, parse_regex=?, return_field=? where id=?`,
|
||||||
|
w.URL.String(), w.ParseRegex.String(), w.ReturnField, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wire) Save(db *sqlx.DB) error {
|
||||||
|
if w.ID.Valid {
|
||||||
|
return w.Update(db)
|
||||||
|
}
|
||||||
|
res, err := db.Exec(`insert into wires (url, parse_regex, return_field) values (?, ?, ?)`,
|
||||||
|
w.URL.String(), w.ParseRegex.String(), w.ReturnField)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = w.ID.Scan(id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) listWires(r bot.Request) bool {
|
||||||
|
var msg string
|
||||||
|
wires, err := p.getWires()
|
||||||
|
if err != nil {
|
||||||
|
msg = err.Error()
|
||||||
|
goto SEND
|
||||||
|
}
|
||||||
|
msg = "Current wires:"
|
||||||
|
for _, w := range wires {
|
||||||
|
id, _ := w.ID.Value()
|
||||||
|
msg += fmt.Sprintf("\n\t%d: `%s` => %s", id, w.ParseRegex, w.URL)
|
||||||
|
}
|
||||||
|
SEND:
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) rmWire(r bot.Request) bool {
|
||||||
|
id, _ := strconv.ParseInt(r.Values["id"], 10, 64)
|
||||||
|
err := p.deleteWire(id)
|
||||||
|
if err != nil {
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete wire: "+err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Deleted wire: %d", id))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) handleWire(r bot.Request) bool {
|
||||||
|
var w wire
|
||||||
|
var msg string
|
||||||
|
var err error
|
||||||
|
w.ParseRegex.Regexp, err = regexp.Compile(r.Values["parse"])
|
||||||
|
if err != nil {
|
||||||
|
msg = err.Error()
|
||||||
|
goto SEND
|
||||||
|
}
|
||||||
|
w.URL.URL, err = url.Parse(r.Values["url"])
|
||||||
|
if err != nil {
|
||||||
|
msg = err.Error()
|
||||||
|
goto SEND
|
||||||
|
}
|
||||||
|
w.ReturnField = r.Values["returnField"]
|
||||||
|
err = w.Save(p.db)
|
||||||
|
if err != nil {
|
||||||
|
msg = err.Error()
|
||||||
|
goto SEND
|
||||||
|
}
|
||||||
|
p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(&w))
|
||||||
|
msg = fmt.Sprintf("Saved %s", w)
|
||||||
|
SEND:
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) mkHandler(w *wire) bot.ResponseHandler {
|
||||||
|
return func(r bot.Request) bool {
|
||||||
|
if r.Msg.User.Name == p.b.Config().GetString("nick", "") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
values := bot.RegexValues{}
|
||||||
|
for k := range r.Values {
|
||||||
|
values[k] = url.QueryEscape(r.Values[k])
|
||||||
|
}
|
||||||
|
urlStr := w.URL.String()
|
||||||
|
parse, err := template.New(urlStr).Parse(urlStr)
|
||||||
|
if p.handleErr(err, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
err = parse.Execute(&buf, values)
|
||||||
|
if p.handleErr(err, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newURL, err := url.Parse(buf.String())
|
||||||
|
if p.handleErr(err, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
resp, err := http.Get(newURL.String())
|
||||||
|
if p.handleErr(err, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Got a status %d: %s from %s",
|
||||||
|
resp.StatusCode, resp.Status, newURL))
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if p.handleErr(err, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var returnValues interface{}
|
||||||
|
json.Unmarshal(body, &returnValues)
|
||||||
|
|
||||||
|
query, err := gojq.Parse(w.ReturnField)
|
||||||
|
if err != nil {
|
||||||
|
msg := fmt.Sprintf("Wire handler did not find return value (%s): %s => `%s`", newURL.String(), w.URL, w.ReturnField)
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
iter := query.Run(returnValues) // or query.RunWithContext
|
||||||
|
|
||||||
|
for {
|
||||||
|
v, ok := iter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err, ok := v.(error); ok {
|
||||||
|
return p.handleErr(err, r)
|
||||||
|
}
|
||||||
|
msg += fmt.Sprintf("%s\n", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = strings.TrimSpace(msg)
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
|
||||||
|
if err != nil {
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
Loading…
Reference in New Issue