/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.server.common.transport.limits;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.thingsboard.server.common.data.EntityType;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration;
import org.thingsboard.server.common.data.tenant.profile.TenantProfileData;
import org.thingsboard.server.common.data.util.TbPair;
import org.thingsboard.server.common.transport.TransportTenantProfileCache;
import org.thingsboard.server.common.transport.limits.DummyTransportRateLimit;
import org.thingsboard.server.common.transport.limits.EntityTransportRateLimits;
import org.thingsboard.server.common.transport.limits.InetAddressRateLimitStats;
import org.thingsboard.server.common.transport.limits.SimpleTransportRateLimit;
import org.thingsboard.server.common.transport.limits.TransportLimitsType;
import org.thingsboard.server.common.transport.limits.TransportRateLimit;
import org.thingsboard.server.common.transport.limits.TransportRateLimitService;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import org.thingsboard.server.queue.util.TbTransportComponent;

@Service
@TbTransportComponent
public class DefaultTransportRateLimitService
implements TransportRateLimitService {
    private static final Logger log = LoggerFactory.getLogger(DefaultTransportRateLimitService.class);
    // Shared limit instance used when a tenant profile has no configuration or a limit string is empty.
    // NOTE(review): presumably a limiter that never rejects (going by the class name) - confirm in DummyTransportRateLimit.
    private static final DummyTransportRateLimit ALLOW = new DummyTransportRateLimit();
    // Per-tenant flag written by update(TenantId, boolean); checkLimits treats a missing entry as "allowed".
    private final ConcurrentMap<TenantId, Boolean> tenantAllowed = new ConcurrentHashMap<TenantId, Boolean>();
    // Ids recorded per tenant when the matching per-device / per-gateway / per-gateway-device limits entry is
    // first created; a profile update walks these sets to find the entries it has to refresh.
    private final ConcurrentMap<TenantId, Set<DeviceId>> tenantDevices = new ConcurrentHashMap<TenantId, Set<DeviceId>>();
    private final ConcurrentMap<TenantId, Set<DeviceId>> tenantGateways = new ConcurrentHashMap<TenantId, Set<DeviceId>>();
    private final ConcurrentMap<TenantId, Set<DeviceId>> tenantGatewayDevices = new ConcurrentHashMap<TenantId, Set<DeviceId>>();
    // Limits currently in force, created lazily on first check and replaced when a profile update changes them.
    private final ConcurrentMap<TenantId, EntityTransportRateLimits> perTenantLimits = new ConcurrentHashMap<TenantId, EntityTransportRateLimits>();
    private final ConcurrentMap<DeviceId, EntityTransportRateLimits> perDeviceLimits = new ConcurrentHashMap<DeviceId, EntityTransportRateLimits>();
    private final ConcurrentMap<DeviceId, EntityTransportRateLimits> perGatewayLimits = new ConcurrentHashMap<DeviceId, EntityTransportRateLimits>();
    private final ConcurrentMap<DeviceId, EntityTransportRateLimits> perGatewayDeviceLimits = new ConcurrentHashMap<DeviceId, EntityTransportRateLimits>();
    // Authentication statistics (failure count, blocked flag, last activity) per remote IP address.
    private final Map<InetAddress, InetAddressRateLimitStats> ipMap = new ConcurrentHashMap<InetAddress, InetAddressRateLimitStats>();
    // Source of the tenant profiles from which limits are built.
    private final TransportTenantProfileCache tenantProfileCache;
    // Master switch for the IP-based checks; every IP-related method is a no-op when this is false (default false).
    @Value(value="${transport.rate_limits.ip_limits_enabled:false}")
    private boolean ipRateLimitsEnabled;
    // Failure count at which onAuthFailure marks an IP as blocked (default 10).
    @Value(value="${transport.rate_limits.max_wrong_credentials_per_ip:10}")
    private int maxWrongCredentialsPerIp;
    // How long a block lasts, in milliseconds (it is added to System.currentTimeMillis() timestamps); default 60000.
    @Value(value="${transport.rate_limits.ip_block_timeout:60000}")
    private long ipBlockTimeout;

    /**
     * Creates the service.
     *
     * @param tenantProfileCache cache the service reads tenant profiles from when it builds rate limits
     */
    public DefaultTransportRateLimitService(TransportTenantProfileCache tenantProfileCache) {
        this.tenantProfileCache = tenantProfileCache;
    }

    /**
     * Runs the applicable rate-limit checks for one message and reports the first limit that rejects it.
     *
     * <p>Checks are evaluated in this order and stop at the first rejection: the tenant's allowed flag,
     * the tenant limits, the gateway-device limits (only when {@code isGateway}), the gateway limits
     * (only when {@code gatewayId} is not null) and the device limits (only when not a gateway and
     * {@code deviceId} is not null). Limits that passed before a later rejection have already been
     * consumed.
     *
     * @param tenantId   tenant the message belongs to
     * @param gatewayId  gateway the device is connected through, or {@code null}
     * @param deviceId   device that sent the message
     * @param dataPoints number of telemetry data points; a value of 0 or less is counted as a regular message
     * @param isGateway  whether {@code deviceId} should be checked against the gateway-device limits
     * @return {@code null} when nothing rejects the message; otherwise the type of the rejecting entity
     *         paired with a flag that is {@code true} when a gateway or gateway-device limit rejected it
     */
    @Override
    public TbPair<EntityType, Boolean> checkLimits(TenantId tenantId, DeviceId gatewayId, DeviceId deviceId, int dataPoints, boolean isGateway) {
        // No explicit Object casts on the arguments: they would erase the types needed to infer
        // TbPair<EntityType, Boolean> from the factory call.
        if (!tenantAllowed.getOrDefault(tenantId, Boolean.TRUE)) {
            return TbPair.of(EntityType.API_USAGE_STATE, false);
        }
        if (!checkEntityRateLimit(dataPoints, getTenantRateLimits(tenantId))) {
            return TbPair.of(EntityType.TENANT, false);
        }
        if (isGateway && !checkEntityRateLimit(dataPoints, getGatewayDeviceRateLimits(tenantId, deviceId))) {
            return TbPair.of(EntityType.DEVICE, true);
        }
        if (gatewayId != null && !checkEntityRateLimit(dataPoints, getGatewayRateLimits(tenantId, gatewayId))) {
            return TbPair.of(EntityType.DEVICE, true);
        }
        if (!isGateway && deviceId != null && !checkEntityRateLimit(dataPoints, getDeviceRateLimits(tenantId, deviceId))) {
            return TbPair.of(EntityType.DEVICE, false);
        }
        return null;
    }

    /**
     * Consumes from one entity's limits for a single message.
     *
     * @param dataPoints number of telemetry data points; 0 or less means a regular (non-telemetry) message
     * @param limits     limits to consume from
     * @return {@code true} when every limit that applies accepted the message
     */
    private boolean checkEntityRateLimit(int dataPoints, EntityTransportRateLimits limits) {
        if (dataPoints <= 0) {
            return limits.getRegularMsgRateLimit().tryConsume();
        }
        // The message limit is consulted first; the data-point limit is only touched if it passes.
        if (!limits.getTelemetryMsgRateLimit().tryConsume()) {
            return false;
        }
        return limits.getTelemetryDataPointsRateLimit().tryConsume(dataPoints);
    }

    /**
     * Applies an updated tenant profile to every tenant listed as affected by the update.
     *
     * @param update the profile update, carrying the new profile and the affected tenant ids
     */
    @Override
    public void update(TenantProfileUpdateResult update) {
        TenantProfile profile = update.getProfile();
        log.info("Received tenant profile update: {}", profile);
        // One prototype per limits type, shared by all affected tenants.
        EntityTransportRateLimits tenantPrototype = createRateLimits(profile, TransportLimitsType.TENANT_LIMITS);
        EntityTransportRateLimits devicePrototype = createRateLimits(profile, TransportLimitsType.DEVICE_LIMITS);
        EntityTransportRateLimits gatewayPrototype = createRateLimits(profile, TransportLimitsType.GATEWAY_LIMITS);
        EntityTransportRateLimits gatewayDevicePrototype = createRateLimits(profile, TransportLimitsType.GATEWAY_DEVICE_LIMITS);
        for (TenantId affectedTenant : update.getAffectedTenants()) {
            update(affectedTenant, tenantPrototype, devicePrototype, gatewayPrototype, gatewayDevicePrototype);
        }
    }

    /**
     * Rebuilds the limits of one tenant, and of its known devices and gateways, from the tenant's
     * currently cached profile.
     *
     * @param tenantId tenant whose limits should be refreshed
     */
    @Override
    public void update(TenantId tenantId) {
        // Look the profile up once so all four limit sets come from the same profile snapshot,
        // even if the cache is updated while this method runs.
        TenantProfile profile = tenantProfileCache.get(tenantId);
        EntityTransportRateLimits tenantRateLimitPrototype = createRateLimits(profile, TransportLimitsType.TENANT_LIMITS);
        EntityTransportRateLimits deviceRateLimitPrototype = createRateLimits(profile, TransportLimitsType.DEVICE_LIMITS);
        EntityTransportRateLimits gatewayRateLimitPrototype = createRateLimits(profile, TransportLimitsType.GATEWAY_LIMITS);
        EntityTransportRateLimits gatewayDeviceRateLimitPrototype = createRateLimits(profile, TransportLimitsType.GATEWAY_DEVICE_LIMITS);
        update(tenantId, tenantRateLimitPrototype, deviceRateLimitPrototype, gatewayRateLimitPrototype, gatewayDeviceRateLimitPrototype);
    }

    /**
     * Merges freshly built limit prototypes into the stored limits of a tenant and of every device,
     * gateway and gateway device currently registered for that tenant.
     *
     * @param tenantId                        tenant being updated
     * @param tenantRateLimitPrototype        new tenant-level limits
     * @param deviceRateLimitPrototype        new limits for the tenant's devices
     * @param gatewayRateLimitPrototype       new limits for the tenant's gateways
     * @param gatewayDeviceRateLimitPrototype new limits for the tenant's gateway devices
     */
    private void update(TenantId tenantId, EntityTransportRateLimits tenantRateLimitPrototype, EntityTransportRateLimits deviceRateLimitPrototype, EntityTransportRateLimits gatewayRateLimitPrototype, EntityTransportRateLimits gatewayDeviceRateLimitPrototype) {
        mergeLimits(tenantId, tenantRateLimitPrototype, perTenantLimits::get, perTenantLimits::put);
        // Ids are passed with their static DeviceId type (no widening cast to EntityId) so the type
        // parameter of mergeLimits matches the key type of the map behind the method references.
        for (DeviceId deviceId : getTenantDevices(tenantId)) {
            mergeLimits(deviceId, deviceRateLimitPrototype, perDeviceLimits::get, perDeviceLimits::put);
        }
        for (DeviceId gatewayId : getTenantGateways(tenantId)) {
            mergeLimits(gatewayId, gatewayRateLimitPrototype, perGatewayLimits::get, perGatewayLimits::put);
        }
        for (DeviceId gatewayDeviceId : getTenantGatewayDevices(tenantId)) {
            mergeLimits(gatewayDeviceId, gatewayDeviceRateLimitPrototype, perGatewayDeviceLimits::get, perGatewayDeviceLimits::put);
        }
    }

    /**
     * Forgets the tenant-level limits and the tenant's device, gateway and gateway-device id sets.
     *
     * <p>NOTE(review): the per-device limit entries of the tenant's devices and the tenant's allowed
     * flag are not touched here - confirm they are cleaned up elsewhere.
     *
     * @param tenantId tenant to remove
     */
    @Override
    public void remove(TenantId tenantId) {
        perTenantLimits.remove(tenantId);
        tenantDevices.remove(tenantId);
        tenantGateways.remove(tenantId);
        tenantGatewayDevices.remove(tenantId);
    }

    /**
     * Forgets every limit entry stored for the device and removes its id from all per-tenant id sets.
     *
     * @param deviceId device (or gateway) to remove
     */
    @Override
    public void remove(DeviceId deviceId) {
        perDeviceLimits.remove(deviceId);
        perGatewayLimits.remove(deviceId);
        perGatewayDeviceLimits.remove(deviceId);
        // The owning tenant is not known here, so every tenant's sets are scanned.
        for (Set<DeviceId> deviceIds : tenantDevices.values()) {
            deviceIds.remove(deviceId);
        }
        for (Set<DeviceId> gatewayIds : tenantGateways.values()) {
            gatewayIds.remove(deviceId);
        }
        for (Set<DeviceId> gatewayDeviceIds : tenantGatewayDevices.values()) {
            gatewayDeviceIds.remove(deviceId);
        }
    }

    /**
     * Records whether the tenant is currently allowed to use the transport; the flag is the first
     * thing {@code checkLimits} consults.
     *
     * @param tenantId tenant the flag applies to
     * @param allowed  {@code false} to make {@code checkLimits} reject the tenant's messages
     */
    @Override
    public void update(TenantId tenantId, boolean allowed) {
        tenantAllowed.put(tenantId, allowed);
    }

    /**
     * Tells whether a connection from the given address may proceed.
     *
     * <p>When IP limits are enabled, a statistics entry is created for the address if none exists yet.
     *
     * @param address remote socket address
     * @return {@code true} when IP limits are disabled, the address is not blocked, or more than the
     *         block timeout has elapsed since the address's last recorded activity
     */
    @Override
    public boolean checkAddress(InetSocketAddress address) {
        if (!ipRateLimitsEnabled) {
            return true;
        }
        InetAddressRateLimitStats stats = ipMap.computeIfAbsent(address.getAddress(), ip -> new InetAddressRateLimitStats());
        if (!stats.isBlocked()) {
            return true;
        }
        // Still flagged as blocked: let it through only once the block has timed out.
        return stats.getLastActivityTs() + ipBlockTimeout < System.currentTimeMillis();
    }

    /**
     * Records a successful authentication from the address: refreshes its last-activity timestamp,
     * resets its failure count and lifts an existing block. Does nothing when IP limits are disabled.
     *
     * @param address remote socket address that authenticated successfully
     */
    @Override
    public void onAuthSuccess(InetSocketAddress address) {
        if (ipRateLimitsEnabled) {
            InetAddress ip = address.getAddress();
            InetAddressRateLimitStats stats = ipMap.computeIfAbsent(ip, key -> new InetAddressRateLimitStats());
            // All changes to one address's statistics are made under that entry's lock.
            stats.getLock().lock();
            try {
                stats.setLastActivityTs(System.currentTimeMillis());
                stats.setFailureCount(0);
                boolean wasBlocked = stats.isBlocked();
                if (wasBlocked) {
                    stats.setBlocked(false);
                    log.info("[{}] IP address un-blocked due to correct credentials.", ip);
                }
            } finally {
                stats.getLock().unlock();
            }
        }
    }

    /**
     * Records a failed authentication from the address: refreshes its last-activity timestamp,
     * increments its failure count and blocks the address once the count reaches the configured
     * maximum. Does nothing when IP limits are disabled.
     *
     * @param address remote socket address that failed to authenticate
     */
    @Override
    public void onAuthFailure(InetSocketAddress address) {
        if (ipRateLimitsEnabled) {
            InetAddress ip = address.getAddress();
            InetAddressRateLimitStats stats = ipMap.computeIfAbsent(ip, key -> new InetAddressRateLimitStats());
            // All changes to one address's statistics are made under that entry's lock.
            stats.getLock().lock();
            try {
                stats.setLastActivityTs(System.currentTimeMillis());
                int failures = stats.getFailureCount() + 1;
                stats.setFailureCount(failures);
                boolean limitReached = failures >= maxWrongCredentialsPerIp;
                if (limitReached) {
                    log.info("[{}] IP address blocked due to constantly wrong credentials.", ip);
                    stats.setBlocked(true);
                }
            } finally {
                stats.getLock().unlock();
            }
        }
    }

    /**
     * Housekeeping pass over the per-IP statistics table. Does nothing when IP limits are disabled.
     *
     * <p>An entry whose last activity is older than the larger of {@code sessionInactivityTimeout}
     * and the block timeout is removed. Of the remaining entries, one that is still flagged as
     * blocked although its block timeout has elapsed is unblocked.
     *
     * @param sessionInactivityTimeout inactivity threshold in milliseconds
     */
    @Override
    public void invalidateRateLimitsIpTable(long sessionInactivityTimeout) {
        if (!ipRateLimitsEnabled) {
            return;
        }
        long currentTime = System.currentTimeMillis();
        long expTime = currentTime - Math.max(sessionInactivityTimeout, ipBlockTimeout);
        for (Map.Entry<InetAddress, InetAddressRateLimitStats> entry : ipMap.entrySet()) {
            InetAddress ip = entry.getKey();
            InetAddressRateLimitStats stats = entry.getValue();
            if (stats.getLastActivityTs() < expTime) {
                log.debug("[{}] IP address removed due to session inactivity timeout.", ip);
                ipMap.remove(ip);
                continue;
            }
            // Cheap unlocked pre-check so entries that need no change are not locked at all.
            if (!stats.isBlocked() || stats.getLastActivityTs() + ipBlockTimeout >= currentTime) {
                continue;
            }
            // Mutate under the same lock the auth callbacks use, and re-check: a concurrent
            // onAuthFailure may have refreshed the block since the pre-check and must not be undone.
            stats.getLock().lock();
            try {
                if (stats.isBlocked() && stats.getLastActivityTs() + ipBlockTimeout < currentTime) {
                    log.info("[{}] IP address unblocked due ip block timeout.", ip);
                    stats.setBlocked(false);
                }
            } finally {
                stats.getLock().unlock();
            }
        }
    }

    /**
     * Stores new limits for an entity, reusing whatever part of the existing limits is unchanged.
     *
     * <p>With no existing entry, {@code newRateLimits} is stored as is. Otherwise the existing and new
     * limits are merged, and the result is stored only if the merge reports a change. Stored limits
     * are logged at INFO for tenants and at DEBUG for any other entity type.
     *
     * @param entityId      entity whose limits are updated
     * @param newRateLimits limits built from the current profile
     * @param getFunction   reads the entity's existing limits
     * @param putFunction   writes the entity's limits
     */
    private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits, Function<T, EntityTransportRateLimits> getFunction, BiConsumer<T, EntityTransportRateLimits> putFunction) {
        EntityTransportRateLimits existing = getFunction.apply(entityId);
        EntityTransportRateLimits toStore;
        String message;
        if (existing == null) {
            toStore = newRateLimits;
            message = "[{}] New rate limits: {}";
        } else {
            toStore = merge(existing, newRateLimits);
            if (toStore == null) {
                // Nothing changed; keep the existing limits in place.
                return;
            }
            message = "[{}] Updated rate limits: {}";
        }
        if (EntityType.TENANT.equals(entityId.getEntityType())) {
            log.info(message, entityId, toStore);
        } else {
            log.debug(message, entityId, toStore);
        }
        putFunction.accept(entityId, toStore);
    }

    /**
     * Combines existing limits with newly configured ones, comparing them by configuration string.
     *
     * <p>Each of the three limits whose configuration differs is replaced by a newly created limit;
     * each unchanged one keeps the existing limit instance.
     *
     * @param oldRateLimits limits currently stored
     * @param newRateLimits limits built from the current profile
     * @return the combined limits, or {@code null} when no configuration changed
     */
    private EntityTransportRateLimits merge(EntityTransportRateLimits oldRateLimits, EntityTransportRateLimits newRateLimits) {
        String regularConfig = newRateLimits.getRegularMsgRateLimit().getConfiguration();
        boolean regularChanged = !oldRateLimits.getRegularMsgRateLimit().getConfiguration().equals(regularConfig);

        String telemetryMsgConfig = newRateLimits.getTelemetryMsgRateLimit().getConfiguration();
        boolean telemetryMsgChanged = !oldRateLimits.getTelemetryMsgRateLimit().getConfiguration().equals(telemetryMsgConfig);

        String dataPointsConfig = newRateLimits.getTelemetryDataPointsRateLimit().getConfiguration();
        boolean dataPointsChanged = !oldRateLimits.getTelemetryDataPointsRateLimit().getConfiguration().equals(dataPointsConfig);

        if (!regularChanged && !telemetryMsgChanged && !dataPointsChanged) {
            return null;
        }
        return new EntityTransportRateLimits(
                regularChanged ? newLimit(regularConfig) : oldRateLimits.getRegularMsgRateLimit(),
                telemetryMsgChanged ? newLimit(telemetryMsgConfig) : oldRateLimits.getTelemetryMsgRateLimit(),
                dataPointsChanged ? newLimit(dataPointsConfig) : oldRateLimits.getTelemetryDataPointsRateLimit());
    }

    /**
     * Builds the three limits (regular messages, telemetry messages, telemetry data points) of the
     * requested kind from a tenant profile.
     *
     * @param tenantProfile profile to read the limit configuration from
     * @param limitsType    which group of limits to build
     * @return the limits; all three are the shared {@code ALLOW} instance when the profile has no
     *         configuration
     * @throws IllegalStateException if {@code limitsType} is not one of the handled types
     */
    private EntityTransportRateLimits createRateLimits(TenantProfile tenantProfile, TransportLimitsType limitsType) {
        TenantProfileData profileData = tenantProfile.getProfileData();
        DefaultTenantProfileConfiguration profile = (DefaultTenantProfileConfiguration) profileData.getConfiguration();
        if (profile == null) {
            return new EntityTransportRateLimits(ALLOW, ALLOW, ALLOW);
        }
        // Every branch builds its three limits inside the branch, so no value is read before it is assigned.
        return switch (limitsType) {
            case TENANT_LIMITS -> new EntityTransportRateLimits(
                    newLimit(profile.getTransportTenantMsgRateLimit()),
                    newLimit(profile.getTransportTenantTelemetryMsgRateLimit()),
                    newLimit(profile.getTransportTenantTelemetryDataPointsRateLimit()));
            case DEVICE_LIMITS -> new EntityTransportRateLimits(
                    newLimit(profile.getTransportDeviceMsgRateLimit()),
                    newLimit(profile.getTransportDeviceTelemetryMsgRateLimit()),
                    newLimit(profile.getTransportDeviceTelemetryDataPointsRateLimit()));
            case GATEWAY_LIMITS -> new EntityTransportRateLimits(
                    newLimit(profile.getTransportGatewayMsgRateLimit()),
                    newLimit(profile.getTransportGatewayTelemetryMsgRateLimit()),
                    newLimit(profile.getTransportGatewayTelemetryDataPointsRateLimit()));
            case GATEWAY_DEVICE_LIMITS -> new EntityTransportRateLimits(
                    newLimit(profile.getTransportGatewayDeviceMsgRateLimit()),
                    newLimit(profile.getTransportGatewayDeviceTelemetryMsgRateLimit()),
                    newLimit(profile.getTransportGatewayDeviceTelemetryDataPointsRateLimit()));
            default -> throw new IllegalStateException("Unknown limits type: " + limitsType);
        };
    }

    /**
     * Turns a limit configuration string into a limit.
     *
     * @param config limit configuration; may be empty
     * @return the shared {@code ALLOW} instance when {@code StringUtils.isEmpty} reports the
     *         configuration as empty, otherwise a new {@link SimpleTransportRateLimit} built from it
     */
    private static TransportRateLimit newLimit(String config) {
        if (StringUtils.isEmpty(config)) {
            return ALLOW;
        }
        return new SimpleTransportRateLimit(config);
    }

    /**
     * Returns the tenant-level limits, building them from the tenant's cached profile on first use.
     *
     * @param tenantId tenant to look up
     * @return the tenant's limits
     */
    private EntityTransportRateLimits getTenantRateLimits(TenantId tenantId) {
        return perTenantLimits.computeIfAbsent(
                tenantId,
                id -> createRateLimits(tenantProfileCache.get(id), TransportLimitsType.TENANT_LIMITS));
    }

    /**
     * Returns the device's limits. On first use they are built from the tenant's cached profile and
     * the device is registered in the tenant's device set.
     *
     * @param tenantId tenant that owns the device
     * @param deviceId device to look up
     * @return the device's limits
     */
    private EntityTransportRateLimits getDeviceRateLimits(TenantId tenantId, DeviceId deviceId) {
        return perDeviceLimits.computeIfAbsent(deviceId, id -> {
            // Build first, register second: a failed build leaves the tenant's set untouched.
            EntityTransportRateLimits created = createRateLimits(tenantProfileCache.get(tenantId), TransportLimitsType.DEVICE_LIMITS);
            getTenantDevices(tenantId).add(id);
            return created;
        });
    }

    /**
     * Returns the gateway's limits. On first use they are built from the tenant's cached profile and
     * the gateway is registered in the tenant's gateway set.
     *
     * @param tenantId  tenant that owns the gateway
     * @param gatewayId gateway to look up
     * @return the gateway's limits
     */
    private EntityTransportRateLimits getGatewayRateLimits(TenantId tenantId, DeviceId gatewayId) {
        return perGatewayLimits.computeIfAbsent(gatewayId, id -> {
            // Build first, register second: a failed build leaves the tenant's set untouched.
            EntityTransportRateLimits created = createRateLimits(tenantProfileCache.get(tenantId), TransportLimitsType.GATEWAY_LIMITS);
            getTenantGateways(tenantId).add(id);
            return created;
        });
    }

    /**
     * Returns the gateway-device limits stored under the given id. On first use they are built from
     * the tenant's cached profile and the id is registered in the tenant's gateway-device set.
     *
     * @param tenantId  tenant that owns the entry
     * @param gatewayId id the gateway-device limits are keyed by
     * @return the gateway-device limits for the id
     */
    private EntityTransportRateLimits getGatewayDeviceRateLimits(TenantId tenantId, DeviceId gatewayId) {
        return perGatewayDeviceLimits.computeIfAbsent(gatewayId, id -> {
            // Build first, register second: a failed build leaves the tenant's set untouched.
            EntityTransportRateLimits created = createRateLimits(tenantProfileCache.get(tenantId), TransportLimitsType.GATEWAY_DEVICE_LIMITS);
            getTenantGatewayDevices(tenantId).add(id);
            return created;
        });
    }

    /**
     * Returns the set of device ids registered for the tenant, creating an empty concurrent set if
     * the tenant has none yet.
     *
     * @param tenantId tenant to look up
     * @return the tenant's live device-id set
     */
    private Set<DeviceId> getTenantDevices(TenantId tenantId) {
        return tenantDevices.computeIfAbsent(tenantId, tenant -> ConcurrentHashMap.newKeySet());
    }

    /**
     * Returns the set of gateway ids registered for the tenant, creating an empty concurrent set if
     * the tenant has none yet.
     *
     * @param tenantId tenant to look up
     * @return the tenant's live gateway-id set
     */
    private Set<DeviceId> getTenantGateways(TenantId tenantId) {
        return tenantGateways.computeIfAbsent(tenantId, tenant -> ConcurrentHashMap.newKeySet());
    }

    /**
     * Returns the set of gateway-device ids registered for the tenant, creating an empty concurrent
     * set if the tenant has none yet.
     *
     * @param tenantId tenant to look up
     * @return the tenant's live gateway-device-id set
     */
    private Set<DeviceId> getTenantGatewayDevices(TenantId tenantId) {
        return tenantGatewayDevices.computeIfAbsent(tenantId, tenant -> ConcurrentHashMap.newKeySet());
    }
}

