Today's article comes from the IOP journal Machine Learning: Science and Technology. The authors are Bae et al., from Seoul National University, in South Korea. In this paper they present a new technique for overcoming latent gradient bias. Let's see if it works.
“Have you tried turning it off and on again?” If you ever watched the British sitcom “The IT Crowd” from the mid-2000s, you’ll recognize this running joke. The characters staff an IT help desk, and the first thing they do when they answer the phone is ask that question: “Hello, IT. Have you tried turning it off and on again?” The joke is funny because it’s real. There’s no two ways about it: when your phone, your laptop, or your server is misbehaving, there’s a decent chance that simply turning it off and on again will fix it. The question we’re asking today is “Does this apply to A.I.?” Specifically, when you’re training a machine learning model, does turning the learning process off and on again improve the end result?
The authors call this idea “stochastic resetting”, and to understand how it works we first need to understand one of the foundational concepts in ML: SGD, or Stochastic Gradient Descent. So today we’re going to take a deep dive into SGD and get you fully familiar with it. Once that’s under our belts, we’ll talk about latent gradient bias, the problem the authors were trying to solve, and then we’ll turn to their idea: stochastic resetting. We’ll walk through the logic of it, see how they implemented it, and then find out whether it really worked…or not. Let’s jump in:
Here’s everything you need to know about Stochastic Gradient Descent. Buckle up. In machine learning and optimization, many problems reduce to minimizing a function (typically a loss function) that measures how far a model’s predictions are from the correct answers. While some functions can be minimized exactly with a closed-form solution, many loss landscapes are too complex for that, especially in high-dimensional spaces. In these cases, we don’t shoot for the exact target; we just try to get closer and closer to the neighborhood it lives in. We use iterative methods that gradually improve our solution by moving toward lower values of the loss. Gradient descent is one such method, and it is built around the geometric intuition that if we’re standing on a hilly surface and want to get to the lowest point, we should take steps in the direction that most steeply decreases elevation. Mathematically, this is the direction of the negative gradient, where the gradient is a vector that points in the direction of the steepest increase in a function's value.
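In symbols, a single gradient descent step is commonly written as below. The notation (parameters θ with d components, learning rate η, loss L) is the standard textbook convention, not notation taken from the paper:

```latex
\theta_{t+1} = \theta_t - \eta \, \nabla_{\theta} L(\theta_t),
\qquad
\nabla_{\theta} L(\theta) = \left( \frac{\partial L}{\partial \theta_1}, \dots, \frac{\partial L}{\partial \theta_d} \right)
```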
The gradient itself is computed from the partial derivatives of the function with respect to each parameter we wish to optimize. It tells us how much the function will increase or decrease if we make a small change in each parameter, and thus guides how to adjust those parameters to reduce the loss. In basic gradient descent, we compute the gradient of the entire loss function using all the training data and then update the model's parameters in the opposite direction. This method is called batch gradient descent because it processes the full batch of data at every step. It produces stable and accurate updates, but it becomes inefficient and slow when the dataset is large, since calculating the full gradient is computationally expensive.
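Here is a minimal full-batch gradient descent sketch in plain Python. The toy problem (fitting y = w·x + b with mean squared error on five noiseless points), the learning rate, and the function names are our own illustrative assumptions; the point to notice is that every single step loops over the entire dataset.

```python
def mse_gradient(w, b, xs, ys):
    """Return (dL/dw, dL/db) for L = mean((w*x + b - y)^2) over ALL data points."""
    n = len(xs)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return grad_w, grad_b


def batch_gradient_descent(xs, ys, lr=0.05, steps=2000):
    w, b = 0.0, 0.0
    for _ in range(steps):
        gw, gb = mse_gradient(w, b, xs, ys)  # one full pass over the dataset per step
        w -= lr * gw                          # move against the gradient
        b -= lr * gb
    return w, b


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]  # generated from y = 2x + 1
w, b = batch_gradient_descent(xs, ys)
print(round(w, 3), round(b, 3))  # 2.0 1.0
```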
Stochastic gradient descent, or SGD, addresses this problem by approximating the full gradient with a much cheaper estimate. Rather than using the entire dataset, SGD selects a single data point at random, computes the gradient of the loss for just that point, and updates the model accordingly. This makes each step very cheap, so the algorithm can take many small steps in the time batch gradient descent needs to compute a single full gradient. However, because the gradient is based on just one data point, it is noisy: the direction may not always point exactly downhill, and individual steps may even briefly increase the loss. Over time, though, the randomness averages out, and the algorithm tends to make progress toward the minimum.
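The same toy regression can be solved with plain SGD. In this sketch (again with our own illustrative data, learning rate, and step count), each update uses the gradient of one randomly chosen example, and a seeded random.Random keeps the run reproducible.

```python
import random


def sgd(xs, ys, lr=0.01, steps=20000, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    for _ in range(steps):
        i = rng.randrange(len(xs))        # pick ONE example at random
        err = w * xs[i] + b - ys[i]
        # Gradient of (w*x_i + b - y_i)^2 for this single example only.
        w -= lr * 2 * err * xs[i]
        b -= lr * 2 * err
    return w, b


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]  # generated from y = 2x + 1
w, b = sgd(xs, ys)
print(w, b)
```

Because this toy data lies exactly on a line, every per-example gradient is zero at the solution, so SGD settles there; with noisy labels it would keep jittering around the minimum unless the learning rate is reduced.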
This noisy behavior is both a weakness and a strength. The erratic updates can slow convergence and make the path toward the minimum less direct. This is especially true when the learning rate (the factor that scales how big a step is taken) has not been carefully chosen. If the learning rate is too high, the algorithm might overshoot the minimum and diverge; if it's too low, convergence becomes painfully slow. But the randomness introduced by SGD can help avoid certain pitfalls of deterministic methods. In non-convex loss landscapes, which are common in deep learning, SGD can sometimes escape local minima or saddle points that would trap traditional gradient descent.
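The learning rate trade-off is easy to see on the one-dimensional function f(x) = x², whose gradient is 2x. Each step multiplies x by (1 - 2·lr), so the three learning rates below (chosen by us purely for illustration) give painfully slow progress, fast convergence, and divergence, respectively.

```python
def gradient_descent_1d(lr, x0=1.0, steps=50):
    """Minimize f(x) = x**2 (gradient 2x) with a fixed learning rate."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x  # equivalent to x *= (1 - 2 * lr)
    return x


for lr in (0.001, 0.1, 1.1):
    print(lr, gradient_descent_1d(lr))
# 0.001 -> about 0.905   (too low: barely moved after 50 steps)
# 0.1   -> about 1.4e-05 (converged toward the minimum at 0)
# 1.1   -> about 9.1e+03 (too high: overshoots further every step)
```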
In practice though, rather than using a single data point, it is common to use a small subset of the data at each step. This approach, known as mini-batch SGD, balances the stability of batch methods with the speed of stochastic updates. Mini-batches provide a better estimate of the gradient than a single point, reducing the noise while still offering computational efficiency. This method also allows better use of parallel hardware like GPUs, which can process batches efficiently.
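A mini-batch version only changes how many examples feed each update. In this sketch the batch size, learning rate, and epoch count are arbitrary assumptions; the data is reshuffled every epoch, and the per-example gradients are averaged within each batch.

```python
import random


def minibatch_sgd(xs, ys, batch_size=2, lr=0.02, epochs=2000, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)  # new random order every epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # Average the per-example gradients over the mini-batch.
            gw = sum(2 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / len(batch)
            gb = sum(2 * (w * xs[i] + b - ys[i]) for i in batch) / len(batch)
            w -= lr * gw
            b -= lr * gb
    return w, b


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]  # generated from y = 2x + 1
w, b = minibatch_sgd(xs, ys)
print(w, b)
```

On a GPU the inner sum would be a single vectorized operation over the batch, which is where the hardware efficiency mentioned above comes from.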
Although the core idea of SGD is simple, its performance is highly sensitive to the learning rate. A fixed learning rate might work for small problems but often needs to be adjusted during training for larger models. It’s common to start with a higher learning rate and gradually reduce it (a process known as learning rate scheduling) to allow faster initial progress and finer adjustments later. This helps SGD stabilize near a minimum, since the noise becomes harder to control if the step size remains large.
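Step decay is one simple form of learning rate scheduling. The helper below is hypothetical (its name, the halving factor, and the 10-epoch interval are our assumptions, not standard defaults): it keeps the rate fixed for a block of epochs and then cuts it.

```python
def step_decay(initial_lr, epoch, drop=0.5, epochs_per_drop=10):
    """Hypothetical schedule: multiply the learning rate by `drop` every `epochs_per_drop` epochs."""
    return initial_lr * (drop ** (epoch // epochs_per_drop))


print([step_decay(0.1, e) for e in (0, 9, 10, 25)])  # [0.1, 0.1, 0.05, 0.025]
```

Inside a training loop you would call step_decay(initial_lr, epoch) at the start of each epoch and use the result as the learning rate for that epoch's updates.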
Because of its simplicity and efficiency, SGD serves as the foundation for more advanced optimizers. These include SGD with momentum, AdaGrad, RMSProp, and Adam.
These methods retain the core of SGD but improve convergence in challenging settings by automatically tuning step sizes or directions.
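To make “improving the direction” concrete, here is classical (heavy-ball) momentum, one of the best-known SGD variants, as a small Python sketch. It is illustrative only: the function name, the hyperparameter values, and the toy objective f(x) = x² are our assumptions, and real libraries differ in details such as whether the learning rate is applied inside or outside the velocity term.

```python
def sgd_momentum(grad_fn, x0, lr=0.1, beta=0.9, steps=500):
    """Heavy-ball momentum: keep a running 'velocity' built from past gradients."""
    x, v = x0, 0.0
    for _ in range(steps):
        v = beta * v + grad_fn(x)  # blend the new gradient into the accumulated direction
        x = x - lr * v             # step along the smoothed direction
    return x


# Minimize f(x) = x**2, whose gradient is 2x; the minimum is at x = 0.
x_final = sgd_momentum(lambda x: 2 * x, x0=1.0)
print(x_final)
```

With stochastic gradients, the velocity term averages out some of the noise across steps, which is one reason momentum often converges more smoothly than plain SGD.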
Now, with all that under our belts, we need to wrap our heads around the problem that the authors were trying to solve in this paper. It’s called latent gradient bias. It arises when the gradient estimates used in SGD are systematically distorted due to correlations between the data sampling procedure and the structure of the loss landscape. Although SGD is typically assumed to produce unbiased gradient estimates (meaning that the average of many noisy updates points in the true gradient direction), this assumption breaks down when certain data points consistently exert more influence on the updates than others. The result is that the optimizer may drift toward regions of parameter space not because they minimize the true expected loss, but because of persistent asymmetries in the sampled gradients. This effect is often subtle, and it stems not from implementation errors but from structural properties of the data, model, or loss function.
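The mechanism can be seen in a toy calculation. In the sketch below (our own illustration of the general idea, not the paper's label-noise analysis), each example i contributes a loss (w - y_i)^2, so the loss averaged over the dataset is minimized at the mean of the y values. The function follows the expected SGD update when example i is drawn with probability probs[i]: uniform sampling recovers the true minimizer, while a sampler that over-represents one example pulls the solution toward it.

```python
def expected_sgd(ys, probs, lr=0.1, steps=500):
    """Follow the expected SGD update when example i is sampled with probability probs[i]."""
    w = 0.0
    for _ in range(steps):
        # Average of the per-example gradients 2*(w - y_i), weighted by how often each is drawn.
        expected_grad = sum(p * 2 * (w - y) for p, y in zip(probs, ys))
        w -= lr * expected_grad
    return w


ys = [0.0, 0.0, 0.0, 10.0]       # true minimizer of the average loss: mean(ys) = 2.5
uniform = [0.25] * 4             # unbiased sampling
skewed = [0.1, 0.1, 0.1, 0.7]    # one example is over-sampled
print(expected_sgd(ys, uniform))  # ~2.5
print(expected_sgd(ys, skewed))   # ~7.0, pulled toward the over-sampled example
```

Even with zero sampling noise in this calculation, the skewed sampler settles at the wrong point, which is the "directional" rather than merely noisy character of the bias.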
For example, consider training a language model on a corpus where shorter sentences are overrepresented. If the model updates its parameters using mini-batches drawn randomly from the full dataset, shorter and syntactically simpler examples will dominate the early gradient estimates. These examples might favor certain token predictions or structural patterns that are not representative of the broader language distribution. As a result, the model parameters will be nudged in directions that overfit these simpler forms. This doesn’t just introduce noise; it creates a directional bias in the gradient field, pulling the optimization process toward a region that minimizes loss on an unbalanced subset of the data. Even as training continues and more varied examples appear, the model may remain stuck in a basin shaped by these early biases. This phenomenon, while hard to detect in raw training curves, has been observed in practice in domains like reinforcement learning and large-scale language modeling, where certain trajectories or token sequences can dominate gradient flow despite their relatively low importance in the target distribution.
In this paper, the authors’ hypothesis was that they could use something called stochastic resetting to overcome this. This is the thing I’m referring to as “turning it off and on again”. Here’s how it works. The basic idea is deceptively simple: every so often, you throw away your current model parameters and start over. But instead of restarting training completely, you keep the same dataset, the same optimizer, and the same hyperparameters; the only thing that resets is the position in parameter space, which jumps back to a checkpoint. In practice, this means running SGD for a while, stopping at some predefined or randomly chosen interval, reinitializing the model weights either to fresh random values or, more likely, to a previously recorded checkpoint (typically the one with minimum validation loss), and then beginning SGD again. You repeat this cycle many times.
Each restart forces the optimizer to explore a new trajectory through the loss landscape, sampling a different sequence of gradients. Over many restarts, the hope is that these independent trajectories will average out the biases that any single run might accumulate. Critically, resetting doesn't "undo" learning in the traditional sense, and it doesn't blend the weights from different runs together. Instead, it's a statistical trick: by resetting over and over again, you're repeatedly sampling new paths through the same optimization problem, which can expose the optimizer to regions of parameter space that a single biased trajectory would never reach.
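In the authors’ implementation, a reset cycle works like this: ordinary SGD updates run as usual, but at every iteration there is a small, fixed probability that the parameters jump back to the saved checkpoint instead of taking a step; SGD then resumes from that checkpoint, and the cycle repeats.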
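Below is a minimal, framework-free sketch of that loop in Python. It is our own illustration under stated assumptions, not the authors' code: parameters are a plain list of floats, sgd_step and val_loss are hypothetical callables you supply, and the checkpoint is assumed to be the lowest-validation-loss parameters seen so far (the paper's exact checkpoint rule may differ). The toy usage at the end only demonstrates the control flow; it has no noisy labels to memorize, so it says nothing about whether resetting helps.

```python
import random


def train_with_stochastic_resetting(params, sgd_step, val_loss,
                                    reset_prob=0.01, iterations=1000, seed=0):
    """Run SGD, but at each iteration reset to the checkpoint with probability reset_prob.

    Assumption (ours): the checkpoint is the lowest-validation-loss parameters seen so far.
    """
    rng = random.Random(seed)
    best_params, best_loss = list(params), val_loss(params)
    resets = 0
    for _ in range(iterations):
        if rng.random() < reset_prob:
            params = list(best_params)  # "turn it off and on again": jump back to the checkpoint
            resets += 1
        else:
            params = sgd_step(params)   # otherwise take an ordinary SGD step
        loss = val_loss(params)
        if loss < best_loss:            # update the checkpoint when validation loss improves
            best_params, best_loss = list(params), loss
    return best_params, best_loss, resets


# Hypothetical toy problem: one parameter w with its optimum at w = 2.
step_rng = random.Random(1)


def noisy_sgd_step(params, lr=0.1):
    # Gradient of (w - 2)^2 plus Gaussian noise, standing in for mini-batch noise.
    w = params[0]
    grad = 2 * (w - 2.0) + step_rng.gauss(0.0, 1.0)
    return [w - lr * grad]


def toy_val_loss(params):
    return (params[0] - 2.0) ** 2


best, loss, n_resets = train_with_stochastic_resetting(
    [0.0], noisy_sgd_step, toy_val_loss, reset_prob=0.05)
print(best, loss, n_resets)
```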
Between cycles, nothing is carried over (beyond the checkpoint) except the fact that you're still trying to minimize the same loss on the same data. So at this point you might be saying:
“Wait, why would this help? What’s the point if we don’t carry anything forward between the cycles?”
Think of it just like resetting a video game to a previous checkpoint. Trying again gives you another opportunity to get it right. Maybe you went to the left last time, and that was a dead end. Your character died, you reset to a checkpoint, and now you get another chance. So you go to the right (or in this case, just some other random direction). And if you do this enough times, eventually you'll find the best way to win that level.
At least, that was their hypothesis. But did it work?
To test it, they conducted a series of controlled experiments across synthetic and real-world datasets with varying levels of label noise. They began by establishing a clean baseline using standard SGD. By progressively increasing the noise rate, they recreated conditions under which gradient bias typically leads to memorization (where the model starts fitting noisy labels instead of learning generalizable features). For each setup, they compared performance between regular SGD and SGD augmented with stochastic resetting. In their implementation, resetting was applied probabilistically during training: at each iteration, the optimizer would either continue its update step or reset the model parameters back to a checkpoint from an earlier training phase, typically chosen from a point where validation loss was minimized.
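If, as described above, a reset happens independently at each iteration with a fixed probability r (our symbol), the waiting time T between resets follows a geometric distribution. This is a standard probability identity rather than a result from the paper, but it shows how a single reset probability sets the typical length of each uninterrupted SGD run; for example, r = 0.01 gives an average of 100 iterations between resets:

```latex
P(T = k) = (1 - r)^{k-1} \, r, \quad k = 1, 2, \dots
\qquad\qquad
\mathbb{E}[T] = \sum_{k=1}^{\infty} k \, (1 - r)^{k-1} r = \frac{1}{r}
```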
The results were clear and consistent across the experiments. Models trained with stochastic resetting showed significantly improved generalization performance, particularly at higher noise levels. Notably, the benefits were most pronounced when the batch size was small and noise rate was high (conditions under which the gradient estimates were more stochastic and more vulnerable to bias). The authors also explored partial resetting, where only later layers of the network were reset. This approach often outperformed full resets, especially in convolutional architectures where earlier layers learn general features and later layers are more prone to memorizing noise.
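Partial resetting only changes which parameters get restored. In this hedged sketch, a model is represented as a dictionary mapping layer names to weight lists; the layer names (conv1, conv2, fc), the weight values, and the helper partial_reset are hypothetical, chosen to mirror the idea of keeping the early feature-extracting layers and resetting only the later ones.

```python
def partial_reset(params, checkpoint, layers_to_reset):
    """Copy `params`, restoring only the layers named in `layers_to_reset` from `checkpoint`."""
    return {
        name: list(checkpoint[name]) if name in layers_to_reset else list(weights)
        for name, weights in params.items()
    }


current = {"conv1": [0.9, -0.3], "conv2": [0.5, 0.1], "fc": [2.4, -1.7]}
checkpoint = {"conv1": [0.8, -0.2], "conv2": [0.4, 0.2], "fc": [0.3, 0.6]}

# Keep the early (convolutional) layers as trained; reset only the final layer.
print(partial_reset(current, checkpoint, {"fc"}))
# {'conv1': [0.9, -0.3], 'conv2': [0.5, 0.1], 'fc': [0.3, 0.6]}
```

In a real framework the same idea amounts to copying checkpointed tensors into a chosen subset of the model's parameters while leaving the rest untouched.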
The key takeaway is that stochastic resetting acts as a form of dynamic regularization. By interrupting trajectories that might otherwise drift too far toward memorizing corrupted data, resetting gives the model multiple chances to reorient toward more generalizable solutions. Crucially, this strategy doesn't require any new hyperparameters beyond the reset probability, and it works across architectures, from CNNs to vision transformers. The method is simple to implement, incurs negligible computational overhead, and is compatible with existing techniques for robust learning. In effect, stochastic resetting transforms SGD from a single, potentially biased optimization path into an ensemble of independent searches, each less likely to get stuck in the same flawed direction. It’s a simple, modest intervention with an outsized impact on learning.
If you want to go deeper, download the PDF. The authors offer a much more rigorous explanation and theoretical analysis of stochastic resetting than we can do here, along with benchmark comparisons across multiple datasets and optimizers if you want to dig into their numbers.