diff --git a/funcs.go b/funcs.go
index 6717914..5946d31 100644
--- a/funcs.go
+++ b/funcs.go
@@ -44,7 +44,7 @@ func (m *Ring) parseList() {
lines := strings.Split(string(file), "\n")
for _, line := range lines[:len(lines)-1] {
fields := strings.Fields(line)
- m.ring = append(m.ring, RingMember{handle: fields[0], url: fields[1]})
+ m.ring = append(m.ring, RingMember{id: fields[0], url: fields[1]})
}
fileStat, err := os.Stat(*flagMembers)
if err != nil {
diff --git a/handlers.go b/handlers.go
index 72dffe6..6d6bd07 100644
--- a/handlers.go
+++ b/handlers.go
@@ -5,11 +5,13 @@
package main
import (
+ "fmt"
"html/template"
"log"
"math/rand"
"net/http"
"net/url"
+ "strings"
"time"
)
@@ -26,16 +28,49 @@ func (m Ring) root(writer http.ResponseWriter, request *http.Request) {
var table string
for _, member := range m.ring {
table = table + "
\n"
- table = table + " " + member.handle + " | \n"
+ table = table + " " + member.id + " | \n"
table = table + " " + link(member.url) + " | \n"
table = table + "
\n"
}
m.index.Execute(writer, template.HTML(table))
}
-func (m Ring) match(query url.Values, item RingMember) bool {
- host := query.Get("host")
- return item.url == host
+type MemberSelector struct {
+ id string // exact match
+ urlpart string
+}
+
+func parseQuery(query url.Values) (*MemberSelector, error) {
+ urlpart := query.Get("urlpart")
+ if urlpart == "" {
+ urlpart = query.Get("host")
+ }
+ n_valid := 0
+ if urlpart != "" {
+ n_valid += 1
+ }
+ id := query.Get("id")
+ if id != "" {
+ n_valid += 1
+ }
+ if n_valid != 1 {
+ return nil, fmt.Errorf("Please specify urlpart=xxx or id=xxx")
+ }
+
+ return &MemberSelector{
+ id: id,
+ urlpart: urlpart,
+ }, nil
+}
+
+func (m Ring) match(selector *MemberSelector, item RingMember) bool {
+ if selector.id != "" {
+ return item.id == selector.id
+ }
+ if selector.urlpart != "" {
+ return strings.Contains(item.url, selector.urlpart)
+ }
+ panic("unreachable")
}
// Redirects the visitor to the next member, wrapping around the list if the
@@ -48,9 +83,13 @@ func (m Ring) next(writer http.ResponseWriter, request *http.Request) {
}
success := false
query := request.URL.Query()
+ selector, err := parseQuery(query)
+ if err != nil {
+ http.Error(writer, fmt.Errorf("invalid query: %w", err).Error(), http.StatusBadRequest)
+ }
length := len(m.ring)
for i, item := range m.ring {
- if m.match(query, item) {
+ if m.match(selector, item) {
for j := i + 1; j < length+i; j++ {
dest := m.ring[j%length].url
log.Println("Checking '" + dest + "'")
@@ -78,9 +117,13 @@ func (m Ring) previous(writer http.ResponseWriter, request *http.Request) {
m.parseList()
}
query := request.URL.Query()
+ selector, err := parseQuery(query)
+ if err != nil {
+ http.Error(writer, fmt.Errorf("invalid query: %w", err).Error(), http.StatusBadRequest)
+ }
length := len(m.ring)
for index, item := range m.ring {
- if m.match(query, item) {
+ if m.match(selector, item) {
// from here to start of list
for i := index - 1; i > 0; i-- {
dest := m.ring[i].url
@@ -105,7 +148,7 @@ please email amolith@secluded.site and let him (me) know what's up.`, 500)
return
}
}
- http.Error(writer, "Ring member '"+query.Encode()+"' not found.", 404)
+ http.Error(writer, "Ring member '"+query.Encode()+"' not )found.", 404)
}
// Redirects the visitor to a random member
diff --git a/main.go b/main.go
index 5e9393a..70bee34 100644
--- a/main.go
+++ b/main.go
@@ -16,8 +16,8 @@ import (
)
type RingMember struct {
- handle string
- url string
+ id string
+ url string
}
type Ring struct {
@@ -68,7 +68,7 @@ func main() {
mux.HandleFunc("/next", m.next)
mux.HandleFunc("/previous", m.previous)
mux.HandleFunc("/random", m.random)
-
+
var err error
if strings.Contains(httpServer.Addr, ":") {
err = httpServer.ListenAndServe()