SpringMVC 实现自定义的session 共享(同步)机制

原创
2018/06/29 14:56
阅读数 4.8K

SpringMVC 实现自定义的session 共享(同步)机制

思路

这个问题是针对线上多台服务器(例如多个tomcat,集群)负载均衡而言,如果只有一个服务器运行(提供服务),则不存在这个问题,请直接略过.

为什么多个tomcat服务会存在session 不同步的问题

以使用 nginx 做负载均衡为例: 假设我们使用服务器端的 session 记录登录鉴权信息(没有使用 redis)。用户登录时,登录接口命中的是服务 A,那么服务 A 中就会记录用户的登录信息; 接着用户修改资料(比如上传图片、修改昵称等),保存资料的接口命中的是服务 B,而服务 B 中并没有记录登录信息,所以直接报错"未登录",并跳转到登录页面.

用户明明已经登录过了,可是莫名其妙地又让用户去登录. 这就是问题, 用户的登录信息 只存储到了一台服务器上, 而用户的各种操作(接口访问)可能负载到任意一台服务上. 而http session是内存级别的,各tomcat服务是不会共享的.

流程

如何让所有服务都能读取到用户的登录信息呢? 我们需要把登录信息存储到一个所有服务器都能访问的地方,这里我们使用 redis; 使用其他分布式缓存(如 Memcached)或 zookeeper 等分布式存储也可以.

方案

包装 HttpServletRequest(继承 HttpServletRequestWrapper),重写它的 getSession(boolean) 和 getSession() 方法.

具体方案

  1. 继承 javax.servlet.http.HttpServletRequestWrapper,重写它的 getSession(boolean) 和 getSession() 方法
  2. 实现HttpSession ,重写HttpSession的三个核心方法: a. getAttribute; b. setAttribute; c. removeAttribute
  3. 在这三个方法中,除了对原始的HttpSession 操作外,还会同时对redis进行操作.
    看下 setAttribute 的重写实现:
 /**
     * Needs to be overridden: stores the attribute in the local session (when there is one) and
     * mirrors it to Redis under the key {@code sessionId + name}.
     * <p>
     * The session id is the local session's id, or {@code JSESSIONID} when there is no local
     * session. If neither is available nothing is written to Redis: the key would otherwise be
     * {@code "null" + name}, which every id-less session would share.
     *
     * @param s attribute name
     * @param o attribute value
     */
    @Override
    public void setAttribute(String s, Object o) {
        String sessionId;
        if (null == this.httpSession) {
            sessionId = this.JSESSIONID;
        } else {
            this.httpSession.setAttribute(s, o);
            sessionId = this.httpSession.getId();
        }
        if (null == sessionId) {
            // No id to scope the Redis key with; never write a key shared across sessions.
            return;
        }
        RedisCacheUtil.setSessionAttribute(sessionId + s, o);
    }

注意

  1. 存储到redis 中的时候,redis 的key一定要有原始sessionId,这样才能区分是哪个会话;
  2. redis 中的value 实际都是String,所以在setAttribute 中存储到redis 时,要对存储的值进行序列化, 同理 getAttribute中,对从redis中获取的value,要反序列化

代码

CustomSharedHttpSession 实现HttpSession

package oa.web.responsibility.impl.custom;

import com.common.util.RedisCacheUtil;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionContext;
import java.util.Enumeration;

/**
 * {@link HttpSession} decorator that mirrors session attributes into Redis so that several servlet
 * containers behind a load balancer share (synchronise) the same session data.
 * <p>
 * Redis keys are the plain concatenation {@code sessionId + attributeName}. Two ids are involved:
 * <ul>
 * <li>the id of the wrapped, container-local {@link #httpSession} (which may be {@code null});</li>
 * <li>{@link #JSESSIONID}, supplied by whoever constructs this object (presumably the session id sent
 * by the client, possibly created on another server &ndash; TODO confirm in HttpSessionSyncShareFilter).
 * It may be {@code null}. In that case it is never used to build a Redis key, because the key
 * {@code "null" + name} would be shared by every session that has no id.</li>
 * </ul>
 * Only the attribute methods (including their deprecated {@code *Value} aliases) and
 * {@link #invalidate()} talk to Redis. All other methods delegate to the wrapped session and throw
 * {@link NullPointerException} when it is {@code null}. {@link #getAttributeNames()} and
 * {@link #getValueNames()} list local attributes only.
 * <p>
 * see oa/web/responsibility/impl/custom/HttpSessionSyncShareFilter.java
 */
public class CustomSharedHttpSession implements HttpSession {
    /** Container-local session being decorated; may be {@code null}. */
    protected HttpSession httpSession;
    /** Fallback session id used as Redis key prefix; may be {@code null}. */
    protected String JSESSIONID;

    public CustomSharedHttpSession() {
        super();
    }

    /**
     * @param httpSession container-local session, may be {@code null}
     * @param JSESSIONID  fallback session id used to build Redis keys, may be {@code null}
     */
    public CustomSharedHttpSession(HttpSession httpSession, String JSESSIONID) {
        this.httpSession = httpSession;
        this.JSESSIONID = JSESSIONID;
    }

    @Override
    public long getCreationTime() {
        return this.httpSession.getCreationTime();
    }

    @Override
    public String getId() {
        return this.httpSession.getId();
    }

    @Override
    public long getLastAccessedTime() {
        return this.httpSession.getLastAccessedTime();
    }

    @Override
    public ServletContext getServletContext() {
        return this.httpSession.getServletContext();
    }

    @Override
    public void setMaxInactiveInterval(int i) {
        this.httpSession.setMaxInactiveInterval(i);
    }

    @Override
    public int getMaxInactiveInterval() {
        return this.httpSession.getMaxInactiveInterval();
    }

    @Override
    public HttpSessionContext getSessionContext() {
        return this.httpSession.getSessionContext();
    }

    /**
     * Returns an attribute, falling back to Redis when the local session does not hold it.
     * <p>
     * Lookup order:
     * <ol>
     * <li>the local session;</li>
     * <li>Redis under the local session id;</li>
     * <li>Redis under {@link #JSESSIONID}, only when that id is non-null and differs from the local
     * id.</li>
     * </ol>
     * A value found in Redis is copied back through {@link #setAttribute(String, Object)}. That puts
     * it into the local session and into Redis under the local session id.
     * <p>
     * A miss performs no write. The old code wrote {@code null} on a miss, which could delete a value
     * another server had just stored.
     * <p>
     * Without a local session, only Redis under {@link #JSESSIONID} is consulted, and nothing is
     * cached.
     *
     * @param s attribute name
     * @return the attribute value, or {@code null} if it is not found anywhere
     */
    @Override
    public Object getAttribute(String s) {
        if (null == this.getHttpSession()) {
            if (null == this.JSESSIONID) {
                return null;
            }
            return RedisCacheUtil.getSessionAttribute(this.JSESSIONID + s);
        }

        Object value = this.httpSession.getAttribute(s);
        if (null != value) {
            return value;
        }
        String currentSessionId = this.httpSession.getId();
        value = RedisCacheUtil.getSessionAttribute(currentSessionId + s);
        if (null == value && null != this.JSESSIONID && !currentSessionId.equals(this.JSESSIONID)) {
            value = RedisCacheUtil.getSessionAttribute(this.JSESSIONID + s);
        }
        if (null != value) {
            // Cache locally and (re)publish under the current session id.
            this.setAttribute(s, value);
        }
        return value;
    }

    /**
     * Deprecated alias of {@link #getAttribute(String)}. It delegates there so the Redis fallback
     * applies to it too.
     */
    @Override
    public Object getValue(String s) {
        return this.getAttribute(s);
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        return this.httpSession.getAttributeNames();
    }

    @Override
    public String[] getValueNames() {
        return this.httpSession.getValueNames();
    }

    /**
     * Stores the attribute in the local session (when there is one) and mirrors it to Redis under
     * {@code sessionId + name}.
     * <p>
     * The session id is the local session's id, or {@link #JSESSIONID} when there is no local
     * session. If neither is available nothing is written to Redis, so no key shared across
     * sessions is ever created.
     *
     * @param s attribute name
     * @param o attribute value
     */
    @Override
    public void setAttribute(String s, Object o) {
        String sessionId;
        if (null == this.httpSession) {
            sessionId = this.JSESSIONID;
        } else {
            this.httpSession.setAttribute(s, o);
            sessionId = this.httpSession.getId();
        }
        if (null == sessionId) {
            return;
        }
        RedisCacheUtil.setSessionAttribute(sessionId + s, o);
    }

    /**
     * Deprecated alias of {@link #setAttribute(String, Object)}. It delegates there so the value is
     * also mirrored to Redis.
     */
    @Override
    public void putValue(String s, Object o) {
        this.setAttribute(s, o);
    }

    /**
     * Removes the attribute from the local session and clears the Redis copies under both the local
     * id and {@link #JSESSIONID}.
     * <p>
     * Clearing is done by writing {@code null}. Presumably {@code RedisCacheUtil} treats that as a
     * delete &ndash; TODO confirm.
     *
     * @param s attribute name
     */
    @Override
    public void removeAttribute(String s) {
        if (null != this.httpSession) {
            this.httpSession.removeAttribute(s);
            RedisCacheUtil.setSessionAttribute(this.httpSession.getId() + s, null);
        }
        if (null != this.JSESSIONID) {
            RedisCacheUtil.setSessionAttribute(this.JSESSIONID + s, null);
        }
    }

    /**
     * Deprecated alias of {@link #removeAttribute(String)}. It delegates there so the Redis copies are
     * cleared as well.
     */
    @Override
    public void removeValue(String s) {
        this.removeAttribute(s);
    }

    /**
     * Invalidates the local session and clears the Redis copies of every attribute it held.
     * Without this step, data such as a login would survive in Redis and remain visible to other
     * servers after a logout.
     * <p>
     * Limitation: only attribute names known to the local session are cleared. Attributes that
     * exist only in Redis are left untouched.
     */
    @Override
    public void invalidate() {
        String sessionId = this.httpSession.getId();
        // Snapshot the names first; they are unavailable once the session is invalidated.
        java.util.List<String> names = java.util.Collections.list(this.httpSession.getAttributeNames());
        this.httpSession.invalidate();
        for (String name : names) {
            RedisCacheUtil.setSessionAttribute(sessionId + name, null);
            if (null != this.JSESSIONID && !this.JSESSIONID.equals(sessionId)) {
                RedisCacheUtil.setSessionAttribute(this.JSESSIONID + name, null);
            }
        }
    }

    @Override
    public boolean isNew() {
        return this.httpSession.isNew();
    }

    /**
     * Custom accessor (not part of {@link HttpSession}).
     *
     * @return the wrapped container-local session, possibly {@code null}
     */
    public HttpSession getHttpSession() {
        return httpSession;
    }

    /**
     * Custom mutator (not part of {@link HttpSession}).
     *
     * @param httpSession the container-local session to wrap, may be {@code null}
     */
    public void setHttpSession(HttpSession httpSession) {
        this.httpSession = httpSession;
    }
}

HttpPutFormContentRequestWrapper 继承 HttpServletRequestWrapper,对 HttpServletRequest 进行包装(注: 下面贴出的这个类中并不包含 getSession 的重写)

package oa.web.request;

import com.common.util.RedisCacheUtil;
import com.common.util.RequestUtil;
import com.common.util.SystemHWUtil;
import com.common.web.filter.CustomFormHttpMessageConverter;
import com.file.hw.props.GenericReadPropsUtil;
import com.io.hw.json.HWJacksonUtils;
import com.string.widget.util.RegexUtil;
import com.string.widget.util.ValueWidget;
import oa.util.SpringMVCUtil;
import org.apache.log4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import javax.servlet.FilterChain;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Request wrapper that:
 * <ul>
 * <li>reads form parameters from the request body once, via {@code RequestUtil.readFormParameters},
 * and merges them with the wrapped request's own parameters (the wrapped request wins);</li>
 * <li>replays the captured body through {@link #getInputStream()};</li>
 * <li>rewrites the servlet path according to {@link #handlerMethodPathMap} ("path mapping", similar
 * to an nginx rewrite);</li>
 * <li>caches response bodies per servlet path and carries the {@link FilterChain} and a "404" flag
 * for other project code.</li>
 * </ul>
 * A new instance is created for every request.
 */
public class HttpPutFormContentRequestWrapper extends HttpServletRequestWrapper {
    protected final static Logger logger = Logger.getLogger(HttpPutFormContentRequestWrapper.class);
    public final static org.slf4j.Logger httpClientRestLogger = LoggerFactory.getLogger("rest_log");
    /** Form parameters parsed from the request body; never {@code null} after construction. */
    protected MultiValueMap<String, String> formParameters;
    /** Raw request body captured during construction; may be {@code null}. */
    protected String requestBody;
    private ResettableServletInputStream servletStream;
    /**
     * Path mapping {@code <non-existent path A, real handler path B>}: requests for A are served by B.
     * It has two sources: (1) config/pathMapping.json and (2) Redis, via
     * {@code RedisCacheUtil.getServletPathMap()}.
     */
    private static Map<String, String> handlerMethodPathMap;
    /**
     * Cached response bodies, {@code <servletPath, responseText>}. Values must not be {@code null},
     * because {@link ConcurrentHashMap} rejects null values.
     */
    private Map<String, Object> responseReturnResultMap = new ConcurrentHashMap<>();
    private FilterChain chain;
    /**
     * Converter shared by every request.
     * <p>
     * NOTE(review): the constructor reads the body back with {@code formConverter.getRequestBody()}.
     * Unless CustomFormHttpMessageConverter keeps that state per thread, concurrent requests can see
     * each other's body. The original author already asked whether this should be an instance field
     * &ndash; TODO confirm before relying on it.
     */
    protected static final CustomFormHttpMessageConverter formConverter = new CustomFormHttpMessageConverter();
    private ThreadLocal<Boolean> is404NotFound = new ThreadLocal<Boolean>() {
        @Override
        protected Boolean initialValue() {
            return Boolean.FALSE;
        }
    };
    /**
     * Per-request cache {@code <original servlet path, resolved path>}. It works around slow handler
     * lookup in SpringMVC (added 2018-06-28, China Standard Time, 8:55:41 PM).
     * <p>
     * see http://i.yhskyc.com/test/1384?testcase=SpringMVC%E8%BF%9B%E5%85%A5%E8%AF%B7%E6%B1%82%E5%B7%A8%E6%85%A2
     */
    protected Map<String, String> servletPathOriginAndTargetMap;

    public void set404NotFound(boolean bool) {
        this.is404NotFound.set(bool);
    }

    public boolean is404NotFound() {
        return this.is404NotFound.get();
    }

    static {
        // A wrapper is created for every request, so the mapping is loaded once, here.
        // NOTE(review): any failure in initMapping() (e.g. SpringMVCUtil.getApplication() returning
        // null) makes this class unusable (ExceptionInInitializerError).
        initMapping();
    }

    /**
     * Caches a response body for a servlet path. A {@code null} path is stored under {@code ""}.
     *
     * @param servletPath servlet path, may be {@code null}
     * @param response    response body, must not be {@code null}
     */
    public void put(String servletPath, Object response) {
        if (null == servletPath) {
            servletPath = "";
        }
        responseReturnResultMap.put(servletPath, response);
    }

    /**
     * Returns the cached response body for a servlet path, or {@code null} if none is cached.
     * A {@code null} path is looked up as {@code ""}, matching {@link #put(String, Object)}.
     * Previously a {@code null} path threw a NullPointerException.
     */
    public String getResponseBodyBackup(String servletPath) {
        String key = (null == servletPath) ? "" : servletPath;
        return (String) this.responseReturnResultMap.get(key);
    }

    public String getResponseBodyBackup() {
        return this.getResponseBodyBackup(getServletPath());
    }

    /** @return whether a response body is cached for {@code servletPath}; {@code false} for null */
    public boolean hasContains(String servletPath) {
        if (null == servletPath) {
            return false;
        }
        return this.responseReturnResultMap.containsKey(servletPath);
    }

    /**
     * (Re)builds the servlet path mapping, similar to nginx forwarding. It adds one built-in entry,
     * then the entries from {@code /config/pathMapping.json} on the application class path.
     * <p>
     * see https://my.oschina.net/huangweiindex/blog/1789164
     */
    public static void initMapping() {
        handlerMethodPathMap = new ConcurrentHashMap<>();
        handlerMethodPathMap.put("/agent/afterbuy/list/json", "/agent/afterbuy/listfilter/json");
        // Read the local file "/config/pathMapping.json".
        ClassLoader classLoader = SpringMVCUtil.getApplication().getClassLoader();
        String resourcePath = "/config/pathMapping.json";
        String json = GenericReadPropsUtil.getConfigTxt(classLoader, resourcePath);
        System.out.println("config/pathMapping.json :" + json);
        if (!ValueWidget.isNullOrEmpty(json)) {
            json = RegexUtil.sedDeleteComment(json); // strip the leading comment line
            if (ValueWidget.isNullOrEmpty(json)) {
                return;
            }
            handlerMethodPathMap.putAll(HWJacksonUtils.deSerializeMap(json, String.class));
        }
    }

    /**
     * Wraps {@code request} and eagerly parses its form body.
     *
     * @param request the request to wrap
     */
    public HttpPutFormContentRequestWrapper(HttpServletRequest request) {
        super(request);
        servletStream = new ResettableServletInputStream();
        MultiValueMap<String, String> parameters = RequestUtil.readFormParameters(request, formConverter);
        this.formParameters = (parameters != null) ? parameters : new LinkedMultiValueMap<String, String>();
        this.requestBody = formConverter.getRequestBody();
    }

    /**
     * Returns the servlet path after applying {@link #handlerMethodPathMap}.
     * <p>
     * The first call resolves the mapping, consulting Redis when the path is not mapped locally, and
     * remembers it in {@link #servletPathOriginAndTargetMap}. Later calls reuse that result until
     * {@link #resetCustom()} is called.
     * <p>
     * see https://my.oschina.net/huangweiindex/blog/1789164
     *
     * @return the mapped path, or the original servlet path when no mapping applies
     */
    @Override
    public String getServletPath() {
        String servletPath = super.getServletPath();
        if (null != this.servletPathOriginAndTargetMap) {
            // Fast path: the mapping for this request has already been resolved.
            // (Was an unconditional System.out.println on every call.)
            if (logger.isDebugEnabled()) {
                logger.debug("servletPath 2 :" + servletPath);
            }
            String targetPath = this.servletPathOriginAndTargetMap.get(servletPath);
            return (null == targetPath) ? servletPath : targetPath;
        }
        String lookupPath = null;
        if (!ValueWidget.isNullOrEmpty(handlerMethodPathMap)) {
            // <non-existent path A, real handler path B>
            if (handlerMethodPathMap.containsKey(servletPath)) {
                lookupPath = handlerMethodPathMap.get(servletPath);
            } else {
                // Not mapped locally: merge mappings published to Redis (see PreServletPathMapController).
                // NOTE(review): they are cleared right after merging, so presumably only the first
                // instance that reads them receives them.
                Map servletPathMap = RedisCacheUtil.getServletPathMap();
                if (!ValueWidget.isNullOrEmpty(servletPathMap)) {
                    lookupPath = (String) servletPathMap.get(servletPath);
                    handlerMethodPathMap.putAll(servletPathMap);
                    RedisCacheUtil.clearServletPathMap();
                }
            }
        }
        if (ValueWidget.isNullOrEmpty(lookupPath)) {
            lookupPath = servletPath;
        } else {
            String msg = "SpringMVC 层实现 Path Mapping,old:" + servletPath + "\tnew:" + lookupPath + " 将被真正调用";
            logger.warn(msg);
            System.out.println(msg);
            httpClientRestLogger.error(msg);
        }
        // Remember the result so repeated calls during handler lookup stay cheap.
        this.servletPathOriginAndTargetMap = new HashMap<>();
        this.servletPathOriginAndTargetMap.put(servletPath, lookupPath);
        return lookupPath;
    }

    /** Returns the wrapped request's value if present, otherwise the first form-body value. */
    @Override
    public String getParameter(String name) {
        String queryStringValue = super.getParameter(name);
        return (queryStringValue != null) ? queryStringValue : this.formParameters.getFirst(name);
    }

    /** Returns all merged parameters, in the order given by {@link #getParameterNames()}. */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> result = new LinkedHashMap<>();
        Enumeration<String> names = this.getParameterNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            result.put(name, this.getParameterValues(name));
        }
        return result;
    }

    /** Returns the wrapped request's parameter names followed by any extra form-body names. */
    @Override
    public Enumeration<String> getParameterNames() {
        Set<String> names = new LinkedHashSet<>(Collections.list(super.getParameterNames()));
        names.addAll(this.formParameters.keySet());
        return Collections.enumeration(names);
    }

    /** Returns the wrapped request's values followed by the form-body values; {@code null} if neither exists. */
    @Override
    public String[] getParameterValues(String name) {
        String[] queryStringValues = super.getParameterValues(name);
        List<String> formValues = this.formParameters.get(name);
        if (formValues == null) {
            return queryStringValues;
        }
        if (queryStringValues == null) {
            return formValues.toArray(new String[0]);
        }
        List<String> result = new ArrayList<>(queryStringValues.length + formValues.size());
        result.addAll(Arrays.asList(queryStringValues));
        result.addAll(formValues);
        return result.toArray(new String[0]);
    }

    /**
     * Returns the container stream while it still reports available bytes. Otherwise it replays the
     * captured {@link #requestBody}, encoded with the request charset (ISO-8859-1 when none is set).
     * A missing body is replayed as an empty stream instead of throwing NullPointerException.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        ServletInputStream containerStream = super.getInputStream();
        if (containerStream.available() > 0) {
            return containerStream;
        }
        String requestCharset = getRequest().getCharacterEncoding();
        if (ValueWidget.isNullOrEmpty(requestCharset)) {
            requestCharset = SystemHWUtil.CHARSET_ISO88591;
        }
        String body = (null == this.requestBody) ? "" : this.requestBody;
        servletStream.stream = new ByteArrayInputStream(body.getBytes(requestCharset));
        return servletStream;
    }

    public MultiValueMap<String, String> getFormParameters() {
        return formParameters;
    }

    /**
     * {@link ServletInputStream} over an in-memory body. {@link #getInputStream()} swaps in a fresh
     * stream on every call.
     * <p>
     * NOTE(review): it overrides only {@code read()}. That compiles only against a Servlet API older
     * than 3.1, which added abstract non-blocking I/O methods &ndash; assumption, confirm.
     */
    private static class ResettableServletInputStream extends ServletInputStream {

        private InputStream stream;

        @Override
        public int read() throws IOException {
            return stream.read();
        }
    }

    public FilterChain getChain() {
        return chain;
    }

    public void setChain(FilterChain chain) {
        this.chain = chain;
    }

    public static CustomFormHttpMessageConverter getFormConverter() {
        return formConverter;
    }

    /** Forgets the resolved servlet path so that the next {@link #getServletPath()} resolves it again. */
    public void resetCustom() {
        this.servletPathOriginAndTargetMap = null;
    }
}

推荐

我的其他开源项目 用于服务器端API 的stub 测试
zookeeper的一个java 图形客户端

展开阅读全文
打赏
0
7 收藏
分享
加载中
更多评论
打赏
0 评论
7 收藏
0
分享
返回顶部
顶部