diff --git a/main.go b/main.go index fb38297..b174c01 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "github.com/velour/catbase/plugins/giphy" "github.com/velour/catbase/plugins/gpt2" "github.com/velour/catbase/plugins/last" + "github.com/velour/catbase/plugins/rest" "github.com/velour/catbase/plugins/achievements" "github.com/velour/catbase/plugins/aoc" @@ -158,6 +159,7 @@ func main() { b.AddPlugin(achievements.New(b)) b.AddPlugin(sms.New(b)) b.AddPlugin(countdown.New(b)) + b.AddPlugin(rest.New(b)) // catches anything left, will always return true b.AddPlugin(fact.New(b)) diff --git a/plugins/rest/rest.go b/plugins/rest/rest.go new file mode 100644 index 0000000..c3dd04f --- /dev/null +++ b/plugins/rest/rest.go @@ -0,0 +1,284 @@ +package rest + +import ( + "bytes" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "text/template" + + "github.com/jmoiron/sqlx" + "github.com/velour/catbase/bot" +) + +type RestPlugin struct { + b bot.Bot + db *sqlx.DB + + handlers bot.HandlerTable +} + +func New(b bot.Bot) *RestPlugin { + p := &RestPlugin{ + b: b, + db: b.DB(), + handlers: bot.HandlerTable{}, + } + p.setupDB() + p.register() + return p +} + +func (p *RestPlugin) setupDB() { + tx := p.db.MustBegin() + tx.MustExec(` + create table if not exists wires ( + id integer primary key autoincrement, + url text not null, + parse_regex text not null, + return_field text not null + )`) + if err := tx.Commit(); err != nil { + panic(err) + } +} + +func (p *RestPlugin) register() { + p.handlers = bot.HandlerTable{ + bot.HandlerSpec{Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile("(?i)^list wires$"), + HelpText: "Lists all REST functions", + Handler: p.listWires}, + bot.HandlerSpec{Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`(?i)^rm wire (?P\d+)$`), + HelpText: "Removes a wire by ID (use list to view)", + Handler: p.rmWire}, + bot.HandlerSpec{Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile("(?i)^wire `(?P[^`]+)` to (?P\\S+) `(?P[^`]+)`$"), + HelpText: "Registers a new REST function", + Handler: p.handleWire}, + } + p.b.RegisterTable(p, p.handlers) + wires, err := p.getWires() + if err != nil { + panic(err) + } + for _, w := range wires { + p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(w)) + } +} + +type ScanableRegexp struct { + *regexp.Regexp +} + +func (s *ScanableRegexp) Scan(src interface{}) error { + var source string + switch src.(type) { + case string: + source = src.(string) + default: + return errors.New("incompatible type for ScanableRegexp") + } + r, err := regexp.Compile(source) + if err != nil { + return err + } + s.Regexp = r + return nil +} + +type ScanableURL struct { + *url.URL +} + +func (s *ScanableURL) Scan(src interface{}) error { + var source string + switch src.(type) { + case string: + source = src.(string) + default: + return errors.New("incompatible type for ScanableURL") + } + u, err := url.Parse(source) + if err != nil { + return err + } + s.URL = u + return nil +} + +type wire struct { + // ID + ID sql.NullInt64 + // The URL to make a request to + URL ScanableURL + // The regex which will trigger this REST action + ParseRegex ScanableRegexp `db:"parse_regex"` + // The JSON field that will contain the REST return value + ReturnField string `db:"return_field"` +} + +func (w wire) String() string { + msg := "Wire:" + msg += fmt.Sprintf("\nURL: %s", w.URL) + msg += fmt.Sprintf("\nParsing to trigger: `%s`", w.ParseRegex) + msg += fmt.Sprintf("\nReturn field: `%s`", w.ReturnField) + return msg +} + +func (p *RestPlugin) getWires() ([]*wire, error) { + wires := []*wire{} + err := p.db.Select(&wires, `select * from wires`) + return wires, err +} + +func (p *RestPlugin) deleteWire(id int64) error { + _, err := p.db.Exec(`delete from wires where id=?`, id) + return err +} + +func (w *wire) Update(db *sqlx.DB) error { + if !w.ID.Valid { + return w.Save(db) + } + id, _ := w.ID.Value() + _, err := db.Exec(`update wires set url=?, parse_regex=?, return_field=? where id=?`, + w.URL.String(), w.ParseRegex.String(), w.ReturnField, id) + return err +} + +func (w *wire) Save(db *sqlx.DB) error { + if w.ID.Valid { + return w.Update(db) + } + res, err := db.Exec(`insert into wires (url, parse_regex, return_field) values (?, ?, ?)`, + w.URL.String(), w.ParseRegex.String(), w.ReturnField) + if err != nil { + return err + } + id, err := res.LastInsertId() + if err != nil { + return err + } + _ = w.ID.Scan(id) + return nil +} + +func (p *RestPlugin) listWires(r bot.Request) bool { + var msg string + wires, err := p.getWires() + if err != nil { + msg = err.Error() + goto SEND + } + msg = "Current wires:" + for _, w := range wires { + id, _ := w.ID.Value() + msg += fmt.Sprintf("\n\t%d: `%s` => %s", id, w.ParseRegex, w.URL) + } +SEND: + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) + return true +} + +func (p *RestPlugin) rmWire(r bot.Request) bool { + id, _ := strconv.ParseInt(r.Values["id"], 10, 64) + err := p.deleteWire(id) + if err != nil { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete wire: "+err.Error()) + return true + } + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Deleted wire: %d", id)) + return true +} + +func (p *RestPlugin) handleWire(r bot.Request) bool { + var w wire + var msg string + var err error + w.ParseRegex.Regexp, err = regexp.Compile(r.Values["parse"]) + if err != nil { + msg = err.Error() + goto SEND + } + w.URL.URL, err = url.Parse(r.Values["url"]) + if err != nil { + msg = err.Error() + goto SEND + } + w.ReturnField = r.Values["returnField"] + err = w.Save(p.db) + if err != nil { + msg = err.Error() + goto SEND + } + p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(&w)) + msg = fmt.Sprintf("Saved %s", w) +SEND: + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) + return true +} + +func (p *RestPlugin) mkHandler(w *wire) bot.ResponseHandler { + return func(r bot.Request) bool { + if r.Msg.User.Name == p.b.Config().GetString("nick", "") { + return false + } + values := bot.RegexValues{} + for k := range r.Values { + values[k] = url.QueryEscape(r.Values[k]) + } + urlStr := w.URL.String() + parse, err := template.New(urlStr).Parse(urlStr) + if p.handleErr(err, r) { + return true + } + buf := bytes.Buffer{} + err = parse.Execute(&buf, values) + if p.handleErr(err, r) { + return true + } + newURL, err := url.Parse(buf.String()) + if p.handleErr(err, r) { + return true + } + resp, err := http.Get(newURL.String()) + if p.handleErr(err, r) { + return true + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Got a status %d: %s from %s", + resp.StatusCode, resp.Status, newURL)) + } + body, err := ioutil.ReadAll(resp.Body) + if p.handleErr(err, r) { + return true + } + returnValues := map[string]string{} + json.Unmarshal(body, &returnValues) + var msg string + if text, ok := returnValues["text"]; ok { + msg = strings.TrimSpace(text) + } else { + msg = fmt.Sprintf("Wire handler did not find return value (%s): %s => `%s`", newURL.String(), w.URL, w.ReturnField) + } + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) + return true + } +} + +func (p *RestPlugin) handleErr(err error, r bot.Request) bool { + if err != nil { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err) + return true + } + return false +}