/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.allocator.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide bookkeeping of memory usage, kept as byte counters per device and for the host.
 *
 * <p>Per-device counters are indexed by device id, for ids {@code 0 .. n-1}, where {@code n} is
 * the device count reported by {@code Nd4j.getAffinityManager()} when the tracker is constructed.
 * Passing an id outside that range makes the backing {@link List} throw
 * {@link IndexOutOfBoundsException}.
 *
 * <p>Every counter is an {@link AtomicLong}, so each individual increment, decrement or read is
 * atomic. The lists holding the counters are filled once in the constructor and are not
 * structurally modified afterwards.
 */
public class MemoryTracker {
    private static final Logger log = LoggerFactory.getLogger(MemoryTracker.class);
    /** Bytes currently allocated, per device. */
    private final List<AtomicLong> allocatedPerDevice = new ArrayList<>();
    /** Bytes currently cached, per device. */
    private final List<AtomicLong> cachedPerDevice = new ArrayList<>();
    /** Total device memory as reported by native ops at construction time, per device. */
    private final List<AtomicLong> totalPerDevice = new ArrayList<>();
    /** Free device memory as reported by native ops at construction time, per device. */
    private final List<AtomicLong> freePerDevice = new ArrayList<>();
    /** Bytes attributed to workspaces, per device. */
    private final List<AtomicLong> workspacesPerDevice = new ArrayList<>();
    /** Bytes currently cached on the host. */
    private final AtomicLong cachedHost = new AtomicLong(0L);
    /** Bytes currently allocated on the host. */
    private final AtomicLong allocatedHost = new AtomicLong(0L);
    private static final MemoryTracker INSTANCE = new MemoryTracker();

    /**
     * Creates zeroed usage counters for every device. It also snapshots each device's total and
     * free memory from the native backend.
     */
    public MemoryTracker() {
        // The device count is loop-invariant; query it once.
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int deviceId = 0; deviceId < numberOfDevices; ++deviceId) {
            // Appending in order keeps list index == device id.
            this.allocatedPerDevice.add(new AtomicLong(0L));
            this.cachedPerDevice.add(new AtomicLong(0L));
            this.workspacesPerDevice.add(new AtomicLong(0L));
            this.totalPerDevice.add(new AtomicLong(
                    NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceTotalMemory(deviceId)));
            this.freePerDevice.add(new AtomicLong(
                    NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(deviceId)));
        }
    }

    /** Returns the shared tracker instance. */
    public static MemoryTracker getInstance() {
        return INSTANCE;
    }

    /** Returns the bytes currently recorded as allocated on {@code deviceId}. */
    public long getAllocatedAmount(int deviceId) {
        return this.allocatedPerDevice.get(deviceId).get();
    }

    /** Returns the bytes currently recorded as cached on {@code deviceId}. */
    public long getCachedAmount(int deviceId) {
        return this.cachedPerDevice.get(deviceId).get();
    }

    /** Returns the bytes currently recorded as cached on the host. */
    public long getCachedHostAmount() {
        return this.cachedHost.get();
    }

    /** Returns the bytes currently recorded as allocated on the host. */
    public long getAllocatedHostAmount() {
        return this.allocatedHost.get();
    }

    /** Returns host allocated + host cached bytes. The two counters are read separately, not as one atomic snapshot. */
    public long getActiveHostAmount() {
        return this.getAllocatedHostAmount() + this.getCachedHostAmount();
    }

    /** Adds {@code numBytes} to the host cached counter. */
    public void incrementCachedHostAmount(long numBytes) {
        this.cachedHost.addAndGet(numBytes);
    }

    /** Adds {@code numBytes} to the host allocated counter. */
    public void incrementAllocatedHostAmount(long numBytes) {
        this.allocatedHost.addAndGet(numBytes);
    }

    /** Subtracts {@code numBytes} from the host cached counter. */
    public void decrementCachedHostAmount(long numBytes) {
        this.cachedHost.addAndGet(-numBytes);
    }

    /** Subtracts {@code numBytes} from the host allocated counter. */
    public void decrementAllocatedHostAmount(long numBytes) {
        this.allocatedHost.addAndGet(-numBytes);
    }

    /** Returns the bytes currently attributed to workspaces on {@code deviceId}. */
    public long getWorkspaceAllocatedAmount(int deviceId) {
        return this.workspacesPerDevice.get(deviceId).get();
    }

    /** Returns the total device memory captured when this tracker was constructed. */
    public long getTotalMemory(int deviceId) {
        return this.totalPerDevice.get(deviceId).get();
    }

    /**
     * Returns the free device memory captured when this tracker was constructed. This value is
     * not refreshed afterwards; use {@link #getPreciseFreeMemory(int)} for a live query.
     */
    public long getFreeMemory(int deviceId) {
        return this.freePerDevice.get(deviceId).get();
    }

    /**
     * Estimates the free memory on {@code deviceId} without querying the device.
     *
     * <p>Memory that was already in use at construction time is treated as "external". The
     * result is {@code total - (active + external)}, which reduces algebraically to
     * {@code freeAtConstruction - active}.
     */
    public long getApproximateFreeMemory(int deviceId) {
        long externalAllocations = this.getTotalMemory(deviceId) - this.getFreeMemory(deviceId);
        long active = this.getActiveMemory(deviceId);
        return this.getTotalMemory(deviceId) - (active + externalAllocations);
    }

    /** Queries the native backend for the current free memory on {@code deviceId}. */
    public long getPreciseFreeMemory(int deviceId) {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(deviceId);
    }

    /**
     * Returns {@code total - free}, using the values captured at construction.
     *
     * <p>NOTE(review): despite the name, this is the memory that was already in use when the
     * tracker was created, not the memory still usable. Confirm the intent with callers.
     */
    public long getUsableMemory(int deviceId) {
        return this.getTotalMemory(deviceId) - this.getFreeMemory(deviceId);
    }

    /** Returns workspace + allocated + cached bytes tracked for {@code deviceId}. */
    public long getActiveMemory(int deviceId) {
        return this.getWorkspaceAllocatedAmount(deviceId) + this.getAllocatedAmount(deviceId)
                + this.getCachedAmount(deviceId);
    }

    /** Returns allocated + cached bytes tracked for {@code deviceId}, excluding workspaces. */
    public long getManagedMemory(int deviceId) {
        return this.getAllocatedAmount(deviceId) + this.getCachedAmount(deviceId);
    }

    /** Adds {@code memoryAdded} bytes to the allocated counter of {@code deviceId}. */
    public void incrementAllocatedAmount(int deviceId, long memoryAdded) {
        this.allocatedPerDevice.get(deviceId).getAndAdd(this.matchBlock(memoryAdded));
    }

    /** Adds {@code memoryAdded} bytes to the cached counter of {@code deviceId}. */
    public void incrementCachedAmount(int deviceId, long memoryAdded) {
        this.cachedPerDevice.get(deviceId).getAndAdd(this.matchBlock(memoryAdded));
    }

    /** Subtracts {@code memorySubtracted} bytes from the allocated counter of {@code deviceId}. */
    public void decrementAllocatedAmount(int deviceId, long memorySubtracted) {
        this.allocatedPerDevice.get(deviceId).getAndAdd(-this.matchBlock(memorySubtracted));
    }

    /** Subtracts {@code memorySubtracted} bytes from the cached counter of {@code deviceId}. */
    public void decrementCachedAmount(int deviceId, long memorySubtracted) {
        this.cachedPerDevice.get(deviceId).getAndAdd(-this.matchBlock(memorySubtracted));
    }

    /** Adds {@code memoryAdded} bytes to the workspace counter of {@code deviceId}. */
    public void incrementWorkspaceAllocatedAmount(int deviceId, long memoryAdded) {
        this.workspacesPerDevice.get(deviceId).getAndAdd(this.matchBlock(memoryAdded));
    }

    /** Subtracts {@code memorySubtracted} bytes from the workspace counter of {@code deviceId}. */
    public void decrementWorkspaceAmount(int deviceId, long memorySubtracted) {
        this.workspacesPerDevice.get(deviceId).getAndAdd(-this.matchBlock(memorySubtracted));
    }

    /**
     * Overwrites the recorded total memory for {@code device}.
     *
     * <p>The existing counter is updated in place. The previous version called
     * {@code List.add(index, e)}, which inserted a new element and shifted every later device's
     * entry, breaking the index == device id mapping.
     */
    private void setTotalPerDevice(int device, long memoryAvailable) {
        this.totalPerDevice.get(device).set(memoryAvailable);
    }

    /**
     * Maps a requested byte count to the amount actually charged to a counter. Currently this is
     * the identity function; presumably a hook for block-size rounding (unconfirmed).
     */
    private long matchBlock(long numBytes) {
        return numBytes;
    }
}

