package com.efuture.ocp.common.annotation;

import com.efuture.ocp.common.exception.ServiceException;
import com.efuture.ocp.common.language.MessageSourceHelper;
import com.efuture.ocp.common.language.ResponseCode;
import com.efuture.ocp.common.rest.ServiceRestReflect;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.RateLimiter;
import org.apache.log4j.Logger;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;


/**
 * Spring AOP aspect that enforces {@link RateLimit} on annotated methods.
 * <p>
 * One Guava {@link RateLimiter} is kept per key built from the current
 * {@code ServiceRestReflect.getLocale()} value's ent_id plus the target class and method name.
 * A call that cannot obtain a permit within the annotation's timeout is rejected with a
 * {@link ServiceException} carrying {@code ResponseCode.LIMIT}. When
 * {@code ServiceRestReflect.getLocale()} yields no value, the call is not limited.
 */
@Aspect
@Component
public class RateLimitAspect
{
    /**
     * Token buckets keyed by {@code "<entId>-<fullyQualifiedClass>.<method>"}.
     * Backed by a ConcurrentHashMap, so {@code computeIfAbsent} creates each bucket atomically.
     */
    private final Map<String, RateLimiter> limitMap = Maps.newConcurrentMap();

    /** Matches every method annotated with {@link RateLimit} (declared in this aspect's package). */
    @Pointcut("@annotation(RateLimit)")
    public void rateLimit() {
    }

    /**
     * Builds the limiter key for the intercepted invocation.
     *
     * @param joinPoint the intercepted invocation
     * @return {@code "<entId>-<fullyQualifiedClass>.<method>"}, or {@code null} when
     *         {@code ServiceRestReflect.getLocale()} or its current value is {@code null}
     */
    private String getLimitKey(ProceedingJoinPoint joinPoint)
    {
        if (ServiceRestReflect.getLocale() == null || ServiceRestReflect.getLocale().get() == null)
        {
            return null;
        }
        // Fully qualified name so same-named classes in different packages do not share a bucket.
        String className = joinPoint.getTarget().getClass().getName();
        String methodName = joinPoint.getSignature().getName();
        return String.valueOf(ServiceRestReflect.getLocale().get().getEnt_id())
                + "-" + className + "." + methodName;
    }

    /**
     * Acquires a permit from the limiter of the intercepted method before letting the call proceed.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws ServiceException with {@code ResponseCode.LIMIT} and the annotation's message when
     *                          no permit becomes available within the configured timeout
     * @throws Throwable        anything thrown by the intercepted method itself
     */
    @Around("rateLimit()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable
    {
        RateLimit rateLimit = getRateLimit(joinPoint);
        if (rateLimit != null)
        {
            String limitKey = getLimitKey(joinPoint);
            if (limitKey != null)
            {
                // Atomic create-if-missing on the concurrent map; replaces manual
                // double-checked locking on 'this'. The first call's permitsPerSecond
                // fixes the bucket's rate for that key.
                RateLimiter limiter = limitMap.computeIfAbsent(limitKey,
                        key -> RateLimiter.create(rateLimit.permitsPerSecond()));
                if (!limiter.tryAcquire(rateLimit.timeout(), rateLimit.timeunit()))
                {
                    throw new ServiceException(ResponseCode.LIMIT, rateLimit.msg());
                }
            }
        }
        return joinPoint.proceed();
    }

    /**
     * Locates the {@link RateLimit} annotation of the intercepted method.
     * <p>
     * Searches the target class and its superclasses, so annotated methods that the target
     * inherits are found. Among annotated methods with the invoked name, one whose parameter
     * count equals the invocation's argument count is preferred. This is a best-effort
     * disambiguation of overloads.
     *
     * @param joinPoint the intercepted invocation
     * @return the annotation, or {@code null} if no annotated method with that name is found
     */
    private RateLimit getRateLimit(final JoinPoint joinPoint) {
        String name = joinPoint.getSignature().getName();
        if (StringUtils.isEmpty(name)) {
            return null;
        }
        Object[] args = joinPoint.getArgs();
        int argCount = args == null ? 0 : args.length;

        RateLimit fallback = null;
        for (Class<?> type = joinPoint.getTarget().getClass();
             type != null && type != Object.class;
             type = type.getSuperclass()) {
            for (Method method : type.getDeclaredMethods()) {
                if (!name.equals(method.getName())) {
                    continue;
                }
                RateLimit annotation = method.getAnnotation(RateLimit.class);
                if (annotation == null) {
                    continue;
                }
                if (method.getParameterTypes().length == argCount) {
                    return annotation;
                }
                // Remember the first annotated same-name method in case no arity matches.
                if (fallback == null) {
                    fallback = annotation;
                }
            }
        }
        return fallback;
    }

}
