diff --git a/entry/entry.go b/entry/entry.go index 0b39b9b..5c18704 100644 --- a/entry/entry.go +++ b/entry/entry.go @@ -6,9 +6,10 @@ import ( "strings" "time" - "code.chrissexton.org/cws/cabinet/db" "github.com/jmoiron/sqlx" "github.com/rs/zerolog/log" + + "code.chrissexton.org/cws/cabinet/db" ) type Entry struct { @@ -94,27 +95,49 @@ func GetByID(db *db.Database, id int64) (Entry, error) { return e, e.populateTags() } -func Search(db *db.Database, query string) ([]*Entry, error) { +func SearchByTag(db *db.Database, query string, tags []string) ([]*Entry, error) { entries := []*Entry{} - log.Debug().Str("query", query).Msg("searching") - if query != "" { - q := `select * from entries where content like ? order by updated desc` - err := db.Select(&entries, q, "%"+query+"%") + query = fmt.Sprintf("%%%s%%", query) + log.Debug().Str("tag query", query).Int("len(tags)", len(tags)).Msg("searching") + + if len(tags) > 0 { + q := `select e.* + from entries e + inner join tags t + on e.id=t.entry_id + where + t.name in (?) + AND content like ? + order by updated desc` + + q, args, err := sqlx.In(q, tags, query) + if err != nil { + return nil, err + } + + err = db.Select(&entries, q, args...) if err != nil { return nil, err } } else { - q := `select * from entries order by updated desc` - err := db.Select(&entries, q) + q := `select e.* + from entries e + where + content like ? + order by updated desc` + + err := db.Select(&entries, q, query) if err != nil { return nil, err } } + for _, e := range entries { e.db = db e.Title = e.GenerateTitle() e.populateTags() } + return entries, nil } @@ -277,3 +300,12 @@ func (e *Entry) Create() error { tx.Commit() return nil } + +func (e *Entry) HasTag(tag string) bool { + for _, t := range e.Tags { + if strings.ToLower(tag) == strings.ToLower(t) { + return true + } + } + return false +} diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 692a2e9..dc4db66 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -33,6 +33,15 @@ components: { Error }, + created() { + if (!this.$store.state.key) { + let key = this.$cookies.get('key') + if (key) { + this.$store.commit('setKey', key) + return + } + } + }, methods: { newFile: function() { this.$store.dispatch('newFile') diff --git a/frontend/src/views/Console.vue b/frontend/src/views/Console.vue index 2263ff2..c0532ba 100644 --- a/frontend/src/views/Console.vue +++ b/frontend/src/views/Console.vue @@ -80,11 +80,6 @@ export default { // because it has not been created yet when this guard is called! next(vm => { if (!vm.$store.state.key) { - let key = vm.$cookies.get('key') - if (key) { - vm.$store.commit('setKey', key) - return - } vm.$router.push({name: "login", params: {returnTo: vm.$route.path}}) } }) diff --git a/web/entry.go b/web/entry.go index 70ddf84..85f58f4 100644 --- a/web/entry.go +++ b/web/entry.go @@ -111,12 +111,17 @@ func (web *Web) newEntry(w http.ResponseWriter, r *http.Request) { func (web *Web) allEntries(w http.ResponseWriter, r *http.Request) { query := "" + tags := []string{} + if !web.AuthCheck(r) { + tags = append(tags, "public") + } items, ok := r.URL.Query()["query"] if ok { query = items[0] } - entries, err := entry.Search(web.db, query) + entries, err := entry.SearchByTag(web.db, query, tags) if err != nil { + log.Error().Msgf("Error querying: %w", err) w.WriteHeader(500) fmt.Fprint(w, err) return @@ -143,6 +148,12 @@ func (web *Web) getEntry(w http.ResponseWriter, r *http.Request) { return } + if !web.AuthCheck(r) && !entry.HasTag("public") { + w.WriteHeader(401) + fmt.Fprint(w, "not allowed") + return + } + resp, err := json.Marshal(entry) if err != nil { w.WriteHeader(500) diff --git a/web/routes.go b/web/routes.go index 44635bc..0cf18c6 100644 --- a/web/routes.go +++ b/web/routes.go @@ -38,27 +38,42 @@ func New(addr string, db *db.Database, static http.FileSystem) *Web { } type AuthMiddleware struct { - db *db.Database + web *Web + db *db.Database +} + +func NewAuthMiddleware(web *Web) AuthMiddleware { + return AuthMiddleware{ + web: web, + db: web.db, + } } func (aw *AuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - key := r.Header.Get("X-Auth-Key") - u, err := auth.GetByKey(aw.db, key) - if key == "" || err != nil { - w.WriteHeader(401) - fmt.Fprint(w, "invalid login") + if aw.web.AuthCheck(r) { + next.ServeHTTP(w, r) return } - log.Debug().Msgf("This shit is authed to user %s!", u.Name) - next.ServeHTTP(w, r) + w.WriteHeader(401) + fmt.Fprint(w, "invalid login") }) } +func (web *Web) AuthCheck(r *http.Request) bool { + key := r.Header.Get("X-Auth-Key") + u, err := auth.GetByKey(web.db, key) + if key == "" || err != nil { + return false + } + log.Debug().Msgf("This shit is authed to user %s!", u.Name) + return true +} + func (web *Web) routeSetup() http.Handler { r := mux.NewRouter() api := r.PathPrefix("/v1/").Subrouter() - auth := AuthMiddleware{web.db} + auth := NewAuthMiddleware(web) authedApi := r.PathPrefix("/v1/").Subrouter() authedApi.Use(auth.Middleware)