From e49973e4853b0fecc01acf65ced307f513003ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=93=D1=80=D0=B8=D0=B3=D0=BE=D1=80=D0=B8=D0=B9=20=D0=A1?= =?UTF-8?q?=D0=B0=D1=84=D1=80=D0=BE=D0=BD=D0=BE=D0=B2?= Date: Sun, 25 May 2025 22:04:59 +0000 Subject: [PATCH] Update src/main.rs --- src/main.rs | 266 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 227 insertions(+), 39 deletions(-) diff --git a/src/main.rs b/src/main.rs index af7891c..76bf540 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,57 +1,209 @@ -use std::io::{self, Write}; -use std::net::TcpStream; -use std::time::Duration; +use std::collections::HashMap; use std::env; -use colored::Colorize; -use regex::Regex; +use std::io::{self, Write}; +use std::net::{TcpStream, SocketAddr}; use std::process; +use std::time::Duration; +use colored::Colorize; mod resp; -const DEFAULT_PORT: u16 = 9880; +const DEFAULT_HOST: &str = "127.0.0.1"; +const DEFAULT_PORT: u16 = 6379; const PROMPT_NAME: &str = "futriix"; const CONNECTION_TIMEOUT_SECS: u64 = 2; -fn main() { - let (host, port) = parse_args(); - let addr = format!("{}:{}", host, port); - - let stream = match TcpStream::connect_timeout( - &addr.parse().unwrap(), - Duration::from_secs(CONNECTION_TIMEOUT_SECS) - ) { - Ok(stream) => stream, - Err(e) if e.kind() == io::ErrorKind::ConnectionRefused => { - eprintln!("{}", "Connection refused".red()); - process::exit(1); - }, - Err(e) => { - eprintln!("Connection error: {}", e); - process::exit(1); - } - }; - - println!("Connected to {}", addr.green()); - run_repl_loop(stream, &host, port); +#[derive(Debug)] +struct ClusterNode { + stream: TcpStream, + slots: (u16, u16), + node_id: String, } -fn parse_args() -> (String, u16) { - let args: Vec = env::args().collect(); - let mut host = "127.0.0.1".to_string(); - let mut port = DEFAULT_PORT; +#[derive(Debug)] +struct RedisCluster { + nodes: HashMap, + slot_cache: HashMap, +} - if args.len() > 1 { - let re = Regex::new(r"^(?:([^:]+):)?([^:]+)(?::(\d+))?$").unwrap(); - if let Some(caps) = re.captures(&args[1]) { - if let Some(h) = caps.get(2) { - host = h.as_str().to_string(); +impl RedisCluster { + fn new(initial_nodes: Vec<&str>) -> io::Result { + let mut cluster = RedisCluster { + nodes: HashMap::new(), + slot_cache: HashMap::new(), + }; + cluster.update_cluster_info(initial_nodes)?; + Ok(cluster) + } + + fn update_cluster_info(&mut self, nodes: Vec<&str>) -> io::Result<()> { + for node in nodes { + let addr: SocketAddr = node.parse().map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid node address: {}", e), + ) + })?; + + let stream = TcpStream::connect_timeout( + &addr, + Duration::from_secs(CONNECTION_TIMEOUT_SECS), + )?; + + self.nodes.insert( + node.to_string(), + ClusterNode { + stream, + slots: (0, 16383), + node_id: "mock_node_id".to_string(), + }, + ); + + for slot in 0..16384 { + self.slot_cache.insert(slot, node.to_string()); } - if let Some(p) = caps.get(3) { - port = p.as_str().parse().unwrap_or(DEFAULT_PORT); + } + Ok(()) + } + + fn get_connection(&mut self, key: &str) -> io::Result<&mut TcpStream> { + let slot = self.calculate_slot(key); + if let Some(node_addr) = self.slot_cache.get(&slot) { + if let Some(node) = self.nodes.get_mut(node_addr) { + return Ok(&mut node.stream); + } + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "Failed to find node for key", + )) + } + + fn calculate_slot(&self, key: &str) -> u16 { + let key = if let Some(start) = key.find('{') { + if let Some(end) = key.find('}') { + &key[start+1..end] + } else { + key + } + } else { + key + }; + + crc16::State::::calculate(key.as_bytes()) % 16384 + } +} + +fn main() { + let args: Vec = env::args().collect(); + + let mut host = DEFAULT_HOST.to_string(); + let mut port = DEFAULT_PORT; + let mut cluster_mode = false; + + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "-h" | "--host" => { + if i + 1 < args.len() { + host = args[i + 1].clone(); + i += 1; + } + }, + "-p" | "--port" => { + if i + 1 < args.len() { + port = args[i + 1].parse().unwrap_or(DEFAULT_PORT); + i += 1; + } + }, + "-c" | "--cluster" => { + cluster_mode = true; + }, + _ => {} + } + i += 1; + } + + let addr = format!("{}:{}", host, port); + + if cluster_mode { + match RedisCluster::new(vec![&addr]) { + Ok(mut cluster) => { + println!("Connected to cluster at {}", addr.green()); + run_cluster_repl(&mut cluster, &host, port); + }, + Err(e) => { + eprintln!("Cluster connection error: {}", e); + process::exit(1); + } + } + } else { + match TcpStream::connect_timeout( + &addr.parse().unwrap(), + Duration::from_secs(CONNECTION_TIMEOUT_SECS) + ) { + Ok(stream) => { + println!("Connected to {}:{}", host.green(), port.to_string().green()); + run_repl_loop(stream, &host, port); + }, + Err(e) if e.kind() == io::ErrorKind::ConnectionRefused => { + eprintln!("{}", format!("Connection refused to {}:{}", host, port).red()); + process::exit(1); + }, + Err(e) => { + eprintln!("Connection error: {}", e); + process::exit(1); + } + } + } +} + +fn run_cluster_repl(cluster: &mut RedisCluster, host: &str, port: u16) { + let mut input = String::new(); + loop { + print_prompt(host, port); + input.clear(); + + if io::stdin().read_line(&mut input).is_err() { + eprintln!("{}", "Failed to read input".red()); + continue; + } + + let input = input.trim(); + if input.is_empty() { + continue; + } + + if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { + break; + } + + if !is_valid_command(input) { + eprintln!("{}", "Error: Invalid command format".red()); + continue; + } + + let parts: Vec<&str> = input.split_whitespace().collect(); + let key = if parts.len() > 1 { parts[1] } else { "" }; + + match cluster.get_connection(key) { + Ok(stream) => { + match send_command(stream, input) { + Ok(response) => print_response(&response), + Err(e) => { + if is_connection_error(&e) { + eprintln!("{}", "Connection error".red()); + break; + } + eprintln!("{}", format!("Error: {}", e.to_string().replace("KeyDB", "Futriix")).red()); + } + } + }, + Err(e) => { + eprintln!("Cluster error: {}", e); } } } - (host, port) } fn run_repl_loop(stream: TcpStream, host: &str, port: u16) { @@ -145,4 +297,40 @@ fn is_connection_error(error: &io::Error) -> bool { io::ErrorKind::ConnectionReset | io::ErrorKind::BrokenPipe ) +} + +mod crc16 { + pub struct State { + crc: u16, + _hasher: std::marker::PhantomData, + } + + pub trait Hasher { + const INIT: u16; + fn update(crc: u16, data: &[u8]) -> u16; + } + + pub struct XMODEM; + impl Hasher for XMODEM { + const INIT: u16 = 0; + fn update(crc: u16, data: &[u8]) -> u16 { + data.iter().fold(crc, |crc, &byte| { + let mut c = crc ^ (u16::from(byte) << 8); + for _ in 0..8 { + if c & 0x8000 != 0 { + c = (c << 1) ^ 0x1021; + } else { + c <<= 1; + } + } + c + }) + } + } + + impl State { + pub fn calculate(data: &[u8]) -> u16 { + H::update(H::INIT, data) + } + } } \ No newline at end of file