疯狂的狮子li
2021-06-16 639816369a797d967cc6804a6da684080c3cccb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package com.ruoyi.framework.aspectj;
 
 
import com.ruoyi.common.annotation.RedisLock;
import com.ruoyi.common.constant.Constants;
import com.ruoyi.common.core.redis.RedisLockManager;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
 
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
 
/**
 * 分布式锁(注解实现版本)
 *
 * @author shenxinquan
 */
 
@Slf4j
@Aspect
@Order(9)
@Component
public class RedisLockAspect {
 
    @Autowired
    private RedisLockManager redisLockManager;
 
    @Pointcut("@annotation(com.ruoyi.common.annotation.RedisLock)")
    public void annotationPointcut() {
    }
 
    @Around("annotationPointcut()")
    public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable {
        // 获得当前访问的class
        Class<?> className = joinPoint.getTarget().getClass();
        // 获得访问的方法名
        String methodName = joinPoint.getSignature().getName();
        // 得到方法的参数的类型
        Class<?>[] argClass = ((MethodSignature) joinPoint.getSignature()).getParameterTypes();
        Object[] args = joinPoint.getArgs();
        String key = "";
        // 默认30秒过期时间
        int expireTime = 30;
 
        try {
            // 得到访问的方法对象
            Method method = className.getMethod(methodName, argClass);
            method.setAccessible(true);
            // 判断是否存在@RedisLock注解
            if (method.isAnnotationPresent(RedisLock.class)) {
                RedisLock annotation = method.getAnnotation(RedisLock.class);
                key = getRedisKey(args, annotation.key());
                expireTime = getExpireTime(annotation);
            }
        } catch (Exception e) {
            throw new RuntimeException("redis分布式锁注解参数异常", e);
        }
 
        // 声明锁名称
        key = Constants.REDIS_LOCK_KEY + key;
        Object res;
        try {
            if (redisLockManager.getLock(key, expireTime, TimeUnit.SECONDS)) {
                log.info("lock => key : " + key + " , ThreadName : " + Thread.currentThread().getName());
                try {
                    res = joinPoint.proceed();
                    return res;
                } catch (Exception e) {
                    throw new RuntimeException(e);
                } finally {
                    redisLockManager.unLock(key);
                    log.info("unlock => key : " + key + " , ThreadName : " + Thread.currentThread().getName());
                }
            } else {
                throw new RuntimeException("redis分布式锁注解参数异常");
            }
        } catch (IllegalMonitorStateException e) {
            log.error("lock timeout => key : " + key + " , ThreadName : " + Thread.currentThread().getName());
            throw new RuntimeException("lock timeout => key : " + key);
        } catch (Exception e) {
            throw new Exception("redis分布式未知异常", e);
        }
    }
 
    private int getExpireTime(RedisLock annotation) {
        return annotation.expireTime();
    }
 
    private String getRedisKey(Object[] args, String primalKey) {
        if (args.length == 0) {
            return primalKey;
        }
        // 获取#p0...集合
        List<String> keyList = getKeyParsList(primalKey);
        for (String keyName : keyList) {
            int keyIndex = Integer.parseInt(keyName.toLowerCase().replace("#p", ""));
            Object parValue = args[keyIndex];
            primalKey = primalKey.replace(keyName, String.valueOf(parValue));
        }
        return primalKey.replace("+", "").replace("'", "");
    }
 
    /**
     * 获取key中#p0中的参数名称
     */
    private static List<String> getKeyParsList(String key) {
        List<String> listPar = new ArrayList<>();
        if (key.contains("#")) {
            int plusIndex = key.substring(key.indexOf("#")).indexOf("+");
            int indexNext = 0;
            String parName;
            int indexPre = key.indexOf("#");
            if (plusIndex > 0) {
                indexNext = key.indexOf("#") + plusIndex;
                parName = key.substring(indexPre, indexNext);
            } else {
                parName = key.substring(indexPre);
            }
            listPar.add(parName.trim());
            key = key.substring(indexNext + 1);
            if (key.contains("#")) {
                listPar.addAll(getKeyParsList(key));
            }
        }
        return listPar;
    }
 
}