use std::io::{self, Write}; use std::net::{TcpStream, ToSocketAddrs}; use std::time::Duration; use std::env; use colored::*; use regex::Regex; mod resp; const DEFAULT_PORT: u16 = 9880; const PROMPT_NAME: &str = "futriix"; const CONNECTION_TIMEOUT: u64 = 5; // seconds fn main() { // Parse command line arguments let args: Vec = env::args().collect(); let mut host = "127.0.0.1".to_string(); let mut port = DEFAULT_PORT; if args.len() > 1 { let re = Regex::new(r"^(?:([^:]+):)?([^:]+)(?::(\d+))?$").unwrap(); if let Some(caps) = re.captures(&args[1]) { if let Some(h) = caps.get(2) { host = h.as_str().to_string(); } if let Some(p) = caps.get(3) { port = p.as_str().parse().unwrap_or(DEFAULT_PORT); } } } // Try to connect to server let addr = format!("{}:{}", host, port); let socket_addr = match format!("{}:{}", host, port).to_socket_addrs() { Ok(mut addrs) => addrs.next().expect("No address found"), Err(_e) => { eprintln!("Failed to resolve address {}", addr); return; } }; let stream = match TcpStream::connect_timeout(&socket_addr, Duration::from_secs(CONNECTION_TIMEOUT)) { Ok(s) => s, Err(_e) => { eprintln!("Connection to port {} refused", port); return; } }; println!("Connected to {}", addr); // Main REPL loop let mut input = String::new(); loop { // Print prompt print_prompt(&host, port); // Read user input input.clear(); io::stdin().read_line(&mut input).expect("Failed to read input"); let input = input.trim(); if input.is_empty() { continue; } // Handle special commands if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { break; } // Validate command before sending if !is_valid_command(input) { eprintln!("Error: Invalid command format"); continue; } // Send command to server match send_command(&stream, input) { Ok(response) => { print_response(&response); } Err(e) => { eprintln!("Error: {}", e); // Check if connection was lost if e.kind() == io::ErrorKind::ConnectionAborted || e.kind() == io::ErrorKind::ConnectionReset { eprintln!("Connection lost. Please restart the client."); break; } } } } } fn is_valid_command(cmd: &str) -> bool { // Basic validation - command should not be empty and should contain only printable characters if cmd.is_empty() { return false; } // Check for control characters if cmd.chars().any(|c| c.is_control()) { return false; } // Check for valid command structure (at least one non-whitespace character) cmd.split_whitespace().next().is_some() } fn print_prompt(host: &str, port: u16) { let prompt = format!("{}:{}:{}:~>", PROMPT_NAME, host, port); print!("{} ", prompt.green()); io::stdout().flush().unwrap(); } fn send_command(stream: &TcpStream, command: &str) -> io::Result { // Parse command into RESP format let parts: Vec<&str> = command.split_whitespace().collect(); let mut resp_command = String::new(); // RESP protocol: *\r\n$\r\n\r\n... resp_command.push_str(&format!("*{}\r\n", parts.len())); for part in parts { resp_command.push_str(&format!("${}\r\n{}\r\n", part.len(), part)); } // Send command let mut stream = stream.try_clone()?; stream.write_all(resp_command.as_bytes())?; // Read response let mut decoder = resp::Decoder::new(&stream); decoder.decode() } fn print_response(value: &resp::Value) { match value { resp::Value::SimpleString(s) | resp::Value::BulkString(s) => println!("{}", s), resp::Value::Error(e) => println!("(error) {}", e), resp::Value::Integer(i) => println!("(integer) {}", i), resp::Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { print!("{}) ", i + 1); print_response(item); } } resp::Value::Null => println!("(nil)"), } } #[cfg(test)] mod tests { use super::*; #[test] fn test_print_prompt() { print_prompt("127.0.0.1", 9880); } }