Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ libc = "0.2"

[dev-dependencies]
anyhow = "1.0"
proptest = "1.11.0"
rand = "0.10"
uuid = "1.10"

Expand Down
281 changes: 280 additions & 1 deletion src/cmdext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,229 @@
//!
//! - File descriptor passing
//! - Changing to a file-descriptor relative directory
//! - Systemd socket activation fd passing

use cap_std::fs::Dir;
use cap_std::io_lifetimes;
use cap_tempfile::cap_std;
use io_lifetimes::OwnedFd;
use rustix::fd::{AsFd, FromRawFd, IntoRawFd};
use rustix::io::FdFlags;
use std::collections::BTreeSet;
use std::ffi::CString;
use std::os::fd::AsRawFd;
use std::os::unix::process::CommandExt;
use std::sync::Arc;

/// The file descriptor number at which systemd passes the first socket.
/// See `sd_listen_fds(3)`.
const SD_LISTEN_FDS_START: i32 = 3;

/// A validated name for a systemd socket-activation file descriptor.
///
/// Names appear in the `LISTEN_FDNAMES` environment variable as
/// colon-separated values. The constructor validates that the name
/// conforms to systemd's `fdname_is_valid()` rules: at most 255
/// printable ASCII characters, excluding `:`.
///
/// ```
/// use cap_std_ext::cmdext::SystemdFdName;
/// let name = SystemdFdName::new("varlink");
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SystemdFdName<'a>(&'a str);

impl<'a> SystemdFdName<'a> {
/// Create a new `SystemdFdName`, panicking if `name` is invalid.
///
/// # Panics
///
/// Panics if `name` is longer than 255 bytes or contains any
/// character that is not printable ASCII (i.e. control characters,
/// DEL, non-ASCII bytes, or `:`).
pub const fn new(name: &'a str) -> Self {
assert!(
name.len() <= 255,
"systemd fd name must be at most 255 characters"
);
let bytes = name.as_bytes();
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
assert!(
b >= b' ' && b < 127 && b != b':',
"systemd fd name must only contain printable ASCII characters except ':'"
);
i += 1;
}
Self(name)
}

/// Return the name as a string slice.
pub fn as_str(&self) -> &'a str {
self.0
}
}

/// File descriptor allocator for child processes.
///
/// Collects fd assignments and optional systemd socket-activation
/// configuration, then applies them all at once via
/// [`CapStdExtCommandExt::take_fds`].
///
/// - [`new_systemd_fds`](Self::new_systemd_fds) creates an allocator
/// with systemd socket-activation fds at 3, 4, … (`SD_LISTEN_FDS_START`).
/// - [`take_fd`](Self::take_fd) auto-assigns the next fd above all
/// previously assigned ones (minimum 3).
/// - [`take_fd_n`](Self::take_fd_n) places an fd at an explicit number,
/// panicking on overlap.
///
/// ```no_run
/// # use std::sync::Arc;
/// # use cap_std_ext::cmdext::{CmdFds, CapStdExtCommandExt, SystemdFdName};
/// # let varlink_fd: Arc<rustix::fd::OwnedFd> = todo!();
/// # let extra_fd: Arc<rustix::fd::OwnedFd> = todo!();
/// let mut cmd = std::process::Command::new("myservice");
/// let mut fds = CmdFds::new_systemd_fds([(varlink_fd, SystemdFdName::new("varlink"))]);
/// let extra_n = fds.take_fd(extra_fd);
/// cmd.take_fds(fds);
/// ```
#[derive(Debug)]
pub struct CmdFds {
taken: BTreeSet<i32>,
fds: Vec<(i32, Arc<OwnedFd>)>,
/// Pre-built CStrings for the systemd env vars, set by new_systemd_fds.
systemd_env: Option<(CString, CString)>,
}

impl Default for CmdFds {
fn default() -> Self {
Self::new()
}
}

impl CmdFds {
/// Create a new fd allocator.
pub fn new() -> Self {
Self {
taken: BTreeSet::new(),
fds: Vec::new(),
systemd_env: None,
}
}

/// Create a new fd allocator with systemd socket-activation fds.
///
/// Each `(fd, name)` pair is assigned a consecutive fd number starting
/// at `SD_LISTEN_FDS_START` (3). The `LISTEN_PID`, `LISTEN_FDS`, and
/// `LISTEN_FDNAMES` environment variables will be set in the child
/// when [`CapStdExtCommandExt::take_fds`] is called.
///
/// Additional (non-systemd) fds can be registered afterwards via
/// [`take_fd`](Self::take_fd) or [`take_fd_n`](Self::take_fd_n).
///
/// [sd_listen_fds]: https://www.freedesktop.org/software/systemd/man/latest/sd_listen_fds.html
pub fn new_systemd_fds<'a>(
fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
) -> Self {
let mut this = Self::new();
this.register_systemd_fds(fds);
this
}

/// Compute the next fd number above everything already taken
/// (minimum `SD_LISTEN_FDS_START`).
fn next_fd(&self) -> i32 {
self.taken
.last()
.map(|n| n.checked_add(1).expect("fd number overflow"))
.unwrap_or(SD_LISTEN_FDS_START)
}

fn insert_fd(&mut self, n: i32) {
let inserted = self.taken.insert(n);
assert!(inserted, "fd {n} is already assigned");
}

/// Register a file descriptor at the next available fd number.
///
/// Returns the fd number that will be assigned in the child.
/// Call [`CapStdExtCommandExt::take_fds`] to apply.
pub fn take_fd(&mut self, fd: Arc<OwnedFd>) -> i32 {
let n = self.next_fd();
self.insert_fd(n);
self.fds.push((n, fd));
n
}

/// Register a file descriptor at a specific fd number.
///
/// Call [`CapStdExtCommandExt::take_fds`] to apply.
///
/// # Panics
///
/// Panics if `target` has already been assigned.
pub fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
self.insert_fd(target);
self.fds.push((target, fd));
self
}

fn register_systemd_fds<'a>(
&mut self,
fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
) {
let mut n_fds: i32 = 0;
let mut names = Vec::new();
for (fd, name) in fds {
let target = SD_LISTEN_FDS_START
.checked_add(n_fds)
.expect("too many fds");
self.insert_fd(target);
self.fds.push((target, fd));
names.push(name.as_str());
n_fds = n_fds.checked_add(1).expect("too many fds");
}

let fd_count = CString::new(n_fds.to_string()).unwrap();
// SAFETY: SystemdFdName guarantees no NUL bytes.
let fd_names = CString::new(names.join(":")).unwrap();
self.systemd_env = Some((fd_count, fd_names));
}
}

/// Extension trait for [`std::process::Command`].
///
/// [`cap_std::fs::Dir`]: https://docs.rs/cap-std/latest/cap_std/fs/struct.Dir.html
pub trait CapStdExtCommandExt {
/// Pass a file descriptor into the target process.
/// Pass a file descriptor into the target process at a specific fd number.
///
/// # Deprecated
///
/// Use [`CmdFds`] with [`take_fds`](Self::take_fds) instead. This method
/// registers an independent `pre_exec` hook per call, which means
/// multiple `take_fd_n` calls (or mixing with `take_fds`) can clobber
/// each other when a source fd's raw number equals another mapping's
/// target. `take_fds` handles this correctly with atomic fd shuffling.
#[deprecated = "Use CmdFds with take_fds() instead"]
fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self;

/// Apply a [`CmdFds`] to this command, passing all registered file
/// descriptors and (if configured) setting up the systemd
/// socket-activation environment.
///
/// # Important: Do not use `Command::env()` with systemd fds
///
/// When systemd socket-activation environment variables are configured
/// (via [`CmdFds::new_systemd_fds`]), they are set using `setenv(3)` in
/// a `pre_exec` hook. If `Command::env()` is also called, Rust will
/// build an `envp` array that replaces the process environment, causing
/// the `LISTEN_*` variables set by the hook to be lost. `Command::envs()`
/// is equally problematic. If you need to set additional environment
/// variables alongside systemd fds, set them via `pre_exec` + `setenv`
/// as well.
fn take_fds(&mut self, fds: CmdFds) -> &mut Self;

/// Use the given directory as the current working directory for the process.
fn cwd_dir(&mut self, dir: Dir) -> &mut Self;

Expand All @@ -39,7 +244,27 @@ pub trait CapStdExtCommandExt {
fn lifecycle_bind_to_parent_thread(&mut self) -> &mut Self;
}

/// Wrapper around `libc::setenv` that checks the return value.
///
/// # Safety
///
/// Must only be called in a single-threaded context (e.g. after `fork()`
/// and before `exec()`).
#[allow(unsafe_code)]
unsafe fn check_setenv(
key: *const std::ffi::c_char,
val: *const std::ffi::c_char,
) -> std::io::Result<()> {
// SAFETY: Caller guarantees we are in a single-threaded context
// with valid nul-terminated C strings.
if unsafe { libc::setenv(key, val, 1) } != 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
}

#[allow(unsafe_code)]
#[allow(deprecated)]
impl CapStdExtCommandExt for std::process::Command {
fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
unsafe {
Expand All @@ -62,6 +287,59 @@ impl CapStdExtCommandExt for std::process::Command {
self
}

fn take_fds(&mut self, fds: CmdFds) -> &mut Self {
// Use a single pre_exec hook that handles all fd shuffling atomically.
// This avoids the problem where separate hooks clobber each other when
// a source fd number equals a target fd number from a different mapping.
unsafe {
self.pre_exec(move || {
// Dup each source fd to a temporary location above all
// targets, so that no dup2() in step 2 can clobber a source.
let safe_min = fds
.fds
.iter()
.map(|(t, _)| *t)
.max()
.unwrap_or(0)
.checked_add(1)
.expect("fd number overflow");
let mut safe_copies: Vec<(i32, OwnedFd)> = Vec::new();
for (target, fd) in &fds.fds {
let copy = rustix::io::fcntl_dupfd_cloexec(fd, safe_min)?;
safe_copies.push((*target, copy));
}

// Place each fd at its target via dup2.
// We use raw dup2 to avoid fabricating an OwnedFd for a
// target number we don't yet own (which would be unsound
// if dup2 failed — the OwnedFd drop would close a wrong fd).
for (target, copy) in safe_copies {
// SAFETY: target is a non-negative fd number that dup2
// will atomically (re)open; we don't own it beforehand.
let r = libc::dup2(copy.as_raw_fd(), target);
if r < 0 {
return Err(std::io::Error::last_os_error());
}
// `copy` drops here, closing the temporary fd.
}

// Handle systemd env vars, if configured
if let Some((ref fd_count, ref fd_names)) = fds.systemd_env {
let pid = rustix::process::getpid();
let pid_dec = rustix::path::DecInt::new(pid.as_raw_nonzero().get());
// SAFETY: After fork() and before exec(), the child is
// single-threaded, so setenv (which is not thread-safe)
// is safe to call here.
check_setenv(c"LISTEN_PID".as_ptr(), pid_dec.as_c_str().as_ptr())?;
check_setenv(c"LISTEN_FDS".as_ptr(), fd_count.as_ptr())?;
check_setenv(c"LISTEN_FDNAMES".as_ptr(), fd_names.as_ptr())?;
}
Ok(())
});
}
self
}

fn cwd_dir(&mut self, dir: Dir) -> &mut Self {
unsafe {
self.pre_exec(move || {
Expand Down Expand Up @@ -92,6 +370,7 @@ mod tests {
use super::*;
use std::sync::Arc;

#[allow(deprecated)]
#[test]
fn test_take_fdn() -> anyhow::Result<()> {
// Pass srcfd == destfd and srcfd != destfd
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(crate) fn escape_attempt() -> io::Error {
/// Prelude, intended for glob import.
pub mod prelude {
#[cfg(not(windows))]
pub use super::cmdext::CapStdExtCommandExt;
pub use super::cmdext::{CapStdExtCommandExt, CmdFds, SystemdFdName};
pub use super::dirext::CapStdExtDirExt;
#[cfg(feature = "fs_utf8")]
pub use super::dirext::CapStdExtDirExtUtf8;
Expand Down
Loading