tlib.rs - sraft - simple raft implementation
HTML git clone https://git.parazyd.org/sraft
DIR Log
DIR Files
DIR Refs
DIR README
---
tlib.rs (16370B)
---
1 use std::{collections::HashMap, io, net::SocketAddr, time::Duration};
2
3 use async_channel::{Receiver, Sender};
4 use async_std::{
5 io::{ReadExt, WriteExt},
6 net::{TcpListener, TcpStream},
7 stream::StreamExt,
8 sync::Mutex,
9 task,
10 };
11 use borsh::{BorshDeserialize, BorshSerialize};
12 use futures::{select, FutureExt};
13 use lazy_static::lazy_static;
14 use log::{debug, error};
15 use rand::Rng;
16
17 mod method;
18 use crate::method::{HeartbeatArgs, HeartbeatReply, RaftMethod, VoteArgs, VoteReply};
19
20 #[derive(BorshSerialize, BorshDeserialize, Clone, Debug)]
21 pub struct LogEntry {
22 log_term: u64,
23 log_index: u64,
24 log_data: Vec<u8>,
25 }
26
27 pub struct LogStore(pub Vec<LogEntry>);
28
29 impl LogStore {
30 fn get_last_index(&self) -> u64 {
31 let rlen = self.0.len();
32 if rlen == 0 {
33 return 0
34 }
35
36 self.0[rlen - 1].log_index
37 }
38 }
39
40 lazy_static! {
41 pub static ref LOG_STORE: Mutex<LogStore> = Mutex::new(LogStore(vec![]));
42 // This is used for heartbeats
43 pub static ref HEARTBEAT_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
44 // This is used to let our node know when it has become a leader
45 pub static ref TOLEADER_CHAN: (Sender<bool>, Receiver<bool>) = async_channel::unbounded();
46
47 pub static ref STATE: Mutex<State> = Mutex::new(State::new());
48 }
49
/// Mutable Raft bookkeeping of a node, shared through the global `STATE` mutex.
#[derive(Default)]
pub struct State {
    /// Latest term this node has seen.
    pub current_term: u64,
    /// Id of the node voted for; `0` is treated as "no vote cast".
    pub voted_for: u64,
    /// Votes collected during the node's current candidacy (includes its own).
    pub vote_count: u64,

    /// Sent to followers as `leader_commit` in every heartbeat.
    pub commit_index: u64,
    /// Not read or written anywhere in this file yet.
    pub _last_applied: u64,

    /// Per-peer index of the next log entry to send; reset when elected leader.
    pub next_index: Vec<u64>,
    /// Per-peer highest index known to be replicated; reset when elected leader.
    pub match_index: Vec<u64>,
}
62
63 impl State {
64 pub fn new() -> Self {
65 Self {
66 current_term: 0,
67 voted_for: 0,
68 vote_count: 0,
69 commit_index: 0,
70 _last_applied: 0,
71 next_index: vec![],
72 match_index: vec![],
73 }
74 }
75 }
76
/// The role a node currently plays in the cluster.
///
/// The derives make the role printable in logs, cheap to copy and directly
/// comparable; existing `match`-based users are unaffected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Waits for heartbeats and starts an election when they stop.
    Follower,
    /// Is asking its peers for votes.
    Candidate,
    /// Sends heartbeats to all peers.
    Leader,
}
82
83 pub struct Raft {
84 pub peers: HashMap<u64, SocketAddr>,
85 node_id: u64,
86 role: Role,
87 }
88
89 impl Raft {
90 pub fn new(node_id: u64) -> Self {
91 Self { peers: Default::default(), node_id, role: Role::Follower }
92 }
93
94 pub async fn start(&mut self) {
95 debug!("Raft::start()");
96 self.role = Role::Follower;
97
98 let mut state = STATE.lock().await;
99 state.current_term = 0;
100 state.voted_for = 0;
101 drop(state);
102
103 let mut rng = rand::thread_rng();
104
105 loop {
106 let delay = Duration::from_millis(rng.gen_range(0..200) + 300);
107
108 match self.role {
109 Role::Follower => {
110 select! {
111 _ = HEARTBEAT_CHAN.1.recv().fuse() => {
112 debug!("[FOLLOWER] Raft::start(): follower_{} got heartbeat", self.node_id);
113 }
114 _ = task::sleep(delay).fuse() => {
115 debug!("[FOLLOWER] Raft::start(): follower_{} timeout", self.node_id);
116 self.role = Role::Candidate;
117 }
118 }
119 }
120
121 Role::Candidate => {
122 debug!("[CANDIDATE] Raft::start(): peer_{} is now a candidate", self.node_id);
123 let mut state = STATE.lock().await;
124 state.current_term += 1;
125 state.voted_for = self.node_id;
126 state.vote_count = 1;
127 drop(state);
128
129 // TODO: In background
130 debug!("[CANDIDATE] Raft::start(): broadcasting request_vote");
131 self.broadcast_request_vote().await;
132
133 select! {
134 _ = task::sleep(delay).fuse() => {
135 debug!("[CANDIDATE] Raft::start(): Timeout as candidate, becoming a follower");
136 self.role = Role::Follower;
137 }
138 _ = TOLEADER_CHAN.1.recv().fuse() => {
139 debug!("[CANDIDATE] Raft::start(): We are now the leader");
140 self.role = Role::Leader;
141
142 let mut state = STATE.lock().await;
143 state.next_index = vec![1_u64; self.peers.len()];
144 state.match_index = vec![0_u64; self.peers.len()];
145 drop(state);
146
147 // TODO: In background
148 let t = task::spawn(async {
149 let mut i = 0;
150 loop {
151 debug!("[CANDIDATE] Raft::start(): Appending data in bg loop");
152 i += 1;
153 let state = STATE.lock().await;
154 let logentry = LogEntry {
155 log_term: state.current_term,
156 log_index: i,
157 log_data: format!("user send: {}", i).as_bytes().to_vec(),
158 };
159 drop(state);
160
161 debug!("[CANDIDATE] Raft::start(): Acquiring logstore lock in bg loop");
162 let mut logstore = LOG_STORE.lock().await;
163 logstore.0.push(logentry);
164 drop(logstore);
165 debug!("[CANDIDATE] Raft::start(): Dropped logstore lock in bg loop");
166 task::sleep(Duration::from_secs(3)).await;
167 }
168 });
169 }
170 }
171 }
172
173 Role::Leader => {
174 debug!("[LEADER] Raft::start(): Broadcasting heartbeat as leader");
175 self.broadcast_heartbeat().await;
176 task::sleep(Duration::from_millis(100)).await;
177 }
178 }
179 }
180 }
181
182 async fn broadcast_request_vote(&mut self) {
183 debug!("Raft::broadcast_request_vote()");
184 let state = STATE.lock().await;
185 let args = VoteArgs { term: state.current_term, candidate_id: self.node_id };
186 drop(state);
187
188 // TODO: Do this concurrently
189 for i in self.peers.clone() {
190 debug!("Raft::broadcast_request_vote(): Sending req to peer {}", i.1);
191 match self.send_request_vote(i.0, args.clone()).await {
192 Ok(v) => debug!("Raft::broadcast_request_vote(): Got reply: {:?}", v),
193 Err(e) => {
194 error!("Raft::broadcast_request_vote(): Failed vote to peer {}, ({})", i.1, e);
195 continue
196 }
197 };
198 }
199 }
200
201 async fn send_request_vote(
202 &mut self,
203 node_id: u64,
204 args: VoteArgs,
205 ) -> Result<VoteReply, io::Error> {
206 debug!("Raft::send_request_vote()");
207 let addr = self.peers[&node_id];
208
209 let method = RaftMethod::Vote(args);
210 let payload = method.try_to_vec().unwrap();
211
212 debug!("Raft::send_request_vote(): Connecting to peer_{}", node_id);
213 let mut stream = TcpStream::connect(addr).await?;
214 debug!("Raft::send_request_vote(): Writing to stream");
215 stream.write_all(&payload).await?;
216 debug!("Raft::send_request_vote(): Wrote to stream");
217
218 debug!("Raft::send_request_vote(): Reading from stream");
219 let mut buf = vec![0_u8; 4096];
220 stream.read(&mut buf).await?;
221 debug!("Raft::send_request_vote(): Read from stream");
222
223 let reply = try_from_slice_unchecked::<VoteReply>(&buf)?;
224 let mut state = STATE.lock().await;
225 if reply.term > state.current_term {
226 debug!("Raft::send_request_vote(): reply.term > state.current_term");
227 state.current_term = reply.term;
228 state.voted_for = 0;
229 drop(state);
230 self.role = Role::Follower;
231 return Ok(reply)
232 }
233 drop(state);
234
235 if reply.vote_granted {
236 debug!("Raft::send_request_vote(): reply.vote_granted == true");
237 let mut state = STATE.lock().await;
238 state.vote_count += 1;
239 drop(state);
240 }
241
242 let state = STATE.lock().await;
243 if state.vote_count >= (self.peers.len() / 2 + 1).try_into().unwrap() {
244 debug!("Raft::send_request_vote(): Elected for leader");
245 TOLEADER_CHAN.0.send(true).await.unwrap();
246 }
247 drop(state);
248
249 Ok(reply)
250 }
251
252 async fn broadcast_heartbeat(&mut self) {
253 debug!("[LEADER] Raft::broadcast_heartbeat()");
254
255 for i in self.peers.clone() {
256 let state = STATE.lock().await;
257 let mut args = HeartbeatArgs {
258 term: state.current_term,
259 leader_id: self.node_id,
260 prev_log_index: 0,
261 prev_log_term: 0,
262 entries: vec![],
263 leader_commit: state.commit_index,
264 };
265
266 let prev_log_index = state.next_index[i.0 as usize] - 1;
267 drop(state);
268
269 debug!("[LEADER] Raft::broadcast_heartbeat(): Acquiring lock on LOG_STORE");
270 let logstore = LOG_STORE.lock().await;
271 if logstore.get_last_index() > prev_log_index {
272 args.prev_log_index = prev_log_index;
273 args.prev_log_term = logstore.0[prev_log_index as usize].log_term;
274 args.entries = logstore.0[prev_log_index as usize..].to_vec();
275 drop(logstore);
276 debug!("[LEADER] Raft::broadcast_heartbeat(): Dropped lock on LOG_STORE");
277 debug!("[LEADER] Raft::broadcast_heartbeat(): Send entries: {:?}", args.entries);
278 }
279
280 // TODO: Run in background
281 match self.send_heartbeat(i.0, args).await {
282 Ok(v) => debug!("[LEADER] Raft::broadcast_heartbeat(): Got reply: {:?}", v),
283 Err(e) => {
284 error!(
285 "[LEADER] Raft::broadcast_heartbeat(): Failed heartbeat to peer_{} ({})",
286 i.0, e
287 );
288 continue
289 }
290 };
291 }
292 }
293
294 async fn send_heartbeat(
295 &mut self,
296 node_id: u64,
297 args: HeartbeatArgs,
298 ) -> Result<HeartbeatReply, io::Error> {
299 debug!("Raft::send_heartbeat({}, {:?}", node_id, args);
300 let addr = self.peers[&node_id];
301
302 let method = RaftMethod::Heartbeat(args);
303 let payload = method.try_to_vec()?;
304
305 debug!("Raft::send_heartbeat(): Connecting to peer_{}", node_id);
306 let mut stream = TcpStream::connect(addr).await?;
307 debug!("Raft::send_heartbeat(): Writing to stream");
308 stream.write_all(&payload).await?;
309 debug!("Raft::send_heartbeat(): Wrote to stream");
310
311 debug!("Raft::send_heartbeat(): Reading from stream");
312 let mut buf = vec![0_u8; 4096];
313 stream.read(&mut buf).await?;
314 debug!("Raft::send_heartbeat(): Read from stream");
315
316 let reply = try_from_slice_unchecked::<HeartbeatReply>(&buf)?;
317
318 let mut state = STATE.lock().await;
319 if reply.success {
320 debug!("Raft::send_heartbeat(): Got success reply");
321 if reply.next_index > 0 {
322 state.next_index[node_id as usize] = reply.next_index;
323 state.match_index[node_id as usize] = reply.next_index - 1;
324 }
325 } else if reply.term > state.current_term {
326 debug!("Raft::send_heartbeat(): reply.term > state.current_term");
327 state.current_term = reply.term;
328 state.voted_for = 0;
329 self.role = Role::Follower;
330 }
331 drop(state);
332
333 Ok(reply)
334 }
335 }
336
/// RPC server of a node; the wrapped address is the one it listens on.
pub struct RaftRpc(pub SocketAddr);
338
339 impl RaftRpc {
340 pub async fn start(&self) {
341 debug!("RaftRpc::start()");
342
343 debug!("RaftRpc::start(): Binding to {}", self.0);
344 let listener = TcpListener::bind(self.0).await.unwrap();
345 let mut incoming = listener.incoming();
346
347 while let Some(stream) = incoming.next().await {
348 debug!("RaftRpc::start(): Got RPC request");
349 let stream = stream.unwrap();
350 let (reader, writer) = &mut (&stream, &stream);
351
352 debug!("RaftRpc::start(): Reading from reader...");
353 let mut buf = vec![0_u8; 4096];
354 reader.read(&mut buf).await.unwrap();
355 debug!("RaftRpc::start(): Read from reader");
356
357 match try_from_slice_unchecked::<RaftMethod>(&buf).unwrap() {
358 RaftMethod::Vote(args) => {
359 debug!("RaftRpc::start(): Got RaftMethod::Vote");
360 let reply = self.request_vote(args).await;
361 let payload = reply.try_to_vec().unwrap();
362
363 debug!("RaftRpc::start(): Vote: Writing to writer...");
364 writer.write_all(&payload).await.unwrap();
365 debug!("RaftRpc::start(): Vote: Wrote to writer");
366 }
367
368 RaftMethod::Heartbeat(args) => {
369 debug!("RaftRpc::start(): Got RaftMethod::Heartbeat");
370 let reply = self.heartbeat(args).await;
371 let payload = reply.try_to_vec().unwrap();
372
373 debug!("RaftRpc::start(): Heartbeat: Writing to writer...");
374 writer.write_all(&payload).await.unwrap();
375 debug!("RaftRpc::start(): Heartbeat: Wrote to writer");
376 }
377 }
378 }
379 }
380
381 async fn request_vote(&self, args: VoteArgs) -> VoteReply {
382 debug!("RaftRpc::request_vote()");
383 let mut reply = VoteReply { term: 0, vote_granted: false };
384
385 debug!("RaftRpc::request_vote(): Acquiring state lock");
386 let mut state = STATE.lock().await;
387 debug!("RaftRpc::request_vote(): Got lock");
388
389 if args.term < state.current_term {
390 reply.term = state.current_term;
391 drop(state);
392 reply.vote_granted = false;
393 return reply
394 }
395
396 if state.voted_for == 0 {
397 state.current_term = args.term;
398 state.voted_for = args.candidate_id;
399 drop(state);
400 reply.term = args.term;
401 reply.vote_granted = true;
402 return reply
403 }
404
405 drop(state);
406 reply
407 }
408
409 async fn heartbeat(&self, args: HeartbeatArgs) -> HeartbeatReply {
410 debug!("RaftRpc::heartbeat()");
411 let mut reply = HeartbeatReply { success: false, term: 0, next_index: 0 };
412
413 debug!("RaftRpc::heartbeat(): Acquiring state lock");
414 let state = STATE.lock().await;
415 debug!("RaftRpc::heartbeat(): Got state lock");
416 let current_term = state.current_term;
417 drop(state);
418 debug!("RaftRpc::heartbeat(): Dropped state lock");
419
420 if args.term < current_term {
421 reply.success = false;
422 reply.term = current_term;
423 return reply
424 }
425
426 debug!("RaftRpc::heartbeat(): Sending to channel");
427 HEARTBEAT_CHAN.0.send(true).await.unwrap();
428 debug!("RaftRpc::heartbeat(): Sent to channel");
429
430 if args.entries.is_empty() {
431 reply.success = true;
432 reply.term = current_term;
433 return reply
434 }
435
436 debug!("RaftRpc::heartbeat(): Acquiring logstore lock");
437 let mut logstore = LOG_STORE.lock().await;
438 debug!("RaftRpc::heartbeat(): Got logstore lock");
439 if args.prev_log_index > logstore.get_last_index() {
440 reply.success = false;
441 reply.term = current_term;
442 reply.next_index = logstore.get_last_index() + 1;
443 drop(logstore);
444 return reply
445 }
446
447 logstore.0.extend_from_slice(&args.entries);
448 reply.next_index = logstore.get_last_index() + 1;
449 drop(logstore);
450 debug!("RaftRpc::heartbeat(): Dropped logstore lock");
451
452 reply.success = true;
453 reply.term = current_term;
454
455 reply
456 }
457 }
458
459 fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, io::Error> {
460 let mut data_mut = data;
461 let result = T::deserialize(&mut data_mut)?;
462 Ok(result)
463 }