diff --git a/client/main.go b/client/main.go index d490d56..d37fd2d 100644 --- a/client/main.go +++ b/client/main.go @@ -19,27 +19,26 @@ var privKey ed25519.PrivateKey var id string var servers []string -// Post new server list to DHT -func dhtPost(s string) { +// Sign and post to a server +func signPost(s string, b []byte) { buf := new(bytes.Buffer) err := binary.Write(buf, binary.LittleEndian, time.Now().Unix()) if err != nil { panic(err) } - _, err = buf.WriteString(strings.Join(servers, "\n")) + _, err = buf.Write(b) if err != nil { panic(err) } - var message []byte - _, err = buf.Read(message) + _, err = buf.Write(ed25519.Sign(privKey, buf.Bytes())) if err != nil { panic(err) } - _, err = buf.Write(ed25519.Sign(privKey, message)) + resp, err := http.Post(s, "application/octet-stream", buf) if err != nil { panic(err) } - fmt.Println(http.Post(s+"/dht/"+id, "application/octet-stream", buf)) + fmt.Println(resp) } func main() { @@ -92,7 +91,7 @@ func main() { if flag.Arg(0) == "add" { // Add server servers = append(servers, flag.Arg(1)) - dhtPost(flag.Arg(1)) + signPost(flag.Arg(1)+"/dht/"+id, []byte(strings.Join(servers, "\n"))) http.Get(flag.Arg(1) + "/user/" + id) if servers[0] != flag.Arg(1) { http.Get(servers[0] + "/user/" + id) @@ -110,10 +109,10 @@ func main() { } } if len(servers) > 0 { - dhtPost(servers[0]) + signPost(servers[0]+"/dht/"+id, []byte(strings.Join(servers, "\n"))) http.Get(servers[0] + "/user/" + id) } else { - dhtPost(flag.Arg(1)) + signPost(flag.Arg(1)+"/dht/"+id, []byte(strings.Join(servers, "\n"))) } http.Get(flag.Arg(1) + "/user/" + id) err := os.WriteFile("servers", []byte(strings.Join(servers, "\n")), 0644) @@ -126,7 +125,7 @@ func main() { if i > 0 && servers[i] == flag.Arg(1) { servers[i] = servers[0] servers[0] = flag.Arg(i) - dhtPost(flag.Arg(i)) + signPost(flag.Arg(1)+"/dht/"+id, []byte(strings.Join(servers, "\n"))) http.Get(servers[0] + "/user/" + id) http.Get(servers[i] + "/user/" + id) break @@ -139,16 +138,24 @@ func main() { } else if flag.Arg(0) == "get" { user := flag.Arg(1) filename := flag.Arg(2) - response, _ := http.Get(servers[0] + "/storage/" + user + "/" + filename) - responseBodyBytes, _ := io.ReadAll(response.Body) - os.WriteFile(user + "/" + filename, responseBodyBytes, 0644) + response, err := http.Get(servers[0] + "/storage/" + user + "/" + filename) + if err != nil { + panic(err) + } + b, err := io.ReadAll(response.Body) + if err != nil { + panic(err) + } + err = os.WriteFile(filename, b, 0644) + if err != nil { + panic(err) + } } else if flag.Arg(0) == "post" { filename := flag.Arg(1) - file, _ := os.Open(filename) - http.Post( - servers[0] + "/storage/" + id + "/" + filename, - "application/octet-stream", - file, - ) + file, err := os.ReadFile(filename) + if err != nil { + panic(err) + } + signPost(servers[0] + "/storage/" + id + "/" + filename, file) } } diff --git a/server/dht.go b/server/dht.go index 7327676..40bf4e6 100644 --- a/server/dht.go +++ b/server/dht.go @@ -234,7 +234,14 @@ func dhtHandler(w http.ResponseWriter, r *http.Request) { w.Write(val) } else if r.Method == "POST" { val, err := io.ReadAll(r.Body) - if err != nil || dhtPost(key, r.Form.Get("phase"), r.Form.Has("direct"), val) != nil { + if err != nil { + log.Print(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + err = dhtPost(key, r.Form.Get("phase"), r.Form.Has("direct"), val) + if err != nil { + log.Print(err) w.WriteHeader(http.StatusInternalServerError) return } diff --git a/server/main.go b/server/main.go index 4cf49ed..03a62a4 100644 --- a/server/main.go +++ b/server/main.go @@ -40,7 +40,7 @@ func main() { peerHashes = append(peerHashes, sha256sum(me)) hashToDomain = map[string]string{peerHashes[0]: me} kvstore = make(map[string][]byte) - users = make(map[string]user) + users = make(map[string]*user) // Load user data from disk os.Mkdir(dataDir, 0755) @@ -55,9 +55,9 @@ func main() { log.Fatal(err) continue } - var user user + var user *user dec := gob.NewDecoder(reader) - dec.Decode(&user) + dec.Decode(user) users[id] = user } @@ -69,7 +69,6 @@ func main() { go cleanKVStore() go redistributeKeys() - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Hello! This is a Kela server.") }) diff --git a/server/storage.go b/server/storage.go index 21950fa..e845dd5 100644 --- a/server/storage.go +++ b/server/storage.go @@ -4,6 +4,7 @@ import ( "crypto/ed25519" "fmt" "io" + "log" "net/http" "os" "strconv" @@ -13,6 +14,7 @@ import ( // Replicate a user's log to another server func replicate(id, s string) { + log.Printf("Starting replication for %s %s", id, s) for true { mu.Lock() // Make sure that this server is still the primary for this user @@ -42,7 +44,7 @@ func replicate(id, s string) { op := user.log[idx] mu.Unlock() file, _ := os.Open(op) - resp, err := http.Post(s + "/storage/" + id + "/" + op + "?idx=" + fmt.Sprint(idx), "application/octet-stream", file) + resp, err := http.Post(s+"/storage/"+id+"/"+op+"?idx="+fmt.Sprint(idx), "application/octet-stream", file) if err != nil { time.Sleep(50 * time.Millisecond) continue @@ -59,18 +61,18 @@ func replicate(id, s string) { } } - // Handle storage requests func storageHandler(w http.ResponseWriter, r *http.Request) { pathSplit := strings.Split(r.URL.Path, "/") - id := pathSplit[1] - filename := pathSplit[2] + id := pathSplit[2] + filename := pathSplit[3] r.ParseForm() if r.Method == "GET" { if r.Form.Has("direct") { // Directly read and respond with file file, err := os.ReadFile(dataDir + "/" + id + "/" + filename) if err != nil { + log.Print(err) w.WriteHeader(http.StatusNotFound) return } @@ -78,7 +80,9 @@ func storageHandler(w http.ResponseWriter, r *http.Request) { return } val := dhtGet(id, false) - if verify(id, val) != nil { + err := verify(id, val) + if err != nil { + verify(id, val) w.WriteHeader(http.StatusNotFound) return } @@ -89,6 +93,7 @@ func storageHandler(w http.ResponseWriter, r *http.Request) { if servers[0] == me { file, err := os.ReadFile(dataDir + "/" + id + "/" + filename) if err != nil { + log.Print(err) w.WriteHeader(http.StatusNotFound) return } @@ -119,10 +124,13 @@ func storageHandler(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) if err != nil { + log.Print(err) w.WriteHeader(http.StatusInternalServerError) return } - if verify(id, b) != nil { + err = verify(id, b) + if err != nil { + log.Print(err) w.WriteHeader(http.StatusUnauthorized) return } @@ -149,16 +157,19 @@ func storageHandler(w http.ResponseWriter, r *http.Request) { // Fetch older version of file resp, err := http.Get(user.servers[0] + "/storage/" + id + "/" + op) if err != nil { + log.Print(err) w.WriteHeader(http.StatusInternalServerError) return } b, err := io.ReadAll(resp.Body) if err != nil { + log.Print(err) w.WriteHeader(http.StatusInternalServerError) return } - err = os.WriteFile(dataDir + "/" + id + "/" + op, b, 0644) + err = os.WriteFile(dataDir+"/"+id+"/"+op, b, 0644) if err != nil { + log.Print(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -166,7 +177,7 @@ func storageHandler(w http.ResponseWriter, r *http.Request) { } } - err = os.WriteFile(dataDir + "/" + id + "/" + filename, b, 0644) + err = os.WriteFile(dataDir+"/"+id+"/"+filename, b, 0644) if err != nil { w.WriteHeader(http.StatusInternalServerError) return diff --git a/server/user.go b/server/user.go index 6b55eb4..cd31ca1 100644 --- a/server/user.go +++ b/server/user.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/gob" "errors" + "log" "net/http" "os" "strings" @@ -18,7 +19,7 @@ type user struct { nextIndex map[string]int } -var users map[string]user +var users map[string]*user // Verify that a body was signed by this ID func verify(id string, body []byte) error { @@ -59,17 +60,20 @@ func reconfigure(id string, dhtVal []byte) { user.dhtVal = dhtVal servers := strings.Split(string(dhtVal[8:len(dhtVal)-ed25519.SignatureSize]), "\n") + log.Printf("Reconfiguring %s %s", id, servers) + user.servers = servers if servers[0] == me { if user.nextIndex == nil { user.nextIndex = make(map[string]int) } - for _, server := range servers { - if _, ok := user.nextIndex[server]; !ok { + for i, server := range servers { + if _, ok := user.nextIndex[server]; !ok && i > 0 { user.nextIndex[server] = len(user.log) go replicate(id, server) } } } + inServers := false for _, server := range servers { if server == me { @@ -95,7 +99,7 @@ func userHandler(w http.ResponseWriter, r *http.Request) { mu.Lock() if _, ok := users[id]; !ok { // Add user - users[id] = user{dhtVal: val} + users[id] = &user{dhtVal: val} os.Mkdir(dataDir+"/"+id, 0755) persist(id) }