JWT works

This commit is contained in:
Omar Muñoz
2026-01-21 10:07:13 -06:00
parent 289de1c7e6
commit 7e09621a7d
8 changed files with 294 additions and 34 deletions

17
pom.xml
View File

@@ -44,6 +44,23 @@
<artifactId>postgresql</artifactId> <artifactId>postgresql</artifactId>
<scope>runtime</scope> <scope>runtime</scope>
</dependency> </dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
<version>0.12.6</version>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<version>0.12.6</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<version>0.12.6</version>
<scope>runtime</scope>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jdbc-test</artifactId> <artifactId>spring-boot-starter-data-jdbc-test</artifactId>

View File

@@ -0,0 +1,65 @@
package com.example.demo;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.bind.annotation.*;
import java.util.HashMap;
import java.util.Map;
@RestController
@RequestMapping("/api/auth")
public class AuthController {
private final JwtUtil jwtUtil;
public AuthController(JwtUtil jwtUtil) {
this.jwtUtil = jwtUtil;
}
/**
* Login endpoint - generates JWT and sets it as HTTP-only cookie
* In production, you'd validate credentials against a database
*/
@PostMapping("/login")
public Map<String, Object> login(@RequestParam Long userId, HttpServletResponse response) {
// Generate JWT token
String token = jwtUtil.generateToken(userId);
// Set JWT as HTTP-only cookie
Cookie cookie = new Cookie("jwt", token);
cookie.setHttpOnly(true);
cookie.setSecure(false); // Set to true in production with HTTPS
cookie.setPath("/");
cookie.setMaxAge(24 * 60 * 60); // 24 hours
response.addCookie(cookie);
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("userId", userId);
result.put("message", "Login successful");
return result;
}
/**
* Logout endpoint - clears the JWT cookie
*/
@PostMapping("/logout")
public Map<String, Object> logout(HttpServletResponse response) {
Cookie cookie = new Cookie("jwt", null);
cookie.setHttpOnly(true);
cookie.setSecure(false);
cookie.setPath("/");
cookie.setMaxAge(0); // Expire immediately
response.addCookie(cookie);
Map<String, Object> result = new HashMap<>();
result.put("success", true);
result.put("message", "Logout successful");
return result;
}
}

View File

@@ -0,0 +1,62 @@
package com.example.demo;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class FilterConfig {
private final JwtFilter jwtFilter;
public FilterConfig(JwtFilter jwtFilter) {
this.jwtFilter = jwtFilter;
}
@Bean
public FilterRegistrationBean<JwtFilter> jwtFilterRegistration() {
FilterRegistrationBean<JwtFilter> registrationBean = new FilterRegistrationBean<>();
registrationBean.setFilter(jwtFilter);
// Add protected URL patterns
registrationBean.addUrlPatterns("/api/rls-test/setup");
registrationBean.addUrlPatterns("/api/rls-test/documents/*");
registrationBean.addUrlPatterns("/api/rls-test/context/*");
registrationBean.setOrder(1);
return registrationBean;
}
}
/*
* ALTERNATIVE APPROACH: Exclude specific paths instead of including
*
* To apply JWT filter to ALL paths EXCEPT certain ones, replace the
* FilterRegistrationBean configuration with this approach:
*
* @Bean
* public FilterRegistrationBean<JwtFilter> jwtFilterRegistration() {
* FilterRegistrationBean<JwtFilter> registrationBean = new FilterRegistrationBean<>();
* registrationBean.setFilter(jwtFilter);
*
* // Apply to all paths
* registrationBean.addUrlPatterns("/*");
*
* registrationBean.setOrder(1);
*
* return registrationBean;
* }
*
* Then modify JwtFilter.shouldNotFilter() method to exclude paths:
*
* @Override
* protected boolean shouldNotFilter(HttpServletRequest request) {
* String path = request.getRequestURI();
* return path.startsWith("/api/auth/") || // Skip auth endpoints
* path.equals("/health") || // Skip health checks
* path.startsWith("/public/"); // Skip public resources
* }
*
* This approach is better when you have MORE protected paths than open paths.
* Current approach (explicit addUrlPatterns) is better when you have FEWER protected paths.
*/

View File

@@ -0,0 +1,55 @@
package com.example.demo;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
@Component
public class JwtFilter extends OncePerRequestFilter {
private final JwtUtil jwtUtil;
public JwtFilter(JwtUtil jwtUtil) {
this.jwtUtil = jwtUtil;
}
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
// Extract JWT from cookie
String token = extractTokenFromCookie(request);
if (token != null && jwtUtil.validateToken(token)) {
// Extract user ID and add to request attribute
Long userId = jwtUtil.getUserIdFromToken(token);
request.setAttribute("userId", userId);
} else {
// No valid JWT - return 401 Unauthorized
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
response.getWriter().write("{\"error\": \"Unauthorized - Invalid or missing JWT\"}");
return;
}
filterChain.doFilter(request, response);
}
private String extractTokenFromCookie(HttpServletRequest request) {
Cookie[] cookies = request.getCookies();
if (cookies != null) {
for (Cookie cookie : cookies) {
if ("jwt".equals(cookie.getName())) {
return cookie.getValue();
}
}
}
return null;
}
}

View File

@@ -0,0 +1,68 @@
package com.example.demo;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.util.Date;
@Component
public class JwtUtil {
@Value("${jwt.secret}")
private String secret;
@Value("${jwt.expiration}")
private long expiration;
private SecretKey getSigningKey() {
return Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8));
}
/**
* Generate JWT token with user ID
*/
public String generateToken(Long userId) {
Date now = new Date();
Date expiryDate = new Date(now.getTime() + expiration);
return Jwts.builder()
.subject(userId.toString())
.issuedAt(now)
.expiration(expiryDate)
.signWith(getSigningKey())
.compact();
}
/**
* Extract user ID from JWT token
*/
public Long getUserIdFromToken(String token) {
Claims claims = Jwts.parser()
.verifyWith(getSigningKey())
.build()
.parseSignedClaims(token)
.getPayload();
return Long.parseLong(claims.getSubject());
}
/**
* Validate JWT token
*/
public boolean validateToken(String token) {
try {
Jwts.parser()
.verifyWith(getSigningKey())
.build()
.parseSignedClaims(token);
return true;
} catch (Exception e) {
return false;
}
}
}

View File

@@ -44,23 +44,12 @@ public class RLSConnectionManager {
connection.setAutoCommit(false); connection.setAutoCommit(false);
try { try {
// Set the RLS context variable on THIS connection using raw Statement // Set the RLS context variable on THIS connection
try (Statement stmt = connection.createStatement()) { try (Statement stmt = connection.createStatement()) {
stmt.execute("SET LOCAL app.current_user_id = '" + userId + "'"); stmt.execute("SET LOCAL app.current_user_id = '" + userId + "'");
System.out.println("RLS context set for user: " + userId);
}
// Verify it's set (for debugging)
try (Statement stmt = connection.createStatement()) {
var rs = stmt.executeQuery("SELECT current_setting('app.current_user_id', true)");
if (rs.next()) {
String value = rs.getString(1);
System.out.println("Verified context value: " + value);
}
} }
// Create a JdbcTemplate bound to THIS specific connection // Create a JdbcTemplate bound to THIS specific connection
// Use SingleConnectionDataSource to ensure the template uses this exact connection
SingleConnectionDataSource singleConnectionDataSource = SingleConnectionDataSource singleConnectionDataSource =
new SingleConnectionDataSource(connection, true); new SingleConnectionDataSource(connection, true);
JdbcTemplate scopedTemplate = new JdbcTemplate(singleConnectionDataSource); JdbcTemplate scopedTemplate = new JdbcTemplate(singleConnectionDataSource);
@@ -81,8 +70,8 @@ public class RLSConnectionManager {
// CRITICAL: Reset the context variable before returning connection to pool // CRITICAL: Reset the context variable before returning connection to pool
try (Statement stmt = connection.createStatement()) { try (Statement stmt = connection.createStatement()) {
stmt.execute("RESET app.current_user_id"); stmt.execute("RESET app.current_user_id");
System.out.println("RLS context reset");
} catch (Exception e) { } catch (Exception e) {
// Log but don't throw - connection will still be returned to pool
System.err.println("Warning: Failed to reset RLS context: " + e.getMessage()); System.err.println("Warning: Failed to reset RLS context: " + e.getMessage());
} }
} }
@@ -102,13 +91,5 @@ public class RLSConnectionManager {
} }
} }
/**
* Alternative approach: Execute raw SQL with RLS context.
* Useful when you need full control over the SQL.
*/
public <T> T executeRawSqlWithRLS(Long userId, String sql, Function<JdbcTemplate, T> resultExtractor) {
return executeWithRLSContext(userId, template -> {
return resultExtractor.apply(template);
});
}
} }

View File

@@ -1,5 +1,6 @@
package com.example.demo; package com.example.demo;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@@ -12,22 +13,22 @@ import java.util.Map;
public class RLSTestController { public class RLSTestController {
private final RLSConnectionManager rlsManager; private final RLSConnectionManager rlsManager;
private final DocumentRepository documentRepository;
private final JdbcTemplate jdbcTemplate; private final JdbcTemplate jdbcTemplate;
public RLSTestController(RLSConnectionManager rlsManager, public RLSTestController(RLSConnectionManager rlsManager,
DocumentRepository documentRepository,
JdbcTemplate jdbcTemplate) { JdbcTemplate jdbcTemplate) {
this.rlsManager = rlsManager; this.rlsManager = rlsManager;
this.documentRepository = documentRepository;
this.jdbcTemplate = jdbcTemplate; this.jdbcTemplate = jdbcTemplate;
} }
/** /**
* Test 1: Execute raw SQL with RLS context * Test 1: Get documents using JWT user ID from filter
*/ */
@GetMapping("/documents/user/{userId}") @GetMapping("/documents")
public List<Map<String, Object>> getDocumentsWithRawSQL(@PathVariable Long userId) { public List<Map<String, Object>> getDocuments(HttpServletRequest request) {
// Get user ID from JWT (set by JwtFilter)
Long userId = (Long) request.getAttribute("userId");
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
// This query will only return documents the user has access to (via RLS policy) // This query will only return documents the user has access to (via RLS policy)
String sql = "SELECT id, title, content, user_id FROM documents"; String sql = "SELECT id, title, content, user_id FROM documents";
@@ -38,8 +39,10 @@ public class RLSTestController {
/** /**
* Test 2: Verify context variable is set correctly * Test 2: Verify context variable is set correctly
*/ */
@GetMapping("/context/verify/{userId}") @GetMapping("/context/verify")
public Map<String, Object> verifyContextVariable(@PathVariable Long userId) { public Map<String, Object> verifyContextVariable(HttpServletRequest request) {
Long userId = (Long) request.getAttribute("userId");
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
// Query the context variable to verify it's set // Query the context variable to verify it's set
String currentUserId = scopedTemplate.queryForObject( String currentUserId = scopedTemplate.queryForObject(
@@ -48,7 +51,7 @@ public class RLSTestController {
); );
Map<String, Object> result = new HashMap<>(); Map<String, Object> result = new HashMap<>();
result.put("requestedUserId", userId); result.put("jwtUserId", userId);
result.put("contextUserId", currentUserId); result.put("contextUserId", currentUserId);
result.put("match", userId.toString().equals(currentUserId)); result.put("match", userId.toString().equals(currentUserId));
@@ -60,7 +63,8 @@ public class RLSTestController {
* Test 3: Verify context is reset after request (simulate concurrent requests) * Test 3: Verify context is reset after request (simulate concurrent requests)
*/ */
@GetMapping("/context/isolation-test") @GetMapping("/context/isolation-test")
public Map<String, Object> testContextIsolation() throws InterruptedException { public Map<String, Object> testContextIsolation(HttpServletRequest request) throws InterruptedException {
Long currentUserId = (Long) request.getAttribute("userId");
Map<String, Object> result = new HashMap<>(); Map<String, Object> result = new HashMap<>();
// Set context for user 1 // Set context for user 1
@@ -123,9 +127,12 @@ public class RLSTestController {
* Test 4: Insert with RLS context (useful for audit trails) * Test 4: Insert with RLS context (useful for audit trails)
*/ */
@PostMapping("/documents") @PostMapping("/documents")
public Map<String, Object> createDocument(@RequestParam Long userId, public Map<String, Object> createDocument(HttpServletRequest request,
@RequestParam String title, @RequestParam String title,
@RequestParam String content) { @RequestParam String content) {
// Get user ID from JWT
Long userId = (Long) request.getAttribute("userId");
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> { return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
// Insert with the user context set // Insert with the user context set
scopedTemplate.update( scopedTemplate.update(
@@ -146,7 +153,8 @@ public class RLSTestController {
* Setup endpoint - creates the table and test data * Setup endpoint - creates the table and test data
*/ */
@PostMapping("/setup") @PostMapping("/setup")
public String setupDatabase() { public String setupDatabase(HttpServletRequest request) {
Long currentUserId = (Long) request.getAttribute("userId");
// Drop existing table // Drop existing table
jdbcTemplate.execute("DROP TABLE IF EXISTS documents CASCADE"); jdbcTemplate.execute("DROP TABLE IF EXISTS documents CASCADE");

View File

@@ -5,3 +5,7 @@ spring.datasource.url=jdbc:postgresql://localhost:5432/rls_test
spring.datasource.username=rls_test spring.datasource.username=rls_test
spring.datasource.password=rls_test spring.datasource.password=rls_test
spring.datasource.hikari.maximum-pool-size=10 spring.datasource.hikari.maximum-pool-size=10
# JWT Configuration
jwt.secret=your-256-bit-secret-key-change-this-in-production-minimum-32-characters
jwt.expiration=86400000