diff --git a/client/main.go b/client/main.go index 22e5230..5397384 100644 --- a/client/main.go +++ b/client/main.go @@ -1,14 +1,30 @@ package main import ( + "bytes" "crypto/ed25519" "encoding/base64" "flag" "fmt" + "net/http" "os" "strings" + "time" ) +var pubKey ed25519.PublicKey +var privKey ed25519.PrivateKey +var id string +var servers []string + +// Post new server list to DHT +func dhtPost(s string) { + message := []byte(fmt.Sprint(time.Now().Unix()) + "\n" + strings.Join(servers, "\n")) + fmt.Print(message) + message = append(message, ed25519.Sign(privKey, message)...) + http.Post(s+"/dht/"+id, "application/octet-stream", bytes.NewBuffer(message)) +} + func main() { flag.Parse() @@ -18,11 +34,11 @@ func main() { if err != nil { panic(err) } - err = os.WriteFile("pubkey", pubKey, 644) + err = os.WriteFile("pubkey", pubKey, 0644) if err != nil { panic(err) } - err = os.WriteFile("privkey", privKey, 600) + err = os.WriteFile("privkey", privKey, 0600) if err != nil { panic(err) } @@ -31,7 +47,7 @@ func main() { if err != nil { panic(err) } - fmt.Printf("Success! Your user ID: %s", base64.RawURLEncoding.EncodeToString(pubKey)) + fmt.Printf("Success! Your user ID: %s\n", base64.RawURLEncoding.EncodeToString(pubKey)) return } @@ -49,11 +65,56 @@ func main() { if err != nil { panic(err) } + id = base64.RawURLEncoding.EncodeToString(pubKey) servers := strings.Split(string(serversBytes), "\n") + fmt.Println(pubKey, privKey, servers) - - if flag.Arg(0) == "associate" { - - + if flag.Arg(0) == "add" { + // Add server + servers = append(servers, flag.Arg(1)) + dhtPost(flag.Arg(1)) + http.Get(flag.Arg(1) + "/user/" + id) + if servers[0] != flag.Arg(1) { + http.Get(servers[0] + "/user/" + id) + } + err := os.WriteFile("servers", []byte(strings.Join(servers, "\n")), 0644) + if err != nil { + panic(err) + } + } else if flag.Arg(0) == "remove" { + // Remove server + for i := range servers { + if servers[i] == flag.Arg(1) { + servers = append(servers[:i], servers[i+1:]...) + break + } + } + if len(servers) > 0 { + dhtPost(servers[0]) + http.Get(servers[0] + "/user/" + id) + } else { + dhtPost(flag.Arg(1)) + } + http.Get(flag.Arg(1) + "/user/" + id) + err := os.WriteFile("servers", []byte(strings.Join(servers, "\n")), 0644) + if err != nil { + panic(err) + } + } else if flag.Arg(0) == "primary" { + // Make server a primary + for i := range servers { + if i > 0 && servers[i] == flag.Arg(1) { + servers[i] = servers[0] + servers[0] = flag.Arg(i) + dhtPost(flag.Arg(i)) + http.Get(servers[0] + "/user/" + id) + http.Get(servers[i] + "/user/" + id) + break + } + } + err := os.WriteFile("servers", []byte(strings.Join(servers, "\n")), 0644) + if err != nil { + panic(err) + } } } diff --git a/server/dht.go b/server/dht.go index c3e984b..6590199 100644 --- a/server/dht.go +++ b/server/dht.go @@ -5,7 +5,6 @@ import ( "crypto/ed25519" "crypto/sha256" "encoding/base64" - "errors" "fmt" "io" "log" @@ -63,7 +62,16 @@ func addPeer(peer string) error { peerHashes[i] = peerHash myPos = sort.SearchStrings(peerHashes, me) - // TODO: redistribute keys + // Distribute keys to new server + for id, user := range users { + phase := time.Now().Unix() / 600 + if keyPos(id + "\n" + fmt.Sprint(phase))-myPos < 5 { + go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + } + if keyPos(id + "\n" + fmt.Sprint(phase+1))-myPos < 5 { + go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + } + } mu.Unlock() // Read response body @@ -116,8 +124,22 @@ func timestamp(val []byte) int { } // Get the value for a key from the DHT -func dhtGet(key string) ([]byte, error) { - keyPos := keyPos(key) +func dhtGet(key, direct string) []byte { + phase := fmt.Sprint(time.Now().Unix() / 600) + keyPos := keyPos(key + "\n" + phase) + if direct != "" && keyPos-myPos < 5 { + // Directly read from kvstore + mu.Lock() + val, ok := kvstore[key+"\n"+phase] + mu.Unlock() + if !ok || verify(key, val) != nil { + return nil + } + return val + } + + // Contact 5 servers that store this key-value pair + var mu sync.Mutex var wg sync.WaitGroup var latest []byte for i := 0; i < 5 && i < len(peerHashes); i++ { @@ -137,30 +159,56 @@ func dhtGet(key string) ([]byte, error) { if err != nil { return } + mu.Lock() if latest == nil || timestamp(val) > timestamp(latest) { latest = val } + mu.Unlock() }() } + // Wait for all to finish or time out wg.Wait() - if latest == nil { - return nil, errors.New("key not found in kvstore") - } - return latest, nil + return latest } // Post a key-value pair into the DHT -func dhtPost(key string, val []byte) error { +func dhtPost(key, phase, direct string, val []byte) error { err := verify(key, val) if err != nil { return err } - keyPos := keyPos(key) + if phase == "" { + phase = fmt.Sprint(time.Now().Unix() / 600) + } + user, ok := users[key] + if ok { + curPhase, err := strconv.Atoi(phase) + if err != nil { + return err + } + nextPhase := time.Now().Unix()/600 + 1 + if int64(curPhase) < nextPhase && user.phase < nextPhase { + user.phase = nextPhase + go dhtPost(key, fmt.Sprint(nextPhase), "", val) + } + } + + keyPos := keyPos(key + "\n" + phase) + if direct != "" && keyPos-myPos < 5 { + // Directly write to kvstore + mu.Lock() + curVal, ok := kvstore[key+"\n"+phase] + if !ok || timestamp(val) > timestamp(curVal) { + kvstore[key+"\n"+phase] = val + } + mu.Unlock() + return nil + } + + // Contact 5 servers that store this key-value pair for i := 0; i < 5 && i < len(peerHashes); i++ { j := hashToDomain[peerHashes[(keyPos+i)%len(peerHashes)]] - go func() { - http.Post(j+"/dht/"+key+"?direct", "application/octet-stream", bytes.NewBuffer(val)) - }() + go http.Post(j+"/dht/"+key+"?phase="+phase+"&direct=true", "application/octet-stream", bytes.NewBuffer(val)) } return nil } @@ -169,52 +217,16 @@ func dhtPost(key string, val []byte) error { func dhtHandler(w http.ResponseWriter, r *http.Request) { key := r.URL.Path[5:] r.ParseForm() - if r.Form.Get("direct") != "" { - // Directly modify kvstore - if keyPos(key)-myPos >= 5 { - w.WriteHeader(http.StatusNotFound) - return - } - phase := time.Now().Unix()/600 - if r.Method == "GET" { - mu.Lock() - val, ok := kvstore[key+"\n"+fmt.Sprint(phase)] - mu.Unlock() - if !ok || verify(key, val) != nil { - w.WriteHeader(http.StatusNotFound) - return - } - w.Write(val) - } else if r.Method == "POST" { - val, err := io.ReadAll(r.Body) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - mu.Lock() - // Update key for this phase and next one - kvstore[key+"\n"+fmt.Sprint(phase)] = val - kvstore[key+"\n"+fmt.Sprint(phase+1)] = val - mu.Unlock() - } - return - } - if r.Method == "GET" { - val, err := dhtGet(key) - if err != nil { + val := dhtGet(key, r.Form.Get("direct")) + if val == nil { w.WriteHeader(http.StatusNotFound) return } - w.Write([]byte(val)) + w.Write(val) } else if r.Method == "POST" { val, err := io.ReadAll(r.Body) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - err = dhtPost(key, val) - if err != nil { + if err != nil || dhtPost(key, r.Form.Get("phase"), r.Form.Get("direct"), val) != nil { w.WriteHeader(http.StatusInternalServerError) return } @@ -235,7 +247,23 @@ func cleanPeers() { i := sort.SearchStrings(peerHashes, sha256sum(peer)) peerHashes = append(peerHashes[:i], peerHashes[i+1:]...) myPos = sort.SearchStrings(peerHashes, me) - // TODO: redistribute keys + + // Distribute keys on this server to other servers + if len(peerHashes) >= 5 { + for id, user := range users { + phase := time.Now().Unix() / 600 + kpos := keyPos(id + "\n" + fmt.Sprint(phase)) + if kpos-i < 5 { + server := hashToDomain[peerHashes[(kpos+4)%len(peerHashes)]] + go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + } + kpos = keyPos(id + "\n" + fmt.Sprint(phase+1)) + if kpos-i < 5 { + server := hashToDomain[peerHashes[(kpos+4)%len(peerHashes)]] + go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + } + } + } mu.Unlock() } time.Sleep(5 * time.Second) @@ -261,9 +289,14 @@ func cleanKVStore() { // Redistribute key-value pairs periodically func redistributeKeys() { for true { + mu.Lock() for id, user := range users { - dhtPost(id, user.dhtVal) + nextPhase := time.Now().Unix()/600 + 1 + if user.phase < nextPhase { + go dhtPost(id, fmt.Sprint(nextPhase), "", user.dhtVal) + } } + mu.Unlock() time.Sleep(time.Duration(rand.Intn(300)) * time.Second) } } diff --git a/server/main.go b/server/main.go index b4fc654..a342ba5 100644 --- a/server/main.go +++ b/server/main.go @@ -30,14 +30,23 @@ func main() { peerHashes = append(peerHashes, sha256sum(me)) hashToDomain = map[string]string{peerHashes[0]: me} + // Start background functions if initialPeer != "" { go addPeer(initialPeer) } go cleanPeers() go cleanKVStore() + go redistributeKeys() // Load user data from disk - entries, _ := os.ReadDir(dataDir) + err := os.Mkdir(dataDir, 0755) + if err != nil { + log.Fatal(err) + } + entries, err := os.ReadDir(dataDir) + if err != nil { + log.Fatal(err) + } for _, entry := range entries { id := entry.Name() reader, err := os.Open(dataDir + "/" + id + "/gob") diff --git a/server/test.sh b/server/test.sh index 8ca1d42..592ab6d 100755 --- a/server/test.sh +++ b/server/test.sh @@ -1,10 +1,10 @@ #!/bin/bash trap "kill 0" EXIT go build -./server -b :4200 -d http://localhost:4200 & +./server -d 0 & for i in $(seq 1 9) do sleep 0.1 - ./server -b :420$i -d http://localhost:420$i -i http://localhost:420$((i-1)) & + ./server -d $i -b :420$i -u http://localhost:420$i -i http://localhost:420$((i-1)) & done wait diff --git a/server/user.go b/server/user.go index 5865c93..aa78239 100644 --- a/server/user.go +++ b/server/user.go @@ -11,7 +11,8 @@ import ( ) type user struct { - dhtVal []byte + dhtVal []byte + phase int64 } var users map[string]user @@ -47,8 +48,8 @@ func persist(id string) { func userHandler(w http.ResponseWriter, r *http.Request) { id := r.URL.Fragment[6:] // Resolve ID to server list - val, err := dhtGet(id) - if err != nil || verify(id, val) != nil { + val := dhtGet(id, "") + if verify(id, val) != nil { w.WriteHeader(http.StatusNotFound) return } @@ -57,7 +58,7 @@ func userHandler(w http.ResponseWriter, r *http.Request) { if !strings.Contains(message, me) { // Delete user if they are no longer associated with this server delete(users, id) - err = os.RemoveAll(id) + err := os.RemoveAll(id) if err != nil { w.WriteHeader(http.StatusNotFound) return @@ -73,7 +74,7 @@ func userHandler(w http.ResponseWriter, r *http.Request) { users[id] = user{ dhtVal: val, } - os.Mkdir(id, 755) + os.Mkdir(id, 0755) persist(id) } }