目 录CONTENT

文章目录

Java拦截请求参数并设置请求头

阿豪
2022-09-09 / 0 评论 / 2 点赞 / 52 阅读 / 811 字 / 正在检测是否收录...
温馨提示:
本文最后更新于 2022-09-09,若内容或图片失效,请留言反馈。部分素材来自网络,若不小心影响到您的利益,请联系我们删除。

业务中遇到这样的场景:需要对请求做统一拦截,用请求参数计算出新的值并设置到请求头中。
以下分别用Filter和Interceptor两种方式实现。建议使用Filter方式:Interceptor方式仅对POST和GET请求有效,不支持PUT等其他方法,原因是这些请求中HttpServletRequest接口的实现类不同(以下Interceptor代码仅支持POST和GET方法)。

一、场景

我们有两类用户:一类用户的请求头(header)中带有用户名参数userName;另一类用户的请求头中没有userName,但请求参数(requestParam)中带有userId,我们可以通过userId查询数据库等方式计算出userName。

二、需求

我们希望统一处理这两类用户,使请求到达controller层时,请求头中都带有userName。

三、解决方案

以下两种方式都实现了这个需求:controller层获取userName时,将优先使用由requestParam中的userId计算出的userName。

1、Filter实现

import org.springframework.context.annotation.Configuration;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.util.*;

@Configuration
@WebFilter(value = "/*")
public class ParamFilter implements Filter {
    
    private static final String USER_PARAM = "userId";
    private static final String USER_HEADER = "userName";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HeaderMapRequestWrapper requestWrapper = new HeaderMapRequestWrapper(req);
        String userId = request.getParameter(USER_PARAM);
        //依据你自己的业务通过userId获取userName
        String userName = getUserNameById(userId);
        requestWrapper.addHeader(USER_HEADER, userName);
        chain.doFilter(requestWrapper, response);
    }

    @Override
    public void destroy() {
    }

    public class HeaderMapRequestWrapper extends HttpServletRequestWrapper {
        public HeaderMapRequestWrapper(HttpServletRequest request) {
            super(request);
        }

        private Map<String, String> headerMap = new HashMap<String, String>();

        public void addHeader(String name, String value) {
            headerMap.put(name, value);
        }

        @Override
        public String getHeader(String name) {
            String headerValue = super.getHeader(name);
            if (headerMap.containsKey(name)) {
                headerValue = headerMap.get(name);
            }
            return headerValue;
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            List<String> names = Collections.list(super.getHeaderNames());
            for (String name : headerMap.keySet()) {
                names.add(name);
            }
            return Collections.enumeration(names);
        }

        @Override
        public Enumeration<String> getHeaders(String name) {
            List<String> values = Collections.list(super.getHeaders(name));
            if (headerMap.containsKey(name)) {
                values.add(headerMap.get(name));
            }
            return Collections.enumeration(values);
        }
    }
}

2、Interceptor实现

import org.apache.commons.lang3.StringUtils;
import org.apache.tomcat.util.http.MimeHeaders;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.handler.HandlerMethod;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Field;

@Component
public class ParamInterceptor extends HandlerInterceptorAdapter {
    
    private static final Logger logger = LoggerFactory.getLogger(ParamInterceptor.class);
    private static final String USER_PARAM = "userId";
    private static final String USER_HEADER = "userName";

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (handler instanceof HandlerMethod) {
            String userId = request.getParameter(USER_PARAM);
            if (StringUtils.isNotBlank(userId)) {
                //依据你自己的业务通过userId获取userName
                String userName = getUserNameById(userId);
                reflectSetHeader(request, USER_HEADER, userName);
            }
        }
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        super.postHandle(request, response, handler, modelAndView);
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        super.afterCompletion(request, response, handler, ex);
    }

    private void reflectSetHeader(HttpServletRequest request, String key, String value) {
        Class<? extends HttpServletRequest> requestClass = request.getClass();
        logger.info("request实现类={}", requestClass.getName());
        try {
            Field request1 = requestClass.getDeclaredField("request");
            request1.setAccessible(true);
            Object o = request1.get(request);
            Field coyoteRequest = o.getClass().getDeclaredField("coyoteRequest");
            coyoteRequest.setAccessible(true);
            Object o1 = coyoteRequest.get(o);
            Field headers = o1.getClass().getDeclaredField("headers");
            headers.setAccessible(true);
            MimeHeaders o2 = (MimeHeaders) headers.get(o1);
            o2.removeHeader(key);
            o2.addValue(key).setString(value);
        } catch (Exception e) {
            logger.info("reflect set header error {}", e);
        }
    }
}
2
广告 广告

评论区