How to Prevent OOM Errors in PyTorch Training

Jan 15, 2026

Maximilian Niroomand

CTO & Co-Founder at Lyceum Technologies

Nothing halts a training run faster than the dreaded CUDA Out of Memory error. As models grow and datasets expand, managing VRAM becomes a critical engineering discipline rather than a trial-and-error exercise.

In short

Use Automatic Mixed Precision (AMP) with BF16 to reduce VRAM usage by up to 50% on modern GPUs.

Use Gradient Checkpointing to trade a small amount of compute time for significant memory savings in deep architectures.

Use distributed sharding techniques like FSDP to scale models across multiple GPUs by distributing optimizer states and parameters.

The RuntimeError: CUDA out of memory (OOM) is one of the most common bottlenecks in modern deep learning. It can occur even when your GPU's total capacity looks sufficient, because of how PyTorch allocates and fragments memory during the forward and backward passes. Efficient memory management is the difference between a successful training run and a wasted compute budget. Lyceum provides high-performance orchestration that respects hardware limits, and the engineering tactics below optimize memory usage and keep training scripts stable on sovereign European infrastructure.

The Foundation: Precision and Batch Dynamics

The most immediate lever for memory management is the precision of your tensors. By default, PyTorch uses 32-bit floating-point (FP32) numbers, which take 4 bytes each. While precise, FP32 is memory-intensive. 16-bit formats (FP16 or BF16) take 2 bytes per value, halving the footprint of every tensor stored in them, whether weights, gradients, or activations. NVIDIA's 2025 performance guidelines recommend Automatic Mixed Precision (AMP), which runs operations in lower precision where it is safe and keeps FP32 for numerically sensitive calculations to prevent instability.
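As a minimal sketch of AMP with BF16, the training step below wraps only the forward pass in torch.autocast. The helper name amp_train_step, the toy model, and the tensor shapes are illustrative assumptions. It falls back to CPU autocast when no GPU is present so it can run anywhere; on CUDA it assumes a BF16-capable GPU such as an A100 or H100. An FP16 setup would additionally need a GradScaler for loss scaling, which is omitted here.

```python
import torch
import torch.nn as nn


def amp_train_step(model, x, y, optimizer, device_type="cpu"):
    """One BF16 mixed-precision training step (illustrative helper, not a library API)."""
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, eligible ops such as matmuls run in BF16;
    # the parameters themselves stay in FP32.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        output = model(x)
    # Compute the loss in FP32 for numerical stability.
    loss = nn.functional.mse_loss(output.float(), y)
    # BF16 shares FP32's exponent range, so no loss scaling is applied here.
    loss.backward()
    optimizer.step()
    return loss.item()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(16, 32, device=device)
y = torch.randn(16, 1, device=device)
loss_value = amp_train_step(model, x, y, optimizer, device_type=device)
```

The memory saving comes from activations produced inside the autocast region being stored in BF16, while the optimizer continues to update FP32 weights.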

Batch size is the second primary factor. A larger batch size doesn't always lead to better convergence, and activation memory grows roughly linearly with it. If you hit an OOM error, a common first step is to halve the batch size. However, if your optimizer setup needs a large effective batch size (as with LAMB or large-batch Adam training), use gradient accumulation instead. This technique processes smaller micro-batches and only updates the weights after several steps, simulating a larger batch without its activation-memory overhead.

  • Use BF16 on modern GPUs: If you are running on NVIDIA A100 or H100 GPUs, prefer BF16 over FP16. BF16 keeps FP32's 8-bit exponent and therefore its dynamic range, so loss scaling is typically unnecessary.

  • Gradient Accumulation: Instead of a batch size of 64, use a micro-batch of 8 and accumulate gradients over 8 steps.

  • Pin Memory: Set pin_memory=True in your DataLoader to speed up transfers from CPU to GPU. It does not reduce VRAM usage, and the pinned (page-locked) buffers consume host RAM.
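The accumulation and pinning tips above can be sketched as follows. The helper names make_loader and train_with_accumulation, the toy dataset, and the SGD settings are hypothetical; the point is that dividing each micro-batch loss by the number of accumulation steps makes the summed gradients match the mean over the full effective batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, micro_batch_size):
    # Pinned host memory only helps when there is a GPU to copy to.
    return DataLoader(dataset, batch_size=micro_batch_size, shuffle=False,
                      pin_memory=torch.cuda.is_available())


def train_with_accumulation(model, loader, optimizer, accum_steps, device):
    """Update the weights once every accum_steps micro-batches (illustrative helper)."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(loader, start=1):
        # non_blocking lets the copy overlap with compute when x is pinned.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        loss = nn.functional.mse_loss(model(x), y)
        # Scale so the summed gradients equal the mean over the effective batch.
        (loss / accum_steps).backward()
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Micro-batch of 8 accumulated over 8 steps: effective batch size 64.
train_with_accumulation(model, make_loader(dataset, 8), optimizer, 8, device)
```

This sketch assumes equal-sized micro-batches; if the dataset length is not a multiple of the effective batch, the leftover gradients are simply discarded at the end of the loop.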

Trading Compute for Capacity: Gradient Checkpointing

When the activations of your model are too large for GPU memory even with a batch size of one, you should look at gradient checkpointing. During a standard forward pass, PyTorch stores the intermediate activations it needs to calculate gradients during the backward pass. For deep networks, these activations often consume the majority of your VRAM.

Gradient checkpointing works by discarding these intermediate activations during the forward pass and recomputing them when needed during the backward pass. This is a classic engineering trade-off: you save a large amount of memory at the cost of roughly 20 to 30 percent additional computation time. In 2025 research workflows, it is often one of the techniques that make it possible to fine-tune 70B+ parameter models on single- or dual-GPU nodes.

To set this up in PyTorch, use the torch.utils.checkpoint module. It is most effective when applied to the most memory-intensive layers, such as the transformer blocks in an LLM. Checkpointing every other block can substantially increase your feasible sequence length or model depth without upgrading your hardware.
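A minimal sketch of checkpointing every other block with torch.utils.checkpoint.checkpoint is shown below. Block is a hypothetical stand-in for a transformer block, and use_reentrant=False follows the recommendation in recent PyTorch 2.x releases. Checkpointing is only applied in training mode, since inference does not keep activations for a backward pass.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block (pre-norm MLP with residual)."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(self.norm(x))


class CheckpointedStack(nn.Module):
    def __init__(self, dim, depth, checkpoint_every=2):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.training and i % self.checkpoint_every == 0:
                # Activations inside this block are not stored;
                # they are recomputed during the backward pass.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = CheckpointedStack(dim=64, depth=6)  # blocks 0, 2 and 4 are checkpointed
x = torch.randn(4, 128, 64)  # (batch, sequence, hidden)
model(x).sum().backward()
```

Recomputation produces the same outputs and gradients as the uncheckpointed model; only peak memory and runtime change.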

Managing the PyTorch Caching Allocator

PyTorch uses a caching allocator to speed up memory allocations. When you delete a tensor, the memory isn't immediately returned to the system; it stays in a pool for future allocations. This can lead to memory fragmentation, where the total free memory is sufficient, but there is no single contiguous block large enough for a new tensor. This is a frequent cause of OOM errors in long-running training loops.

While it is tempting to call torch.cuda.empty_cache() frequently, doing so can slow down training: it forces a synchronization with the GPU, later allocations must go back through the much slower CUDA allocator, and it cannot free memory still held by live tensors. Use it strategically instead, for example after a validation loop or when switching between phases of a training pipeline. For deeper insight, torch.cuda.memory_summary() breaks down allocated versus reserved (cached) memory, helping you determine whether fragmentation is your primary enemy.
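The pattern described above might look like this minimal sketch: a validation pass under torch.no_grad() followed by a single empty_cache() call at the phase boundary. The function name validate and the mean-squared-error metric are illustrative assumptions; the CUDA-only calls are skipped on machines without a GPU.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is built, so activations are freed immediately
def validate(model, loader, device):
    was_training = model.training
    model.eval()
    total_sq_error, count = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        total_sq_error += nn.functional.mse_loss(model(x), y, reduction="sum").item()
        count += y.numel()
    model.train(was_training)
    if torch.cuda.is_available():
        # Release cached blocks once, at a phase boundary, not every iteration.
        torch.cuda.empty_cache()
        print(torch.cuda.memory_summary(abbreviated=True))
    return total_sq_error / count


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 1).to(device)
val_loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8)
val_mse = validate(model, val_loader, device)
```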

Another often overlooked issue is how you log metrics. If you store the loss tensor itself, for example with total_loss += loss, you keep its autograd history alive across iterations. Use loss.item() to extract a Python scalar for logging, or loss.detach() if you need to keep a tensor.
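A minimal logging loop that avoids the problem could look like the sketch below; train_and_log is a hypothetical helper. Calling .item() once per step and reusing the value keeps the number of host-device synchronizations down.

```python
import torch
import torch.nn as nn


def train_and_log(model, batches, optimizer):
    history = []
    running_loss = 0.0
    for x, y in batches:
        optimizer.zero_grad(set_to_none=True)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        # Leaky pattern: `running_loss += loss` would turn running_loss into a
        # tensor that chains together the autograd history of every iteration.
        value = loss.item()  # plain Python float, no graph attached
        running_loss += value
        history.append(value)
    return running_loss, history


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
total, history = train_and_log(model, batches, optimizer)
```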

Scaling Out: Distributed Memory Strategies

When a single GPU is no longer enough, you must move to distributed training. Plain data parallelism, whether DataParallel (DP) or DistributedDataParallel (DDP), replicates the entire model, its gradients, and its optimizer states on every GPU, so adding GPUs does not lower per-GPU memory. For modern AI infrastructure, Fully Sharded Data Parallel (FSDP) and DeepSpeed's ZeRO (Zero Redundancy Optimizer) are the standard alternatives. These techniques shard the model states (parameters, gradients, and optimizer states) across all available GPUs in your cluster.

A 2025 Hugging Face report shows that sharding optimizer states alone can cut the per-GPU memory for those states eightfold when using 8 GPUs. Sharding all model states allows the training of models that are significantly larger than the memory of any single card. Lyceum handles these sharded workloads natively, keeping data within European borders while compute scales horizontally.

  1. ZeRO-1: Shards optimizer states.

  2. ZeRO-2: Shards optimizer states and gradients.

  3. ZeRO-3: Shards everything, including model parameters.
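To make the three stages concrete, the sketch below estimates per-GPU memory for model states using the accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for the FP32 master weights, momentum, and variance. The function name is hypothetical, and the estimate ignores activations, buffers, and fragmentation, so treat the results as lower bounds.

```python
def zero_model_state_bytes(num_params, num_gpus, stage):
    """Estimated per-GPU bytes for weights, gradients and Adam states under ZeRO."""
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= num_gpus  # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus  # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= num_gpus  # ZeRO-3: also shard parameters
    return params + grads + optim


# Stage 0 means no sharding (plain data parallelism).
for stage in range(4):
    gb = zero_model_state_bytes(7.5e9, num_gpus=64, stage=stage) / 1e9
    print(f"ZeRO-{stage}: {gb:.1f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs this prints 120.0, 31.4, 16.6 and 1.9 GB for stages 0 to 3. The optimizer-state term alone shrinks by exactly the number of GPUs, which is where an eightfold reduction on 8 GPUs comes from.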

These setups are easiest to adopt through frameworks like PyTorch Lightning or Accelerate, which abstract the complexity of sharding while providing the performance benefits of a distributed backend. FSDP is also available directly in PyTorch as torch.distributed.fsdp.

Data Loading and CPU Offloading

Memory management isn't just about the GPU. If your CPU RAM is exhausted, it can cause system-wide instability or slow data transfers that starve the GPU. Setting num_workers in your DataLoader is essential for throughput, but each worker is a separate process that consumes its own memory. If you are working with high-resolution images or massive text corpora, monitor your host memory closely.

CPU offloading is an option for extremely large models. Tools like DeepSpeed let you offload optimizer states, and with ZeRO-3 also model parameters, to CPU RAM or NVMe storage. While this introduces latency, it can enable the training of trillion-parameter models on hardware that would otherwise be insufficient. This flexibility maximizes available resources without relying on proprietary, closed-source hardware optimizations.
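As a sketch, a DeepSpeed configuration enabling ZeRO-3 with CPU offload could be built as a Python dict and serialized to JSON, which is the format DeepSpeed reads (for example through the config argument of deepspeed.initialize). The batch sizes, learning rate, and NVMe path are illustrative placeholders, not recommendations.

```python
import json

# Sketch of a DeepSpeed config: ZeRO-3 with optimizer and parameter offload to CPU.
# Batch sizes and learning rate are illustrative placeholders.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        # Keep optimizer states (and run the update step) in pinned host RAM.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        # Stage 3 only: keep parameters in host RAM between uses.
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}

# NVMe variant: "/local_nvme" is a hypothetical mount point.
nvme_variant = json.loads(json.dumps(ds_config))  # deep copy via JSON round trip
nvme_variant["zero_optimization"]["offload_param"] = {
    "device": "nvme", "nvme_path": "/local_nvme"}

config_json = json.dumps(ds_config, indent=2)
```

Offloading optimizer states is usually the first step, since they are the largest share of model-state memory; parameter offload adds more transfer latency per forward and backward pass.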

Common Mistakes and Decision Frameworks

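Failing to clear gradients is a common mistake. If you forget optimizer.zero_grad(), each backward pass adds to the previous gradients, silently corrupting your updates. Because the .grad buffers are summed in place, this is mainly a correctness bug rather than a source of memory growth. A real memory leak, by contrast, comes from keeping large tensors in the global scope or in lists that are never cleared. Python frees most objects as soon as their last reference disappears, but objects caught in reference cycles wait for the garbage collector, which might not run as often as you expect, especially when those objects hold large CUDA tensors.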
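The accumulation behavior is easy to verify with a small sketch; grads_over_two_steps is a hypothetical helper, and lr=0.0 keeps the weights fixed so both backward passes see identical inputs and weights.

```python
import torch
import torch.nn as nn


def grads_over_two_steps(clear_between):
    """Return weight gradients recorded after two identical backward passes."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)  # lr=0 keeps weights fixed
    x = torch.randn(2, 4)
    grads = []
    for _ in range(2):
        if clear_between:
            # set_to_none=True drops the .grad tensors instead of zero-filling them.
            optimizer.zero_grad(set_to_none=True)
        model(x).sum().backward()
        optimizer.step()
        grads.append(model.weight.grad.clone())
    return grads


stale = grads_over_two_steps(clear_between=False)
fresh = grads_over_two_steps(clear_between=True)
# Without zero_grad the second gradient is the sum of both passes, 2x the first.
print(torch.allclose(stale[1], 2 * stale[0]), torch.allclose(fresh[1], fresh[0]))  # True True
```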

When deciding which optimization to apply, consider the following framework:

  • Is the model small but the batch size large? Use Mixed Precision and Gradient Accumulation.

  • Is the model too large for one GPU? Use FSDP or ZeRO-3.

  • Is the sequence length the bottleneck? Use Gradient Checkpointing and Flash Attention.

  • Is the error intermittent? Check for memory fragmentation and use empty_cache() after validation steps.
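For the sequence-length case, one way to use Flash Attention from plain PyTorch 2.x is torch.nn.functional.scaled_dot_product_attention, which can dispatch to a FlashAttention kernel on supported GPUs with FP16 or BF16 inputs. The sketch below compares it with a naive implementation that materializes the full attention matrix; the naive function and the tensor sizes are illustrative assumptions.

```python
import math
import torch
import torch.nn.functional as F


def naive_causal_attention(q, k, v):
    # Materializes a (seq_len x seq_len) score matrix per head: O(L^2) memory.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    seq_len = q.size(-2)
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 128, 32
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
# Fused attention; the backend (flash, memory-efficient or math) is chosen by PyTorch.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
reference = naive_causal_attention(q, k, v)
```

Both produce the same result up to floating-point error; the fused kernels avoid storing the quadratic score matrix, which is what makes long sequences fit.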


FAQ

How does gradient accumulation help with OOM?

Gradient accumulation allows you to process smaller batches (micro-batches) that fit in memory and sum their gradients over multiple steps. You only update the model weights after reaching the desired effective batch size, effectively bypassing VRAM limits for large batches.

When should I use torch.utils.checkpoint?

Use it when your model is too deep to fit the activations of a single forward pass in memory. It is particularly effective for transformer-based models where you can checkpoint individual layers to save space at the cost of re-computing them during the backward pass.

Is it better to use DataParallel or DistributedDataParallel?

Always prefer DistributedDataParallel (DDP). DataParallel (DP) is single-process and suffers from significant overhead and unbalanced memory usage on the primary GPU. DDP is multi-process, faster, and more memory-efficient.

How can I monitor GPU memory usage in real-time?

You can use the 'nvidia-smi' command in the terminal or integrate 'GPUtil' in your Python script. For PyTorch-specific insights, 'torch.cuda.memory_allocated()' and 'torch.cuda.max_memory_allocated()' are useful for tracking peaks during training.
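A small sketch for measuring the peak of one training step is shown below; measure_peak_memory and forward_backward are hypothetical helpers built on torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated(). It returns None for the peak on machines without CUDA.

```python
import torch
import torch.nn as nn


def measure_peak_memory(fn, *args, **kwargs):
    """Run fn and return (result, peak bytes allocated by tensors, or None without CUDA)."""
    if not torch.cuda.is_available():
        return fn(*args, **kwargs), None
    torch.cuda.reset_peak_memory_stats()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()  # make sure queued kernels have finished
    return result, torch.cuda.max_memory_allocated()


def forward_backward(model, x):
    loss = model(x).sum()
    loss.backward()
    return loss.item()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1024, 1024).to(device)
x = torch.randn(256, 1024, device=device)
loss_value, peak = measure_peak_memory(forward_backward, model, x)
if peak is not None:
    print(f"peak allocated: {peak / 2**20:.1f} MiB")
```

This figure counts memory held by tensors only, so it will be lower than what nvidia-smi reports, which also includes the allocator cache and the CUDA context.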

Does the optimizer affect memory usage?

Yes, significantly. Optimizers like Adam maintain two states for every model parameter (momentum and variance), effectively tripling the memory required for parameters. Using 8-bit optimizers (like those from the bitsandbytes library) can reduce this overhead by 75%.
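You can check the Adam overhead directly with a sketch like the one below; tensor_bytes and adam_state_bytes are hypothetical helpers. Adam allocates its state lazily, so one step is taken first, and the scalar step counter is ignored.

```python
import torch
import torch.nn as nn


def tensor_bytes(tensors):
    return sum(t.numel() * t.element_size() for t in tensors)


def adam_state_bytes(optimizer):
    """Bytes held in Adam's per-parameter moment buffers (exp_avg, exp_avg_sq)."""
    return sum(tensor_bytes([state["exp_avg"], state["exp_avg_sq"]])
               for state in optimizer.state.values())


model = nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 1024)).sum().backward()
optimizer.step()  # Adam allocates its state on the first step

param_bytes = tensor_bytes(model.parameters())
state_bytes = adam_state_bytes(optimizer)
print(state_bytes / param_bytes)  # 2.0: two FP32 buffers per FP32 parameter
```

Together with the parameters themselves, that is the threefold figure mentioned above, before counting gradients or activations.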
