diff --git a/rust/docker-compose.yml b/rust/docker-compose.yml new file mode 100644 index 0000000..cc3f89c --- /dev/null +++ b/rust/docker-compose.yml @@ -0,0 +1,20 @@ +services: + postgres: + image: postgres:15 + environment: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + volumes: + - postgres-data:/var/lib/postgresql/data + ports: + - "5432:5432" + networks: + - app-network + +volumes: + postgres-data: + +networks: + app-network: + driver: bridge diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs index 3bc176c..bef073b 100644 --- a/rust/impls/src/postgres_store.rs +++ b/rust/impls/src/postgres_store.rs @@ -660,10 +660,15 @@ where } #[cfg(test)] + mod tests { use super::{drop_database, DUMMY_MIGRATION, MIGRATIONS}; use crate::postgres_store::PostgresPlaintextBackend; use api::define_kv_store_tests; + use api::kv_store::KvStore; + use api::types::{DeleteObjectRequest, GetObjectRequest, KeyValue, PutObjectRequest}; + + use bytes::Bytes; use tokio::sync::OnceCell; use tokio_postgres::NoTls; @@ -779,4 +784,70 @@ mod tests { drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap(); } + + #[tokio::test] + async fn supports_objects_up_to_non_large_object_threshold() { + let vss_db = "supports_objects_up_to_non_large_object_threshold"; + let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await; + + const MAXIMUM_SUPPORTED_VALUE_SIZE: usize = 1024 * 1024 * 1024; + const PROTOCOL_OVERHEAD_MARGIN: usize = 150; + + // Construct entry that's for a field that's the maximum size of a non-"large_object" object + let large_value = vec![0u8; MAXIMUM_SUPPORTED_VALUE_SIZE - PROTOCOL_OVERHEAD_MARGIN]; + let kv = KeyValue { key: "k1".into(), version: 0, value: Bytes::from(large_value) }; + + { + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); + let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap(); + assert_eq!(start, MIGRATIONS_START); + assert_eq!(end, MIGRATIONS_END); + assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]); + assert_eq!(store.get_schema_version().await, MIGRATIONS_END); + + // Round trip with non-large_object of threshold size + + store + .put( + "token".to_string(), + PutObjectRequest { + store_id: "store_id".to_string(), + global_version: None, + transaction_items: vec![kv], + delete_items: vec![], + }, + ) + .await + .unwrap(); + + let resp_kv = store + .get( + "token".to_string(), + GetObjectRequest { store_id: "store_id".to_string(), key: "k1".to_string() }, + ) + .await + .unwrap() + .value + .unwrap(); + assert_eq!( + resp_kv.value.len(), + MAXIMUM_SUPPORTED_VALUE_SIZE - PROTOCOL_OVERHEAD_MARGIN + ); + assert!(resp_kv.value.iter().all(|&b| b == 0)); + + store + .delete( + "token".to_string(), + DeleteObjectRequest { + store_id: "store_id".to_string(), + key_value: Some(resp_kv), + }, + ) + .await + .unwrap(); + }; + + drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap(); + } } diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index 9092ff2..321f185 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -24,7 +24,7 @@ use auth_impls::jwt::JWTAuthorizer; #[cfg(feature = "sigs")] use auth_impls::signature::SignatureValidatingAuthorizer; use impls::postgres_store::{PostgresPlaintextBackend, PostgresTlsBackend}; -use vss_service::VssService; +use vss_service::{VssService, VssServiceConfig}; mod util; mod vss_service; @@ -37,6 +37,16 @@ fn main() { eprintln!("Failed to load configuration: {}", e); std::process::exit(-1); }); + let vss_service_config = match &config.maximum_request_body_size { + Some(size) => match VssServiceConfig::new(*size) { + Ok(config) => Arc::new(config), + Err(e) => { + eprintln!("Configuration validation error: {}", e); + return; + }, + }, + None => Arc::new(VssServiceConfig::default()), + }; let runtime = match tokio::runtime::Builder::new_multi_thread().enable_all().build() { Ok(runtime) => Arc::new(runtime), @@ -132,7 +142,7 @@ fn main() { match res { Ok((stream, _)) => { let io_stream = TokioIo::new(stream); - let vss_service = VssService::new(Arc::clone(&store), Arc::clone(&authorizer)); + let vss_service = VssService::new(Arc::clone(&store), Arc::clone(&authorizer), Arc::clone(&vss_service_config)); runtime.spawn(async move { if let Err(err) = http1::Builder::new().serve_connection(io_stream, vss_service).await { eprintln!("Failed to serve connection: {}", err); diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs index 3236941..bbe9d94 100644 --- a/rust/server/src/util/config.rs +++ b/rust/server/src/util/config.rs @@ -2,6 +2,7 @@ use serde::Deserialize; use std::net::SocketAddr; const BIND_ADDR_VAR: &str = "VSS_BIND_ADDRESS"; +const MAX_REQUEST_BODY_SIZE: &str = "VSS_MAX_REQUEST_BODY_SIZE"; const JWT_RSA_PEM_VAR: &str = "VSS_JWT_RSA_PEM"; const PSQL_USER_VAR: &str = "VSS_PSQL_USERNAME"; const PSQL_PASS_VAR: &str = "VSS_PSQL_PASSWORD"; @@ -23,6 +24,7 @@ struct TomlConfig { #[derive(Deserialize)] struct ServerConfig { bind_address: Option, + maximum_request_body_size: Option, } #[derive(Deserialize)] @@ -48,6 +50,7 @@ struct TlsConfig { // Encapsulates the result of reading both the environment variables and the config file. pub(crate) struct Configuration { pub(crate) bind_address: SocketAddr, + pub(crate) maximum_request_body_size: Option, pub(crate) rsa_pem: Option, pub(crate) postgresql_prefix: String, pub(crate) default_db: String, @@ -85,6 +88,11 @@ pub(crate) fn load_configuration(config_file_path: Option<&str>) -> Result TomlConfig::default(), // All fields are set to `None` }; + let (bind_address_config, max_request_body_size_config) = match server_config { + Some(c) => (c.bind_address, c.maximum_request_body_size), + None => (None, None), + }; + let bind_address_env = read_env(BIND_ADDR_VAR)? .map(|addr| { addr.parse().map_err(|e| { @@ -94,11 +102,25 @@ pub(crate) fn load_configuration(config_file_path: Option<&str>) -> Result().map_err(|e| { + format!("Unable to parse the maximum request body size environment variable: {}", e) + }) + }) + .transpose()?; + let maximum_request_body_size = read_config( + maximum_request_body_size_env, + max_request_body_size_config, + "VSS server maximum request body size", + MAX_REQUEST_BODY_SIZE, + )?; + let rsa_pem_env = read_env(JWT_RSA_PEM_VAR)?; let rsa_pem = rsa_pem_env.or(jwt_auth_config.and_then(|config| config.rsa_pem)); @@ -155,5 +177,13 @@ pub(crate) fn load_configuration(config_file_path: Option<&str>) -> Result Result { + if maximum_request_body_size > MAXIMUM_REQUEST_BODY_SIZE { + return Err(format!( + "Request body size {} exceeds maximum {}", + maximum_request_body_size, MAXIMUM_REQUEST_BODY_SIZE + )); + } + + Ok(Self { maximum_request_body_size }) + } +} + +impl Default for VssServiceConfig { + fn default() -> Self { + Self { maximum_request_body_size: MAXIMUM_REQUEST_BODY_SIZE } + } +} + #[derive(Clone)] pub struct VssService { store: Arc, authorizer: Arc, + config: Arc, } impl VssService { - pub(crate) fn new(store: Arc, authorizer: Arc) -> Self { - Self { store, authorizer } + pub(crate) fn new( + store: Arc, authorizer: Arc, config: Arc, + ) -> Self { + Self { store, authorizer, config } } } @@ -41,22 +70,51 @@ impl Service> for VssService { let store = Arc::clone(&self.store); let authorizer = Arc::clone(&self.authorizer); let path = req.uri().path().to_owned(); + let maximum_request_body_size = self.config.maximum_request_body_size; Box::pin(async move { let prefix_stripped_path = path.strip_prefix(BASE_PATH_PREFIX).unwrap_or_default(); match prefix_stripped_path { "/getObject" => { - handle_request(store, authorizer, req, handle_get_object_request).await + handle_request( + store, + authorizer, + req, + maximum_request_body_size, + handle_get_object_request, + ) + .await }, "/putObjects" => { - handle_request(store, authorizer, req, handle_put_object_request).await + handle_request( + store, + authorizer, + req, + maximum_request_body_size, + handle_put_object_request, + ) + .await }, "/deleteObject" => { - handle_request(store, authorizer, req, handle_delete_object_request).await + handle_request( + store, + authorizer, + req, + maximum_request_body_size, + handle_delete_object_request, + ) + .await }, "/listKeyVersions" => { - handle_request(store, authorizer, req, handle_list_object_request).await + handle_request( + store, + authorizer, + req, + maximum_request_body_size, + handle_list_object_request, + ) + .await }, _ => { let error_msg = "Invalid request path.".as_bytes(); @@ -97,7 +155,7 @@ async fn handle_request< Fut: Future> + Send, >( store: Arc, authorizer: Arc, request: Request, - handler: F, + maximum_request_body_size: usize, handler: F, ) -> Result<>>::Response, hyper::Error> { let (parts, body) = request.into_parts(); let headers_map = parts @@ -110,8 +168,17 @@ async fn handle_request< Ok(auth_response) => auth_response.user_token, Err(e) => return Ok(build_error_response(e)), }; - // TODO: we should bound the amount of data we read to avoid allocating too much memory. - let bytes = body.collect().await?.to_bytes(); + + let limited_body = Limited::new(body, maximum_request_body_size); + let bytes = match limited_body.collect().await { + Ok(body) => body.to_bytes(), + Err(_) => { + return Ok(Response::builder() + .status(StatusCode::PAYLOAD_TOO_LARGE) + .body(Full::new(Bytes::from("Request body too large"))) + .unwrap()); + }, + }; match T::decode(bytes) { Ok(request) => match handler(store.clone(), user_token, request).await { Ok(response) => Ok(Response::builder() diff --git a/rust/server/vss-server-config.toml b/rust/server/vss-server-config.toml index 7a80dff..7412f8e 100644 --- a/rust/server/vss-server-config.toml +++ b/rust/server/vss-server-config.toml @@ -1,5 +1,6 @@ [server_config] bind_address = "127.0.0.1:8080" # Optional in TOML, can be overridden by env var `VSS_BIND_ADDRESS` +maximum_request_body_size = 1073741824 # Optional in TOML: maximum request body size in bytes capped at 1 GB, can be overriden by env var 'VSS_MAX_REQUEST_BODY_SIZE' # Uncomment the table below to verify JWT tokens in the HTTP Authorization header against the given RSA public key, # can be overridden by env var `VSS_JWT_RSA_PEM`