Part 3: Gradual Magnitude Pruning (GMP) Hyperparameters
TL;DR: To facilitate the GMP process when pruning a network, several hyperparameters must be defined. These include general hyperparameters such as learning rate, pruning update frequency, and pruning schedule function in addition to the sparsity per layer. All hyperparameters affect end level recovery, loss, and performance.
Welcome to Part 3 in Neural Magic’s five-part blog series on pruning in machine learning. In case you missed it, Part 1 gave a pruning overview, detailed the difference between structured and unstructured pruning, and described commonly used algorithms, including Gradual Magnitude Pruning (GMP). In Part 2, we argued that GMP is one of the best pruning approaches to use due to its simplicity, ease of use, and performance on a wide variety of models. We also discussed the three general stages of GMP: stabilization, pruning, and fine-tuning.
To facilitate the GMP process when pruning a network, several hyperparameters must be defined. These include general hyperparameters such as learning rate, pruning update frequency, and pruning schedule function, in addition to the sparsity per layer, which we’ll describe in full detail in Part 4. All hyperparameters affect end level recovery, loss, and performance.
An important parameter to select is the learning rate to use during the stabilization and pruning phases. Picking a learning rate that is too high can quickly lead to the model diverging or failing to train while pruning. Selecting a learning rate that is too low will fail to regularize the weights properly and will not allow the pruned model to generalize.
If you are using an adaptive optimizer, such as Adam, you should generally keep the same configuration and learning rate that were initially used to train the network. Adaptive methods can give suboptimal generalization for both training and pruning, though, so we recommend using a properly tuned SGD schedule instead.
For SGD, a learning rate roughly in the middle of your start and end learning rates for the training process works well. For example, a standard ImageNet training schedule goes through three steps: 0.1, 0.01, and 0.001. Therefore, a typical pruning schedule will run at 0.01 for ImageNet. This is not a hard rule, though; you may need to adjust the learning rate based on how the pruned model behaves. Specifically, if you see a wide generalization gap after pruning finishes (where training loss is much lower than validation loss), increasing the learning rate can help. If both the training and validation loss are higher than the baseline, then decreasing it can help.
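As a rough illustration of the "middle of your start and end learning rates" heuristic, the sketch below takes the midpoint on a log scale (the geometric mean). Reading "middle" as a log-scale midpoint is our assumption, chosen because it matches the ImageNet example above; the function name is hypothetical.

```python
import math


def midpoint_lr(start_lr: float, end_lr: float) -> float:
    """Geometric midpoint of two learning rates (the middle on a log scale).

    Assumption: "middle" is interpreted on a log scale, which matches the
    0.1 -> 0.01 -> 0.001 ImageNet example in the text.
    """
    return math.sqrt(start_lr * end_lr)


# ImageNet-style schedule that starts at 0.1 and ends at 0.001.
pruning_lr = midpoint_lr(0.1, 0.001)
print(f"suggested pruning learning rate: {pruning_lr:.4g}")
```

Treat the result as a starting point only; the generalization-gap checks described above still decide whether to raise or lower it.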
After pruning is complete, it is vital to step the learning rate down multiple times in the fine-tuning stage (typically to one-tenth of the previous value at each step). This allows the network to continue to lower the loss and generalize as usual. Given that the network is now much smaller, though, the loss should converge more quickly than it did for the baseline model. Additionally, since the model is smaller, it can continue to learn at lower learning rates than the dense baseline model could. Adding a step or two past where the baseline training would typically stop helps significantly.
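The stepped fine-tuning schedule can be sketched in a few lines of plain Python. The function name, the number of steps, and the epochs per step are all hypothetical, and we assume the first step down happens as soon as fine-tuning begins.

```python
def fine_tune_lr_schedule(pruning_lr, num_steps, epochs_per_step, factor=0.1):
    """Per-epoch learning rates for the fine-tuning stage.

    Each step multiplies the previous learning rate by `factor`.
    Assumption: the first step down is taken when fine-tuning starts.
    """
    schedule = []
    for step in range(1, num_steps + 1):
        lr = pruning_lr * factor ** step
        schedule.extend([lr] * epochs_per_step)
    return schedule


# Pruned at 0.01; three steps reach 1e-5, which is two steps past the
# 0.001 where the ImageNet baseline schedule from the text stops.
schedule = fine_tune_lr_schedule(pruning_lr=0.01, num_steps=3, epochs_per_step=5)
for epoch, lr in enumerate(schedule):
    print(f"fine-tuning epoch {epoch:2d}: lr = {lr:.0e}")
```

In a real training loop, the same values would typically be handed to the optimizer through whatever learning rate scheduler your framework provides.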
Learning Rate Analysis
A learning rate analysis can give a rough approximation for the best learning rate to prune at as well. To run an analysis, start by using the trained model and a newly constructed optimizer with a small learning rate, such as 1e-9. Continue the training process to run batches through the optimizer and gradually increase the learning rate until it reaches or exceeds 1.0. There will be an inflection point where the model begins to diverge from its trained solution significantly. This point is generally an ideal learning rate with which to prune the model. More information on the learning rate analysis can be found in the cyclic learning rate paper. Additionally, for PyTorch users, an API is available in the neuralmagicML package:
from neuralmagicML.pytorch.recal import lr_loss_sensitivity
help(lr_loss_sensitivity)
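For readers not using that package, the bookkeeping behind a learning rate analysis can be sketched in plain Python. This is not the neuralmagicML implementation: the function names, the exponential spacing of the learning rates, and the 10% loss tolerance used to flag divergence are our assumptions, and the losses below are made-up stand-ins for values you would record while running batches through your own model.

```python
def exponential_lr_sweep(start_lr=1e-9, end_lr=1.0, num_batches=100):
    """Learning rates spaced evenly on a log scale from start_lr to end_lr."""
    ratio = (end_lr / start_lr) ** (1.0 / (num_batches - 1))
    return [start_lr * ratio ** i for i in range(num_batches)]


def first_divergence_lr(lrs, losses, baseline_loss, tolerance=0.1):
    """First learning rate whose loss exceeds the baseline by `tolerance`.

    Assumption: "diverging from the trained solution" is approximated as a
    relative loss increase above `tolerance`; returns None if never reached.
    """
    for lr, loss in zip(lrs, losses):
        if loss > baseline_loss * (1.0 + tolerance):
            return lr
    return None


# Toy data: one recorded loss per learning rate in the sweep.
lrs = exponential_lr_sweep(num_batches=10)
toy_losses = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.05, 1.2, 3.0, 9.0]
found_lr = first_divergence_lr(lrs, toy_losses, baseline_loss=1.0)
print(f"loss starts to diverge near lr = {found_lr:.1e}")
```

In practice the loss curve is noisy, so it usually helps to smooth the recorded losses before looking for the inflection point.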
Pruning Update Frequency
The pruning update frequency defines how often a pruning step is taken between the start and end of the pruning stage. Updating only once or twice during the whole pruning process makes it very close to one-shot pruning with retraining, which risks cutting too many essential weights at once. At the other extreme, updating after every batch is expensive (the weights are sorted each time to figure out which to cut) and may not give the network time to see enough examples to properly renormalize after the previous cut.
Generally, stepping once per epoch or a few times per epoch is a safe choice that works well. For the ImageNet example that prunes over 35 epochs, once per epoch works well. However, if stepping once per epoch yields only ten pruning steps, steps should be taken more frequently. A minimum of 30 steps is a good rule to have in place, and more frequent steps will not hurt, provided they are not taken after every batch.
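The 30-step rule of thumb translates into a small calculation. The helper names below are hypothetical, and spreading the updates evenly within each epoch is our assumption.

```python
import math


def updates_per_epoch(pruning_epochs: int, min_steps: int = 30) -> int:
    """Smallest whole number of updates per epoch that reaches min_steps."""
    return max(1, math.ceil(min_steps / pruning_epochs))


def update_epochs(start_epoch: float, pruning_epochs: int, min_steps: int = 30):
    """Fractional epochs at which a pruning step is taken.

    Assumption: updates are spaced evenly within each epoch.
    """
    per_epoch = updates_per_epoch(pruning_epochs, min_steps)
    total_steps = pruning_epochs * per_epoch
    return [start_epoch + i / per_epoch for i in range(total_steps)]


# 35 pruning epochs: once per epoch already gives 35 steps.
print(updates_per_epoch(35))
# 10 pruning epochs: three updates per epoch are needed to reach 30 steps.
print(updates_per_epoch(10), len(update_epochs(0, 10)))
```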
Pruning Schedule Function
The pruning schedule function is an easy choice and is included here for completeness. The Neural Magic ML Tooling defaults to a cubic function in which the early pruning steps are much larger than the final ones. Experimentally, this works better than a linear function in which each step removes the same number of weights. In general, this is because most of the remaining weights are relatively large as we approach the target sparsity, so the network needs more time to regularize relative to the number of weights being cut.
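To make the difference concrete, here is a minimal sketch comparing a cubic schedule with a linear one. The cubic form below follows the commonly used formulation from Zhu and Gupta's gradual pruning work; we assume, but have not confirmed, that the tooling's default has the same shape. The function names and the 5% to 85% sparsity range are hypothetical.

```python
def cubic_sparsity(step, total_steps, init_sparsity, final_sparsity):
    """Target sparsity after `step` of `total_steps` pruning updates (cubic)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


def linear_sparsity(step, total_steps, init_sparsity, final_sparsity):
    """Target sparsity when every update removes the same share of weights."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return init_sparsity + (final_sparsity - init_sparsity) * progress


total_steps = 35
init_s, final_s = 0.05, 0.85  # illustrative values only

first_cut = cubic_sparsity(1, total_steps, init_s, final_s) - cubic_sparsity(0, total_steps, init_s, final_s)
last_cut = cubic_sparsity(total_steps, total_steps, init_s, final_s) - cubic_sparsity(total_steps - 1, total_steps, init_s, final_s)
print(f"cubic: first step adds {first_cut:.4f} sparsity, last step adds {last_cut:.6f}")

for step in (0, 10, 20, 30, 35):
    c = cubic_sparsity(step, total_steps, init_s, final_s)
    l = linear_sparsity(step, total_steps, init_s, final_s)
    print(f"step {step:2d}: cubic = {c:.3f}, linear = {l:.3f}")
```

Both schedules start and end at the same sparsity; the cubic one front-loads the cuts, leaving only small adjustments for the final steps when the remaining weights matter most.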
Next up: Sparsity per Layer Hyperparameter
In addition to the general hyperparameters described above, the sparsity to target per layer is arguably the most critical hyperparameter you can set. It controls the amount of performance speedup for a network and most strongly correlates with the likelihood of accuracy recovery after pruning. Stay on the lookout next week for a deep dive on this critical hyperparameter.
Learn More About Pruning
In our fourth post in this series, we’ll elaborate on the most critical hyperparameter: sparsity per layer. If you have questions or would like us to help you speed up your neural networks with pruning, reach out to us!