diff --git a/Cargo.lock b/Cargo.lock index 1a76312..7f24178 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "hteapot" -version = "0.6.2" +version = "0.6.5" dependencies = [ "libc", ] diff --git a/Cargo.toml b/Cargo.toml index f022889..4515c61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hteapot" -version = "0.6.2" +version = "0.6.5" edition = "2024" authors = ["Alb Ruiz G. "] description = "HTeaPot is a lightweight HTTP server library designed to be easy to use and extend." @@ -25,6 +25,7 @@ path = "src/hteapot/mod.rs" name = "hteapot" [dependencies] +[target.'cfg(unix)'.dependencies] libc = "0.2.172" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..57c7aea --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +FROM rust AS builder + +WORKDIR /app +COPY Cargo.lock Cargo.lock +COPY Cargo.toml Cargo.toml +COPY src ./src + +RUN cargo build --release + +FROM ubuntu + +COPY --from=builder /app/target/release/hteapot /bin/hteapot + +EXPOSE 80 + +WORKDIR /config + +ENTRYPOINT ["/bin/hteapot"] +CMD ["config.toml"] diff --git a/examples/basic2.rs b/examples/basic2.rs new file mode 100644 index 0000000..be12e5d --- /dev/null +++ b/examples/basic2.rs @@ -0,0 +1,18 @@ +use hteapot::{Hteapot, HttpRequest, HttpResponse, HttpStatus}; + +fn main() { + let server = Hteapot::new("localhost", 8081); + server.listen(move |req: HttpRequest| { + // This will be executed for each request + let body = String::from_utf8(req.body).unwrap_or("NOPE".to_string()); + for header in req.headers { + println!("- {}: {}", header.0, header.1); + } + println!("{}", body); + HttpResponse::new( + HttpStatus::IAmATeapot, + format!("Hello, I am HTeaPot\n{}", body), + None, + ) + }); +} diff --git a/examples/proxy_con.rs b/examples/proxy_con.rs new file mode 100644 index 0000000..878b88a --- /dev/null +++ b/examples/proxy_con.rs @@ -0,0 +1,28 @@ +use hteapot::{Hteapot, HttpMethod, HttpRequest, HttpResponse, TunnelResponse}; + +fn main() { + let server = Hteapot::new_threaded("0.0.0.0", 8081, 3); + server.listen(move |req: HttpRequest| { + println!("New request to {} {}!", req.method.to_str(), &req.path); + if req.method == HttpMethod::CONNECT { + TunnelResponse::new(&req.path) + } else { + println!("{:?}", req); + let addr = req.headers.get("host"); + let addr = if let Some(addr) = addr { + addr + } else { + return HttpResponse::new( + hteapot::HttpStatus::InternalServerError, + "content", + None, + ); + }; + req.brew(addr).unwrap_or(HttpResponse::new( + hteapot::HttpStatus::InternalServerError, + "content", + None, + )) + } + }); +} diff --git a/src/cache.rs b/src/cache.rs index 16f1c02..5dd7d63 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,10 +1,11 @@ // Written by Alberto Ruiz, 2024-11-05 -// +// // Config module: handles application configuration setup and parsing. // This module defines structs and functions to load and validate // configuration settings from files, environment variables, or other sources. use std::collections::HashMap; +use std::hash::Hash; use std::time; use std::time::SystemTime; @@ -24,14 +25,18 @@ use std::time::SystemTime; /// let data = cache.get("hello".into()); /// assert!(data.is_some()); /// ``` -pub struct Cache { +pub struct Cache { // TODO: consider make it generic // The internal store: (data, expiration timestamp) - data: HashMap, u64)>, + data: HashMap, max_ttl: u64, } -impl Cache { +impl Cache +where + K: Eq + Hash, + V: Clone, +{ /// Creates a new `Cache` with the specified TTL in seconds. pub fn new(max_ttl: u64) -> Self { Cache { @@ -61,14 +66,14 @@ impl Cache { } /// Stores data in the cache with the given key and a TTL. - pub fn set(&mut self, key: String, data: Vec) { + pub fn set(&mut self, key: K, data: V) { self.data.insert(key, (data, self.get_ttl())); } /// Retrieves data from the cache if it exists and hasn't expired. /// /// Removes and returns `None` if the TTL has expired. - pub fn get(&mut self, key: String) -> Option> { + pub fn get(&mut self, key: &K) -> Option { let r = self.data.get(&key); if r.is_some() { let (data, ttl) = r.unwrap(); diff --git a/src/config.rs b/src/config.rs index 42e68c1..28673a7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,9 @@ // Written by Alberto Ruiz 2024-04-07 (Happy 3th monthsary) -// +// // This is the config module: responsible for loading application configuration // from a file and providing structured access to settings. -use std::{any::Any, collections::HashMap, fs}; +use std::{any::Any, collections::HashMap, fs, path::Path}; /// Dynamic TOML value representation. /// @@ -62,7 +62,7 @@ pub fn toml_parser(content: &str) -> HashMap { let mut map = HashMap::new(); let mut submap = HashMap::new(); let mut title = "".to_string(); - + let lines = content.split("\n"); for line in lines { if line.starts_with("#") || line.is_empty() { @@ -88,13 +88,13 @@ pub fn toml_parser(content: &str) -> HashMap { submap = HashMap::new(); continue; } - + // Split key and value let parts = line.split("=").collect::>(); if parts.len() != 2 { continue; } - + // Remove leading and trailing whitespace let key = parts[0] .trim() @@ -103,7 +103,7 @@ pub fn toml_parser(content: &str) -> HashMap { if key.is_empty() { continue; } - + // Remove leading and trailing whitespace let value = parts[1].trim(); let value = if value.contains('\'') || value.contains('"') { @@ -152,14 +152,14 @@ pub fn toml_parser(content: &str) -> HashMap { /// such as host, port, caching behavior, and proxy rules. #[derive(Debug)] pub struct Config { - pub port: u16, // Port number to listen - pub host: String, // Host name or IP - pub root: String, // Root directory to serve files + pub port: u16, // Port number to listen + pub host: String, // Host name or IP + pub root: String, // Root directory to serve files pub cache: bool, pub cache_ttl: u16, pub threads: u16, pub log_file: Option, - pub index: String, // Index file to serve by default + pub index: String, // Index file to serve by default // pub error: String, // Error file to serve when a file is not found pub proxy_rules: HashMap, } @@ -192,6 +192,35 @@ impl Config { } } + pub fn new_serve(path: &str) -> Config { + let mut s_path = "./".to_string(); + s_path.push_str(path); + let serving_path = Path::new(&s_path); + let file_name: &str; + let root_dir: String; + if serving_path.is_file() { + let parent_path = serving_path.parent().unwrap(); + root_dir = parent_path.to_str().unwrap().to_string(); + file_name = serving_path.file_name().unwrap().to_str().unwrap(); + } else { + file_name = "index.html"; + root_dir = serving_path.to_str().unwrap().to_string(); + }; + + Config { + port: 8080, + host: "0.0.0.0".to_string(), + root: root_dir, + index: file_name.to_string(), + log_file: None, + + threads: 1, + cache: false, + cache_ttl: 0, + proxy_rules: HashMap::new(), + } + } + /// Loads configuration from a TOML file, returning defaults on failure. /// /// Expects the file to contain `[HTEAPOT]` and optionally `[proxy]` sections. @@ -224,13 +253,13 @@ impl Config { // Suggested alternative parsing logic // if let Some(proxy_map) = map.get("proxy") { - // for k in proxy_map.keys() { - // if let Some(url) = proxy_map.get2(k) { - // proxy_rules.insert(k.clone(), url); - // } else { - // println!("Missing or invalid proxy URL for key: {}", k); - // } - // } + // for k in proxy_map.keys() { + // if let Some(url) = proxy_map.get2(k) { + // proxy_rules.insert(k.clone(), url); + // } else { + // println!("Missing or invalid proxy URL for key: {}", k); + // } + // } // } // Extract main configuration @@ -239,7 +268,6 @@ impl Config { // Suggested alternative parsing logic (Not working) // let map = map.get("HTEAPOT").unwrap_or(&TOMLSchema::new()); - Config { port: map.get2("port").unwrap_or(8080), host: map.get2("host").unwrap_or("".to_string()), @@ -253,4 +281,20 @@ impl Config { proxy_rules, } } + + pub fn new_proxy() -> Config { + let mut proxy_rules = HashMap::new(); + proxy_rules.insert("/".to_string(), "".to_string()); + Config { + port: 8080, + host: "0.0.0.0".to_string(), + root: "./".to_string(), + cache: false, + cache_ttl: 0, + threads: 2, + log_file: None, + index: "index.html".to_string(), + proxy_rules, + } + } } diff --git a/src/handler/file.rs b/src/handler/file.rs new file mode 100644 index 0000000..8bcfbaf --- /dev/null +++ b/src/handler/file.rs @@ -0,0 +1,131 @@ +use std::{ + fs, + path::{Path, PathBuf}, +}; + +use crate::{ + handler::handler::{Handler, HandlerFactory}, + hteapot::{HttpHeaders, HttpResponse, HttpStatus}, + utils::{Context, get_mime_tipe}, +}; + +/// Safely joins a root directory with a requested relative path. +/// +/// Ensures that: +/// - Symbolic links and `..` segments are resolved (`canonicalize`) +/// - The resulting path stays within `root` +/// - The path exists on disk +/// +/// This prevents directory traversal attacks (e.g., accessing `/etc/passwd`). +/// +/// # Arguments +/// * `root` - Allowed root directory. +/// * `requested_path` - Path requested by the client. +/// +/// # Returns +/// `Some(PathBuf)` if the path is valid and exists, `None` otherwise. +/// +/// # Example +/// ``` +/// let safe_path = safe_join_paths("/var/www", "/index.html"); +/// assert!(safe_path.unwrap().ends_with("index.html")); +/// ``` +pub fn safe_join_paths(root: &str, requested_path: &str) -> Option { + let root_path = Path::new(root).canonicalize().ok()?; + let requested_full_path = root_path.join(requested_path.trim_start_matches("/")); + + if !requested_full_path.exists() { + return None; + } + + let canonical_path = requested_full_path.canonicalize().ok()?; + if canonical_path.starts_with(&root_path) { + Some(canonical_path) + } else { + None + } +} + +/// Handles serving static files from a root directory, including index files. +pub struct FileHandler { + root: String, + index: String, +} + +impl FileHandler {} + +impl Handler for FileHandler { + fn run(&self, ctx: &mut Context) -> Box { + let logger = ctx.log.with_component("HTTP"); + + // Resolve the requested path safely + let safe_path_result = if ctx.request.path == "/" { + // Special handling for the root path: serve the index file + Path::new(&self.root) + .canonicalize() + .ok() + .map(|root_path| root_path.join(&self.index)) + .filter(|index_path| index_path.exists()) + } else { + // Other paths: use safe join + safe_join_paths(&self.root, &ctx.request.path) + }; + + // Handle directories or invalid paths + let safe_path = match safe_path_result { + Some(path) => { + if path.is_dir() { + let index_path = path.join(&self.index); + if index_path.exists() { + index_path + } else { + logger.warn(format!( + "Index file not found in directory: {}", + ctx.request.path + )); + return HttpResponse::new(HttpStatus::NotFound, "Index not found", None); + } + } else { + path + } + } + None => { + logger.warn(format!( + "Path not found or access denied: {}", + ctx.request.path + )); + return HttpResponse::new(HttpStatus::NotFound, "Not found", None); + } + }; + + // Determine MIME type + let mimetype = get_mime_tipe(&safe_path.to_string_lossy().to_string()); + + // Read file content + match fs::read(&safe_path).ok() { + Some(content) => { + let mut headers = HttpHeaders::new(); + headers.insert("Content-Type", &mimetype); + headers.insert("X-Content-Type-Options", "nosniff"); + let response = HttpResponse::new(HttpStatus::OK, content, Some(headers)); + + // Cache the response if caching is enabled + if let Some(cache) = ctx.cache.as_deref_mut() { + cache.set(ctx.request.clone(), (*response).clone()); + } + + response + } + None => HttpResponse::new(HttpStatus::NotFound, "Not found", None), + } + } +} + +impl HandlerFactory for FileHandler { + fn is(ctx: &Context) -> Option> { + Some(Box::new(FileHandler { + root: ctx.config.root.to_string(), + index: ctx.config.index.to_string(), + })) + } +} diff --git a/src/handler/handler.rs b/src/handler/handler.rs new file mode 100644 index 0000000..8458f6a --- /dev/null +++ b/src/handler/handler.rs @@ -0,0 +1,9 @@ +use crate::{hteapot::HttpResponseCommon, utils::Context}; + +pub trait Handler { + fn run(&self, context: &mut Context) -> Box; +} + +pub trait HandlerFactory { + fn is(context: &Context) -> Option>; +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs new file mode 100644 index 0000000..a0576a5 --- /dev/null +++ b/src/handler/mod.rs @@ -0,0 +1,63 @@ +use crate::{ + handler::handler::{Handler, HandlerFactory}, + utils::Context, +}; + +pub mod file; +mod handler; +pub mod proxy; + +/// Type alias for a handler factory function. +/// +/// A factory takes a reference to the current `Config` and `HttpRequest` +/// and returns an `Option>`. It returns `Some(handler)` +/// if it can handle the request, or `None` if it cannot. +type Factory = fn(&Context) -> Option>; + +/// List of all available handler factories. +/// +/// New handlers can be added to this array to make them available +/// for request processing. +static HANDLERS: &[Factory] = &[proxy::ProxyHandler::is, file::FileHandler::is]; + +/// Returns the first handler that can process the given request. +/// +/// Iterates over all registered handler factories in `HANDLERS`. +/// Calls each factory with the provided `config` and `request`. +/// Returns `Some(Box)` if a suitable handler is found, +/// or `None` if no handler can handle the request. +/// +/// # Examples +/// +/// ```rust +/// let handler = get_handler(&config, &request); +/// if let Some(h) = handler { +/// let response = h.run(&request); +/// // process the response +/// } +/// ``` + +pub struct HandlerEngine { + handlers: Vec, +} + +impl HandlerEngine { + pub fn new() -> HandlerEngine { + let mut handlers = Vec::new(); + handlers.extend_from_slice(HANDLERS); + HandlerEngine { handlers } + } + + pub fn add_handler(&mut self, handler: Factory) { + self.handlers.insert(0, handler); + } + + pub fn get_handler(&self, ctx: &Context) -> Option> { + for h in self.handlers.iter() { + if let Some(handler) = h(ctx) { + return Some(handler); + } + } + None + } +} diff --git a/src/handler/proxy.rs b/src/handler/proxy.rs new file mode 100644 index 0000000..d6a7457 --- /dev/null +++ b/src/handler/proxy.rs @@ -0,0 +1,87 @@ +use crate::handler::handler::{Handler, HandlerFactory}; +use crate::hteapot::{HttpMethod, HttpResponse, HttpResponseCommon, HttpStatus, TunnelResponse}; +use crate::utils::Context; + +/// Handles HTTP proxying based on server configuration. +/// +/// Determines whether a request matches any proxy rules and forwards it +/// to the corresponding upstream server, rewriting the path and `Host` header +/// as needed. +/// +/// # Fields +/// * `new_path` - Path to use for the proxied request. +/// * `url` - Target upstream URL. +pub struct ProxyHandler { + new_path: String, + url: String, +} + +impl Handler for ProxyHandler { + fn run(&self, ctx: &mut Context) -> Box { + let _proxy_logger = &ctx.log.with_component("proxy"); + + // Return a tunnel response immediately for OPTIONS requests + if ctx.request.method == HttpMethod::OPTIONS { + return TunnelResponse::new(&ctx.request.path); + } + + // Prepare a modified request for proxying + let mut proxy_req = ctx.request.clone(); + proxy_req.path = self.new_path.clone(); + proxy_req.headers.remove("Host"); + + // Determine the upstream host from the URL + let host_parts: Vec<&str> = self.url.split("://").collect(); + let host = if host_parts.len() == 1 { + host_parts.first().unwrap() + } else { + host_parts.last().unwrap() + }; + proxy_req.headers.insert("host", host); + + // Forward the request and handle errors + let response = proxy_req.brew(&self.url).unwrap_or(HttpResponse::new( + HttpStatus::NotAcceptable, + "", + None, + )); + + // Cache the response if caching is enabled + if let Some(cache) = ctx.cache.as_deref_mut() { + cache.set(ctx.request.clone(), (*response).clone()); + } + + response + } +} + +impl HandlerFactory for ProxyHandler { + fn is(ctx: &Context) -> Option> { + // OPTIONS requests are always handled + if ctx.request.method == HttpMethod::OPTIONS { + return Some(Box::new(ProxyHandler { + url: String::new(), + new_path: String::new(), + })); + } + + // Check if the request matches any configured proxy rules + for proxy_path in ctx.config.proxy_rules.keys() { + if let Some(path_match) = ctx.request.path.strip_prefix(proxy_path) { + let new_path = path_match.to_string(); + let url = ctx.config.proxy_rules.get(proxy_path).unwrap(); + let url = if url.is_empty() { + // If the rule URL is empty, fallback to Host header + let proxy_url = ctx.request.headers.get("host")?; + proxy_url.to_owned() + } else { + url.to_string() + }; + + return Some(Box::new(ProxyHandler { url, new_path })); + } + } + + None + } +} diff --git a/src/hteapot/brew.rs b/src/hteapot/brew.rs index cf8f7dc..d47dc88 100644 --- a/src/hteapot/brew.rs +++ b/src/hteapot/brew.rs @@ -14,17 +14,17 @@ use super::response::HttpResponse; // use std::net::{IpAddr, Ipv4Addr, SocketAddr}; impl HttpRequest { - /// Adds a query argument to the HTTP request. - pub fn arg(&mut self, key: &str, value: &str) -> &mut HttpRequest { - self.args.insert(key.to_string(), value.to_string()); - self - } + // /// Adds a query argument to the HTTP request. + // pub fn arg(&mut self, key: &str, value: &str) -> &mut HttpRequest { + // self.args.insert(key.to_string(), value.to_string()); + // self + // } - /// Adds a header to the HTTP request. - pub fn header(&mut self, key: &str, value: &str) -> &mut HttpRequest { - self.headers.insert(key.to_string(), value.to_string()); - self - } + // /// Adds a header to the HTTP request. + // pub fn header(&mut self, key: &str, value: &str) -> &mut HttpRequest { + // self.headers.insert(key, value); + // self + // } /// Converts the request into a raw HTTP/1.1-compliant string. /// @@ -120,10 +120,9 @@ impl HttpRequest { } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - println!("Read timeout"); - break; + return Err("Connection timeout"); } - Err(_) => return Err("Error reading"), + Err(_e) => return Err("Error reading"), } } @@ -144,7 +143,7 @@ pub fn brew(direction: &str, request: &mut HttpRequest) -> Result, + body_size: usize, + response_base: BaseResponse, + buffer: Vec, + state: State, +} + +impl HttpResponseBuilder { + pub fn new() -> HttpResponseBuilder { + HttpResponseBuilder { + response_base: BaseResponse { + status: HttpStatus::IAmATeapot, + headers: HttpHeaders::new(), + }, + body_size: 0, + body: Vec::new(), + buffer: Vec::new(), + state: State::Init, + } + } + + pub fn get(&self) -> Option { + match self.state { + State::Finish => { + let response = + HttpResponse::new_with_base(self.response_base.clone(), self.body.clone()); + Some(response) + } + _ => None, + } + } + + pub fn append(&mut self, chunk: &[u8]) -> Result { + self.buffer.extend_from_slice(chunk); + + while !self.buffer.is_empty() { + match self.state { + State::Init => { + if let Some(line) = get_line(&mut self.buffer) { + let parts: Vec<&str> = line.split(" ").collect(); + if parts.len() < 3 { + return Err("Invalid response"); + } + let status_str = parts.get(1).ok_or("Invalid status")?; + let status = status_str.parse::().map_err(|_| "Invalid status")?; + self.response_base.status = + HttpStatus::from_u16(status).map_err(|_| "Invalid status")?; + self.state = State::Headers; + } else { + return Ok(false); + } + } + State::Headers => { + if let Some(line) = get_line(&mut self.buffer) { + if line.is_empty() { + self.state = if self.body_size == 0 { + State::Finish + } else { + State::Body + }; + continue; + } + let (key, value) = line.split_once(":").ok_or("Invalid header")?; + let key = key.trim(); + let value = value.trim(); + if key.to_lowercase() == "content-length" { + self.body_size = value + .parse::() + .map_err(|_| "invalid content-length")?; + } + self.response_base.headers.insert(key, value); + } else { + return Ok(false); + } + } + State::Body => { + self.body.extend_from_slice(&mut self.buffer.as_slice()); + self.buffer.clear(); + if let Some(content_length) = self.response_base.headers.get("content-length") { + let content_length = content_length + .parse::() + .map_err(|_| "invalid content-length")?; + if self.body.len() >= content_length { + self.state = State::Finish; + return Ok(true); + } else { + return Ok(false); + } + } else { + //TODO: handle chunked + self.state = State::Finish; + return Ok(true); + } + } + State::Finish => { + return Ok(true); + } + } + } + + Ok(false) + } +} + +fn get_line(buffer: &mut Vec) -> Option { + if let Some(pos) = buffer.windows(2).position(|w| w == b"\r\n") { + let line = buffer.drain(..pos).collect::>(); + buffer.drain(..2); // remove CRLF + return match str::from_utf8(line.as_slice()) { + Ok(v) => Some(v.to_string()), + Err(_e) => None, + }; + } + None +} + +#[cfg(test)] +#[test] +fn basic_response() { + // Placeholder test — add real body/header parsing test here. + + let buffer = "HTTP/1.1 204 No Content\r\n\r\n".as_bytes().to_vec(); + let mut response_builder = HttpResponseBuilder::new(); + let done = response_builder.append(buffer.as_slice()); + assert!(done.is_ok()); + let response = response_builder.get(); + assert!(response.is_some()); + let mut response = response.unwrap(); + let response_base = response.base(); + assert!(response_base.status == HttpStatus::NoContent); + assert!(response_base.headers.len() == 0); + assert!(response.content.len() == 0); +} diff --git a/src/hteapot/engine.rs b/src/hteapot/engine.rs new file mode 100644 index 0000000..cb6bc7b --- /dev/null +++ b/src/hteapot/engine.rs @@ -0,0 +1,340 @@ +use std::collections::VecDeque; +use std::io::{self, Read, Write}; +use std::net::{Shutdown, TcpListener, TcpStream}; + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread; +use std::time::Instant; + +use super::BUFFER_SIZE; +use super::KEEP_ALIVE_TTL; +use crate::{HttpRequest, HttpResponse, HttpStatus}; +// Internal types used for connection management +use super::request::HttpRequestBuilder; +use super::response::{EmptyHttpResponse, HttpResponseCommon, IterError}; + +/// Helper macro to construct header maps. +/// +/// # Example +/// ```rust +/// use hteapot::headers; +/// let headers = headers! { +/// "Content-Type" => "text/html", +/// "X-Custom" => "value" +/// }; +/// ``` + +pub struct Hteapot { + port: u16, + address: String, + threads: u16, + shutdown_signal: Option>, + shutdown_hooks: Vec>, +} + +#[derive(PartialEq)] +enum Status { + Read, + Write, + Finish, +} + +/// Represents the state of a connection's lifecycle. +struct SocketStatus { + ttl: Instant, + status: Status, + response: Box, + request: HttpRequestBuilder, + index_writed: usize, +} + +/// Wraps a TCP stream and its associated state. +struct SocketData { + stream: TcpStream, + status: Option, +} + +impl Hteapot { + pub fn set_shutdown_signal(&mut self, signal: Arc) { + self.shutdown_signal = Some(signal); + } + + pub fn get_shutdown_signal(&self) -> Option> { + self.shutdown_signal.clone() + } + + pub fn add_shutdown_hook(&mut self, hook: F) + where + F: Fn() + Send + Sync + 'static, + { + self.shutdown_hooks.push(Arc::new(hook)); + } + + pub fn get_addr(&self) -> (String, u16) { + return (self.address.clone(), self.port); + } + + // Constructor + pub fn new(address: &str, port: u16) -> Self { + Hteapot { + port, + address: address.to_string(), + threads: 1, + shutdown_signal: None, + shutdown_hooks: Vec::new(), + } + } + + pub fn new_threaded(address: &str, port: u16, threads: u16) -> Self { + Hteapot { + port, + address: address.to_string(), + threads: if threads == 0 { 1 } else { threads }, + shutdown_signal: None, + shutdown_hooks: Vec::new(), + } + } + + // Start the server + pub fn listen( + &self, + action: impl Fn(HttpRequest) -> Box + Send + Sync + 'static, + ) { + let addr = format!("{}:{}", self.address, self.port); + let listener = match TcpListener::bind(addr) { + Ok(listener) => listener, + Err(e) => { + eprintln!("Error binding to address: {}", e); + return; + } + }; + + let pool: Arc<(Mutex>, Condvar)> = + Arc::new((Mutex::new(VecDeque::new()), Condvar::new())); + let priority_list: Arc>> = + Arc::new(Mutex::new(vec![0; self.threads as usize])); + let arc_action = Arc::new(action); + + // Clone shutdown_signal and share the shutdown_hooks via Arc + let shutdown_signal = self.shutdown_signal.clone(); + let shutdown_hooks = Arc::new(self.shutdown_hooks.clone()); + + for thread_index in 0..self.threads { + let pool_clone = pool.clone(); + let action_clone = arc_action.clone(); + let priority_list_clone = priority_list.clone(); + let shutdown_signal_clone = shutdown_signal.clone(); + + thread::spawn(move || { + let mut streams_to_handle = Vec::new(); + loop { + { + let (lock, cvar) = &*pool_clone; + let mut pool = lock.lock().expect("Error locking pool"); + if streams_to_handle.is_empty() { + // Store the returned guard back into pool + pool = cvar + .wait_while(pool, |pool| pool.is_empty()) + .expect("Error waiting on cvar"); + } + //TODO: move this to allow process the last request + if let Some(signal) = &shutdown_signal_clone { + if !signal.load(Ordering::SeqCst) { + break; // Exit the server loop + } + } + + while let Some(stream) = pool.pop_back() { + let socket_status = SocketStatus { + ttl: Instant::now(), + status: Status::Read, + response: Box::new(EmptyHttpResponse {}), + request: HttpRequestBuilder::new(), + index_writed: 0, + }; + let socket_data = SocketData { + stream, + status: Some(socket_status), + }; + streams_to_handle.push(socket_data); + } + } + + // { + // let mut priority_list = priority_list_clone + // .lock() + // .expect("Error locking priority list"); + // priority_list[thread_index as usize] = streams_to_handle.len(); + // } + + streams_to_handle.retain_mut(|s| { + if s.status.is_none() { + return false; + } + Hteapot::handle_client(s, &action_clone).is_some() + }); + } + }); + } + + loop { + if let Some(signal) = &shutdown_signal { + if !signal.load(Ordering::SeqCst) { + let (lock, cvar) = &*pool; + let _guard = lock.lock().unwrap(); + cvar.notify_all(); + for hook in shutdown_hooks.iter() { + hook(); + } + break; + } + } + let stream = match listener.accept() { + Ok((stream, _)) => stream, + Err(_) => continue, + }; + + if stream.set_nonblocking(true).is_err() { + eprintln!("Error setting non-blocking mode on stream"); + continue; + } + if stream.set_nodelay(true).is_err() { + eprintln!("Error setting no delay on stream"); + continue; + } + + { + let (lock, cvar) = &*pool; + let mut pool = lock.lock().expect("Error locking pool"); + + // Add the connection to the pool for the least-loaded thread + pool.push_front(stream); + cvar.notify_one(); + } + } + } + + fn handle_client( + socket_data: &mut SocketData, + action: &Arc Box + Send + Sync + 'static>, + ) -> Option<()> { + let status = socket_data.status.as_mut()?; + + // Check if the TTL (time-to-live) for the connection has expired. + if Instant::now().duration_since(status.ttl) > KEEP_ALIVE_TTL + && status.status != Status::Write + { + let _ = socket_data.stream.shutdown(Shutdown::Both); + return None; + } + + match status.status { + Status::Read => { + while !status.request.done() { + let mut buffer = [0; BUFFER_SIZE]; + match socket_data.stream.read(&mut buffer) { + Err(e) => match e.kind() { + io::ErrorKind::WouldBlock => return Some(()), + io::ErrorKind::ConnectionReset => return None, + _ => { + eprintln!("Read error: {:?}", e); + return None; + } + }, + Ok(m) => { + if m == 0 { + return None; + } + status.ttl = Instant::now(); + let r = status.request.append(buffer[..m].to_vec()); + if r.is_err() { + // Early return response if not valid request is sended + let error_msg = r.err().unwrap(); + let response = + HttpResponse::new(HttpStatus::BadRequest, error_msg, None) + .to_bytes(); + let _ = socket_data.stream.write(&response); + let _ = socket_data.stream.flush(); + let _ = socket_data.stream.shutdown(Shutdown::Both); + return None; + } + } + } + } + let request = status.request.get()?; + let keep_alive = request + .headers + .get("connection") + .map(|v| v.to_lowercase() == "keep-alive") + .unwrap_or(false); + + let mut response = action(request); + if keep_alive { + response + .base() + .headers + .entry("connection") + .or_insert("keep-alive".to_string()); + response.base().headers.insert( + "Keep-Alive", + &format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), + ); + } else { + response.base().headers.insert("Connection", "close"); + } + status.status = Status::Write; + status.response = response; + status.response.set_stream(&socket_data.stream); + } + Status::Write => { + loop { + match status.response.peek() { + Ok(n) => match socket_data.stream.write(&n) { + Ok(_) => { + status.ttl = Instant::now(); + let _ = status.response.next(); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Some(()), + Err(e) => { + eprintln!("Write error: {:?}", e); + return None; + } + }, + Err(IterError::WouldBlock) => { + status.ttl = Instant::now(); + return Some(()); + } + Err(_) => break, + } + } + status.status = Status::Finish; + let request = status.request.get()?; + let keep_alive = request + .headers + .get("connection") + .map(|v| v.to_lowercase() == "keep-alive") + .unwrap_or(false); + if keep_alive { + status.status = Status::Read; + status.index_writed = 0; + status.request = HttpRequestBuilder::new(); + return Some(()); + } else { + let _ = socket_data.stream.shutdown(Shutdown::Both); + return None; + } + } + Status::Finish => { + return None; + } + }; + Some(()) + + // If the request is not yet complete, read data from the stream into a buffer. + // This ensures that the server can handle partial or chunked requests. + + // Seting the stream in case is needed for the response, (example: streaming) + // Write the response to the client in chunks + } +} diff --git a/src/hteapot/http/headers.rs b/src/hteapot/http/headers.rs new file mode 100644 index 0000000..b55d50f --- /dev/null +++ b/src/hteapot/http/headers.rs @@ -0,0 +1,143 @@ +use std::collections::hash_map::Entry; +use std::collections::{HashMap, hash_map}; +use std::fmt::Display; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; + +#[derive(Debug, Clone)] +pub struct CaseInsensitiveString(String); + +impl PartialEq for CaseInsensitiveString { + fn eq(&self, other: &Self) -> bool { + self.0.eq_ignore_ascii_case(&other.0) + } +} + +impl Display for CaseInsensitiveString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} +impl Eq for CaseInsensitiveString {} + +impl Hash for CaseInsensitiveString { + fn hash(&self, state: &mut H) { + for b in self.0.bytes() { + state.write_u8(b.to_ascii_lowercase()); + } + } +} + +impl Deref for CaseInsensitiveString { + type Target = str; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Default, Clone)] +pub struct HttpHeaders(HashMap); + +impl HttpHeaders { + pub fn new() -> Self { + HttpHeaders(HashMap::new()) + } + + pub fn insert(&mut self, key: &str, value: &str) { + // Ejemplo: forzar keys a lowercase + self.0 + .insert(CaseInsensitiveString(key.to_string()), value.to_string()); + } + + pub fn get(&self, key: &str) -> Option<&String> { + self.0.get(&CaseInsensitiveString(key.to_string())) + } + + pub fn get_owned(&self, key: &str) -> Option { + self.0.get(&CaseInsensitiveString(key.to_string())).cloned() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn entry(&mut self, key: &str) -> Entry<'_, CaseInsensitiveString, String> { + self.0.entry(CaseInsensitiveString(key.to_string())) + } + + pub fn remove(&mut self, key: &str) -> Option { + self.0.remove(&CaseInsensitiveString(key.to_string())) + } + + pub fn iter(&self) -> std::collections::hash_map::Iter<'_, CaseInsensitiveString, String> { + self.0.iter() + } +} + +impl IntoIterator for HttpHeaders { + type Item = (CaseInsensitiveString, String); + type IntoIter = hash_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a HttpHeaders { + type Item = (&'a CaseInsensitiveString, &'a String); + type IntoIter = hash_map::Iter<'a, CaseInsensitiveString, String>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'a> IntoIterator for &'a mut HttpHeaders { + type Item = (&'a CaseInsensitiveString, &'a mut String); + type IntoIter = hash_map::IterMut<'a, CaseInsensitiveString, String>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + +impl PartialEq for HttpHeaders { + fn eq(&self, other: &Self) -> bool { + other.0 == self.0 + } +} + +#[macro_export] +macro_rules! headers { + ( $($k:expr => $v:expr),* $(,)? ) => {{ + let mut headers = hteapot::HttpHeaders::new(); + $( headers.insert($k, $v); )* + Some(headers) + }}; +} + +#[cfg(test)] +#[test] +fn test_caseinsensitive() { + let mut headers = HttpHeaders::new(); + headers.insert("X-Test-Header", "Value"); + assert!(headers.get("x-test-header").is_some()); + assert!(headers.get("x-test-header").unwrap() == "Value"); + assert!(headers.get("x-test-header").unwrap() != "value"); +} + +#[cfg(test)] +#[test] +fn test_remove() { + let mut headers = HttpHeaders::new(); + headers.insert("X-Test-Header", "Value"); + assert!(headers.get("x-test-header").is_some()); + assert!(headers.get("x-test-header").unwrap() == "Value"); + assert!(headers.get("x-test-header").unwrap() != "value"); + assert!(headers.remove("x-test-header").is_some()); + assert!(headers.get("x-test-header").is_none()); +} diff --git a/src/hteapot/methods.rs b/src/hteapot/http/methods.rs similarity index 96% rename from src/hteapot/methods.rs rename to src/hteapot/http/methods.rs index 78592f6..ab374fb 100644 --- a/src/hteapot/methods.rs +++ b/src/hteapot/http/methods.rs @@ -32,7 +32,8 @@ impl HttpMethod { /// assert_eq!(custom, HttpMethod::Other("CUSTOM".into())); /// ``` pub fn from_str(method: &str) -> HttpMethod { - match method { + let method = method.to_uppercase(); + match method.as_str() { "GET" => HttpMethod::GET, "POST" => HttpMethod::POST, "PUT" => HttpMethod::PUT, diff --git a/src/hteapot/http/mod.rs b/src/hteapot/http/mod.rs new file mode 100644 index 0000000..32345b2 --- /dev/null +++ b/src/hteapot/http/mod.rs @@ -0,0 +1,7 @@ +mod headers; +mod methods; +mod status; + +pub use headers::HttpHeaders; +pub use methods::HttpMethod; +pub use status::HttpStatus; diff --git a/src/hteapot/status.rs b/src/hteapot/http/status.rs similarity index 99% rename from src/hteapot/status.rs rename to src/hteapot/http/status.rs index f066a47..1268474 100644 --- a/src/hteapot/status.rs +++ b/src/hteapot/http/status.rs @@ -11,7 +11,7 @@ /// /// Use [`HttpStatus::from_u16`] to convert from raw codes, /// and [`HttpStatus::to_string`] to get the standard reason phrase. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq)] pub enum HttpStatus { // 2xx Success OK = 200, @@ -98,7 +98,7 @@ impl HttpStatus { 305 => Ok(HttpStatus::UseProxy), 307 => Ok(HttpStatus::TemporaryRedirect), 308 => Ok(HttpStatus::PermanentRedirect), - + 400 => Ok(HttpStatus::BadRequest), 401 => Ok(HttpStatus::Unauthorized), 402 => Ok(HttpStatus::PaymentRequired), @@ -127,7 +127,7 @@ impl HttpStatus { 428 => Ok(HttpStatus::PreconditionRequired), 429 => Ok(HttpStatus::TooManyRequests), 431 => Ok(HttpStatus::RequestHeaderFieldsTooLarge), - + 500 => Ok(HttpStatus::InternalServerError), 501 => Ok(HttpStatus::NotImplemented), 502 => Ok(HttpStatus::BadGateway), diff --git a/src/hteapot/mod.rs b/src/hteapot/mod.rs index 95df09b..5c841a1 100644 --- a/src/hteapot/mod.rs +++ b/src/hteapot/mod.rs @@ -18,356 +18,38 @@ /// Submodules for HTTP functionality. pub mod brew; // HTTP client implementation -mod methods; // HTTP method and status enums +mod engine; +mod http; // HTTP method and status enums mod request; // Request parsing and builder mod response; // Response types and streaming -mod status; // Status code mapping +// Status code mapping -// Internal types used for connection management -use self::response::{EmptyHttpResponse, HttpResponseCommon, IterError}; // use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + // Public API exposed by this module -pub use self::methods::HttpMethod; pub use self::request::HttpRequest; -use self::request::HttpRequestBuilder; -pub use self::response::{HttpResponse, StreamedResponse}; -pub use self::status::HttpStatus; +pub use engine::Hteapot; +pub use http::HttpHeaders; +pub use http::HttpMethod; +pub use http::HttpStatus; -use std::collections::VecDeque; -use std::io::{self, Read, Write}; -use std::net::{Shutdown, TcpListener, TcpStream}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Condvar, Mutex}; -use std::thread; -use std::time::{Duration, Instant}; +pub use response::{HttpResponse, HttpResponseCommon, StreamedResponse, TunnelResponse}; /// Crate version as set by `Cargo.toml`. pub const VERSION: &str = env!("CARGO_PKG_VERSION"); - /// Size of the buffer used for reading from the TCP stream. const BUFFER_SIZE: usize = 1024 * 2; /// Time-to-live for keep-alive connections. const KEEP_ALIVE_TTL: Duration = Duration::from_secs(10); -/// Helper macro to construct header maps. -/// -/// # Example -/// ```rust -/// use hteapot::headers; -/// let headers = headers! { -/// "Content-Type" => "text/html", -/// "X-Custom" => "value" -/// }; -/// ``` -#[macro_export] -macro_rules! headers { - ( $($k:expr => $v:expr),*) => { - { - use std::collections::HashMap; - let mut headers: HashMap = HashMap::new(); - $( headers.insert($k.to_string(), $v.to_string()); )* - Some(headers) - } - }; -} - -pub struct Hteapot { - port: u16, - address: String, - threads: u16, - shutdown_signal: Option>, - shutdown_hooks: Vec>, -} - -/// Represents the state of a connection's lifecycle. -struct SocketStatus { - ttl: Instant, - reading: bool, - write: bool, - response: Box, - request: HttpRequestBuilder, - index_writed: usize, -} - -/// Wraps a TCP stream and its associated state. -struct SocketData { - stream: TcpStream, - status: Option, -} - -impl Hteapot { - pub fn set_shutdown_signal(&mut self, signal: Arc) { - self.shutdown_signal = Some(signal); - } - - pub fn get_shutdown_signal(&self) -> Option> { - self.shutdown_signal.clone() - } - - pub fn add_shutdown_hook(&mut self, hook: F) - where - F: Fn() + Send + Sync + 'static, - { - self.shutdown_hooks.push(Arc::new(hook)); - } - - pub fn get_addr(&self) -> (String, u16) { - return (self.address.clone(), self.port); - } - - // Constructor - pub fn new(address: &str, port: u16) -> Self { - Hteapot { - port, - address: address.to_string(), - threads: 1, - shutdown_signal: None, - shutdown_hooks: Vec::new(), - } - } - - pub fn new_threaded(address: &str, port: u16, threads: u16) -> Self { - Hteapot { - port, - address: address.to_string(), - threads: if threads == 0 { 1 } else { threads }, - shutdown_signal: None, - shutdown_hooks: Vec::new(), - } - } - - // Start the server - pub fn listen( - &self, - action: impl Fn(HttpRequest) -> Box + Send + Sync + 'static, - ) { - let addr = format!("{}:{}", self.address, self.port); - let listener = match TcpListener::bind(addr) { - Ok(listener) => listener, - Err(e) => { - eprintln!("Error binding to address: {}", e); - return; - } - }; - - let pool: Arc<(Mutex>, Condvar)> = - Arc::new((Mutex::new(VecDeque::new()), Condvar::new())); - let priority_list: Arc>> = - Arc::new(Mutex::new(vec![0; self.threads as usize])); - let arc_action = Arc::new(action); - - // Clone shutdown_signal and share the shutdown_hooks via Arc - let shutdown_signal = self.shutdown_signal.clone(); - let shutdown_hooks = Arc::new(self.shutdown_hooks.clone()); - - for thread_index in 0..self.threads { - let pool_clone = pool.clone(); - let action_clone = arc_action.clone(); - let priority_list_clone = priority_list.clone(); - let shutdown_signal_clone = shutdown_signal.clone(); - - thread::spawn(move || { - let mut streams_to_handle = Vec::new(); - loop { - { - let (lock, cvar) = &*pool_clone; - let mut pool = lock.lock().expect("Error locking pool"); - if streams_to_handle.is_empty() { - // Store the returned guard back into pool - pool = cvar - .wait_while(pool, |pool| pool.is_empty()) - .expect("Error waiting on cvar"); - } - //TODO: move this to allow process the last request - if let Some(signal) = &shutdown_signal_clone { - if !signal.load(Ordering::SeqCst) { - break; // Exit the server loop - } - } - - while let Some(stream) = pool.pop_back() { - let socket_status = SocketStatus { - ttl: Instant::now(), - reading: true, - write: false, - response: Box::new(EmptyHttpResponse {}), - request: HttpRequestBuilder::new(), - index_writed: 0, - }; - let socket_data = SocketData { - stream, - status: Some(socket_status), - }; - streams_to_handle.push(socket_data); - } - } - - { - let mut priority_list = priority_list_clone - .lock() - .expect("Error locking priority list"); - priority_list[thread_index as usize] = streams_to_handle.len(); - } - - streams_to_handle.retain_mut(|s| { - if s.status.is_none() { - return false; - } - Hteapot::handle_client(s, &action_clone).is_some() - }); - } - }); - } - - loop { - if let Some(signal) = &shutdown_signal { - if !signal.load(Ordering::SeqCst) { - let (lock, cvar) = &*pool; - let _guard = lock.lock().unwrap(); - cvar.notify_all(); - for hook in shutdown_hooks.iter() { - hook(); - } - break; - } - } - let stream = match listener.accept() { - Ok((stream, _)) => stream, - Err(_) => continue, - }; - - if stream.set_nonblocking(true).is_err() { - eprintln!("Error setting non-blocking mode on stream"); - continue; - } - if stream.set_nodelay(true).is_err() { - eprintln!("Error setting no delay on stream"); - continue; - } - - { - let (lock, cvar) = &*pool; - let mut pool = lock.lock().expect("Error locking pool"); - - // Add the connection to the pool for the least-loaded thread - pool.push_front(stream); - cvar.notify_one(); - } - } - } - - fn handle_client( - socket_data: &mut SocketData, - action: &Arc Box + Send + Sync + 'static>, - ) -> Option<()> { - let status = socket_data.status.as_mut()?; - - // Check if the TTL (time-to-live) for the connection has expired. - if Instant::now().duration_since(status.ttl) > KEEP_ALIVE_TTL && !status.write { - let _ = socket_data.stream.shutdown(Shutdown::Both); - return None; - } - - // If the request is not yet complete, read data from the stream into a buffer. - // This ensures that the server can handle partial or chunked requests. - if !status.request.done { - let mut buffer = [0; BUFFER_SIZE]; - match socket_data.stream.read(&mut buffer) { - Err(e) => match e.kind() { - io::ErrorKind::WouldBlock => return Some(()), - io::ErrorKind::ConnectionReset => return None, - _ => { - eprintln!("Read error: {:?}", e); - return None; - } - }, - Ok(m) => { - if m == 0 { - return None; - } - status.ttl = Instant::now(); - let r = status.request.append(buffer[..m].to_vec()); - if r.is_err() { - // Early return response if not valid request is sended - let error_msg = r.err().unwrap(); - let response = - HttpResponse::new(HttpStatus::BadRequest, error_msg, None).to_bytes(); - let _ = socket_data.stream.write(&response); - let _ = socket_data.stream.flush(); - let _ = socket_data.stream.shutdown(Shutdown::Both); - return None; - } - } - } - } - - let request = status.request.get()?; - let keep_alive = request - .headers - .get("connection") //all headers are turn lowercase in the builder - .map(|v| v.to_lowercase() == "keep-alive") - .unwrap_or(false); - if !status.write { - let mut response = action(request); - if keep_alive { - response - .base() - .headers - .entry("Connection".to_string()) - .or_insert("keep-alive".to_string()); - response.base().headers.insert( - "Keep-Alive".to_string(), - format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), - ); - } else { - response - .base() - .headers - .insert("Connection".to_string(), "close".to_string()); - } - status.write = true; - status.response = response; - } - - // Write the response to the client in chunks - loop { - match status.response.peek() { - Ok(n) => match socket_data.stream.write(&n) { - Ok(_) => { - status.ttl = Instant::now(); - let _ = status.response.next(); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Some(()), - Err(e) => { - eprintln!("Write error: {:?}", e); - return None; - } - }, - Err(IterError::WouldBlock) => { - status.ttl = Instant::now(); - return Some(()); - } - Err(_) => break, - } - } - - if keep_alive { - status.reading = true; - status.write = false; - status.index_writed = 0; - status.request = HttpRequestBuilder::new(); - return Some(()); - } else { - let _ = socket_data.stream.shutdown(Shutdown::Both); - None - } - } -} - #[cfg(test)] mod tests { + use crate::{HttpResponse, HttpStatus}; + use http::HttpHeaders; + const VERSION: &str = env!("CARGO_PKG_VERSION"); use super::*; #[test] @@ -376,7 +58,7 @@ mod tests { let response = String::from_utf8(response.to_bytes()).unwrap(); let expected_response = format!( "HTTP/1.1 418 I'm a teapot\r\nContent-Length: 13\r\nServer: HTeaPot/{}\r\n\r\nHello, World!\r\n", - VERSION + VERSION //TODO: fix ); let expected_response_list = expected_response.split("\r\n"); for item in expected_response_list { @@ -386,19 +68,15 @@ mod tests { #[test] fn test_keep_alive_connection() { - let mut response = HttpResponse::new( - HttpStatus::OK, - "Keep-Alive Test", - headers! { - "Connection" => "keep-alive", - "Content-Length" => "15" - }, + let mut headers = HttpHeaders::new(); + headers.insert("Connection", "keep-alive"); + headers.insert("Content-Length", "15"); + headers.insert( + "Keep-Alive", + &format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), ); - response.base().headers.insert( - "Keep-Alive".to_string(), - format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), - ); + let mut response = HttpResponse::new(HttpStatus::OK, "Keep-Alive Test", Some(headers)); let response_bytes = response.to_bytes(); let response_str = String::from_utf8(response_bytes.clone()).unwrap(); @@ -409,20 +87,16 @@ mod tests { assert!(response_str.contains("Keep-Alive: timeout=10")); assert!(response_str.contains("Server: HTeaPot/")); assert!(response_str.contains("Keep-Alive Test")); - - let mut second_response = HttpResponse::new( - HttpStatus::OK, - "Second Request", - headers! { - "Connection" => "keep-alive", - "Content-Length" => "14" // Length for "Second Request" - }, + let mut headers = HttpHeaders::new(); + headers.insert("Connection", "keep-alive"); + headers.insert("Content-Length", "14"); + headers.insert( + "Keep-Alive", + &format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), ); - second_response.base().headers.insert( - "Keep-Alive".to_string(), - format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), - ); + let mut second_response = + HttpResponse::new(HttpStatus::OK, "Second Request", Some(headers)); let second_response_bytes = second_response.to_bytes(); let second_response_str = String::from_utf8(second_response_bytes.clone()).unwrap(); diff --git a/src/hteapot/request.rs b/src/hteapot/request.rs index 4dfea67..25775c8 100644 --- a/src/hteapot/request.rs +++ b/src/hteapot/request.rs @@ -6,9 +6,10 @@ // - Partial header validation // - No URI normalization or encoding // -// ⚠️ A full refactor is recommended before production use. +use super::HttpHeaders; use super::HttpMethod; +use std::hash::Hash; use std::{cmp::min, collections::HashMap, net::TcpStream, str}; const MAX_HEADER_SIZE: usize = 1024 * 16; @@ -22,11 +23,34 @@ pub struct HttpRequest { pub method: HttpMethod, pub path: String, pub args: HashMap, - pub headers: HashMap, + pub headers: HttpHeaders, pub body: Vec, stream: Option, } +impl Hash for HttpRequest { + fn hash(&self, state: &mut H) { + self.method.hash(state); + self.path.hash(state); + // self.args.hash(state); + // self.headers.hash(state); + self.body.hash(state); + } +} + +impl PartialEq for HttpRequest { + fn eq(&self, other: &Self) -> bool { + let same_method = self.method == other.method; + let same_path = self.path == other.path; + let same_body = self.body == other.body; + let same_args = other.args == self.args; + let same_headers = self.headers == other.headers; + return same_method && same_path && same_body && same_args && same_headers; + } +} + +impl Eq for HttpRequest {} + impl HttpRequest { /// Creates a new HTTP request with the given method and path. pub fn new(method: HttpMethod, path: &str) -> Self { @@ -34,7 +58,7 @@ impl HttpRequest { method, path: path.to_string(), args: HashMap::new(), - headers: HashMap::new(), + headers: HttpHeaders::new(), body: Vec::new(), stream: None, }; @@ -46,7 +70,7 @@ impl HttpRequest { method: HttpMethod::Other(String::new()), path: String::new(), args: HashMap::new(), - headers: HashMap::new(), + headers: HttpHeaders::new(), body: Vec::new(), stream: None, } @@ -88,10 +112,17 @@ impl HttpRequest { pub struct HttpRequestBuilder { request: HttpRequest, buffer: Vec, - header_done: bool, - header_size: usize, + state: State, body_size: usize, - pub done: bool, + chunked: bool, +} + +#[derive(PartialEq)] +enum State { + Init, + Headers, + Body, + Finish, } impl HttpRequestBuilder { @@ -102,40 +133,34 @@ impl HttpRequestBuilder { method: HttpMethod::GET, path: String::new(), args: HashMap::new(), - headers: HashMap::new(), + headers: HttpHeaders::new(), body: Vec::new(), stream: None, }, - header_size: 0, - header_done: false, + chunked: false, + state: State::Init, body_size: 0, buffer: Vec::new(), - done: false, }; } /// Returns the built request if parsing is complete. pub fn get(&self) -> Option { - if self.done { - return Some(self.request.clone()); - } else { - None + match self.state { + State::Finish => Some(self.request.clone()), + _ => None, } } /// Reads bytes into the request body based on `Content-Length`. fn read_body_len(&mut self) -> Option<()> { let body_left = self.body_size.saturating_sub(self.request.body.len()); - let to_take = min(body_left, self.buffer.len()); - let to_append = self.buffer.drain(..to_take); - let to_append = to_append.as_slice(); - self.request.body.extend_from_slice(to_append); + let body_left = self.body_size.saturating_sub(self.request.body.len()); if body_left > 0 { return None; } else { - self.done = true; return Some(()); } } @@ -151,105 +176,191 @@ impl HttpRequestBuilder { return self.read_body_len(); } + pub fn done(&self) -> bool { + self.state == State::Finish + } + /// Feeds a chunk of bytes into the builder. /// /// This function may return an error if the header is too large or malformed. pub fn append(&mut self, chunk: Vec) -> Result<(), &'static str> { - if !self.header_done && self.buffer.len() > MAX_HEADER_SIZE { - return Err("Entity Too large"); - } - - let chunk_size = chunk.len(); self.buffer.extend(chunk); - if self.header_done { - self.read_body(); - return Ok(()); - } else { - self.header_size += chunk_size; - if self.header_size > MAX_HEADER_SIZE { - return Err("Entity Too large"); - } - } - - while let Some(pos) = self.buffer.windows(2).position(|w| w == b"\r\n") { - let line = self.buffer.drain(..pos).collect::>(); - self.buffer.drain(..2); // remove CRLF - - let line_str = match str::from_utf8(line.as_slice()) { - Ok(v) => v.to_string(), - Err(_e) => return Err("No utf-8"), - }; - - if self.request.path.is_empty() { - // This is the request line - let parts: Vec<&str> = line_str.split_whitespace().collect(); - if parts.len() < 2 { - return Ok(()); - } - - if parts.len() != 3 { - return Err("Invalid method + path + version request"); - } - self.request.method = HttpMethod::from_str(parts[0]); - let path_parts: Vec<&str> = parts[1].split('?').collect(); - self.request.path = path_parts[0].to_string(); - - if path_parts.len() > 1 { - self.request.args = path_parts[1] - .split('&') - .filter_map(|pair| { - let kv: Vec<&str> = pair.split('=').collect(); - if kv.len() == 2 { - Some((kv[0].to_string(), kv[1].to_string())) - } else { - None - } - }) - .collect(); + while !self.buffer.is_empty() { + match self.state { + State::Init => { + let line = get_line(&mut self.buffer); + if line.is_none() { + if self.buffer.len() >= MAX_HEADER_SIZE { + return Err("Entity Too Large"); + } + return Ok(()); + } + let line = line.unwrap(); + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() != 3 { + return Err("Invalid method + path + version request"); + } + self.request.method = HttpMethod::from_str(parts[0]); + let path_parts: Vec<&str> = parts[1].split('?').collect(); + self.request.path = path_parts[0].to_string(); + if path_parts.len() > 1 { + self.request.args = path_parts[1] + .split('&') + .filter_map(|pair| { + let kv: Vec<&str> = pair.split('=').collect(); + if kv.len() == 2 { + Some((kv[0].to_string(), kv[1].to_string())) + } else { + None + } + }) + .collect(); + } + self.state = State::Headers; } - } else if !line_str.is_empty() { - // Header line - if let Some((key, value)) = line_str.split_once(":") { - //Check the number of headers, if the actual headers exceed that number - //drop the connection - if self.request.headers.len() > MAX_HEADER_COUNT { + State::Headers => { + let line = get_line(&mut self.buffer); + if line.is_none() { + return Ok(()); + } + let line = line.unwrap(); + if line.is_empty() { + self.state = if self.body_size == 0 && !self.chunked { + State::Finish + } else { + State::Body + }; + continue; + } + if self.request.headers.len() > MAX_HEADER_COUNT || line.len() > MAX_HEADER_SIZE + { return Err("Header number exceed allowed"); } - let key = key.trim().to_lowercase(); + let (key, value) = line.split_once(':').ok_or("Invalid Header")?; + let key = key.trim(); let value = value.trim(); - - if key == "content-length" { + if key.to_lowercase() == "content-length" { if self.request.headers.get("content-length").is_some() - || self - .request - .headers - .get("transfer-encoding") - .map(|te| te == "chunked") - .unwrap_or(false) + || self.request.headers.get("Transfer-Encoding") + == Some(&"chunked".to_string()) { - return Err("Duplicated content-length"); + continue; + } + self.body_size = value + .parse::() + .map_err(|_| "invalid content-length")?; + } + if key.to_lowercase() == "transfer-encoding" + && value.to_lowercase() == "chunked" + { + if self.request.headers.get("content-length").is_some() { + continue; } - self.body_size = value.parse().unwrap_or(0); + self.chunked = true; } - self.request - .headers - .insert(key.to_string(), value.to_string()); + self.request.headers.insert(key, value); } - } else { - // Empty line = end of headers - self.header_done = true; - self.read_body(); - return Ok(()); + State::Body => { + let body_left = self.body_size - self.request.body.len(); + if body_left > 0 { + let to_take = min(body_left, self.buffer.len()); + let to_append = self.buffer.drain(..to_take); + let to_append = to_append.as_slice(); + self.request.body.extend_from_slice(to_append); + } + if self.chunked { + if self.body_size != 0 { + let empty = get_line(&mut self.buffer); + if empty.is_none() { + return Ok(()); + } + } + let size = get_line(&mut self.buffer); + if size.is_none() { + return Ok(()); + } + let size = size.unwrap(); + let size = size.strip_prefix("0x").unwrap_or(&size); + let size = + i64::from_str_radix(size, 16).map_err(|_| "Invalud chunk size")?; + if size == 0 { + self.state = State::Finish; + return Ok(()); + } + self.body_size += size as usize; + } else { + self.state = State::Finish; + return Ok(()); + } + } + State::Finish => return Ok(()), } } + Ok(()) } } +fn get_line(buffer: &mut Vec) -> Option { + if let Some(pos) = buffer.windows(2).position(|w| w == b"\r\n") { + let line = buffer.drain(..pos).collect::>(); + buffer.drain(..2); // remove CRLF + return match str::from_utf8(line.as_slice()) { + Ok(v) => Some(v.to_string()), + Err(_e) => None, + }; + } + None +} + #[cfg(test)] #[test] fn basic_request() { // Placeholder test — add real body/header parsing test here. + let buffer = "GET / HTTP/1.1\r\n\r\n".as_bytes().to_vec(); + let mut request_builder = HttpRequestBuilder::new(); + let done = request_builder.append(buffer); + assert!(done.is_ok()); + let request = request_builder.get(); + assert!(request.is_some()); + let request = request.unwrap(); + assert!(request.path == "/"); + assert!(request.method == HttpMethod::GET); + assert!(request.headers.len() == 0); +} + +#[cfg(test)] +#[test] +fn basic_request_headers() { + // Placeholder test — add real body/header parsing test here. + let buffer = "GET / HTTP/1.1\r\nHost: test\r\n\r\n".as_bytes().to_vec(); + let mut request_builder = HttpRequestBuilder::new(); + let done = request_builder.append(buffer); + assert!(done.is_ok()); + let request = request_builder.get(); + assert!(request.is_some()); + let request = request.unwrap(); + assert!(request.path == "/"); + assert!(request.method == HttpMethod::GET); + assert!(request.headers.len() == 1); +} + +#[cfg(test)] +#[test] +fn post_request() { + // Placeholder test — add real body/header parsing test here. + let buffer = "POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\nhello\r\n" + .as_bytes() + .to_vec(); + let mut request_builder = HttpRequestBuilder::new(); + let done = request_builder.append(buffer); + assert!(done.is_ok()); + let request = request_builder.get(); + assert!(request.is_some()); + let request = request.unwrap(); + assert!(request.path == "/"); + assert!(request.method == HttpMethod::POST); + assert!(request.headers.len() == 1); } diff --git a/src/hteapot/response.rs b/src/hteapot/response.rs index e2d35a0..e3a326f 100644 --- a/src/hteapot/response.rs +++ b/src/hteapot/response.rs @@ -7,26 +7,33 @@ //! //! All response types implement the [`HttpResponseCommon`] trait. +use super::HttpHeaders; + use super::HttpStatus; use super::{BUFFER_SIZE, VERSION}; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; +use std::io::Write; +use std::net::TcpStream; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{self, Receiver, SendError, Sender, TryRecvError}; -use std::thread; use std::thread::JoinHandle; +use std::time::Duration; +use std::vec; +use std::{io, thread}; /// Basic HTTP status line + headers. +#[derive(Clone)] pub struct BaseResponse { pub status: HttpStatus, - pub headers: HashMap, + pub headers: HttpHeaders, } impl BaseResponse { /// Converts the status + headers into a properly formatted HTTP header block. pub fn to_bytes(&mut self) -> Vec { let mut headers_text = String::new(); - for (key, value) in self.headers.iter() { + for (key, value) in &self.headers { headers_text.push_str(&format!("{}: {}\r\n", key, value)); } @@ -44,6 +51,7 @@ impl BaseResponse { } /// Represents a full HTTP response (headers + body). +#[derive(Clone)] pub struct HttpResponse { base: BaseResponse, pub content: Vec, @@ -62,6 +70,10 @@ pub trait HttpResponseCommon { /// Advances and returns the next chunk of the response body. fn peek(&mut self) -> Result, IterError>; + + fn set_stream(&mut self, _stream: &TcpStream) { + () + } } /// Error returned during response iteration. @@ -78,16 +90,13 @@ impl HttpResponse { pub fn new>( status: HttpStatus, content: B, - headers: Option>, + headers: Option, ) -> Box { - let mut headers = headers.unwrap_or(HashMap::new()); + let mut headers = headers.unwrap_or(HttpHeaders::new()); let content = content.as_ref(); - headers.insert("Content-Length".to_string(), content.len().to_string()); - headers.insert( - "Server".to_string(), - format!("HTeaPot/{}", VERSION).to_string(), - ); + headers.insert("Content-Length", &content.len().to_string()); + headers.insert("Server", &format!("HTeaPot/{}", VERSION).to_string()); Box::new(HttpResponse { base: BaseResponse { status, headers }, @@ -103,7 +112,7 @@ impl HttpResponse { HttpResponse { base: BaseResponse { status: HttpStatus::IAmATeapot, - headers: HashMap::new(), + headers: HttpHeaders::new(), }, content: vec![], raw: Some(raw), @@ -124,7 +133,7 @@ impl HttpResponse { } let mut headers_text = String::new(); - for (key, value) in self.base.headers.iter() { + for (key, value) in self.base.headers.clone() { headers_text.push_str(&format!("{}: {}\r\n", key, value)); } @@ -138,8 +147,6 @@ impl HttpResponse { let mut response = Vec::new(); response.extend_from_slice(response_header.as_bytes()); response.append(&mut self.content); - response.push(0x0D); // Carriage Return - response.push(0x0A); // Line Feed response } } @@ -150,9 +157,9 @@ impl HttpResponseCommon for HttpResponse { } fn next(&mut self) -> Result, IterError> { - let byte_chunk = self.peek()?; + //let byte_chunk = self.peek()?; self.index += 1; - return Ok(byte_chunk); + return Ok(Vec::new()); } fn peek(&mut self) -> Result, IterError> { @@ -229,15 +236,12 @@ impl StreamedResponse { let mut base = BaseResponse { status: HttpStatus::OK, - headers: HashMap::new(), + headers: HttpHeaders::new(), }; + base.headers.insert("Transfer-Encoding", "chunked"); base.headers - .insert("Transfer-Encoding".to_string(), "chunked".to_string()); - base.headers.insert( - "Server".to_string(), - format!("HTeaPot/{}", VERSION).to_string(), - ); + .insert("Server", &format!("HTeaPot/{}", VERSION)); let _ = tx.send(base.to_bytes()); let has_end = Arc::new(AtomicBool::new(false)); @@ -293,3 +297,83 @@ impl HttpResponseCommon for StreamedResponse { } } } + +pub struct TunnelResponse { + base: BaseResponse, + addr: String, + has_end: Arc, + stream_in: Option, // In as Stream from the client *in* this server + stream_out: Option, // Out as Stream from the server *to* this server +} + +impl TunnelResponse { + pub fn new(addr: &str) -> Box { + return Box::new(TunnelResponse { + base: BaseResponse { + status: HttpStatus::OK, + headers: HttpHeaders::new(), + // headers: headers! {"connection" => "keep-alive"}.unwrap(), + }, + addr: addr.to_string(), + has_end: Arc::new(AtomicBool::new(false)), + stream_in: None, + stream_out: None, + }); + } +} + +impl HttpResponseCommon for TunnelResponse { + fn base(&mut self) -> &mut BaseResponse { + &mut self.base + } + + fn next(&mut self) -> Result, IterError> { + self.peek() + } + + fn peek(&mut self) -> Result, IterError> { + if self.has_end.load(Ordering::SeqCst) { + if let Some(sock_in) = &self.stream_in { + let _ = sock_in.shutdown(std::net::Shutdown::Both); + } + if let Some(sock_out) = &self.stream_out { + let _ = sock_out.shutdown(std::net::Shutdown::Both); + } + return Err(IterError::Finished); + } + let mut buf = [0; 1]; + let _ = self.stream_in.as_ref().unwrap().peek(&mut buf); + + return Err(IterError::WouldBlock); + } + + fn set_stream(&mut self, stream: &TcpStream) { + let mut client_stream = stream.try_clone().expect("clone failed..."); + self.stream_in = Some(client_stream.try_clone().expect("clone failed...")); + let server_stream = TcpStream::connect(&self.addr); + if server_stream.is_err() { + println!("Error connecting"); + return; + } + let mut server_stream = server_stream.unwrap(); + self.stream_out = Some(server_stream.try_clone().expect("clone failed...")); + let _ = client_stream.set_nonblocking(false); + let _ = client_stream.set_nodelay(true); + let _ = client_stream.set_read_timeout(Some(Duration::from_millis(500))); + let _ = client_stream.set_write_timeout(Some(Duration::from_millis(500))); + let _ = client_stream.write_all(&self.base.to_bytes()); + let mut server_stream_1 = server_stream.try_clone().expect("Error cloning"); + let mut client_stream_1 = client_stream.try_clone().expect("clone failed..."); + let has_ended = self.has_end.clone(); + thread::spawn(move || { + let _ = io::copy(&mut client_stream_1, &mut server_stream_1); + has_ended.store(true, Ordering::SeqCst); + }); + + let has_ended = self.has_end.clone(); + thread::spawn(move || { + let _ = io::copy(&mut server_stream, &mut client_stream); + has_ended.store(true, Ordering::SeqCst); + }); + } +} diff --git a/src/main.rs b/src/main.rs index 69799e6..c101fc3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,120 +36,25 @@ //! See the [`config`](crate::config) module for configuration options and structure. mod cache; mod config; +mod handler; pub mod hteapot; mod logger; mod shutdown; mod utils; -use std::path::Path; +use std::fs; +use std::io; use std::sync::Mutex; -use std::{fs, io, path::PathBuf}; use cache::Cache; -use config::Config; + use hteapot::{Hteapot, HttpRequest, HttpResponse, HttpStatus}; -use utils::get_mime_tipe; use logger::{LogLevel, Logger}; use std::time::Instant; -/// Attempts to safely join a root directory and a requested relative path. -/// -/// Ensures that the resulting path: -/// - Resolves symbolic links and `..` segments via `canonicalize` -/// - Remains within the bounds of the specified root directory -/// - Actually exists on disk -/// -/// This protects against directory traversal vulnerabilities, such as accessing -/// files outside of the intended root (e.g., `/etc/passwd`). -/// -/// # Arguments -/// * `root` - The root directory from which serving is allowed. -/// * `requested_path` - The path requested by the client (usually from the URL). -/// -/// # Returns -/// `Some(PathBuf)` if the resolved path exists and is within the root. `None` otherwise. -/// -/// # Example -/// ``` -/// let safe_path = safe_join_paths("/var/www", "/index.html"); -/// assert!(safe_path.unwrap().ends_with("index.html")); -/// ``` -fn safe_join_paths(root: &str, requested_path: &str) -> Option { - let root_path = Path::new(root).canonicalize().ok()?; - let requested_full_path = root_path.join(requested_path.trim_start_matches("/")); - - if !requested_full_path.exists() { - return None; - } - - let canonical_path = requested_full_path.canonicalize().ok()?; - - if canonical_path.starts_with(&root_path) { - Some(canonical_path) - } else { - None - } -} - -/// Determines whether a given HTTP request should be proxied based on the configuration. -/// -/// If a matching proxy rule is found in `config.proxy_rules`, the function rewrites the -/// request path and updates the `Host` header accordingly. -/// -/// # Arguments -/// * `config` - Server configuration containing proxy rules. -/// * `req` - The original HTTP request. -/// -/// # Returns -/// `Some((proxy_url, modified_request))` if the request should be proxied, otherwise `None`. -fn is_proxy(config: &Config, req: HttpRequest) -> Option<(String, HttpRequest)> { - for proxy_path in config.proxy_rules.keys() { - let path_match = req.path.strip_prefix(proxy_path); - if path_match.is_some() { - let new_path = path_match.unwrap(); - let url = config.proxy_rules.get(proxy_path).unwrap().clone(); - let mut proxy_req = req.clone(); - proxy_req.path = new_path.to_string(); - proxy_req.headers.remove("host"); - proxy_req.headers.remove("Host"); - let host_parts: Vec<_> = url.split("://").collect(); - let host = if host_parts.len() == 1 { - host_parts.first().unwrap() - } else { - host_parts.last().clone().unwrap() - }; - proxy_req.header("host", host); - return Some((url, proxy_req)); - } - } - None -} - -/// Reads the content of a file from the filesystem. -/// -/// # Arguments -/// * `path` - A reference to a `PathBuf` representing the target file. -/// -/// # Returns -/// `Some(Vec)` if the file is read successfully, or `None` if an error occurs. -/// -/// # Notes -/// Uses `PathBuf` instead of `&str` to clearly express intent and reduce path handling bugs. -/// -/// # See Also -/// [`std::fs::read`](https://doc.rust-lang.org/std/fs/fn.read.html) -fn serve_file(path: &PathBuf) -> Option> { - let r = fs::read(path); - if r.is_ok() { Some(r.unwrap()) } else { None } -} -// -// Suggest to use .ok()? instead of manual unwrap/if is_ok for more idiomatic error handling: -// fn serve_file(path: &PathBuf) -> Option> { -// fs::read(path).ok() -// } -// -// +use crate::handler::HandlerEngine; +use crate::utils::Context; /// Main entry point of the Hteapot server. /// @@ -178,7 +83,7 @@ fn main() { } // Initialize logger based on config or default to stdout - let config = match args[1].as_str() { + let mut config = match args[1].as_str() { "--help" | "-h" => { println!("Hteapot {}", hteapot::VERSION); println!("usage: {} ", args[0]); @@ -189,33 +94,28 @@ fn main() { return; } "--serve" | "-s" => { - let mut c = config::Config::new_default(); - let serving_path = Some(args.get(2).unwrap().clone()); - let serving_path_str = serving_path.unwrap(); - let serving_path_str = serving_path_str.as_str(); - let serving_path = Path::new(serving_path_str); - if serving_path.is_dir() { - c.root = serving_path.to_str().unwrap_or_default().to_string(); - } else { - c.index = serving_path - .file_name() - .unwrap() - .to_str() - .unwrap_or_default() - .to_string(); - c.root = serving_path - .parent() - .unwrap_or(Path::new("./")) - .to_str() - .unwrap_or_default() - .to_string(); - } - c.host = "0.0.0.0".to_string(); + let path = args.get(2).unwrap().clone(); + config::Config::new_serve(&path) + } + "--proxy" => { + let c = config::Config::new_proxy(); c } _ => config::Config::load_config(&args[1]), }; + if args.contains(&"-p".to_string()) { + let i = args.iter().position(|e| *e == "-p".to_string()).unwrap(); + let port = args[i + 1].clone(); + let port = port.parse::(); + if port.is_err() { + println!("Invalid port provided"); + return; + } + let port = port.unwrap(); + config.port = port; + } + // Determine if the server should proxy all requests let proxy_only = config.proxy_rules.get("/").is_some(); @@ -243,7 +143,8 @@ fn main() { // Set up the cache with thread-safe locking // The Mutex ensures that only one thread can access the cache at a time, // preventing race conditions when reading and writing to the cache. - let cache: Mutex = Mutex::new(Cache::new(config.cache_ttl as u64)); // Initialize the cache with TTL + let cache: Mutex> = + Mutex::new(Cache::new(config.cache_ttl as u64)); // Initialize the cache with TTL // Create a new threaded HTTP server with the provided host, port, and number of threads let mut server = Hteapot::new_threaded(config.host.as_str(), config.port, config.threads); @@ -269,154 +170,65 @@ fn main() { // Create separate loggers for each component (proxy, cache, and HTTP) // This allows for more granular control over logging and better separation of concerns - let proxy_logger = logger.with_component("proxy"); + let cache_logger = logger.with_component("cache"); let http_logger = logger.with_component("http"); + let handlers = HandlerEngine::new(); // Start listening for HTTP requests - server.listen(move |req| { + server.listen(move |req: HttpRequest| { // SERVER CORE: For each incoming request, we handle it in this closure let start_time = Instant::now(); // Track request processing time let req_method = req.method.to_str(); // Get the HTTP method (e.g., GET, POST) - let req_path = req.path.clone(); // Get the requested path // Log the incoming request method and path http_logger.info(format!("Request {} {}", req_method, req.path)); - // Check if the request should be proxied (either because proxy-only mode is on, or it matches a rule) - let is_proxy = is_proxy(&config, req.clone()); - if proxy_only || is_proxy.is_some() { - // If proxying is enabled or this request matches a proxy rule, handle it - let (host, proxy_req) = is_proxy.unwrap(); // Get the target host and modified request - proxy_logger.info(format!( - "Proxying request {} {} to {}", - req_method, req_path, host - )); - - // Perform the proxy request (forward the request to the target server) - let res = proxy_req.brew(host.as_str()); - let elapsed = start_time.elapsed(); // Measure the time taken to process the proxy request - if res.is_ok() { - // If the proxy request is successful, log the time taken and return the response - let response = res.unwrap(); - proxy_logger.info(format!( - "Proxy request processed in {:.6}ms", + if config.cache { + let cache_start = Instant::now(); // Track cache operation time + let mut cache_lock = cache.lock().expect("Error locking cache"); + if let Some(response) = cache_lock.get(&req) { + cache_logger.debug(format!("cache hit for {}", &req.path)); + let elapsed = start_time.elapsed(); + http_logger.debug(format!( + "Request processed in {:.6}ms", elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds )); - return response; - } else { - // If the proxy request fails, log the error and return a 500 Internal Server Error - proxy_logger.error(format!("Proxy request failed: {:?}", res.err())); - return HttpResponse::new( - HttpStatus::InternalServerError, - "Internal Server Error", - None, - ); - } - } - - // If the request is not a proxy request, resolve the requested path safely - let safe_path_result = if req.path == "/" { - // Special handling for the root "/" path - let root_path = Path::new(&config.root).canonicalize(); - if root_path.is_ok() { - // If the root path exists and is valid, try to join the index file - let index_path = root_path.unwrap().join(&config.index); - if index_path.exists() { - Some(index_path) // If index exists, return its path - } else { - None // If no index exists, return None - } - } else { - None // If the root path is invalid, return None - } - } else { - // For any other path, resolve it safely using the `safe_join_paths` function - safe_join_paths(&config.root, &req.path) - }; - - // Handle the case where the resolved path is a directory - let safe_path = match safe_path_result { - Some(path) => { - if path.is_dir() { - // If it's a directory, check for the index file in that directory - let index_path = path.join(&config.index); - if index_path.exists() { - index_path // If index exists, return its path - } else { - // If no index file exists, log a warning and return a 404 response - http_logger - .warn(format!("Index file not found in directory: {}", req.path)); - return HttpResponse::new(HttpStatus::NotFound, "Index not found", None); - } - } else { - path // If it's not a directory, just return the path - } - } - None => { - // If the path is invalid or access is denied, return a 404 response - http_logger.warn(format!("Path not found or access denied: {}", req.path)); - return HttpResponse::new(HttpStatus::NotFound, "Not found", None); - } - }; - - // Determine the MIME type for the file based on its extension - let mimetype = get_mime_tipe(&safe_path.to_string_lossy().to_string()); - - // Try to serve the file from the cache, or read it from disk if not cached - let content: Option> = if config.cache { - // Lock the cache to ensure thread-safe access - let mut cachee = cache.lock().expect("Error locking cache"); - let cache_start = Instant::now(); // Track cache operation time - let cache_key = req.path.clone(); // Use the request path as the cache key - let mut r = cachee.get(cache_key.clone()); // Try to get the content from cache - if r.is_none() { - // If cache miss, read the file from disk and store it in cache - cache_logger.debug(format!("cache miss for {}", cache_key)); - r = serve_file(&safe_path); - if r.is_some() { - // If the file is read successfully, add it to the cache - cache_logger.debug(format!("Adding {} to cache", cache_key)); - cachee.set(cache_key, r.clone().unwrap()); - } + return Box::new(response); } else { - // If cache hit, log it - cache_logger.debug(format!("cache hit for {}", cache_key)); + cache_logger.debug(format!("cache miss for {}", &req.path)); } - - // Log how long the cache operation took let cache_elapsed = cache_start.elapsed(); cache_logger.debug(format!( "Cache operation completed in {:.6}µs", cache_elapsed.as_micros() )); - r // Return the cached content (or None if not found) - } else { - // If cache is disabled, read the file from disk - serve_file(&safe_path) + } + + let mut ctx = Context { + request: &req, + log: &logger, + config: &config, + cache: if config.cache { + Some(&mut cache.lock().unwrap()) + } else { + None + }, }; + let response = handlers.get_handler(&ctx); + if response.is_none() { + return HttpResponse::new(HttpStatus::InternalServerError, "content", None); + } + let response = response.unwrap().run(&mut ctx); + // Log how long the request took to process let elapsed = start_time.elapsed(); http_logger.debug(format!( "Request processed in {:.6}ms", elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds )); - + response // If content was found, return it with the appropriate headers, otherwise return a 404 - match content { - Some(c) => { - // If content is found, create response with proper headers and a 200 OK status - let headers = headers!( - "Content-Type" => mimetype, - "X-Content-Type-Options" => "nosniff" - ); - HttpResponse::new(HttpStatus::OK, c, headers) - } - None => { - // If no content is found, return a 404 Not Found response - HttpResponse::new(HttpStatus::NotFound, "Not found", None) - } - } }); } diff --git a/src/utils.rs b/src/utils.rs index 8e56ca4..98b8a8f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,12 @@ use std::path::Path; +use crate::{ + cache::Cache, + config::Config, + hteapot::{HttpRequest, HttpResponse}, + logger::Logger, +}; + /// Returns the MIME type based on the file extension of a given path. /// /// This function maps common file extensions to their appropriate @@ -27,7 +34,7 @@ pub fn get_mime_tipe(path: &String) -> String { // Suggest using `to_str()` directly on the extension // Alternative way to get the extension // .and_then(|ext| ext.to_str()) - + let mimetipe = match extension { // Text "html" | "htm" => "text/html; charset=utf-8", @@ -39,7 +46,7 @@ pub fn get_mime_tipe(path: &String) -> String { "txt" => "text/plain", "md" => "text/markdown", "csv" => "text/csv", - + // Images "ico" => "image/x-icon", "png" => "image/png", @@ -49,19 +56,19 @@ pub fn get_mime_tipe(path: &String) -> String { "webp" => "image/webp", "bmp" => "image/bmp", "tiff" | "tif" => "image/tiff", - + // Audio "mp3" => "audio/mpeg", "wav" => "audio/wav", "ogg" => "audio/ogg", "flac" => "audio/flac", - + // Video "mp4" => "video/mp4", "webm" => "video/webm", "avi" => "video/x-msvideo", "mkv" => "video/x-matroska", - + // Documents "pdf" => "application/pdf", "doc" => "application/msword", @@ -70,20 +77,20 @@ pub fn get_mime_tipe(path: &String) -> String { "xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "ppt" => "application/vnd.ms-powerpoint", "pptx" => "application/vnd.openxmlformats-officedocument.presentationml.presentation", - + // Archives "zip" => "application/zip", "tar" => "application/x-tar", "gz" => "application/gzip", "7z" => "application/x-7z-compressed", "rar" => "application/vnd.rar", - + // Fonts "ttf" => "font/ttf", "otf" => "font/otf", "woff" => "font/woff", "woff2" => "font/woff2", - + // For unknown types, use a safe default _ => "application/octet-stream", }; @@ -92,4 +99,11 @@ pub fn get_mime_tipe(path: &String) -> String { } //TODO: make a parser args to config -//pub fn args_to_dict(list: Vec) -> HashMap {} \ No newline at end of file +//pub fn args_to_dict(list: Vec) -> HashMap {} + +pub struct Context<'a> { + pub request: &'a HttpRequest, + pub log: &'a Logger, + pub config: &'a Config, + pub cache: Option<&'a mut Cache>, +}