// WebSocketUtil.java 8.2 KB
package com.subsidy.util.websocket;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.subsidy.common.ResponseData;
import com.subsidy.mapper.ClassDictMapper;
import com.subsidy.mapper.OprMemDictMapper;
import com.subsidy.model.OprMemDictDO;
import com.subsidy.util.DateFormatUtil;
import com.subsidy.vo.classdict.ClassSettingsVO;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

/**
 * <p>
 *  WebSocket handler that keeps at most one tracked session per user id and
 *  writes login / logout rows to the operation log based on connection state.
 * </p>
 * <p>
 *  The user id is read from the first query parameter of the handshake URI
 *  (for example {@code ?id=42}). When a connection closes, a periodic
 *  "heartbeat" check waits for the user to reconnect; if no newer session
 *  shows up within {@code HEARTBEAT_MAX * RECONNECTION_SECONDS} seconds a
 *  logout row is inserted.
 * </p>
 *
 * @author DengMin
 * @since 2022/7/13
 */
@Slf4j
@Component
public class WebSocketUtil implements WebSocketHandler {

    /** Operation type stored for a login row (runtime value, do not translate). */
    private static final String OPR_TYPE_LOGIN = "登录";

    /** Operation type stored for a logout row (runtime value, do not translate). */
    private static final String OPR_TYPE_LOGOUT = "登出";

    /** Response code pushed to a session that is being replaced by a newer login of the same account. */
    private static final int CODE_LOGGED_IN_ELSEWHERE = 1011;

    /** Initial value of the heartbeat counter after a disconnect. */
    private static final int HEARTBEAT_MIN = 1;

    /** Counter value that must be exceeded before the user is considered offline. */
    private static final int HEARTBEAT_MAX = 3;

    /** Seconds between two heartbeat checks. */
    private static final int RECONNECTION_SECONDS = 60;

    @Autowired
    private OprMemDictMapper oprMemDictMapper;

    @Autowired
    private ClassDictMapper classDictMapper;

    /**
     * Most recent WebSocket session per user id. A closed session may stay in
     * here until its heartbeat check records the logout.
     */
    private final ConcurrentHashMap<Long, WebSocketSession> webSocketMap = new ConcurrentHashMap<>();

    /** Running heartbeat checks, keyed by the id of the session that was closed. */
    ConcurrentHashMap<String, ScheduledFuture<?>> taskMap = new ConcurrentHashMap<>();

    /**
     * One shared scheduler for all heartbeat checks. The original code created
     * a new, never shut down thread pool per disconnect. The worker is a daemon
     * thread so it does not keep the JVM alive on shutdown.
     */
    private final ScheduledExecutorService heartbeatScheduler =
            Executors.newSingleThreadScheduledExecutor(runnable -> {
                Thread worker = new Thread(runnable, "websocket-heartbeat");
                worker.setDaemon(true);
                return worker;
            });

    /**
     * Handles a successfully established WebSocket connection.
     * <ul>
     *   <li>Registers the session as the current one for its user id.</li>
     *   <li>If another session of the same user is still open, notifies it
     *       (code 1011) and closes it.</li>
     *   <li>Otherwise, if the user's latest operation-log row is a logout,
     *       inserts a login row (connection without a fresh password login).</li>
     *   <li>Pushes the user's class settings, if any.</li>
     * </ul>
     *
     * @param session the newly opened session; ignored when {@code null}
     * @throws IOException if closing a rejected session or sending the class settings fails
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws IOException {
        if (session == null) {
            return;
        }

        Long id = extractUserId(session);
        if (id == null) {
            log.warn("webSocket session {} rejected, no numeric user id in query [{}]",
                    session.getId(), queryOf(session));
            session.close(CloseStatus.BAD_DATA);
            return;
        }

        // Register the new session BEFORE closing the old one: the close
        // callback of the old session then already sees the replacement and
        // does not start counting towards a logout.
        WebSocketSession previous = webSocketMap.put(id, session);
        if (previous != null && previous != session && previous.isOpen()) {
            kickPreviousSession(previous);
        } else {
            // NOTE(review): the original ran this check only while a closed
            // session of the user was still held in memory. It now also runs
            // when nothing is held (e.g. entry removed after a transport
            // error), so a recorded logout is always paired with a login.
            recordLoginIfLastWasLogout(id, session);
        }

        List<ClassSettingsVO> list = classDictMapper.getClassSettings(id);
        if (CollectionUtils.isNotEmpty(list)) {
            session.sendMessage(new TextMessage(
                    JSONObject.toJSONString(ResponseData.generateCreatedResponse(0, list))));
        }

        String httpSessionId = session.getId();
        String host = session.getUri().getHost();
        String query = session.getUri().getQuery();
        log.info("----> webSocket connection success");
        log.info("parameter:[ httpSessionId: {}, host: {}, {} ]", httpSessionId, host, query);
        log.info("connection time: {}", DateFormatUtil.format(new Date(), DateFormatUtil.FMT_sdf14_L));
    }

    /**
     * Handles a WebSocket transport error: closes the session if it is still
     * open and drops it from the session map.
     *
     * @param session   the session on which the error occurred
     * @param throwable the transport error
     * @throws Exception if closing the session fails
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable throwable) throws Exception {
        try {
            if (session.isOpen()) {
                session.close();
            }
        } finally {
            Long id = extractUserId(session);
            if (id != null) {
                // Two-argument remove: never evict a newer session of the same user.
                webSocketMap.remove(id, session);
            }
            log.error("<---- webSocket transport error");
            log.error("error message: {}", throwable.getMessage(), throwable);
        }
    }

    /**
     * Invoked after the connection has been closed by either side or after a
     * transport error. Starts the heartbeat check that decides whether the
     * user reconnects in time or has to be recorded as logged out.
     *
     * @param session     the closed session; ignored when {@code null}
     * @param closeStatus the reason the connection was closed
     * @throws Exception declared by the handler interface
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        if (session == null) {
            return;
        }
        // Pass the closed session itself: the map entry may already be gone
        // (transport error) or may already hold the user's next session.
        heartbeat(session);
        log.info("<---- webSocket is close");
        log.info("session {} close, status: {}", session.getId(), closeStatus);
    }

    /**
     * Starts a periodic check for a disconnected session. The check runs every
     * {@code RECONNECTION_SECONDS} seconds, starting immediately:
     * <ul>
     *   <li>if the session is open again, or a newer session is registered
     *       for the same user, the check stops without writing anything;</li>
     *   <li>otherwise a counter is increased; once it exceeds
     *       {@code HEARTBEAT_MAX} a logout row is inserted and the check stops.</li>
     * </ul>
     * A check already running for the same session id is replaced.
     *
     * @param session the session to watch; nothing is scheduled when it is
     *                {@code null} or carries no parsable user id
     */
    public void heartbeat(WebSocketSession session) {
        if (session == null) {
            return;
        }
        Long userId = extractUserId(session);
        if (userId == null) {
            log.warn("no heartbeat scheduled for session {}, no numeric user id in query [{}]",
                    session.getId(), queryOf(session));
            return;
        }

        String taskKey = session.getId();
        ScheduledFuture<?> future = heartbeatScheduler.scheduleAtFixedRate(
                new HeartbeatTask(taskKey, userId, session), 0, RECONNECTION_SECONDS, TimeUnit.SECONDS);
        ScheduledFuture<?> replaced = taskMap.put(taskKey, future);
        if (replaced != null) {
            replaced.cancel(false);
        }
    }

    /**
     * Receives a message from the WebSocket client. Text payloads are only logged.
     *
     * @param session the session the message arrived on
     * @param message the received message
     * @throws Exception declared by the handler interface
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (message instanceof TextMessage) {
            log.debug("webSocket session {} received: {}", session.getId(), message.getPayload());
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /** Tells an older, still open session that the account logged in elsewhere, then closes it (best effort). */
    private void kickPreviousSession(WebSocketSession previous) {
        try {
            previous.sendMessage(new TextMessage(JSONObject.toJSONString(
                    ResponseData.generateCreatedResponse(CODE_LOGGED_IN_ELSEWHERE))));
        } catch (IOException e) {
            log.warn("could not notify replaced webSocket session {}", previous.getId(), e);
        }
        try {
            previous.close();
        } catch (IOException e) {
            log.warn("could not close replaced webSocket session {}", previous.getId(), e);
        }
    }

    /** Inserts a login row when the user's most recent operation-log row is a logout. */
    private void recordLoginIfLastWasLogout(Long id, WebSocketSession session) {
        List<OprMemDictDO> records = oprMemDictMapper.selectList(new QueryWrapper<OprMemDictDO>()
                .lambda()
                .eq(OprMemDictDO::getUserId, id)
                .orderByDesc(OprMemDictDO::getCreateDate));
        if (CollectionUtils.isEmpty(records)) {
            return;
        }
        // NOTE(review): the original additionally tested
        // (createDate + heartbeatMax * reconnectionSeconds).after(createDate),
        // which is always true. The effective rule, "latest row is a logout",
        // is kept unchanged; confirm whether a real time window was intended.
        if (OPR_TYPE_LOGOUT.equals(records.get(0).getOprType())) {
            insertOperationRecord(id, OPR_TYPE_LOGIN, session);
        }
    }

    /** Writes one operation-log row of the given type for the user, with result 1 and the session's remote host. */
    private void insertOperationRecord(Long id, String oprType, WebSocketSession session) {
        OprMemDictDO oprMemDictDO = new OprMemDictDO();
        oprMemDictDO.setUserId(id);
        oprMemDictDO.setResult(1);
        oprMemDictDO.setOprType(oprType);
        oprMemDictDO.setIpAddress(remoteHostOf(session));
        oprMemDictMapper.insert(oprMemDictDO);
    }

    /** Returns the host name of the session's remote address, or {@code null} when the address is unknown. */
    private static String remoteHostOf(WebSocketSession session) {
        java.net.InetSocketAddress address = session.getRemoteAddress();
        return address == null ? null : address.getHostName();
    }

    /** Returns the raw query string of the session URI, or {@code null} when there is none. */
    private static String queryOf(WebSocketSession session) {
        return session.getUri() == null ? null : session.getUri().getQuery();
    }

    /**
     * Reads the user id from the value of the first query parameter.
     *
     * @return the id, or {@code null} when the query is missing, has no value
     *         or the value is not a number
     */
    private static Long extractUserId(WebSocketSession session) {
        String query = queryOf(session);
        if (query == null) {
            return null;
        }
        // Only look at the first "name=value" pair so extra parameters do not break parsing.
        String[] nameAndValue = query.split("&", 2)[0].split("=");
        if (nameAndValue.length < 2) {
            return null;
        }
        try {
            return Long.valueOf(nameAndValue[1]);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Periodic check that runs after a session was closed. Executions of one
     * task never overlap, so the plain fields need no synchronisation.
     */
    private final class HeartbeatTask implements Runnable {

        private final String taskKey;
        private final Long userId;
        private final WebSocketSession session;
        private int beatsNum = HEARTBEAT_MIN;
        /** Set once the outcome is decided; lets a later run retry the cancellation. */
        private boolean finished;

        private HeartbeatTask(String taskKey, Long userId, WebSocketSession session) {
            this.taskKey = taskKey;
            this.userId = userId;
            this.session = session;
        }

        @Override
        public void run() {
            try {
                if (finished) {
                    stop();
                    return;
                }
                WebSocketSession current = webSocketMap.get(userId);
                boolean reconnected = session.isOpen() || (current != null && current != session);
                if (reconnected) {
                    // A newer session took over (or this one is open again): no logout.
                    stop();
                    return;
                }
                if (beatsNum > HEARTBEAT_MAX) {
                    insertOperationRecord(userId, OPR_TYPE_LOGOUT, session);
                    // The stale closed session is no longer needed once the logout is recorded.
                    webSocketMap.remove(userId, session);
                    stop();
                    return;
                }
                beatsNum++;
            } catch (RuntimeException e) {
                // An exception escaping a periodic task silently ends it; log it and clean up instead.
                log.error("webSocket heartbeat check for session {} failed", taskKey, e);
                stop();
            }
        }

        /**
         * Cancels this task. The very first run can happen before the future
         * is registered in {@code taskMap}; in that case the next run cancels it.
         */
        private void stop() {
            finished = true;
            ScheduledFuture<?> self = taskMap.remove(taskKey);
            if (self != null) {
                self.cancel(false);
            }
        }
    }
}