From 5dc7f8b2ddc981f2cb463a2157198309ef676f94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Omar=20Mu=C3=B1oz?= Date: Wed, 21 Jan 2026 10:44:12 -0600 Subject: [PATCH] before fixing latency --- .../com/example/demo/DemoApplication.java | 2 + .../com/example/demo/DocumentController.java | 131 ++++++++++++++++++ .../java/com/example/demo/FilterConfig.java | 1 + .../example/demo/RLSConnectionManager.java | 98 +++++-------- .../com/example/demo/RLSTestController.java | 30 ++-- 5 files changed, 181 insertions(+), 81 deletions(-) create mode 100644 src/main/java/com/example/demo/DocumentController.java diff --git a/src/main/java/com/example/demo/DemoApplication.java b/src/main/java/com/example/demo/DemoApplication.java index 64b538a..d0e41f8 100644 --- a/src/main/java/com/example/demo/DemoApplication.java +++ b/src/main/java/com/example/demo/DemoApplication.java @@ -2,8 +2,10 @@ package com.example.demo; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.transaction.annotation.EnableTransactionManagement; @SpringBootApplication +@EnableTransactionManagement public class DemoApplication { public static void main(String[] args) { diff --git a/src/main/java/com/example/demo/DocumentController.java b/src/main/java/com/example/demo/DocumentController.java new file mode 100644 index 0000000..25a0be9 --- /dev/null +++ b/src/main/java/com/example/demo/DocumentController.java @@ -0,0 +1,131 @@ +package com.example.demo; + +import jakarta.servlet.http.HttpServletRequest; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +import java.util.HashMap; +import java.util.Map; + +@RestController +@RequestMapping("/api/documents") +public class DocumentController { + + private final RLSConnectionManager rlsManager; + private final DocumentRepository documentRepository; + + public DocumentController(RLSConnectionManager rlsManager, DocumentRepository documentRepository) { + this.rlsManager = rlsManager; + this.documentRepository = documentRepository; + } + + /** + * CREATE - Insert a new document using repository + */ + @PostMapping + public ResponseEntity createDocument(HttpServletRequest request, + @RequestBody Document document) { + Long userId = (Long) request.getAttribute("userId"); + + // Force the user_id to match JWT user (prevent privilege escalation) + document.setUserId(userId); + + return rlsManager.executeWithRLSContext(userId, () -> { + Document saved = documentRepository.save(document); + return ResponseEntity.status(HttpStatus.CREATED).body(saved); + }); + } + + /** + * READ - Get all documents (filtered by RLS to current user) + */ + @GetMapping + public ResponseEntity> getAllDocuments(HttpServletRequest request) { + Long userId = (Long) request.getAttribute("userId"); + + return rlsManager.executeWithRLSContext(userId, () -> { + Iterable documents = documentRepository.findAll(); + return ResponseEntity.ok(documents); + }); + } + + /** + * READ - Get single document by ID (RLS ensures user owns it) + */ + @GetMapping("/{id}") + public ResponseEntity> getDocumentById(HttpServletRequest request, + @PathVariable Long id) { + Long userId = (Long) request.getAttribute("userId"); + + return rlsManager.executeWithRLSContext(userId, () -> { + return documentRepository.findById(id) + .>>map(doc -> { + Map response = new HashMap<>(); + response.put("document", doc); + return ResponseEntity.ok(response); + }) + .orElseGet(() -> { + Map error = new HashMap<>(); + error.put("error", "Document not found or access denied"); + return ResponseEntity.status(HttpStatus.NOT_FOUND).body(error); + }); + }); + } + + /** + * UPDATE - Update existing document (RLS ensures user owns it) + */ + @PutMapping("/{id}") + public ResponseEntity> updateDocument(HttpServletRequest request, + @PathVariable Long id, + @RequestBody Document updatedDocument) { + Long userId = (Long) request.getAttribute("userId"); + + return rlsManager.executeWithRLSContext(userId, () -> { + return documentRepository.findById(id) + .>>map(existingDoc -> { + // Update fields + existingDoc.setTitle(updatedDocument.getTitle()); + existingDoc.setContent(updatedDocument.getContent()); + // Don't allow changing user_id + + Document saved = documentRepository.save(existingDoc); + Map response = new HashMap<>(); + response.put("document", saved); + return ResponseEntity.ok(response); + }) + .orElseGet(() -> { + Map error = new HashMap<>(); + error.put("error", "Document not found or access denied"); + return ResponseEntity.status(HttpStatus.NOT_FOUND).body(error); + }); + }); + } + + /** + * DELETE - Delete document (RLS ensures user owns it) + */ + @DeleteMapping("/{id}") + public ResponseEntity> deleteDocument(HttpServletRequest request, + @PathVariable Long id) { + Long userId = (Long) request.getAttribute("userId"); + + return rlsManager.executeWithRLSContext(userId, () -> { + return documentRepository.findById(id) + .>>map(doc -> { + documentRepository.deleteById(id); + Map response = new HashMap<>(); + response.put("success", true); + response.put("message", "Document deleted"); + response.put("id", id); + return ResponseEntity.ok(response); + }) + .orElseGet(() -> { + Map error = new HashMap<>(); + error.put("error", "Document not found or access denied"); + return ResponseEntity.status(HttpStatus.NOT_FOUND).body(error); + }); + }); + } +} diff --git a/src/main/java/com/example/demo/FilterConfig.java b/src/main/java/com/example/demo/FilterConfig.java index a9c2cb5..2b81efa 100644 --- a/src/main/java/com/example/demo/FilterConfig.java +++ b/src/main/java/com/example/demo/FilterConfig.java @@ -22,6 +22,7 @@ public class FilterConfig { registrationBean.addUrlPatterns("/api/rls-test/setup"); registrationBean.addUrlPatterns("/api/rls-test/documents/*"); registrationBean.addUrlPatterns("/api/rls-test/context/*"); + registrationBean.addUrlPatterns("/api/documents/*"); registrationBean.setOrder(1); diff --git a/src/main/java/com/example/demo/RLSConnectionManager.java b/src/main/java/com/example/demo/RLSConnectionManager.java index 7db2942..bf7f233 100644 --- a/src/main/java/com/example/demo/RLSConnectionManager.java +++ b/src/main/java/com/example/demo/RLSConnectionManager.java @@ -1,95 +1,61 @@ package com.example.demo; import org.springframework.jdbc.core.JdbcTemplate; -import org.springframework.jdbc.datasource.SingleConnectionDataSource; import org.springframework.stereotype.Component; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionStatus; +import org.springframework.transaction.support.DefaultTransactionDefinition; -import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.function.Function; +import java.util.function.Supplier; /** * Manages PostgreSQL RLS context variables with proper connection scoping. * - * CRITICAL: This ensures context variables are set and reset on the SAME connection - * to prevent leaking between concurrent requests when connections are pooled. + * Uses programmatic transaction management for better performance. */ @Component public class RLSConnectionManager { - private final DataSource dataSource; + private final JdbcTemplate jdbcTemplate; + private final PlatformTransactionManager transactionManager; - public RLSConnectionManager(DataSource dataSource) { - this.dataSource = dataSource; + public RLSConnectionManager(JdbcTemplate jdbcTemplate, PlatformTransactionManager transactionManager) { + this.jdbcTemplate = jdbcTemplate; + this.transactionManager = transactionManager; } /** - * Executes an operation with RLS context variables set on a dedicated connection. - * The connection is obtained, configured, used, reset, and returned to pool - all atomically. + * Executes an operation with RLS context variables set. + * Uses programmatic transactions for better performance. * * @param userId The user ID to set in the context - * @param operation The operation to execute (receives JdbcTemplate bound to the connection) + * @param operation The operation to execute * @param Return type * @return Result of the operation */ - public T executeWithRLSContext(Long userId, Function operation) { - Connection connection = null; + public T executeWithRLSContext(Long userId, Supplier operation) { + // Start transaction + DefaultTransactionDefinition def = new DefaultTransactionDefinition(); + def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED); + TransactionStatus status = transactionManager.getTransaction(def); + try { - // Get a connection from the pool - connection = dataSource.getConnection(); + // Set the RLS context variable (uses transaction connection) + jdbcTemplate.execute("SET LOCAL app.current_user_id = '" + userId + "'"); - // CRITICAL: Set connection to manual commit to ensure atomicity - connection.setAutoCommit(false); + // Execute the operation + T result = operation.get(); - try { - // Set the RLS context variable on THIS connection - try (Statement stmt = connection.createStatement()) { - stmt.execute("SET LOCAL app.current_user_id = '" + userId + "'"); - } - - // Create a JdbcTemplate bound to THIS specific connection - SingleConnectionDataSource singleConnectionDataSource = - new SingleConnectionDataSource(connection, true); - JdbcTemplate scopedTemplate = new JdbcTemplate(singleConnectionDataSource); - - // Execute the operation with the scoped template - T result = operation.apply(scopedTemplate); - - // Commit the transaction - connection.commit(); - - return result; - - } catch (Exception e) { - // Rollback on error - connection.rollback(); - throw new RuntimeException("Error executing RLS operation", e); - } finally { - // CRITICAL: Reset the context variable before returning connection to pool - try (Statement stmt = connection.createStatement()) { - stmt.execute("RESET app.current_user_id"); - } catch (Exception e) { - // Log but don't throw - connection will still be returned to pool - System.err.println("Warning: Failed to reset RLS context: " + e.getMessage()); - } - } + // Commit transaction (SET LOCAL auto-reverts after commit) + transactionManager.commit(status); - } catch (SQLException e) { - throw new RuntimeException("Failed to obtain database connection", e); - } finally { - // Return connection to pool - if (connection != null) { - try { - connection.setAutoCommit(true); // Restore default behavior - connection.close(); // Returns to pool - } catch (SQLException e) { - System.err.println("Error closing connection: " + e.getMessage()); - } - } + return result; + + } catch (Exception e) { + // Rollback on error + transactionManager.rollback(status); + throw new RuntimeException("Error executing RLS operation", e); } } - - } diff --git a/src/main/java/com/example/demo/RLSTestController.java b/src/main/java/com/example/demo/RLSTestController.java index f9b6e67..1cfe382 100644 --- a/src/main/java/com/example/demo/RLSTestController.java +++ b/src/main/java/com/example/demo/RLSTestController.java @@ -29,10 +29,10 @@ public class RLSTestController { // Get user ID from JWT (set by JwtFilter) Long userId = (Long) request.getAttribute("userId"); - return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { + return rlsManager.executeWithRLSContext(userId, () -> { // This query will only return documents the user has access to (via RLS policy) String sql = "SELECT id, title, content, user_id FROM documents"; - return scopedTemplate.queryForList(sql); + return jdbcTemplate.queryForList(sql); }); } @@ -43,9 +43,9 @@ public class RLSTestController { public Map verifyContextVariable(HttpServletRequest request) { Long userId = (Long) request.getAttribute("userId"); - return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { + return rlsManager.executeWithRLSContext(userId, () -> { // Query the context variable to verify it's set - String currentUserId = scopedTemplate.queryForObject( + String currentUserId = jdbcTemplate.queryForObject( "SELECT current_setting('app.current_user_id', true)", String.class ); @@ -68,8 +68,8 @@ public class RLSTestController { Map result = new HashMap<>(); // Set context for user 1 - rlsManager.executeWithRLSContext(1L, scopedTemplate -> { - String ctx = scopedTemplate.queryForObject( + rlsManager.executeWithRLSContext(1L, () -> { + String ctx = jdbcTemplate.queryForObject( "SELECT current_setting('app.current_user_id', true)", String.class ); @@ -94,8 +94,8 @@ public class RLSTestController { result.put("afterUser1", leakedContext); // Set context for user 2 - rlsManager.executeWithRLSContext(2L, scopedTemplate -> { - String ctx = scopedTemplate.queryForObject( + rlsManager.executeWithRLSContext(2L, () -> { + String ctx = jdbcTemplate.queryForObject( "SELECT current_setting('app.current_user_id', true)", String.class ); @@ -133,9 +133,9 @@ public class RLSTestController { // Get user ID from JWT Long userId = (Long) request.getAttribute("userId"); - return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { + return rlsManager.executeWithRLSContext(userId, () -> { // Insert with the user context set - scopedTemplate.update( + jdbcTemplate.update( "INSERT INTO documents (title, content, user_id) VALUES (?, ?, ?)", title, content, userId ); @@ -188,20 +188,20 @@ public class RLSTestController { // Insert test data WITH RLS context set // Now that FORCE RLS is enabled, even our inserts must respect the policy - rlsManager.executeWithRLSContext(1L, scopedTemplate -> { - scopedTemplate.update( + rlsManager.executeWithRLSContext(1L, () -> { + jdbcTemplate.update( "INSERT INTO documents (title, content, user_id) VALUES (?, ?, ?)", "User 1 Document", "Private content for user 1", 1L ); - scopedTemplate.update( + jdbcTemplate.update( "INSERT INTO documents (title, content, user_id) VALUES (?, ?, ?)", "Another User 1 Doc", "More private content for user 1", 1L ); return null; }); - rlsManager.executeWithRLSContext(2L, scopedTemplate -> { - scopedTemplate.update( + rlsManager.executeWithRLSContext(2L, () -> { + jdbcTemplate.update( "INSERT INTO documents (title, content, user_id) VALUES (?, ?, ?)", "User 2 Document", "Private content for user 2", 2L );