diff --git a/funcs.go b/funcs.go index 6717914..5946d31 100644 --- a/funcs.go +++ b/funcs.go @@ -44,7 +44,7 @@ func (m *Ring) parseList() { lines := strings.Split(string(file), "\n") for _, line := range lines[:len(lines)-1] { fields := strings.Fields(line) - m.ring = append(m.ring, RingMember{handle: fields[0], url: fields[1]}) + m.ring = append(m.ring, RingMember{id: fields[0], url: fields[1]}) } fileStat, err := os.Stat(*flagMembers) if err != nil { diff --git a/handlers.go b/handlers.go index 72dffe6..6d6bd07 100644 --- a/handlers.go +++ b/handlers.go @@ -5,11 +5,13 @@ package main import ( + "fmt" "html/template" "log" "math/rand" "net/http" "net/url" + "strings" "time" ) @@ -26,16 +28,49 @@ func (m Ring) root(writer http.ResponseWriter, request *http.Request) { var table string for _, member := range m.ring { table = table + " \n" - table = table + " " + member.handle + "\n" + table = table + " " + member.id + "\n" table = table + " " + link(member.url) + "\n" table = table + " \n" } m.index.Execute(writer, template.HTML(table)) } -func (m Ring) match(query url.Values, item RingMember) bool { - host := query.Get("host") - return item.url == host +type MemberSelector struct { + id string // exact match + urlpart string +} + +func parseQuery(query url.Values) (*MemberSelector, error) { + urlpart := query.Get("urlpart") + if urlpart == "" { + urlpart = query.Get("host") + } + n_valid := 0 + if urlpart != "" { + n_valid += 1 + } + id := query.Get("id") + if id != "" { + n_valid += 1 + } + if n_valid != 1 { + return nil, fmt.Errorf("Please specify urlpart=xxx or id=xxx") + } + + return &MemberSelector{ + id: id, + urlpart: urlpart, + }, nil +} + +func (m Ring) match(selector *MemberSelector, item RingMember) bool { + if selector.id != "" { + return item.id == selector.id + } + if selector.urlpart != "" { + return strings.Contains(item.url, selector.urlpart) + } + panic("unreachable") } // Redirects the visitor to the next member, wrapping around the list if the @@ -48,9 +83,13 @@ func (m Ring) next(writer http.ResponseWriter, request *http.Request) { } success := false query := request.URL.Query() + selector, err := parseQuery(query) + if err != nil { + http.Error(writer, fmt.Errorf("invalid query: %w", err).Error(), http.StatusBadRequest) + } length := len(m.ring) for i, item := range m.ring { - if m.match(query, item) { + if m.match(selector, item) { for j := i + 1; j < length+i; j++ { dest := m.ring[j%length].url log.Println("Checking '" + dest + "'") @@ -78,9 +117,13 @@ func (m Ring) previous(writer http.ResponseWriter, request *http.Request) { m.parseList() } query := request.URL.Query() + selector, err := parseQuery(query) + if err != nil { + http.Error(writer, fmt.Errorf("invalid query: %w", err).Error(), http.StatusBadRequest) + } length := len(m.ring) for index, item := range m.ring { - if m.match(query, item) { + if m.match(selector, item) { // from here to start of list for i := index - 1; i > 0; i-- { dest := m.ring[i].url @@ -105,7 +148,7 @@ please email amolith@secluded.site and let him (me) know what's up.`, 500) return } } - http.Error(writer, "Ring member '"+query.Encode()+"' not found.", 404) + http.Error(writer, "Ring member '"+query.Encode()+"' not )found.", 404) } // Redirects the visitor to a random member diff --git a/main.go b/main.go index 5e9393a..70bee34 100644 --- a/main.go +++ b/main.go @@ -16,8 +16,8 @@ import ( ) type RingMember struct { - handle string - url string + id string + url string } type Ring struct { @@ -68,7 +68,7 @@ func main() { mux.HandleFunc("/next", m.next) mux.HandleFunc("/previous", m.previous) mux.HandleFunc("/random", m.random) - + var err error if strings.Contains(httpServer.Addr, ":") { err = httpServer.ListenAndServe()