diff --git a/src/uu/head/src/head.rs b/src/uu/head/src/head.rs index 896e0c923e..0f5b4fc1e3 100644 --- a/src/uu/head/src/head.rs +++ b/src/uu/head/src/head.rs @@ -166,34 +166,73 @@ fn wrap_in_stdout_error(err: io::Error) -> io::Error { ) } +enum PrintError { + ReadError(io::Error), + WriteError(io::Error), +} + +impl PrintError { + /// Wrap a read `io::Error` into a `PrintError::ReadError`, stripping the errno suffix. + fn read(e: io::Error) -> Self { + Self::ReadError(io::Error::new(e.kind(), uucore::error::strip_errno(&e))) + } + + /// Wrap a write `io::Error` into a `PrintError::WriteError`, adding the stdout context. + fn write(e: io::Error) -> Self { + Self::WriteError(wrap_in_stdout_error(e)) + } + + fn into_uu_error(self, name: PathBuf) -> Box { + match self { + Self::ReadError(err) => HeadError::Io { name, err }.into(), + Self::WriteError(err) => err.into(), + } + } +} + // zero-copy fast-path #[cfg(any(target_os = "linux", target_os = "android"))] -fn print_n_bytes(input: impl Read + AsFd, n: u64) -> io::Result { +fn print_n_bytes(input: impl Read + AsFd, n: u64) -> Result { let mut out = io::stdout(); - let res = uucore::pipes::send_n_bytes(input, &out, n).map_err(wrap_in_stdout_error); + let res = uucore::pipes::send_n_bytes(input, &out, n).map_err(PrintError::write); // flush prevents ignoring I/O error - out.flush().map_err(wrap_in_stdout_error)?; + out.flush().map_err(PrintError::write)?; res } #[cfg(not(any(target_os = "linux", target_os = "android")))] -fn print_n_bytes(input: impl Read, n: u64) -> io::Result { +fn print_n_bytes(input: impl Read, n: u64) -> Result { // Read the first `n` bytes from the `input` reader. let mut reader = input.take(n); // Write those bytes to `stdout`. let stdout = io::stdout(); - let mut stdout = stdout.lock(); + let stdout = stdout.lock(); + let mut writer = BufWriter::with_capacity(BUF_SIZE, stdout); + + let mut buf = [0u8; BUF_SIZE]; + let mut bytes_written = 0u64; - let bytes_written = io::copy(&mut reader, &mut stdout).map_err(wrap_in_stdout_error)?; + loop { + let bytes_read = reader.read(&mut buf).map_err(PrintError::read)?; + + if bytes_read == 0 { + break; + } + + writer + .write_all(&buf[..bytes_read]) + .map_err(PrintError::write)?; + bytes_written += bytes_read as u64; + } // flush prevents ignoring I/O error - stdout.flush().map_err(wrap_in_stdout_error)?; + writer.flush().map_err(PrintError::write)?; Ok(bytes_written) } -fn print_n_lines(input: &mut impl io::BufRead, n: u64, separator: u8) -> io::Result { +fn print_n_lines(input: &mut impl io::BufRead, n: u64, separator: u8) -> Result { // Read the first `n` lines from the `input` reader. let mut reader = take_lines(input, n, separator); @@ -202,12 +241,26 @@ fn print_n_lines(input: &mut impl io::BufRead, n: u64, separator: u8) -> io::Res let stdout = stdout.lock(); let mut writer = BufWriter::with_capacity(BUF_SIZE, stdout); - let bytes_written = io::copy(&mut reader, &mut writer).map_err(wrap_in_stdout_error)?; + let mut buf = [0u8; BUF_SIZE]; + let mut bytes_written = 0; + + loop { + let bytes_read = reader.read(&mut buf).map_err(PrintError::read)?; + + if bytes_read == 0 { + break; + } + + writer + .write_all(&buf[..bytes_read]) + .map_err(PrintError::write)?; + bytes_written += bytes_read as u64; + } // Make sure we finish writing everything to the target before // exiting. Otherwise, when Rust is implicitly flushing, any // error will be silently ignored. - writer.flush().map_err(wrap_in_stdout_error)?; + writer.flush().map_err(PrintError::write)?; Ok(bytes_written) } @@ -216,41 +269,41 @@ fn catch_too_large_numbers_in_backwards_bytes_or_lines(n: u64) -> Option usize::try_from(n).ok() } -fn print_but_last_n_bytes(mut input: impl Read, n: u64) -> io::Result { +fn print_but_last_n_bytes(mut input: impl Read, n: u64) -> Result { let mut bytes_written: u64 = 0; if let Some(n) = catch_too_large_numbers_in_backwards_bytes_or_lines(n) { let stdout = io::stdout(); let mut stdout = stdout.lock(); bytes_written = copy_all_but_n_bytes(&mut input, &mut stdout, n) - .map_err(wrap_in_stdout_error)? + .map_err(PrintError::write)? .try_into() .unwrap(); // Make sure we finish writing everything to the target before // exiting. Otherwise, when Rust is implicitly flushing, any // error will be silently ignored. - stdout.flush().map_err(wrap_in_stdout_error)?; + stdout.flush().map_err(PrintError::write)?; } Ok(bytes_written) } -fn print_but_last_n_lines(mut input: impl Read, n: u64, separator: u8) -> io::Result { +fn print_but_last_n_lines(mut input: impl Read, n: u64, separator: u8) -> Result { let stdout = io::stdout(); let mut stdout = stdout.lock(); if n == 0 { - return io::copy(&mut input, &mut stdout).map_err(wrap_in_stdout_error); + return io::copy(&mut input, &mut stdout).map_err(PrintError::write); } let mut bytes_written: u64 = 0; if let Some(n) = catch_too_large_numbers_in_backwards_bytes_or_lines(n) { bytes_written = copy_all_but_n_lines(input, &mut stdout, n, separator) - .map_err(wrap_in_stdout_error)? + .map_err(PrintError::write)? .try_into() .unwrap(); // Make sure we finish writing everything to the target before // exiting. Otherwise, when Rust is implicitly flushing, any // error will be silently ignored. - stdout.flush().map_err(wrap_in_stdout_error)?; + stdout.flush().map_err(PrintError::write)?; } Ok(bytes_written) } @@ -351,8 +404,8 @@ fn is_seekable(input: &mut File) -> bool { && input.seek(SeekFrom::Start(current_pos.unwrap())).is_ok() } -fn head_backwards_file(input: &mut File, options: &HeadOptions) -> io::Result { - let st = input.metadata()?; +fn head_backwards_file(input: &mut File, options: &HeadOptions) -> Result { + let st = input.metadata().map_err(PrintError::read)?; let seekable = is_seekable(input); let blksize_limit = uucore::fs::sane_blksize::sane_blksize_from_metadata(&st); if !seekable || st.len() <= blksize_limit || options.presume_input_pipe { @@ -362,7 +415,10 @@ fn head_backwards_file(input: &mut File, options: &HeadOptions) -> io::Result io::Result { +fn head_backwards_without_seek_file( + input: &mut File, + options: &HeadOptions, +) -> Result { match options.mode { Mode::AllButLastBytes(n) => print_but_last_n_bytes(input, n), Mode::AllButLastLines(n) => print_but_last_n_lines(input, n, options.line_ending.into()), @@ -370,10 +426,13 @@ fn head_backwards_without_seek_file(input: &mut File, options: &HeadOptions) -> } } -fn head_backwards_on_seekable_file(input: &mut File, options: &HeadOptions) -> io::Result { +fn head_backwards_on_seekable_file( + input: &mut File, + options: &HeadOptions, +) -> Result { match options.mode { Mode::AllButLastBytes(n) => { - let size = input.metadata()?.len(); + let size = input.metadata().map_err(PrintError::read)?.len(); if n >= size { Ok(0) } else { @@ -381,14 +440,15 @@ fn head_backwards_on_seekable_file(input: &mut File, options: &HeadOptions) -> i } } Mode::AllButLastLines(n) => { - let found = find_nth_line_from_end(input, n, options.line_ending.into())?; + let found = find_nth_line_from_end(input, n, options.line_ending.into()) + .map_err(PrintError::read)?; print_n_bytes(input, found) } _ => unreachable!(), } } -fn head_file(input: &mut File, options: &HeadOptions) -> io::Result { +fn head_file(input: &mut File, options: &HeadOptions) -> Result { match options.mode { Mode::FirstBytes(n) => print_n_bytes(input, n), Mode::FirstLines(n) => print_n_lines( @@ -424,10 +484,16 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { // last byte read so that any tools that parse the remainder of // the stdin stream read from the correct place. - let bytes_read = head_file(&mut stdin_file, options)?; + let bytes_read = match head_file(&mut stdin_file, options) { + Ok(n) => n, + Err(e) => return Err(e.into_uu_error("standard input".into())), + }; stdin_file.seek(SeekFrom::Start(current_pos + bytes_read))?; } else { - let _bytes_read = head_file(&mut stdin_file, options)?; + match head_file(&mut stdin_file, options) { + Ok(_) => {} + Err(e) => return Err(e.into_uu_error("standard input".into())), + } } } @@ -435,14 +501,17 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { { let mut stdin = stdin.lock(); - match options.mode { + let res = match options.mode { Mode::FirstBytes(n) => print_n_bytes(&mut stdin, n), Mode::AllButLastBytes(n) => print_but_last_n_bytes(&mut stdin, n), Mode::FirstLines(n) => print_n_lines(&mut stdin, n, options.line_ending.into()), Mode::AllButLastLines(n) => { print_but_last_n_lines(&mut stdin, n, options.line_ending.into()) } - }?; + }; + if let Err(e) = res { + return Err(e.into_uu_error("standard input".into())); + } } Ok(()) @@ -493,8 +562,7 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { continue; } }; - head_file(&mut file_handle, options)?; - Ok(()) + head_file(&mut file_handle, options).map(|_| ()) }; if let Err(err) = res { let name = if file == "-" { @@ -502,7 +570,8 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { } else { file.into() }; - return Err(HeadError::Io { name, err }.into()); + + return Err(err.into_uu_error(name)); } first = false; } diff --git a/tests/by-util/test_head.rs b/tests/by-util/test_head.rs index 8406c3d908..9aedd28823 100644 --- a/tests/by-util/test_head.rs +++ b/tests/by-util/test_head.rs @@ -934,6 +934,22 @@ fn test_do_not_attempt_to_read_a_directory() { .stderr_contains("error reading '.'"); } +/// Regression test: reading an unreadable special file such as /proc/self/mem +/// must report a read error on that file, not a "error writing +/// 'standard output'" message. +#[test] +#[cfg(target_os = "linux")] +#[cfg_attr(wasi_runner, ignore = "WASI sandbox: host paths (/proc) not visible")] +fn test_proc_self_mem_reports_read_error() { + new_ucmd!() + .arg("/proc/self/mem") + .fails_with_code(1) + // Must mention the file, not stdout + .stderr_contains("error reading '/proc/self/mem'") + // Must NOT blame standard output + .stderr_does_not_contain("error writing 'standard output'"); +} + /// Regression test for https://github.com/uutils/coreutils/issues/12215 /// `head -c0 ` should succeed (nothing to read), matching GNU. #[test]