Android On-device AI Memory Management: Model Loading Peaks, Tensor Lifetimes, and KV Cache Reclaim

When deploying larger on-device models, the first bottleneck is usually not inference speed. It is memory. A quantized 3B-parameter model can push peak memory to 4-5 GB during loading. On an Android device with 8 GB RAM, being killed by the system OOM path is common. This article records the memory-management strategies I accumulated while engineering on-device inference: model loading, tensor lifetimes, and KV cache reclamation, assembled into a practical low-memory-pressure approach.

Model loading: why mmap is standard on device

The conventional approach is to allocate heap memory with malloc and read all model weights into it. A 2 GB model file occupies 2 GB on the heap, and I/O buffers can add hundreds of megabytes. Peak memory can reach 1.5x the model size.

With memory mapping, or mmap, the kernel maps the file directly into virtual address space and loads physical pages on demand.

// Traditional loading: peak = file size + I/O buffer
char* buffer = new char[file_size];
fread(buffer, 1, file_size, fp);  // Physical memory +2 GB

// mmap loading: virtual address mapping, physical pages allocated on demand
int fd = open("model.gguf", O_RDONLY);
void* mapped = mmap(nullptr, file_size, PROT_READ,
                    MAP_SHARED, fd, 0);
// Physical memory is only used for pages that have actually been touched

mmap brings two benefits. First, only pages that are actually read are loaded from disk into physical memory; untouched weights do not consume RAM. Second, when the system is under memory pressure, clean mapped pages can be reclaimed directly and paged back in later.

But mmap is not a silver bullet. If the model structure causes inference to randomly access all weights, it degenerates toward full loading. In my measurements, dense matrix operations in MLP layers had poor page locality, so on-demand behavior helped less there. Embedding and attention layers only touched the weight regions needed by the current sequence, and the effect was much better.

One trap I hit was file sharding and mmap alignment. If exported model weights are not aligned to page boundaries, usually 4 KB, cross-page access after mapping can cost up to 30% performance. During export, add page-aligned padding for weight tensors and combine it with madvise(MADV_SEQUENTIAL) to give the kernel a prefetch hint.

// Set prefetch strategy for the mapped region
madvise(mapped, file_size, MADV_SEQUENTIAL);  // Sequential access, prefetch early
// Hint after inference is complete
madvise(mapped, file_size, MADV_DONTNEED);    // Actively reclaim physical pages

Tensor lifetimes: from “keep everything” to “free after last use”

Most inference frameworks roughly follow this flow: read input, compute layer by layer, keep all intermediate tensors, then output the result. Tensors generated by every layer pile up in memory until the whole inference finishes.

That is not always necessary. Attention KV tensors must be preserved because later tokens need them for context. Intermediate FFN tensors become dead data after the current token is computed.

My approach is to inject tensor lifecycle markers so the allocator can understand, at operator-scheduling time, which tensors can be released early:

struct TensorLifecycle {
    int last_consumer_op;  // Index of the last operator that consumes this tensor
    bool is_persistent;    // Whether it survives across forward passes, such as KV cache
};

// Inject memory reclamation into the operator scheduler
void on_op_complete(int op_index, TensorPool& pool) {
    for (auto& tensor : pool.tensors) {
        if (tensor.last_consumer_op == op_index 
            && !tensor.is_persistent) {
            pool.free(tensor);  // Reclaim immediately, do not wait for frame end
        }
    }
}

After this optimization, peak intermediate-tensor memory for a Llama-style 3B model dropped from 1.2 GB to 360 MB.

Two implementation details still need attention.

Operator fusion: combine kernels such as LayerNorm plus MatMul so intermediate results stay in registers instead of landing in memory. This is especially effective around normalization operations before and after attention.

Memory-pool fragmentation: frequent allocate/free cycles create fragmentation, and actual usage can end up 20-30% higher than the theoretical value. I used BFC, or Best-Fit with Coalescing, so adjacent free blocks merge automatically. Fragmentation stayed under 5% for a single inference.

KV cache: sliding windows are the survival line on low-memory devices

In long-context scenarios, KV cache is the real memory killer. For a 3B model with 32 layers and head dimension 128, a 4096-token input uses about 512 MB of KV cache. At 32K tokens, it heads toward 4 GB.

The problem with static truncation

Limiting sequence length is the most direct fix, such as hard-truncating to 2048 tokens. For question answering and code completion, however, truncating context means the model loses key information, and answer quality can fall off a cliff.

Sliding-window reclamation

A sliding window is more practical: keep only the latest N tokens of KV data and reclaim the rest.

void sliding_window_reclaim(KVCache& cache, int window_size) {
    if (cache.seq_len > window_size) {
        int evict_count = cache.seq_len - window_size;
        // Mark KV slots in [0, evict_count) as reusable.
        // In a real implementation this is a ring-buffer pointer move,
        // so it does not trigger a memory copy.
        cache.head = (cache.head + evict_count) % cache.capacity;
        cache.seq_len = window_size;
    }
}

The key is a ring buffer. If every reclaim does a memmove to shift data forward, CPU overhead becomes unacceptable. Use a head pointer to mark the start of the current window. Reclamation becomes pointer arithmetic and finishes in O(1).

Window size depends on the scenario. Code completion is often fine with 512-1024 tokens because it relies mostly on local context. Document QA needs 2048-4096 tokens because dependencies cross paragraph boundaries. I usually use a dynamic window: default to 1024 and temporarily expand to 2048 when attention weights show stronger long-distance dependency.

Layer-wise decay reclamation

Uniform sliding-window reclamation has a weakness: shallow attention layers focus more on local information, while deeper layers rely more on global context. Using the same window for every layer wastes memory.

Layer-wise decay is more precise: use smaller windows for shallow layers and larger windows for deeper layers. For evicted KV, instead of hard-dropping it, decay it and merge it into neighboring slots:

void layer_wise_evict(KVCache& cache, int layer_id, int max_tokens) {
    int layer_window = max_tokens * (0.5 + 0.5 * layer_id / total_layers);
    // Shallow layers keep fewer tokens; deeper layers keep more
    
    if (cache.seq_len > layer_window) {
        // Apply weighted decay instead of direct discard
        float decay = exp(-(cache.seq_len - layer_window) * 0.01);
        cache.evict_with_decay(cache.seq_len - layer_window, decay);
    }
}

On the LongBench benchmark, compared with equal-length sliding windows, this improved ROUGE-L for long-document QA by about 4.2% while reducing peak memory by 35%.

Three trade-offs in production

mmap loading versus warmup strategy. Page faults during first access can slow down first-token latency, especially on cold start. The compromise is to call madvise(MADV_WILLNEED) for the first three layers at load time and leave the rest on demand. First-token latency rises by only 100-200 ms, while peak memory stays controlled.

Quantized KV cache storage. Quantizing KV cache from FP16 to INT8 cuts memory in half, with an accuracy cost. For code generation, INT8 KV cache reduced pass@1 by less than 1%, which was acceptable. For complex reasoning, degradation reached 3-5%. I switch dynamically by task: code and summarization use INT8; translation and reasoning stay on FP16.

Process priority and LMK. Even after these optimizations, Android’s Low Memory Killer can still kill the inference process in extreme cases. The mitigation is to request android:largeHeap="true", raise the inference service to IMPORTANCE_FOREGROUND, and listen for onTrimMemory(TRIM_MEMORY_RUNNING_CRITICAL) to actively shrink the KV cache to a 256-token emergency window.

With this combination, a 3B model can run 8192-token long-text inference stably on a 6 GB RAM device, and a 4 GB RAM device can support 4096-token chat scenarios. Moving on-device AI from “it can run” to “it is usable” requires serious memory engineering.