diff --git a/examples/list.rs b/examples/list.rs index e618183..1f4501b 100644 --- a/examples/list.rs +++ b/examples/list.rs @@ -6,6 +6,8 @@ extern crate env_logger; fn main() { env_logger::init(); + // We allow unused_unsafe here because the call is unsafe only on some targets. + #[allow(unused_unsafe)] let mut users: Vec = unsafe { all_users() }.collect(); users.sort_by(|a, b| a.uid().cmp(&b.uid())); diff --git a/src/base.rs b/src/base.rs index 5f7e274..f578bb6 100644 --- a/src/base.rs +++ b/src/base.rs @@ -29,6 +29,7 @@ //! best bet is to check for them yourself before passing strings into any //! functions. +use std::convert::TryFrom; use std::ffi::{CStr, CString, OsStr, OsString}; use std::fmt; use std::io; @@ -829,7 +830,10 @@ pub fn get_user_groups + ?Sized>(username: &S, gid: gid_t) -> Op } /// An iterator over every user present on the system. -struct AllUsers; +struct AllUsers { + #[cfg(target_os = "linux")] + file: *mut libc::FILE, +} /// Creates a new iterator over every user present on the system. /// @@ -868,13 +872,68 @@ struct AllUsers; /// println!("User #{} ({:?})", user.uid(), user.name()); /// } /// ``` +#[cfg(not(target_os = "linux"))] pub unsafe fn all_users() -> impl Iterator { #[cfg(feature = "logging")] trace!("Running setpwent"); #[cfg(not(target_os = "android"))] libc::setpwent(); - AllUsers + + AllUsers {} +} + +/// Creates a new iterator over every user present on the system. +/// +/// # libc functions used +/// +/// - [`fopen`](https://docs.rs/libc/*/libc/fn.fopen.html) +/// - [`fgetpwent_r`](https://docs.rs/libc/*/libc/fn.fgetpwent_r.html) +/// - [`fclose`](https://docs.rs/libc/*/libc/fn.fclose.html) +/// +/// # Examples +/// +/// ``` +/// use uzers::all_users; +/// +/// let iter = all_users(); +/// for user in iter { +/// println!("User #{} ({:?})", user.uid(), user.name()); +/// } +/// ``` +#[cfg(target_os = "linux")] +pub fn all_users() -> impl Iterator { + all_users_from_file("/etc/passwd") +} + +/// Creates a new iterator over every user defined in the given passwd file. +/// +/// # libc functions used +/// +/// - [`fopen`](https://docs.rs/libc/*/libc/fn.fopen.html) +/// - [`fgetpwent_r`](https://docs.rs/libc/*/libc/fn.fgetpwent_r.html) +/// - [`fclose`](https://docs.rs/libc/*/libc/fn.fclose.html) +/// +/// # Examples +/// +/// ``` +/// use uzers::all_users_from_file; +/// +/// let iter = all_users_from_file("mypasswd"); +/// for user in iter { +/// println!("User #{} ({:?})", user.uid(), user.name()); +/// } +/// ``` +#[cfg(target_os = "linux")] +pub fn all_users_from_file>(passwd_file_path: S) -> impl Iterator { + let mut result = AllUsers { + file: std::ptr::null_mut(), + }; + + let file_path = CString::new(passwd_file_path.as_ref()).unwrap(); + result.file = unsafe { libc::fopen(file_path.as_ptr(), CString::new("r").unwrap().as_ptr()) }; + + result } impl Drop for AllUsers { @@ -883,13 +942,20 @@ impl Drop for AllUsers { // nothing to do here } - #[cfg(not(target_os = "android"))] + #[cfg(not(any(target_os = "android", target_os = "linux")))] fn drop(&mut self) { #[cfg(feature = "logging")] trace!("Running endpwent"); unsafe { libc::endpwent() }; } + + #[cfg(target_os = "linux")] + fn drop(&mut self) { + if !self.file.is_null() { + unsafe { libc::fclose(self.file) }; + } + } } impl Iterator for AllUsers { @@ -900,7 +966,7 @@ impl Iterator for AllUsers { None } - #[cfg(not(target_os = "android"))] + #[cfg(not(any(target_os = "android", target_os = "linux")))] fn next(&mut self) -> Option { #[cfg(feature = "logging")] trace!("Running getpwent"); @@ -914,10 +980,59 @@ impl Iterator for AllUsers { Some(user) } } + + #[cfg(target_os = "linux")] + fn next(&mut self) -> Option { + if self.file.is_null() { + // We weren't able to open the backing file, so we have no users to return. + return None; + } + + // Compute the maximum size of the buffer. + let buffer_len = unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) }; + if buffer_len == -1 { + // We couldn't compute the required buffer size, so we can't read the user. + return None; + } + + let buffer_len = usize::try_from(buffer_len).unwrap_or(0); + if buffer_len == 0 { + // The buffer size was invalid, we can't proceed. + return None; + } + + // Create the buffers we need. + let mut buffer = vec![0; buffer_len]; + let buffer_ptr = buffer.as_mut_ptr() as *mut c_char; + let mut pwd = c_passwd { + pw_name: std::ptr::null_mut(), + pw_passwd: std::ptr::null_mut(), + pw_uid: 0, + pw_gid: 0, + pw_gecos: std::ptr::null_mut(), + pw_dir: std::ptr::null_mut(), + pw_shell: std::ptr::null_mut(), + }; + + // Call fgetpwent_r to read the next entry. + let mut result = ptr::null_mut(); + let ret = + unsafe { libc::fgetpwent_r(self.file, &mut pwd, buffer_ptr, buffer_len, &mut result) }; + if ret != 0 || result.is_null() || result != &mut pwd { + // We expect to get a pointer to `pwd` back; in any other case, we can't safely proceed. + return None; + } + + // Parse the struct and return it. + Some(unsafe { passwd_to_user(result.read()) }) + } } /// An iterator over every group present on the system. -struct AllGroups; +struct AllGroups { + #[cfg(target_os = "linux")] + file: *mut libc::FILE, +} /// Creates a new iterator over every group present on the system. /// @@ -956,13 +1071,67 @@ struct AllGroups; /// println!("Group #{} ({:?})", group.gid(), group.name()); /// } /// ``` +#[cfg(not(target_os = "linux"))] pub unsafe fn all_groups() -> impl Iterator { #[cfg(feature = "logging")] trace!("Running setgrent"); #[cfg(not(target_os = "android"))] libc::setgrent(); - AllGroups + AllGroups {} +} + +/// Creates a new iterator over every group present on the system. +/// +/// # libc functions used +/// +/// - [`fopen`](https://docs.rs/libc/*/libc/fn.fopen.html) +/// - [`fgetgrent_r`](https://docs.rs/libc/*/libc/fn.fgetgrent_r.html) +/// - [`fclose`](https://docs.rs/libc/*/libc/fn.fclose.html) +/// +/// # Examples +/// +/// ``` +/// use uzers::all_groups; +/// +/// let iter = all_groups(); +/// for group in iter { +/// println!("Group #{} ({:?})", group.gid(), group.name()); +/// } +/// ``` +#[cfg(target_os = "linux")] +pub fn all_groups() -> impl Iterator { + all_groups_from_file("/etc/group") +} + +/// Creates a new iterator over every group present in the provided group file. +/// +/// # libc functions used +/// +/// - [`fopen`](https://docs.rs/libc/*/libc/fn.fopen.html) +/// - [`fgetgrent_r`](https://docs.rs/libc/*/libc/fn.fgetgrent_r.html) +/// - [`fclose`](https://docs.rs/libc/*/libc/fn.fclose.html) +/// +/// # Examples +/// +/// ``` +/// use uzers::all_groups_from_file; +/// +/// let iter = all_groups_from_file("mygroup"); +/// for group in iter { +/// println!("Group #{} ({:?})", group.gid(), group.name()); +/// } +/// ``` +#[cfg(target_os = "linux")] +pub fn all_groups_from_file>(file_path: S) -> impl Iterator { + let mut result = AllGroups { + file: std::ptr::null_mut(), + }; + + let file_path = CString::new(file_path.as_ref()).unwrap(); + result.file = unsafe { libc::fopen(file_path.as_ptr(), CString::new("r").unwrap().as_ptr()) }; + + result } impl Drop for AllGroups { @@ -971,13 +1140,20 @@ impl Drop for AllGroups { // nothing to do here } - #[cfg(not(target_os = "android"))] + #[cfg(not(any(target_os = "android", target_os = "linux")))] fn drop(&mut self) { #[cfg(feature = "logging")] trace!("Running endgrent"); unsafe { libc::endgrent() }; } + + #[cfg(target_os = "linux")] + fn drop(&mut self) { + if !self.file.is_null() { + unsafe { libc::fclose(self.file) }; + } + } } impl Iterator for AllGroups { @@ -988,7 +1164,7 @@ impl Iterator for AllGroups { None } - #[cfg(not(target_os = "android"))] + #[cfg(not(any(target_os = "android", target_os = "linux")))] fn next(&mut self) -> Option { #[cfg(feature = "logging")] trace!("Running getgrent"); @@ -1002,6 +1178,50 @@ impl Iterator for AllGroups { Some(group) } } + + #[cfg(target_os = "linux")] + fn next(&mut self) -> Option { + if self.file.is_null() { + // We weren't able to open the backing file, so we have no groups to return. + return None; + } + + // Compute the maximum size of the buffer. + let buffer_len = unsafe { libc::sysconf(libc::_SC_GETGR_R_SIZE_MAX) }; + if buffer_len == -1 { + // We couldn't compute the required buffer size, so we can't read the group. + return None; + } + + let buffer_len = usize::try_from(buffer_len).unwrap_or(0); + if buffer_len == 0 { + // The buffer size was invalid, we can't proceed. + return None; + } + + // Create the buffers we need. + let mut buffer = vec![0; buffer_len]; + let buffer_ptr = buffer.as_mut_ptr() as *mut c_char; + let mut group = c_group { + gr_name: std::ptr::null_mut(), + gr_passwd: std::ptr::null_mut(), + gr_gid: 0, + gr_mem: std::ptr::null_mut(), + }; + + // Call fgetgrent_r to read the next entry. + let mut result = ptr::null_mut(); + let ret = unsafe { + libc::fgetgrent_r(self.file, &mut group, buffer_ptr, buffer_len, &mut result) + }; + if ret != 0 || result.is_null() || result != &mut group { + // We expect to get a pointer to `group` back; in any other case, we can't safely proceed. + return None; + } + + // Parse the struct and return it. + Some(unsafe { struct_to_group(result.read()) }) + } } /// OS-specific extensions to users and groups. diff --git a/src/lib.rs b/src/lib.rs index 7c5b43c..4e42bd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,6 +143,9 @@ pub use base::{get_user_by_name, get_user_by_uid}; pub use base::{get_user_groups, group_access_list}; pub use base::{os, Group, User}; +#[cfg(target_os = "linux")] +pub use base::{all_groups_from_file, all_users_from_file}; + #[cfg(feature = "cache")] pub mod cache; diff --git a/tests/groups.rs b/tests/groups.rs index 618d33b..3d6472c 100644 --- a/tests/groups.rs +++ b/tests/groups.rs @@ -15,6 +15,25 @@ mod integration { assert_eq!(group.name(), "bosses"); } + #[cfg(target_os = "linux")] + #[test] + #[serial_test::serial] + fn test_all_groups_from_file() { + let test_group_file_path = std::env::var("NSS_WRAPPER_GROUP").unwrap(); + + let groups: Vec<_> = uzers::all_groups_from_file(test_group_file_path).collect(); + assert_eq!(groups.len(), 2); + + let group = &groups[0]; + assert_eq!(group.gid(), 42); + assert_eq!(group.name(), "bosses"); + + let group = &groups[1]; + assert_eq!(group.gid(), 43); + assert_eq!(group.name(), "contributors"); + } + + #[cfg(not(target_os = "linux"))] #[test] #[serial_test::serial] fn test_all_groups() { diff --git a/tests/users.rs b/tests/users.rs index c1f5793..aece110 100644 --- a/tests/users.rs +++ b/tests/users.rs @@ -21,6 +21,23 @@ mod integration { assert_eq!(user.home_dir(), PathBuf::from("/home/fred")); } + #[cfg(target_os = "linux")] + #[test] + #[serial_test::serial] + fn test_all_users_from_file() { + let test_passwd_file_path = std::env::var("NSS_WRAPPER_PASSWD").unwrap(); + + let users: Vec<_> = uzers::all_users_from_file(test_passwd_file_path).collect(); + assert_eq!(users.len(), 1); + + let user = users.first().unwrap(); + assert_eq!(user.uid(), 1337); + assert_eq!(user.name(), "fred"); + assert_eq!(user.primary_group_id(), 42); + assert_eq!(user.home_dir(), PathBuf::from("/home/fred")); + } + + #[cfg(not(target_os = "linux"))] #[test] #[serial_test::serial] fn test_all_users() {