Learning Low-Rank Tensor Cores with Probabilistic ℓ0-Regularized Rank Selection for Model Compression
Tianxiao Cao, Lu Sun, Canh Hao Nguyen, Hiroshi Mamitsuka
Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence
Main Track. Pages 3780-3788.
https://doi.org/10.24963/ijcai.2024/418
Compressing deep neural networks is of great importance for real-world applications on resource-constrained devices. Tensor decomposition is a promising approach that retains the functionality and most of the expressive power of the original deep models by replacing the weights with their decomposed cores. Decomposition with optimal ranks can achieve a good compression-accuracy trade-off, but finding those ranks is expensive because the search is discrete and combinatorial. A common practice is to set all ranks equal and tune a single hyperparameter, but this may significantly harm flexibility and generalization. In this paper, we propose a novel automatic rank selection method for deep model compression that learns model weights and decomposition ranks simultaneously. We penalize the ℓ0 (quasi-)norm of the slices of the decomposed tensor cores during model training. To avoid combinatorial optimization, we develop a probabilistic formulation and apply an approximate Bernoulli gate to each slice of the tensor cores, which can be implemented in an end-to-end, scalable framework trained by gradient descent. This allows automatic rank selection to be combined with arbitrary tensor decompositions and neural network layers, such as linear, convolutional, and embedding layers. Comprehensive experiments on various tasks, including image classification, text sentiment classification, and neural machine translation, demonstrate the superior effectiveness of the proposed method over baselines.
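The gating idea in the abstract can be illustrated with a minimal sketch. The abstract does not say which approximation of the Bernoulli gate is used, so the code below assumes the hard-concrete relaxation that is common in ℓ0-regularization work, and it uses a plain two-factor matrix product U diag(z) V as a stand-in for a general tensor decomposition. All function names, the constants GAMMA, ZETA, and BETA, and the log_alpha parameterization are illustrative assumptions, not the authors' implementation.

```python
import math
import random

# Assumed hard-concrete constants: stretch interval (GAMMA, ZETA) and temperature BETA.
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gate(log_alpha, rng):
    """Draw one relaxed gate z in [0, 1] for a single rank slice (training time)."""
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)  # keep the logs finite
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    return min(1.0, max(0.0, s * (ZETA - GAMMA) + GAMMA))


def deterministic_gate(log_alpha):
    """Noise-free gate value, used here to decide which slices are kept."""
    return min(1.0, max(0.0, sigmoid(log_alpha) * (ZETA - GAMMA) + GAMMA))


def expected_l0(log_alphas):
    """Sum of P(z_r != 0): a smooth surrogate for the number of active slices."""
    shift = BETA * math.log(-GAMMA / ZETA)
    return sum(sigmoid(a - shift) for a in log_alphas)


def gated_matvec(U, V, gates, x):
    """Compute y = U diag(gates) V x, with U and V given as nested lists."""
    h = [g * sum(v * xj for v, xj in zip(row, x)) for g, row in zip(gates, V)]
    return [sum(u * hr for u, hr in zip(row, h)) for row in U]


def selected_rank(log_alphas):
    """Count slices whose deterministic gate is strictly positive."""
    return sum(1 for a in log_alphas if deterministic_gate(a) > 0.0)
```

In a training loop under these assumptions, the task loss would be computed with sampled gates and a weighted `expected_l0` term added to it, so that gradient descent can push the log_alpha of unneeded slices down until their gates are exactly zero. The slices with zero gates can then be dropped, and `selected_rank` gives the resulting rank.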
Keywords:
Machine Learning: ML: Matrix/tensor methods
Machine Learning: ML: Learning sparse models