Differentiable Mask for Pruning Convolutional and Recurrent Networks
Ramchalam Kinattinkara Ramakrishnan*, Eyyüb Sari*, Vahid Partovi Nia
Accelerated Neural Technology (Ant), Huawei Noah’s Ark Lab
Abstract
Deep networks require massive computation, and such models need to be compressed before they can be deployed on edge devices. Most existing pruning techniques focus on vision-based models such as convolutional networks, while pruning of text-based models is still evolving. The emergence of multi-modal, multi-task learning calls for a general method that works on vision and text architectures simultaneously. To fill this gap, we introduce a differentiable mask that induces sparsity at various granularities. We successfully apply our method to prune weights, filters and subnetworks of a convolutional architecture, as well as nodes of a recurrent network.

Introduction
Recent models for machine translation, self-driving cars and AlphaGo have shown game-changing breakthroughs. However, most of these models are highly over-parametrised for a variety of reasons, ranging from the increase in computational power to the lack of domain expertise. As a result, deploying these models on constrained edge devices is counter-intuitive. For instance, real-time updates to mobile phones could be hampered by the model size, and training and inference time are affected as well. One alternative is to store deep models on the cloud rather than on edge devices, which overcomes many of the edge implementation drawbacks, and to perform the computation on the cloud server. However, the cons far outweigh the pros, especially in terms of security and the latency of transferring data to and from the cloud. In real applications, most models are preferably stored and computed on the edge, a goal that can be achieved only by simplifying neural network computations. Various pruning techniques have been proposed, such as Network Slimming [2], which is similar to our approach, and Scalpel [3].

Proposed Method
Let $\{(x_i, y_i) \mid i = 1, \dots, N\}$ be a dataset of $N$ samples, with $x_i$ representing the input vector and $y_i$ the output vector. We consider a model $M$ with $L$ layers, where a layer $l \in L$ represents a prunable entity and its parameters are denoted by $\theta_l$. We define a prunable entity as a node in the computational graph whose parameters can be removed without invalidating the graph (i.e., the forward pass can still be performed). Let $f: \mathbb{R}^n \to \mathbb{R}^n$ be any element-wise transformation applied to a node's output (e.g., ReLU, Batch Normalization, Identity). Let $J(\theta)$ be the objective to be minimized,

$$J^* = \min_{\theta \in \Theta} J(\theta),$$

where $\Theta$ is the set of all learnable parameters and $J^*$ is the optimized loss.

We introduce Differentiable Mask Pruning (DMP) for gradual pruning while training a network. Our method generalizes to unstructured (i.e., weights) or structured (i.e., vectors of parameters, filters, subnetworks) sparsity.

Let $\alpha \in \mathbb{R}^d_+$ be a strictly positive scaling factor of dimension $d$ for a given prunable entity, let $f$ be a scale-sensitive differentiable function (i.e., $f(\alpha X) \neq f(X)$), and let $I(\alpha)$ be a mask function with threshold $t$:

$$I(\alpha) = \begin{cases} 1 & \text{if } |\alpha| > t, \\ 0 & \text{otherwise.} \end{cases} \tag{1}$$

The intuition behind our approach is to replace $f(X)$ with $g(X) = f(\alpha I(\alpha) X)$ and to apply $\ell_1$-regularization on $\alpha$ to induce sparsity on the corresponding prunable entity. Formally, we propose to replace $J(\theta)$ with

$$J(\theta, \alpha) = C(\theta, \alpha) + R(\theta) + \lambda \sum_{l \in L} \lVert \alpha_l \rVert_1, \tag{2}$$

in which $R(\theta)$ is a regularizer, often an $\ell_2$ norm on the weights.

Figure 1 (plots over the range −4 to 4): Top: original mask function (left) and its derivative (right). Bottom: approximated mask function (left) and its derivative (right) [1].

Figure 2 (block diagrams built from Input, Conv2D, BatchNorm and ReLU layers with a residual addition, a Filter Mask and a Sub-network Mask): A ResNet-style subnetwork (left panel), DMP filter pruning (middle panel), DMP block pruning (right panel).

Figure 3: The training process of three scaling factors for subnetwork pruning. As one scaling factor gets pruned, another scaling factor compensates for the resulting instability and increases in magnitude. The rightmost panel also shows a scaling factor rejuvenating throughout the training process before it is finally pruned at epoch 80.

Results

Method     λ1 (×10⁻³)   Test error   Pruned subnets (of 27)   Pruned filters (of 2032)   Pruned ratio (%)
Unpruned   -            6.53         0                        0                          0
DMP        1            7.10         5                        320                        15.70
DMP        5            7.78         12                       928                        45.60
DMP        10           8.34         17                       1152                       56.60
Table 1: ResNet-56 on CIFAR-10, pruning subnetworks with different regularization constants λ1.

Method     λ1 (×10⁻⁴)   Test error   Pruned filters (of 2032)   Pruned ratio (%)
Unpruned   -            6.53         0                          0
DMP        1            6.81         264                        12
DMP        5            8.48         1227                       60
DMP        10           9.50         1599                       78
Table 2: ResNet-56 on CIFAR-10, pruning filters with different regularization constants λ1.

Method     λ1 (×10⁻⁷)   Test perplexity   Pruned nodes (of 5200)   Pruned ratio (%)
Unpruned   -            84.9              0                        0
DMP        1            85.1              524                      10
DMP        5            85.02             672                      12
DMP        10           86.1              940                      18
Table 3: Penn Treebank language model using a recurrent network (1 LSTM layer), pruning LSTM nodes.

Conclusion
We introduced DMP, a new technique that extends pruning in two directions: structured and unstructured. The sparsity DMP induces can easily be extended to prune weights, nodes, vectors, filters and subnetworks. The main difficulty in pruning is training the network properly with fewer parameters. We proposed to improve the training procedure by approximating the gradient of the hard threshold and updating back-propagation accordingly. We demonstrated the versatility of DMP by easily integrating it into computer vision (CV) and natural language processing (NLP) architectures. When the pruning entity is a subnetwork, DMP can be regarded as a differentiable architecture search method that always spans architectures of lower complexity.

References
[1] Mouloud Belbahri, Eyyüb Sari, Sajad Darabi, and Vahid Partovi Nia. Foothill: A quasiconvex regularization for edge computing of deep neural networks. Pages 3–14, 2019.
[2] Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, and Changshui Zhang. Learning efficient convolutional networks through network slimming. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 2755–2763, 2017.
[3] Jiecao Yu, Andrew Lukefahr, David Palframan, Ganesh Dasika, Reetuparna Das, and Scott Mahlke. Scalpel: Customizing DNN pruning to the underlying hardware parallelism. SIGARCH Comput. Archit. News, 45(2):548–560, June 2017.
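The hard mask of Eq. (1) and the masked transformation g(X) = f(α I(α) X) can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes f is ReLU, a filter-level granularity with one scaling factor per channel, and a hypothetical threshold t = 0.01. The names `mask`, `relu` and `masked_transform` are illustrative.

```python
from typing import Callable, List


def mask(alpha_k: float, t: float) -> float:
    """Hard threshold mask I(alpha) of Eq. (1): 1 if |alpha| > t, else 0."""
    return 1.0 if abs(alpha_k) > t else 0.0


def relu(v: float) -> float:
    """Element-wise f assumed for this sketch (ReLU)."""
    return v if v > 0.0 else 0.0


def masked_transform(
    x: List[List[float]],
    alpha: List[float],
    t: float = 0.01,  # hypothetical threshold
    f: Callable[[float], float] = relu,
) -> List[List[float]]:
    """Compute g(X) = f(alpha * I(alpha) * X) with one alpha per channel."""
    if len(x) != len(alpha):
        raise ValueError("expected one scaling factor per channel")
    out = []
    for channel, a in zip(x, alpha):
        scale = a * mask(a, t)  # scale is 0.0 when the channel is masked out
        out.append([f(scale * v) for v in channel])
    return out


x = [[1.0, -2.0, 3.0], [0.5, 0.5, -1.0]]
alpha = [0.8, 0.005]  # the second factor is below t
print(masked_transform(x, alpha)[1])  # prints [0.0, 0.0, 0.0]
```

Because the masked channel's scale is multiplied by zero, every activation in that channel becomes f(0), which is 0 for ReLU and Identity, so the corresponding filter can be removed without changing the forward pass.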
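Eq. (2) can be evaluated directly once the task loss is known. The sketch below assumes C(θ, α) has already been computed and is passed in as a number, and takes R(θ) to be a squared ℓ2 penalty (weight decay) with a hypothetical coefficient `weight_decay`; `dmp_objective` and its helpers are illustrative names, not part of the poster.

```python
from typing import Sequence


def l1_norm(v: Sequence[float]) -> float:
    """||v||_1, the sum of absolute values."""
    return sum(abs(a) for a in v)


def l2_penalty(weights: Sequence[float], weight_decay: float) -> float:
    """R(theta): assumed here to be weight_decay * sum of squared weights."""
    return weight_decay * sum(w * w for w in weights)


def dmp_objective(
    task_loss: float,
    weights: Sequence[float],
    alphas_per_layer: Sequence[Sequence[float]],
    lam: float,
    weight_decay: float = 5e-4,  # hypothetical default
) -> float:
    """J(theta, alpha) = C + R(theta) + lam * sum over prunable layers of ||alpha_l||_1."""
    sparsity = sum(l1_norm(alpha_l) for alpha_l in alphas_per_layer)
    return task_loss + l2_penalty(weights, weight_decay) + lam * sparsity


print(dmp_objective(1.0, [1.0, 2.0], [[0.5, 0.25], [1.0]], lam=0.5, weight_decay=0.25))
# prints 3.125
```

In this example the three terms are 1.0, 1.25 and 0.875. A larger λ (λ1 in the tables) weights the ℓ1 term more heavily, which is consistent with the reported trend of more entities being pruned as λ1 grows.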
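The hard threshold of Eq. (1) has zero derivative almost everywhere, so back-propagation alone would never update α through the mask; this is why the poster approximates the mask gradient (Figure 1, following [1]). The exact approximation from [1] is not reproduced here. As a stand-in, the sketch below uses a different technique, a straight-through-style surrogate: the forward pass keeps the hard mask, while the backward pass substitutes the derivative of a steep sigmoid σ(K(|α| − t)). The constants `T`, `K`, `LAM`, the learning rate, the positivity floor and the single-gate squared-error loss are all hypothetical.

```python
import math

# Hypothetical hyper-parameters for this sketch (not from the poster).
T = 0.05    # mask threshold t in Eq. (1)
K = 50.0    # slope of the sigmoid surrogate
LAM = 0.01  # l1 strength (lambda in Eq. (2))


def hard_mask(alpha: float) -> float:
    """Forward pass: the exact mask I(alpha) of Eq. (1)."""
    return 1.0 if abs(alpha) > T else 0.0


def surrogate_mask_grad(alpha: float) -> float:
    """d/d(alpha) of sigmoid(K * (|alpha| - T)), used in place of dI/d(alpha)."""
    s = 1.0 / (1.0 + math.exp(-K * (abs(alpha) - T)))
    return K * s * (1.0 - s) * math.copysign(1.0, alpha)


def grad_alpha(alpha: float, x: float, target: float) -> float:
    """Gradient of (alpha * I(alpha) * x - target)^2 + LAM * |alpha| w.r.t. alpha.

    The product rule on alpha * I(alpha) uses the surrogate derivative for I.
    """
    y = alpha * hard_mask(alpha) * x
    d_scale = hard_mask(alpha) + alpha * surrogate_mask_grad(alpha)
    return 2.0 * (y - target) * x * d_scale + LAM * math.copysign(1.0, alpha)


def train(alpha: float, x: float, target: float, lr: float = 0.1, steps: int = 200) -> float:
    """Plain gradient descent on a single scaling factor."""
    for _ in range(steps):
        alpha -= lr * grad_alpha(alpha, x, target)
        alpha = max(alpha, 1e-6)  # hypothetical floor keeping alpha strictly positive
    return alpha


print(hard_mask(train(1.0, 1.0, 0.0)))  # prints 0.0
print(f"{train(1.0, 1.0, 1.0):.4f}")    # prints 0.9950
```

With target 0 (an entity whose output is not needed), the data term and the ℓ1 pull drive α below T and the mask switches off; once off, only the ℓ1 term acts, so the entity stays pruned. With target 1, α settles where the ℓ1 pull balances the data term and the entity is kept.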
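At the subnetwork granularity (Figure 2, right panel), one plausible reading of the definition of a prunable entity is a residual branch whose output is scaled by a single α before the skip addition, so that masking it leaves the identity path and the forward pass remains valid. Where exactly the sub-network mask sits is an assumption of this sketch, not a detail stated on the poster; `residual_block` and the toy `double` branch are hypothetical names.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(
    x: Vector,
    branch: Callable[[Vector], Vector],
    alpha: float,
    t: float = 0.01,  # hypothetical threshold
) -> Vector:
    """y = x + alpha * I(alpha) * F(x); with the mask off this is exactly y = x."""
    fx = branch(x)
    scale = alpha * (1.0 if abs(alpha) > t else 0.0)
    return [xi + scale * fi for xi, fi in zip(x, fx)]


def double(v: Vector) -> Vector:
    """Toy stand-in for the Conv2D/BatchNorm/ReLU branch F."""
    return [2.0 * vi for vi in v]


print(residual_block([1.0, -1.0], double, alpha=0.5))    # prints [2.0, -2.0]
print(residual_block([1.0, -1.0], double, alpha=0.005))  # prints [1.0, -1.0]
```

When the mask is off, the block reduces to the identity, which is what would allow a whole subnetwork to be removed without invalidating the graph.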
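Tables 1 to 3 report how many entities are pruned and the corresponding ratio. Assuming an entity counts as pruned exactly when its mask in Eq. (1) is zero, these quantities can be computed from the learned scaling factors as below; the layer names, α values and threshold are made up for illustration, and `count_pruned` is a hypothetical helper.

```python
from typing import Dict, List, Tuple


def count_pruned(alphas: Dict[str, List[float]], t: float) -> Tuple[int, int, float]:
    """Return (pruned, total, pruned ratio in percent) over all prunable entities.

    An entity is pruned when its mask I(alpha) is 0, i.e. when |alpha| <= t.
    """
    total = sum(len(layer) for layer in alphas.values())
    if total == 0:
        raise ValueError("no prunable entities")
    pruned = sum(1 for layer in alphas.values() for a in layer if abs(a) <= t)
    return pruned, total, 100.0 * pruned / total


alphas = {
    "conv1": [0.9, 0.001, 0.4, 0.002],  # hypothetical per-filter scaling factors
    "conv2": [0.003, 0.7, 0.5, 0.6],
}
print(count_pruned(alphas, t=0.01))  # prints (3, 8, 37.5)
```

Here 3 of the 8 filters fall at or below t, giving 37.5%, the same kind of quantity reported in the ratio column of the tables.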