JWT works
This commit is contained in:
17
pom.xml
17
pom.xml
@@ -44,6 +44,23 @@
|
||||
<artifactId>postgresql</artifactId>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.jsonwebtoken</groupId>
|
||||
<artifactId>jjwt-api</artifactId>
|
||||
<version>0.12.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.jsonwebtoken</groupId>
|
||||
<artifactId>jjwt-impl</artifactId>
|
||||
<version>0.12.6</version>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.jsonwebtoken</groupId>
|
||||
<artifactId>jjwt-jackson</artifactId>
|
||||
<version>0.12.6</version>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-data-jdbc-test</artifactId>
|
||||
|
||||
65
src/main/java/com/example/demo/AuthController.java
Normal file
65
src/main/java/com/example/demo/AuthController.java
Normal file
@@ -0,0 +1,65 @@
|
||||
package com.example.demo;
|
||||
|
||||
import jakarta.servlet.http.Cookie;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/auth")
|
||||
public class AuthController {
|
||||
|
||||
private final JwtUtil jwtUtil;
|
||||
|
||||
public AuthController(JwtUtil jwtUtil) {
|
||||
this.jwtUtil = jwtUtil;
|
||||
}
|
||||
|
||||
/**
|
||||
* Login endpoint - generates JWT and sets it as HTTP-only cookie
|
||||
* In production, you'd validate credentials against a database
|
||||
*/
|
||||
@PostMapping("/login")
|
||||
public Map<String, Object> login(@RequestParam Long userId, HttpServletResponse response) {
|
||||
// Generate JWT token
|
||||
String token = jwtUtil.generateToken(userId);
|
||||
|
||||
// Set JWT as HTTP-only cookie
|
||||
Cookie cookie = new Cookie("jwt", token);
|
||||
cookie.setHttpOnly(true);
|
||||
cookie.setSecure(false); // Set to true in production with HTTPS
|
||||
cookie.setPath("/");
|
||||
cookie.setMaxAge(24 * 60 * 60); // 24 hours
|
||||
|
||||
response.addCookie(cookie);
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("success", true);
|
||||
result.put("userId", userId);
|
||||
result.put("message", "Login successful");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Logout endpoint - clears the JWT cookie
|
||||
*/
|
||||
@PostMapping("/logout")
|
||||
public Map<String, Object> logout(HttpServletResponse response) {
|
||||
Cookie cookie = new Cookie("jwt", null);
|
||||
cookie.setHttpOnly(true);
|
||||
cookie.setSecure(false);
|
||||
cookie.setPath("/");
|
||||
cookie.setMaxAge(0); // Expire immediately
|
||||
|
||||
response.addCookie(cookie);
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("success", true);
|
||||
result.put("message", "Logout successful");
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
62
src/main/java/com/example/demo/FilterConfig.java
Normal file
62
src/main/java/com/example/demo/FilterConfig.java
Normal file
@@ -0,0 +1,62 @@
|
||||
package com.example.demo;
|
||||
|
||||
import org.springframework.boot.web.servlet.FilterRegistrationBean;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class FilterConfig {
|
||||
|
||||
private final JwtFilter jwtFilter;
|
||||
|
||||
public FilterConfig(JwtFilter jwtFilter) {
|
||||
this.jwtFilter = jwtFilter;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public FilterRegistrationBean<JwtFilter> jwtFilterRegistration() {
|
||||
FilterRegistrationBean<JwtFilter> registrationBean = new FilterRegistrationBean<>();
|
||||
registrationBean.setFilter(jwtFilter);
|
||||
|
||||
// Add protected URL patterns
|
||||
registrationBean.addUrlPatterns("/api/rls-test/setup");
|
||||
registrationBean.addUrlPatterns("/api/rls-test/documents/*");
|
||||
registrationBean.addUrlPatterns("/api/rls-test/context/*");
|
||||
|
||||
registrationBean.setOrder(1);
|
||||
|
||||
return registrationBean;
|
||||
}
|
||||
}
|
||||
/*
|
||||
* ALTERNATIVE APPROACH: Exclude specific paths instead of including
|
||||
*
|
||||
* To apply JWT filter to ALL paths EXCEPT certain ones, replace the
|
||||
* FilterRegistrationBean configuration with this approach:
|
||||
*
|
||||
* @Bean
|
||||
* public FilterRegistrationBean<JwtFilter> jwtFilterRegistration() {
|
||||
* FilterRegistrationBean<JwtFilter> registrationBean = new FilterRegistrationBean<>();
|
||||
* registrationBean.setFilter(jwtFilter);
|
||||
*
|
||||
* // Apply to all paths
|
||||
* registrationBean.addUrlPatterns("/*");
|
||||
*
|
||||
* registrationBean.setOrder(1);
|
||||
*
|
||||
* return registrationBean;
|
||||
* }
|
||||
*
|
||||
* Then modify JwtFilter.shouldNotFilter() method to exclude paths:
|
||||
*
|
||||
* @Override
|
||||
* protected boolean shouldNotFilter(HttpServletRequest request) {
|
||||
* String path = request.getRequestURI();
|
||||
* return path.startsWith("/api/auth/") || // Skip auth endpoints
|
||||
* path.equals("/health") || // Skip health checks
|
||||
* path.startsWith("/public/"); // Skip public resources
|
||||
* }
|
||||
*
|
||||
* This approach is better when you have MORE protected paths than open paths.
|
||||
* Current approach (explicit addUrlPatterns) is better when you have FEWER protected paths.
|
||||
*/
|
||||
55
src/main/java/com/example/demo/JwtFilter.java
Normal file
55
src/main/java/com/example/demo/JwtFilter.java
Normal file
@@ -0,0 +1,55 @@
|
||||
package com.example.demo;
|
||||
|
||||
import jakarta.servlet.FilterChain;
|
||||
import jakarta.servlet.ServletException;
|
||||
import jakarta.servlet.http.Cookie;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@Component
|
||||
public class JwtFilter extends OncePerRequestFilter {
|
||||
|
||||
private final JwtUtil jwtUtil;
|
||||
|
||||
public JwtFilter(JwtUtil jwtUtil) {
|
||||
this.jwtUtil = jwtUtil;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
FilterChain filterChain) throws ServletException, IOException {
|
||||
|
||||
// Extract JWT from cookie
|
||||
String token = extractTokenFromCookie(request);
|
||||
|
||||
if (token != null && jwtUtil.validateToken(token)) {
|
||||
// Extract user ID and add to request attribute
|
||||
Long userId = jwtUtil.getUserIdFromToken(token);
|
||||
request.setAttribute("userId", userId);
|
||||
} else {
|
||||
// No valid JWT - return 401 Unauthorized
|
||||
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
|
||||
response.getWriter().write("{\"error\": \"Unauthorized - Invalid or missing JWT\"}");
|
||||
return;
|
||||
}
|
||||
|
||||
filterChain.doFilter(request, response);
|
||||
}
|
||||
|
||||
private String extractTokenFromCookie(HttpServletRequest request) {
|
||||
Cookie[] cookies = request.getCookies();
|
||||
if (cookies != null) {
|
||||
for (Cookie cookie : cookies) {
|
||||
if ("jwt".equals(cookie.getName())) {
|
||||
return cookie.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
68
src/main/java/com/example/demo/JwtUtil.java
Normal file
68
src/main/java/com/example/demo/JwtUtil.java
Normal file
@@ -0,0 +1,68 @@
|
||||
package com.example.demo;
|
||||
|
||||
import io.jsonwebtoken.Claims;
|
||||
import io.jsonwebtoken.Jwts;
|
||||
import io.jsonwebtoken.security.Keys;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.crypto.SecretKey;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Date;
|
||||
|
||||
@Component
|
||||
public class JwtUtil {
|
||||
|
||||
@Value("${jwt.secret}")
|
||||
private String secret;
|
||||
|
||||
@Value("${jwt.expiration}")
|
||||
private long expiration;
|
||||
|
||||
private SecretKey getSigningKey() {
|
||||
return Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate JWT token with user ID
|
||||
*/
|
||||
public String generateToken(Long userId) {
|
||||
Date now = new Date();
|
||||
Date expiryDate = new Date(now.getTime() + expiration);
|
||||
|
||||
return Jwts.builder()
|
||||
.subject(userId.toString())
|
||||
.issuedAt(now)
|
||||
.expiration(expiryDate)
|
||||
.signWith(getSigningKey())
|
||||
.compact();
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract user ID from JWT token
|
||||
*/
|
||||
public Long getUserIdFromToken(String token) {
|
||||
Claims claims = Jwts.parser()
|
||||
.verifyWith(getSigningKey())
|
||||
.build()
|
||||
.parseSignedClaims(token)
|
||||
.getPayload();
|
||||
|
||||
return Long.parseLong(claims.getSubject());
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate JWT token
|
||||
*/
|
||||
public boolean validateToken(String token) {
|
||||
try {
|
||||
Jwts.parser()
|
||||
.verifyWith(getSigningKey())
|
||||
.build()
|
||||
.parseSignedClaims(token);
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,23 +44,12 @@ public class RLSConnectionManager {
|
||||
connection.setAutoCommit(false);
|
||||
|
||||
try {
|
||||
// Set the RLS context variable on THIS connection using raw Statement
|
||||
// Set the RLS context variable on THIS connection
|
||||
try (Statement stmt = connection.createStatement()) {
|
||||
stmt.execute("SET LOCAL app.current_user_id = '" + userId + "'");
|
||||
System.out.println("RLS context set for user: " + userId);
|
||||
}
|
||||
|
||||
// Verify it's set (for debugging)
|
||||
try (Statement stmt = connection.createStatement()) {
|
||||
var rs = stmt.executeQuery("SELECT current_setting('app.current_user_id', true)");
|
||||
if (rs.next()) {
|
||||
String value = rs.getString(1);
|
||||
System.out.println("Verified context value: " + value);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a JdbcTemplate bound to THIS specific connection
|
||||
// Use SingleConnectionDataSource to ensure the template uses this exact connection
|
||||
SingleConnectionDataSource singleConnectionDataSource =
|
||||
new SingleConnectionDataSource(connection, true);
|
||||
JdbcTemplate scopedTemplate = new JdbcTemplate(singleConnectionDataSource);
|
||||
@@ -81,8 +70,8 @@ public class RLSConnectionManager {
|
||||
// CRITICAL: Reset the context variable before returning connection to pool
|
||||
try (Statement stmt = connection.createStatement()) {
|
||||
stmt.execute("RESET app.current_user_id");
|
||||
System.out.println("RLS context reset");
|
||||
} catch (Exception e) {
|
||||
// Log but don't throw - connection will still be returned to pool
|
||||
System.err.println("Warning: Failed to reset RLS context: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
@@ -102,13 +91,5 @@ public class RLSConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Alternative approach: Execute raw SQL with RLS context.
|
||||
* Useful when you need full control over the SQL.
|
||||
*/
|
||||
public <T> T executeRawSqlWithRLS(Long userId, String sql, Function<JdbcTemplate, T> resultExtractor) {
|
||||
return executeWithRLSContext(userId, template -> {
|
||||
return resultExtractor.apply(template);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.example.demo;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
@@ -12,22 +13,22 @@ import java.util.Map;
|
||||
public class RLSTestController {
|
||||
|
||||
private final RLSConnectionManager rlsManager;
|
||||
private final DocumentRepository documentRepository;
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
public RLSTestController(RLSConnectionManager rlsManager,
|
||||
DocumentRepository documentRepository,
|
||||
JdbcTemplate jdbcTemplate) {
|
||||
this.rlsManager = rlsManager;
|
||||
this.documentRepository = documentRepository;
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test 1: Execute raw SQL with RLS context
|
||||
* Test 1: Get documents using JWT user ID from filter
|
||||
*/
|
||||
@GetMapping("/documents/user/{userId}")
|
||||
public List<Map<String, Object>> getDocumentsWithRawSQL(@PathVariable Long userId) {
|
||||
@GetMapping("/documents")
|
||||
public List<Map<String, Object>> getDocuments(HttpServletRequest request) {
|
||||
// Get user ID from JWT (set by JwtFilter)
|
||||
Long userId = (Long) request.getAttribute("userId");
|
||||
|
||||
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
|
||||
// This query will only return documents the user has access to (via RLS policy)
|
||||
String sql = "SELECT id, title, content, user_id FROM documents";
|
||||
@@ -38,8 +39,10 @@ public class RLSTestController {
|
||||
/**
|
||||
* Test 2: Verify context variable is set correctly
|
||||
*/
|
||||
@GetMapping("/context/verify/{userId}")
|
||||
public Map<String, Object> verifyContextVariable(@PathVariable Long userId) {
|
||||
@GetMapping("/context/verify")
|
||||
public Map<String, Object> verifyContextVariable(HttpServletRequest request) {
|
||||
Long userId = (Long) request.getAttribute("userId");
|
||||
|
||||
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
|
||||
// Query the context variable to verify it's set
|
||||
String currentUserId = scopedTemplate.queryForObject(
|
||||
@@ -48,7 +51,7 @@ public class RLSTestController {
|
||||
);
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("requestedUserId", userId);
|
||||
result.put("jwtUserId", userId);
|
||||
result.put("contextUserId", currentUserId);
|
||||
result.put("match", userId.toString().equals(currentUserId));
|
||||
|
||||
@@ -60,7 +63,8 @@ public class RLSTestController {
|
||||
* Test 3: Verify context is reset after request (simulate concurrent requests)
|
||||
*/
|
||||
@GetMapping("/context/isolation-test")
|
||||
public Map<String, Object> testContextIsolation() throws InterruptedException {
|
||||
public Map<String, Object> testContextIsolation(HttpServletRequest request) throws InterruptedException {
|
||||
Long currentUserId = (Long) request.getAttribute("userId");
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
|
||||
// Set context for user 1
|
||||
@@ -123,9 +127,12 @@ public class RLSTestController {
|
||||
* Test 4: Insert with RLS context (useful for audit trails)
|
||||
*/
|
||||
@PostMapping("/documents")
|
||||
public Map<String, Object> createDocument(@RequestParam Long userId,
|
||||
public Map<String, Object> createDocument(HttpServletRequest request,
|
||||
@RequestParam String title,
|
||||
@RequestParam String content) {
|
||||
// Get user ID from JWT
|
||||
Long userId = (Long) request.getAttribute("userId");
|
||||
|
||||
return rlsManager.executeWithRLSContext(userId, scopedTemplate -> {
|
||||
// Insert with the user context set
|
||||
scopedTemplate.update(
|
||||
@@ -146,7 +153,8 @@ public class RLSTestController {
|
||||
* Setup endpoint - creates the table and test data
|
||||
*/
|
||||
@PostMapping("/setup")
|
||||
public String setupDatabase() {
|
||||
public String setupDatabase(HttpServletRequest request) {
|
||||
Long currentUserId = (Long) request.getAttribute("userId");
|
||||
// Drop existing table
|
||||
jdbcTemplate.execute("DROP TABLE IF EXISTS documents CASCADE");
|
||||
|
||||
|
||||
@@ -5,3 +5,7 @@ spring.datasource.url=jdbc:postgresql://localhost:5432/rls_test
|
||||
spring.datasource.username=rls_test
|
||||
spring.datasource.password=rls_test
|
||||
spring.datasource.hikari.maximum-pool-size=10
|
||||
|
||||
# JWT Configuration
|
||||
jwt.secret=your-256-bit-secret-key-change-this-in-production-minimum-32-characters
|
||||
jwt.expiration=86400000
|
||||
|
||||
Reference in New Issue
Block a user