catbase/plugins/rest/rest.go

394 lines
9.5 KiB
Go

package rest
import (
	"bytes"
	"crypto/sha512"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"text/template"
	"time"

	"github.com/itchyny/gojq"
	"github.com/rs/zerolog/log"
	bh "github.com/timshannon/bolthold"
	"github.com/velour/catbase/bot"
)
// RestPlugin lets chat users define "wires": regex-triggered HTTP requests
// whose JSON responses are turned into chat replies.
type RestPlugin struct {
	b        bot.Bot          // bot the plugin is attached to
	store    *bh.Store        // persistence for saved wires
	handlers bot.HandlerTable // built-in commands; also listed by the help handler
}
// postProcessor turns a decoded JSON response into the chat reply. A wire
// selects one by using its name as the return field instead of a jq query.
type postProcessor func(interface{}) string

// postProcessors maps return-field names to their post-processors.
var postProcessors = map[string]postProcessor{
	// gpt2 expects an object with a string "text" field (presumably a GPT-2
	// completion service — confirm against the endpoint). It drops everything
	// from the first "<|endoftext|>" marker on, then trims what is left back
	// to its last sentence-ending punctuation so the reply does not stop
	// mid-sentence. Input of any other shape yields "" instead of panicking.
	"gpt2": func(input interface{}) string {
		values, ok := input.(map[string]interface{})
		if !ok {
			return ""
		}
		text, ok := values["text"].(string)
		if !ok {
			return ""
		}
		// Cut at the end-of-text marker first: anything after it is not part
		// of this completion and must not decide where the sentence trim lands.
		if eot := strings.Index(text, "<|endoftext|>"); eot >= 0 {
			text = text[:eot]
		}
		if lastStop := strings.LastIndexAny(text, ".!?"); lastStop > 0 {
			text = text[:lastStop+1]
		}
		return text
	},
}
// New creates the REST plugin for b, backed by b's store, and immediately
// registers its commands, help handler and any previously saved wires.
func New(b bot.Bot) *RestPlugin {
	plugin := new(RestPlugin)
	plugin.b = b
	plugin.store = b.Store()
	plugin.handlers = bot.HandlerTable{}
	plugin.register()
	return plugin
}
// register installs the plugin's built-in commands, re-registers every wire
// already in the store, and installs a help handler listing the commands.
func (p *RestPlugin) register() {
	p.handlers = bot.HandlerTable{
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^list wires$"),
			HelpText: "Lists all REST functions",
			Handler:  p.listWires},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile(`(?i)^rm Wire (?P<id>\d+)$`),
			HelpText: "Removes a Wire by ID (use list to view)",
			Handler:  p.rmWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^testwire `(?P<parse>[^`]+)` to (?P<url>\\S+) `(?P<returnField>[^`]+)` => (?P<text>.*)$"),
			HelpText: "Tests a new REST function",
			Handler:  p.handleTestWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^testwirepost `(?P<parse>[^`]+)` to (?P<url>\\S+) with body `(?P<body>[^`]+)` `(?P<returnField>[^`]+)` => (?P<text>.*)$"),
			HelpText: "Tests a new REST POST function",
			Handler:  p.handleTestWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^Wire `(?P<parse>[^`]+)` to (?P<url>\\S+) `(?P<returnField>[^`]+)`$"),
			HelpText: "Registers a new REST function",
			Handler:  p.handleWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^wirepost `(?P<parse>[^`]+)` to (?P<url>\\S+) with body `(?P<body>[^`]+)` `(?P<returnField>[^`]+)`$"),
			HelpText: "Registers a new REST POST function",
			Handler:  p.handleWire},
	}
	p.b.RegisterTable(p, p.handlers)

	wires, err := p.getWires()
	if err != nil {
		// A storage failure should not take the whole bot down; the
		// management commands registered above still work.
		log.Error().Err(err).Msg("could not load saved wires")
	}
	for _, w := range wires {
		// Wire.regex is unexported, so it is never persisted and comes back
		// nil from the store; rebuild it from ParseRegex before registering.
		re, err := regexp.Compile(w.ParseRegex)
		if err != nil {
			log.Error().Err(err).Int64("id", w.ID).Str("regex", w.ParseRegex).
				Msg("skipping saved wire with invalid regex")
			continue
		}
		w.regex = re
		p.b.RegisterRegex(p, bot.Message, w.regex, p.mkHandler(w))
	}

	p.b.RegisterRegex(p, bot.Help, regexp.MustCompile(`.*`), func(r bot.Request) bool {
		var out strings.Builder
		out.WriteString("Rest commands:\n")
		for _, h := range p.handlers {
			// One command per line; previously the entries ran together.
			fmt.Fprintf(&out, "```%s```\t%s\n", h.Regex.String(), h.HelpText)
		}
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, strings.TrimSpace(out.String()))
		return true
	})
}
// ScanableRegexp wraps *regexp.Regexp so it can be filled from a stored
// pattern string through Scan (the signature of database/sql's Scanner).
type ScanableRegexp struct {
	*regexp.Regexp
}

// Scan compiles src, which must be a string, and stores the result in s.
// A non-string src or an invalid pattern returns an error and leaves s
// unchanged.
func (s *ScanableRegexp) Scan(src interface{}) error {
	pattern, ok := src.(string)
	if !ok {
		return errors.New("incompatible type for ScanableRegexp")
	}
	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}
	s.Regexp = compiled
	return nil
}
// ScanableURL wraps *url.URL so it can be filled from a stored URL string
// through Scan (the signature of database/sql's Scanner).
type ScanableURL struct {
	*url.URL
}

// Scan parses src, which must be a string, and stores the result in s.
// A non-string src or an unparsable URL returns an error and leaves s
// unchanged.
func (s *ScanableURL) Scan(src interface{}) error {
	raw, ok := src.(string)
	if !ok {
		return errors.New("incompatible type for ScanableURL")
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return err
	}
	s.URL = parsed
	return nil
}
// Wire is a user-defined REST integration. A chat message matching
// ParseRegex triggers a request to URL, and ReturnField selects what part of
// the JSON response is sent back to the channel.
type Wire struct {
	// ID identifies the wire in "list wires" output and in "rm wire".
	ID int64 `boltholdIndex:"ID"`
	// URL is the request target; its string form is rendered as a
	// text/template against secrets and regex captures before each request.
	URL ScanableURL
	// ParseRegex is the trigger pattern as typed by the user.
	ParseRegex string `db:"parse_regex"`
	// regex is ParseRegex compiled. It is unexported, so reflection-based
	// encoders skip it and it must be rebuilt after loading from storage.
	regex *regexp.Regexp
	// ReturnField is either a postProcessors key or a jq query applied to
	// the decoded response.
	ReturnField string `db:"return_field"`
	// Body is the request-body template; when it renders empty a GET is
	// sent, otherwise a POST.
	Body string `db:"body"`
}
// String renders a multi-line, human-readable summary of the wire. It is used
// to confirm a newly saved wire. The body line appears only for POST wires,
// so a confirmation shows everything the wire will send.
func (w Wire) String() string {
	var b strings.Builder
	b.WriteString("Wire:")
	fmt.Fprintf(&b, "\nURL: %s", w.URL)
	fmt.Fprintf(&b, "\nParsing to trigger: `%s`", w.ParseRegex)
	if w.Body != "" {
		fmt.Fprintf(&b, "\nBody: `%s`", w.Body)
	}
	fmt.Fprintf(&b, "\nReturn field: `%s`", w.ReturnField)
	return b.String()
}
// getWires loads the stored wires using an empty bolthold query, which
// matches every record. The unexported regex field of each result is not
// filled in; callers that need it must compile ParseRegex themselves.
func (p *RestPlugin) getWires() ([]Wire, error) {
	found := make([]Wire, 0)
	err := p.store.Find(&found, &bh.Query{})
	return found, err
}
// deleteWire removes the Wire stored under key id and returns the store's
// error, if any.
func (p *RestPlugin) deleteWire(id int64) error {
	return p.store.Delete(id, Wire{})
}
// Update persists w over the record stored under key w.ID. A wire that has
// never been stored (ID <= 0) is inserted instead via Save.
func (w *Wire) Update(store *bh.Store) error {
	if w.ID <= 0 {
		return w.Save(store)
	}
	return store.Update(w.ID, w)
}

// Save persists w. A wire that already has a positive ID is updated in place
// via Update; any other wire is inserted under the next sequence key.
//
// Zero counts as "not yet stored": mkWire never sets ID, and bolt bucket
// sequences start at 1, so key 0 can never exist. Previously a fresh wire
// (ID 0) was routed to store.Update(0, ...), which cannot succeed.
func (w *Wire) Save(store *bh.Store) error {
	if w.ID > 0 {
		return w.Update(store)
	}
	// w is already a *Wire; passing &w would hand the store a **Wire.
	return store.Insert(bh.NextSequence(), w)
}
// listWires handles "list wires". It replies with each stored wire's ID,
// trigger regex and URL, or with the storage error if loading fails.
// It always reports the message as handled.
func (p *RestPlugin) listWires(r bot.Request) bool {
	reply := func(text string) bool {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, text)
		return true
	}
	wires, err := p.getWires()
	if err != nil {
		return reply(err.Error())
	}
	var sb strings.Builder
	sb.WriteString("Current wires:")
	for _, w := range wires {
		fmt.Fprintf(&sb, "\n\t%d: `%s` => %s", w.ID, w.ParseRegex, w.URL)
	}
	return reply(sb.String())
}
// rmWire handles "rm wire <id>". It deletes the stored wire with that ID and
// reports the outcome to the requesting channel. It always reports the
// message as handled.
func (p *RestPlugin) rmWire(r bot.Request) bool {
	id, err := strconv.ParseInt(r.Values["id"], 10, 64)
	if err != nil {
		// The trigger regex admits only digits, so this is an out-of-range ID.
		// Previously the error was dropped and key 0 was deleted instead.
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete Wire: "+err.Error())
		return true
	}
	if err := p.deleteWire(id); err != nil {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete Wire: "+err.Error())
		return true
	}
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Deleted Wire: %d", id))
	return true
}
// mkWire builds an unsaved Wire (ID 0) from the capture groups of a
// wire/testwire command: "parse" (trigger regex), "url", "returnField" and,
// for the POST variants, "body". A missing body leaves Body empty, which
// marks the wire as a GET.
//
// It returns an error if the trigger regex or the URL does not parse. In that
// case the returned Wire is the zero value rather than a half-built one.
func (p *RestPlugin) mkWire(r bot.Request) (Wire, error) {
	parse := r.Values["parse"]
	re, err := regexp.Compile(parse)
	if err != nil {
		return Wire{}, err
	}
	u, err := url.Parse(r.Values["url"])
	if err != nil {
		return Wire{}, err
	}
	return Wire{
		URL:         ScanableURL{URL: u},
		ParseRegex:  parse,
		regex:       re,
		ReturnField: r.Values["returnField"],
		Body:        r.Values["body"],
	}, nil
}
// handleWire handles "wire"/"wirepost". It builds a wire from the command,
// persists it, and registers its trigger right away. It then reports the
// saved wire, or the first error, back to the channel. It always reports the
// message as handled.
func (p *RestPlugin) handleWire(r bot.Request) bool {
	reply := func(text string) bool {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, text)
		return true
	}
	w, err := p.mkWire(r)
	if err != nil {
		return reply(err.Error())
	}
	if err := w.Save(p.store); err != nil {
		return reply(err.Error())
	}
	p.b.RegisterRegex(p, bot.Message, w.regex, p.mkHandler(w))
	return reply(fmt.Sprintf("Saved %s", w))
}
// handleTestWire handles "testwire"/"testwirepost". It builds a wire from the
// command without saving it, then runs the wire's handler once. The handler
// receives the values captured by matching the wire's trigger regex against
// the command's "text" capture.
func (p *RestPlugin) handleTestWire(r bot.Request) bool {
	text := r.Values["text"]
	w, err := p.mkWire(r)
	if err != nil {
		// Send the message text, not the error value itself, so the
		// connector gets a string like every other reply in this plugin.
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, err.Error())
		return true
	}
	h := p.mkHandler(w)
	r.Values = bot.ParseValues(w.regex, text)
	return h(r)
}
// mkHandler builds the bot.ResponseHandler that runs wire w when its trigger
// regex matches. The handler:
//   - ignores messages from the bot itself (config key "nick"), returning false;
//   - renders w.URL and w.Body as text/template templates against the bot's
//     secrets plus the regex capture groups, with captures URL query-escaped;
//   - sends a POST when the rendered body is non-empty and a GET otherwise,
//     adding an Origin header if "rest.<sha512 of URL template>.origin" is set;
//   - decodes the JSON response and builds the reply with the named
//     post-processor, or else by running w.ReturnField as a gojq query;
//   - sends the reply, or a description of the first error, to the channel
//     and returns true.
func (p *RestPlugin) mkHandler(w Wire) bot.ResponseHandler {
	// One client per wire lets connections be reused between invocations.
	// The timeout keeps a hung endpoint from blocking the handler forever.
	client := &http.Client{Timeout: 60 * time.Second}
	return func(r bot.Request) bool {
		if r.Msg.User.Name == p.b.Config().GetString("nick", "") {
			return false
		}
		values := bot.RegexValues{}
		for _, s := range p.b.Config().GetAllSecrets() {
			values[s.Key] = s.Value
		}
		// Capture groups come from chat users; escape them so they cannot
		// inject extra query parameters or path segments.
		for k, v := range r.Values {
			values[k] = url.QueryEscape(v)
		}
		// Never log `values`: it holds every configured secret.
		log.Debug().Interface("captures", r.Values).Msg("rendering wire templates")

		urlStr := w.URL.String()
		urlTpl, err := template.New(urlStr).Parse(urlStr)
		if p.handleErr(err, r) {
			return true
		}
		var urlBuf bytes.Buffer
		if err := urlTpl.Execute(&urlBuf, values); p.handleErr(err, r) {
			return true
		}
		newURL, err := url.Parse(urlBuf.String())
		if p.handleErr(err, r) {
			return true
		}

		bodyTpl, err := template.New(w.Body).Parse(w.Body)
		if p.handleErr(err, r) {
			return true
		}
		var bodyBuf bytes.Buffer
		if err := bodyTpl.Execute(&bodyBuf, values); p.handleErr(err, r) {
			return true
		}

		// An empty rendered body means a GET wire. Anything else is POSTed
		// as-is and is expected to already be in the endpoint's format.
		method := http.MethodGet
		if bodyBuf.Len() > 0 {
			method = http.MethodPost
		}
		// Log only the URL template: the rendered URL and body may embed secrets.
		log.Debug().
			Str("method", method).
			Str("URL", urlStr).
			Msg("Querying URL with values")
		req, err := http.NewRequest(method, newURL.String(), bytes.NewReader(bodyBuf.Bytes()))
		if p.handleErr(err, r) {
			return true
		}
		key := fmt.Sprintf("rest.%x.origin", sha512.Sum512([]byte(urlStr)))
		if origin := p.b.Config().Get(key, ""); origin != "" {
			req.Header.Add("Origin", origin)
		}
		resp, err := client.Do(req)
		if p.handleErr(err, r) {
			return true
		}
		// Previously never closed, leaking a connection per invocation.
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Got a status %d: %s",
				resp.StatusCode, resp.Status))
		}
		body, err := ioutil.ReadAll(resp.Body)
		if p.handleErr(err, r) {
			return true
		}
		log.Debug().Str("body", string(body)).Msg("got a body back")

		var returnValues interface{}
		if err := json.Unmarshal(body, &returnValues); p.handleErr(err, r) {
			return true
		}

		var msg string
		if pp, ok := postProcessors[w.ReturnField]; ok {
			msg = pp(returnValues)
		} else {
			query, err := gojq.Parse(w.ReturnField)
			if err != nil {
				p.b.Send(r.Conn, bot.Message, r.Msg.Channel,
					fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField))
				return true
			}
			var out strings.Builder
			iter := query.Run(returnValues)
			for {
				v, ok := iter.Next()
				if !ok {
					break
				}
				if err, ok := v.(error); ok {
					return p.handleErr(err, r)
				}
				// %v rather than %s: jq results may be numbers, booleans,
				// arrays or objects, which %s renders as "%!s(...)".
				fmt.Fprintf(&out, "%v\n", v)
			}
			msg = out.String()
		}

		msg = strings.TrimSpace(msg)
		if msg == "" {
			p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "we got an empty body after all we've done")
			return true
		}
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
		return true
	}
}
// handleErr reports a non-nil err to the request's channel as
// "Error: <err>" and returns true. For a nil err it sends nothing and returns
// false. Callers use the result to decide whether to stop processing.
func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
	if err == nil {
		return false
	}
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err))
	return true
}