/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.server.common.transport.limits;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.thingsboard.server.common.data.EntityType;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration;
import org.thingsboard.server.common.data.tenant.profile.TenantProfileData;
import org.thingsboard.server.common.transport.TransportTenantProfileCache;
import org.thingsboard.server.common.transport.limits.DummyTransportRateLimit;
import org.thingsboard.server.common.transport.limits.EntityTransportRateLimits;
import org.thingsboard.server.common.transport.limits.InetAddressRateLimitStats;
import org.thingsboard.server.common.transport.limits.SimpleTransportRateLimit;
import org.thingsboard.server.common.transport.limits.TransportRateLimit;
import org.thingsboard.server.common.transport.limits.TransportRateLimitService;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import org.thingsboard.server.queue.util.TbTransportComponent;

@Service
@TbTransportComponent
public class DefaultTransportRateLimitService
implements TransportRateLimitService {
    private static final Logger log = LoggerFactory.getLogger(DefaultTransportRateLimitService.class);
    private static final DummyTransportRateLimit ALLOW = new DummyTransportRateLimit();
    private final ConcurrentMap<TenantId, Boolean> tenantAllowed = new ConcurrentHashMap<TenantId, Boolean>();
    private final ConcurrentMap<TenantId, Set<DeviceId>> tenantDevices = new ConcurrentHashMap<TenantId, Set<DeviceId>>();
    private final ConcurrentMap<TenantId, EntityTransportRateLimits> perTenantLimits = new ConcurrentHashMap<TenantId, EntityTransportRateLimits>();
    private final ConcurrentMap<DeviceId, EntityTransportRateLimits> perDeviceLimits = new ConcurrentHashMap<DeviceId, EntityTransportRateLimits>();
    private final Map<InetAddress, InetAddressRateLimitStats> ipMap = new ConcurrentHashMap<InetAddress, InetAddressRateLimitStats>();
    private final TransportTenantProfileCache tenantProfileCache;
    @Value(value="${transport.rate_limits.ip_limits_enabled:false}")
    private boolean ipRateLimitsEnabled;
    @Value(value="${transport.rate_limits.max_wrong_credentials_per_ip:10}")
    private int maxWrongCredentialsPerIp;
    @Value(value="${transport.rate_limits.ip_block_timeout:60000}")
    private long ipBlockTimeout;

    public DefaultTransportRateLimitService(TransportTenantProfileCache tenantProfileCache) {
        this.tenantProfileCache = tenantProfileCache;
    }

    @Override
    public EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints) {
        if (!this.tenantAllowed.getOrDefault(tenantId, Boolean.TRUE).booleanValue()) {
            return EntityType.API_USAGE_STATE;
        }
        if (!this.checkEntityRateLimit(dataPoints, this.getTenantRateLimits(tenantId))) {
            return EntityType.TENANT;
        }
        if (!this.checkEntityRateLimit(dataPoints, this.getDeviceRateLimits(tenantId, deviceId))) {
            return EntityType.DEVICE;
        }
        return null;
    }

    private boolean checkEntityRateLimit(int dataPoints, EntityTransportRateLimits tenantLimits) {
        if (dataPoints > 0) {
            return tenantLimits.getTelemetryMsgRateLimit().tryConsume() && tenantLimits.getTelemetryDataPointsRateLimit().tryConsume(dataPoints);
        }
        return tenantLimits.getRegularMsgRateLimit().tryConsume();
    }

    @Override
    public void update(TenantProfileUpdateResult update) {
        log.info("Received tenant profile update: {}", (Object)update.getProfile());
        EntityTransportRateLimits tenantRateLimitPrototype = this.createRateLimits(update.getProfile(), true);
        EntityTransportRateLimits deviceRateLimitPrototype = this.createRateLimits(update.getProfile(), false);
        for (TenantId tenantId : update.getAffectedTenants()) {
            this.mergeLimits(tenantId, tenantRateLimitPrototype, this.perTenantLimits::get, this.perTenantLimits::put);
            ((Set)this.tenantDevices.get(tenantId)).forEach(deviceId -> this.mergeLimits((EntityId)deviceId, deviceRateLimitPrototype, this.perDeviceLimits::get, this.perDeviceLimits::put));
        }
    }

    @Override
    public void update(TenantId tenantId) {
        EntityTransportRateLimits tenantRateLimitPrototype = this.createRateLimits(this.tenantProfileCache.get(tenantId), true);
        EntityTransportRateLimits deviceRateLimitPrototype = this.createRateLimits(this.tenantProfileCache.get(tenantId), false);
        this.mergeLimits(tenantId, tenantRateLimitPrototype, this.perTenantLimits::get, this.perTenantLimits::put);
        ((Set)this.tenantDevices.get(tenantId)).forEach(deviceId -> this.mergeLimits((EntityId)deviceId, deviceRateLimitPrototype, this.perDeviceLimits::get, this.perDeviceLimits::put));
    }

    @Override
    public void remove(TenantId tenantId) {
        this.perTenantLimits.remove(tenantId);
        this.tenantDevices.remove(tenantId);
    }

    @Override
    public void remove(DeviceId deviceId) {
        this.perDeviceLimits.remove(deviceId);
        this.tenantDevices.values().forEach(set -> set.remove(deviceId));
    }

    @Override
    public void update(TenantId tenantId, boolean allowed) {
        this.tenantAllowed.put(tenantId, allowed);
    }

    @Override
    public boolean checkAddress(InetSocketAddress address) {
        if (!this.ipRateLimitsEnabled) {
            return true;
        }
        InetAddressRateLimitStats stats = this.ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
        return !stats.isBlocked() || stats.getLastActivityTs() + this.ipBlockTimeout < System.currentTimeMillis();
    }

    @Override
    public void onAuthSuccess(InetSocketAddress address) {
        if (!this.ipRateLimitsEnabled) {
            return;
        }
        InetAddressRateLimitStats stats = this.ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
        stats.getLock().lock();
        try {
            stats.setLastActivityTs(System.currentTimeMillis());
            stats.setFailureCount(0);
            if (stats.isBlocked()) {
                stats.setBlocked(false);
                log.info("[{}] IP address un-blocked due to correct credentials.", (Object)address.getAddress());
            }
        }
        finally {
            stats.getLock().unlock();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void onAuthFailure(InetSocketAddress address) {
        if (!this.ipRateLimitsEnabled) {
            return;
        }
        InetAddressRateLimitStats stats = this.ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
        stats.getLock().lock();
        try {
            stats.setLastActivityTs(System.currentTimeMillis());
            int failureCount = stats.getFailureCount() + 1;
            stats.setFailureCount(failureCount);
            if (failureCount >= this.maxWrongCredentialsPerIp) {
                log.info("[{}] IP address blocked due to constantly wrong credentials.", (Object)address.getAddress());
                stats.setBlocked(true);
            }
        }
        finally {
            stats.getLock().unlock();
        }
    }

    @Override
    public void invalidateRateLimitsIpTable(long sessionInactivityTimeout) {
        if (!this.ipRateLimitsEnabled) {
            return;
        }
        long currentTime = System.currentTimeMillis();
        long expTime = currentTime - Math.max(sessionInactivityTimeout, this.ipBlockTimeout);
        for (Map.Entry<InetAddress, InetAddressRateLimitStats> entry : this.ipMap.entrySet()) {
            InetAddressRateLimitStats stats = entry.getValue();
            if (stats.getLastActivityTs() < expTime) {
                log.debug("[{}] IP address removed due to session inactivity timeout.", (Object)entry.getKey());
                this.ipMap.remove(entry.getKey());
                continue;
            }
            if (!stats.isBlocked() || stats.getLastActivityTs() + this.ipBlockTimeout >= currentTime) continue;
            log.info("[{}] IP address unblocked due ip block timeout.", (Object)entry.getKey());
            stats.setBlocked(false);
        }
    }

    private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits, Function<T, EntityTransportRateLimits> getFunction, BiConsumer<T, EntityTransportRateLimits> putFunction) {
        EntityTransportRateLimits oldRateLimits = getFunction.apply(entityId);
        if (oldRateLimits == null) {
            if (EntityType.TENANT.equals((Object)entityId.getEntityType())) {
                log.info("[{}] New rate limits: {}", entityId, (Object)newRateLimits);
            } else {
                log.debug("[{}] New rate limits: {}", entityId, (Object)newRateLimits);
            }
            putFunction.accept(entityId, newRateLimits);
        } else {
            EntityTransportRateLimits updated = this.merge(oldRateLimits, newRateLimits);
            if (updated != null) {
                if (EntityType.TENANT.equals((Object)entityId.getEntityType())) {
                    log.info("[{}] Updated rate limits: {}", entityId, (Object)updated);
                } else {
                    log.debug("[{}] Updated rate limits: {}", entityId, (Object)updated);
                }
                putFunction.accept(entityId, updated);
            }
        }
    }

    private EntityTransportRateLimits merge(EntityTransportRateLimits oldRateLimits, EntityTransportRateLimits newRateLimits) {
        boolean telemetryDataPointUpdate;
        boolean regularUpdate = !oldRateLimits.getRegularMsgRateLimit().getConfiguration().equals(newRateLimits.getRegularMsgRateLimit().getConfiguration());
        boolean telemetryMsgRateUpdate = !oldRateLimits.getTelemetryMsgRateLimit().getConfiguration().equals(newRateLimits.getTelemetryMsgRateLimit().getConfiguration());
        boolean bl = telemetryDataPointUpdate = !oldRateLimits.getTelemetryDataPointsRateLimit().getConfiguration().equals(newRateLimits.getTelemetryDataPointsRateLimit().getConfiguration());
        if (regularUpdate || telemetryMsgRateUpdate || telemetryDataPointUpdate) {
            return new EntityTransportRateLimits(regularUpdate ? DefaultTransportRateLimitService.newLimit(newRateLimits.getRegularMsgRateLimit().getConfiguration()) : oldRateLimits.getRegularMsgRateLimit(), telemetryMsgRateUpdate ? DefaultTransportRateLimitService.newLimit(newRateLimits.getTelemetryMsgRateLimit().getConfiguration()) : oldRateLimits.getTelemetryMsgRateLimit(), telemetryDataPointUpdate ? DefaultTransportRateLimitService.newLimit(newRateLimits.getTelemetryDataPointsRateLimit().getConfiguration()) : oldRateLimits.getTelemetryDataPointsRateLimit());
        }
        return null;
    }

    private EntityTransportRateLimits createRateLimits(TenantProfile tenantProfile, boolean tenant) {
        TenantProfileData profileData = tenantProfile.getProfileData();
        DefaultTenantProfileConfiguration profile = (DefaultTenantProfileConfiguration)profileData.getConfiguration();
        if (profile == null) {
            return new EntityTransportRateLimits(ALLOW, ALLOW, ALLOW);
        }
        TransportRateLimit regularMsgRateLimit = DefaultTransportRateLimitService.newLimit(tenant ? profile.getTransportTenantMsgRateLimit() : profile.getTransportDeviceMsgRateLimit());
        TransportRateLimit telemetryMsgRateLimit = DefaultTransportRateLimitService.newLimit(tenant ? profile.getTransportTenantTelemetryMsgRateLimit() : profile.getTransportDeviceTelemetryMsgRateLimit());
        TransportRateLimit telemetryDpRateLimit = DefaultTransportRateLimitService.newLimit(tenant ? profile.getTransportTenantTelemetryDataPointsRateLimit() : profile.getTransportTenantTelemetryDataPointsRateLimit());
        return new EntityTransportRateLimits(regularMsgRateLimit, telemetryMsgRateLimit, telemetryDpRateLimit);
    }

    private static TransportRateLimit newLimit(String config) {
        return StringUtils.isEmpty((String)config) ? ALLOW : new SimpleTransportRateLimit(config);
    }

    private EntityTransportRateLimits getTenantRateLimits(TenantId tenantId) {
        EntityTransportRateLimits limits = (EntityTransportRateLimits)this.perTenantLimits.get(tenantId);
        if (limits == null) {
            limits = this.createRateLimits(this.tenantProfileCache.get(tenantId), true);
            this.perTenantLimits.put(tenantId, limits);
        }
        return limits;
    }

    private EntityTransportRateLimits getDeviceRateLimits(TenantId tenantId, DeviceId deviceId) {
        EntityTransportRateLimits limits = (EntityTransportRateLimits)this.perDeviceLimits.get(deviceId);
        if (limits == null) {
            limits = this.createRateLimits(this.tenantProfileCache.get(tenantId), false);
            this.perDeviceLimits.put(deviceId, limits);
            this.tenantDevices.computeIfAbsent(tenantId, id -> ConcurrentHashMap.newKeySet()).add(deviceId);
        }
        return limits;
    }
}

