From 7e09621a7dbc0d174044be4860c568d18366c9fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Omar=20Mu=C3=B1oz?= Date: Wed, 21 Jan 2026 10:07:13 -0600 Subject: [PATCH] JWT works --- pom.xml | 17 +++++ .../java/com/example/demo/AuthController.java | 65 ++++++++++++++++++ .../java/com/example/demo/FilterConfig.java | 62 +++++++++++++++++ src/main/java/com/example/demo/JwtFilter.java | 55 +++++++++++++++ src/main/java/com/example/demo/JwtUtil.java | 68 +++++++++++++++++++ .../example/demo/RLSConnectionManager.java | 25 +------ .../com/example/demo/RLSTestController.java | 32 +++++---- src/main/resources/application.properties | 4 ++ 8 files changed, 294 insertions(+), 34 deletions(-) create mode 100644 src/main/java/com/example/demo/AuthController.java create mode 100644 src/main/java/com/example/demo/FilterConfig.java create mode 100644 src/main/java/com/example/demo/JwtFilter.java create mode 100644 src/main/java/com/example/demo/JwtUtil.java diff --git a/pom.xml b/pom.xml index f7d5ca4..b430fe5 100644 --- a/pom.xml +++ b/pom.xml @@ -44,6 +44,23 @@ postgresql runtime + + io.jsonwebtoken + jjwt-api + 0.12.6 + + + io.jsonwebtoken + jjwt-impl + 0.12.6 + runtime + + + io.jsonwebtoken + jjwt-jackson + 0.12.6 + runtime + org.springframework.boot spring-boot-starter-data-jdbc-test diff --git a/src/main/java/com/example/demo/AuthController.java b/src/main/java/com/example/demo/AuthController.java new file mode 100644 index 0000000..4e2de3a --- /dev/null +++ b/src/main/java/com/example/demo/AuthController.java @@ -0,0 +1,65 @@ +package com.example.demo; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.web.bind.annotation.*; + +import java.util.HashMap; +import java.util.Map; + +@RestController +@RequestMapping("/api/auth") +public class AuthController { + + private final JwtUtil jwtUtil; + + public AuthController(JwtUtil jwtUtil) { + this.jwtUtil = jwtUtil; + } + + /** + * Login endpoint - generates JWT and sets it as HTTP-only cookie + * In production, you'd validate credentials against a database + */ + @PostMapping("/login") + public Map login(@RequestParam Long userId, HttpServletResponse response) { + // Generate JWT token + String token = jwtUtil.generateToken(userId); + + // Set JWT as HTTP-only cookie + Cookie cookie = new Cookie("jwt", token); + cookie.setHttpOnly(true); + cookie.setSecure(false); // Set to true in production with HTTPS + cookie.setPath("/"); + cookie.setMaxAge(24 * 60 * 60); // 24 hours + + response.addCookie(cookie); + + Map result = new HashMap<>(); + result.put("success", true); + result.put("userId", userId); + result.put("message", "Login successful"); + + return result; + } + + /** + * Logout endpoint - clears the JWT cookie + */ + @PostMapping("/logout") + public Map logout(HttpServletResponse response) { + Cookie cookie = new Cookie("jwt", null); + cookie.setHttpOnly(true); + cookie.setSecure(false); + cookie.setPath("/"); + cookie.setMaxAge(0); // Expire immediately + + response.addCookie(cookie); + + Map result = new HashMap<>(); + result.put("success", true); + result.put("message", "Logout successful"); + + return result; + } +} diff --git a/src/main/java/com/example/demo/FilterConfig.java b/src/main/java/com/example/demo/FilterConfig.java new file mode 100644 index 0000000..a9c2cb5 --- /dev/null +++ b/src/main/java/com/example/demo/FilterConfig.java @@ -0,0 +1,62 @@ +package com.example.demo; + +import org.springframework.boot.web.servlet.FilterRegistrationBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class FilterConfig { + + private final JwtFilter jwtFilter; + + public FilterConfig(JwtFilter jwtFilter) { + this.jwtFilter = jwtFilter; + } + + @Bean + public FilterRegistrationBean jwtFilterRegistration() { + FilterRegistrationBean registrationBean = new FilterRegistrationBean<>(); + registrationBean.setFilter(jwtFilter); + + // Add protected URL patterns + registrationBean.addUrlPatterns("/api/rls-test/setup"); + registrationBean.addUrlPatterns("/api/rls-test/documents/*"); + registrationBean.addUrlPatterns("/api/rls-test/context/*"); + + registrationBean.setOrder(1); + + return registrationBean; + } +} +/* + * ALTERNATIVE APPROACH: Exclude specific paths instead of including + * + * To apply JWT filter to ALL paths EXCEPT certain ones, replace the + * FilterRegistrationBean configuration with this approach: + * + * @Bean + * public FilterRegistrationBean jwtFilterRegistration() { + * FilterRegistrationBean registrationBean = new FilterRegistrationBean<>(); + * registrationBean.setFilter(jwtFilter); + * + * // Apply to all paths + * registrationBean.addUrlPatterns("/*"); + * + * registrationBean.setOrder(1); + * + * return registrationBean; + * } + * + * Then modify JwtFilter.shouldNotFilter() method to exclude paths: + * + * @Override + * protected boolean shouldNotFilter(HttpServletRequest request) { + * String path = request.getRequestURI(); + * return path.startsWith("/api/auth/") || // Skip auth endpoints + * path.equals("/health") || // Skip health checks + * path.startsWith("/public/"); // Skip public resources + * } + * + * This approach is better when you have MORE protected paths than open paths. + * Current approach (explicit addUrlPatterns) is better when you have FEWER protected paths. + */ diff --git a/src/main/java/com/example/demo/JwtFilter.java b/src/main/java/com/example/demo/JwtFilter.java new file mode 100644 index 0000000..75271e1 --- /dev/null +++ b/src/main/java/com/example/demo/JwtFilter.java @@ -0,0 +1,55 @@ +package com.example.demo; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; + +@Component +public class JwtFilter extends OncePerRequestFilter { + + private final JwtUtil jwtUtil; + + public JwtFilter(JwtUtil jwtUtil) { + this.jwtUtil = jwtUtil; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + + // Extract JWT from cookie + String token = extractTokenFromCookie(request); + + if (token != null && jwtUtil.validateToken(token)) { + // Extract user ID and add to request attribute + Long userId = jwtUtil.getUserIdFromToken(token); + request.setAttribute("userId", userId); + } else { + // No valid JWT - return 401 Unauthorized + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + response.getWriter().write("{\"error\": \"Unauthorized - Invalid or missing JWT\"}"); + return; + } + + filterChain.doFilter(request, response); + } + + private String extractTokenFromCookie(HttpServletRequest request) { + Cookie[] cookies = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if ("jwt".equals(cookie.getName())) { + return cookie.getValue(); + } + } + } + return null; + } +} diff --git a/src/main/java/com/example/demo/JwtUtil.java b/src/main/java/com/example/demo/JwtUtil.java new file mode 100644 index 0000000..34485ab --- /dev/null +++ b/src/main/java/com/example/demo/JwtUtil.java @@ -0,0 +1,68 @@ +package com.example.demo; + +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.security.Keys; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +import javax.crypto.SecretKey; +import java.nio.charset.StandardCharsets; +import java.util.Date; + +@Component +public class JwtUtil { + + @Value("${jwt.secret}") + private String secret; + + @Value("${jwt.expiration}") + private long expiration; + + private SecretKey getSigningKey() { + return Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Generate JWT token with user ID + */ + public String generateToken(Long userId) { + Date now = new Date(); + Date expiryDate = new Date(now.getTime() + expiration); + + return Jwts.builder() + .subject(userId.toString()) + .issuedAt(now) + .expiration(expiryDate) + .signWith(getSigningKey()) + .compact(); + } + + /** + * Extract user ID from JWT token + */ + public Long getUserIdFromToken(String token) { + Claims claims = Jwts.parser() + .verifyWith(getSigningKey()) + .build() + .parseSignedClaims(token) + .getPayload(); + + return Long.parseLong(claims.getSubject()); + } + + /** + * Validate JWT token + */ + public boolean validateToken(String token) { + try { + Jwts.parser() + .verifyWith(getSigningKey()) + .build() + .parseSignedClaims(token); + return true; + } catch (Exception e) { + return false; + } + } +} diff --git a/src/main/java/com/example/demo/RLSConnectionManager.java b/src/main/java/com/example/demo/RLSConnectionManager.java index cb31a25..7db2942 100644 --- a/src/main/java/com/example/demo/RLSConnectionManager.java +++ b/src/main/java/com/example/demo/RLSConnectionManager.java @@ -44,23 +44,12 @@ public class RLSConnectionManager { connection.setAutoCommit(false); try { - // Set the RLS context variable on THIS connection using raw Statement + // Set the RLS context variable on THIS connection try (Statement stmt = connection.createStatement()) { stmt.execute("SET LOCAL app.current_user_id = '" + userId + "'"); - System.out.println("RLS context set for user: " + userId); - } - - // Verify it's set (for debugging) - try (Statement stmt = connection.createStatement()) { - var rs = stmt.executeQuery("SELECT current_setting('app.current_user_id', true)"); - if (rs.next()) { - String value = rs.getString(1); - System.out.println("Verified context value: " + value); - } } // Create a JdbcTemplate bound to THIS specific connection - // Use SingleConnectionDataSource to ensure the template uses this exact connection SingleConnectionDataSource singleConnectionDataSource = new SingleConnectionDataSource(connection, true); JdbcTemplate scopedTemplate = new JdbcTemplate(singleConnectionDataSource); @@ -81,8 +70,8 @@ public class RLSConnectionManager { // CRITICAL: Reset the context variable before returning connection to pool try (Statement stmt = connection.createStatement()) { stmt.execute("RESET app.current_user_id"); - System.out.println("RLS context reset"); } catch (Exception e) { + // Log but don't throw - connection will still be returned to pool System.err.println("Warning: Failed to reset RLS context: " + e.getMessage()); } } @@ -102,13 +91,5 @@ public class RLSConnectionManager { } } - /** - * Alternative approach: Execute raw SQL with RLS context. - * Useful when you need full control over the SQL. - */ - public T executeRawSqlWithRLS(Long userId, String sql, Function resultExtractor) { - return executeWithRLSContext(userId, template -> { - return resultExtractor.apply(template); - }); - } + } diff --git a/src/main/java/com/example/demo/RLSTestController.java b/src/main/java/com/example/demo/RLSTestController.java index 50f21eb..f9b6e67 100644 --- a/src/main/java/com/example/demo/RLSTestController.java +++ b/src/main/java/com/example/demo/RLSTestController.java @@ -1,5 +1,6 @@ package com.example.demo; +import jakarta.servlet.http.HttpServletRequest; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.web.bind.annotation.*; @@ -12,22 +13,22 @@ import java.util.Map; public class RLSTestController { private final RLSConnectionManager rlsManager; - private final DocumentRepository documentRepository; private final JdbcTemplate jdbcTemplate; public RLSTestController(RLSConnectionManager rlsManager, - DocumentRepository documentRepository, JdbcTemplate jdbcTemplate) { this.rlsManager = rlsManager; - this.documentRepository = documentRepository; this.jdbcTemplate = jdbcTemplate; } /** - * Test 1: Execute raw SQL with RLS context + * Test 1: Get documents using JWT user ID from filter */ - @GetMapping("/documents/user/{userId}") - public List> getDocumentsWithRawSQL(@PathVariable Long userId) { + @GetMapping("/documents") + public List> getDocuments(HttpServletRequest request) { + // Get user ID from JWT (set by JwtFilter) + Long userId = (Long) request.getAttribute("userId"); + return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { // This query will only return documents the user has access to (via RLS policy) String sql = "SELECT id, title, content, user_id FROM documents"; @@ -38,8 +39,10 @@ public class RLSTestController { /** * Test 2: Verify context variable is set correctly */ - @GetMapping("/context/verify/{userId}") - public Map verifyContextVariable(@PathVariable Long userId) { + @GetMapping("/context/verify") + public Map verifyContextVariable(HttpServletRequest request) { + Long userId = (Long) request.getAttribute("userId"); + return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { // Query the context variable to verify it's set String currentUserId = scopedTemplate.queryForObject( @@ -48,7 +51,7 @@ public class RLSTestController { ); Map result = new HashMap<>(); - result.put("requestedUserId", userId); + result.put("jwtUserId", userId); result.put("contextUserId", currentUserId); result.put("match", userId.toString().equals(currentUserId)); @@ -60,7 +63,8 @@ public class RLSTestController { * Test 3: Verify context is reset after request (simulate concurrent requests) */ @GetMapping("/context/isolation-test") - public Map testContextIsolation() throws InterruptedException { + public Map testContextIsolation(HttpServletRequest request) throws InterruptedException { + Long currentUserId = (Long) request.getAttribute("userId"); Map result = new HashMap<>(); // Set context for user 1 @@ -123,9 +127,12 @@ public class RLSTestController { * Test 4: Insert with RLS context (useful for audit trails) */ @PostMapping("/documents") - public Map createDocument(@RequestParam Long userId, + public Map createDocument(HttpServletRequest request, @RequestParam String title, @RequestParam String content) { + // Get user ID from JWT + Long userId = (Long) request.getAttribute("userId"); + return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { // Insert with the user context set scopedTemplate.update( @@ -146,7 +153,8 @@ public class RLSTestController { * Setup endpoint - creates the table and test data */ @PostMapping("/setup") - public String setupDatabase() { + public String setupDatabase(HttpServletRequest request) { + Long currentUserId = (Long) request.getAttribute("userId"); // Drop existing table jdbcTemplate.execute("DROP TABLE IF EXISTS documents CASCADE"); diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 60ff06e..dc206f5 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -5,3 +5,7 @@ spring.datasource.url=jdbc:postgresql://localhost:5432/rls_test spring.datasource.username=rls_test spring.datasource.password=rls_test spring.datasource.hikari.maximum-pool-size=10 + +# JWT Configuration +jwt.secret=your-256-bit-secret-key-change-this-in-production-minimum-32-characters +jwt.expiration=86400000