Representation Range Needs for 16-Bit Neural Network Training
Valentina Popescu, Abhinav Venigalla, Di Wu, Robert Schreiber
Cerebras Systems; University of Wisconsin–Madison
*Work done while at Cerebras Systems

arXiv:2103.15940v2 [cs.AI] 6 Apr 2021

Abstract—Deep learning has grown rapidly thanks to its state-of-the-art performance across a wide range of real-world applications. While neural networks have been trained using IEEE-754 binary32 arithmetic, the rapid growth of computational demands in deep learning has boosted interest in faster, low-precision training. Mixed-precision training that combines IEEE-754 binary16 with IEEE-754 binary32 has been tried, and other 16-bit formats, for example Google's bfloat16, have become popular.

In floating-point arithmetic there is a tradeoff between precision and representation range as the number of exponent bits changes; denormal numbers extend the representation range. This raises questions of how much exponent range is needed, of whether there is a format between binary16 (5 exponent bits) and bfloat16 (8 exponent bits) that works better than either of them, and whether or not denormals are necessary.

In the current paper we study the need for denormal numbers for mixed-precision training, and we propose a 1/6/9 format, i.e., 6-bit exponent and 9-bit explicit mantissa, that offers a better range-precision tradeoff. We show that 1/6/9 mixed-precision training is able to speed up training on hardware that incurs a performance slowdown on denormal operations, or eliminates the need for denormal numbers altogether. And, for a number of fully connected and convolutional neural networks in computer vision and natural language processing, 1/6/9 achieves numerical parity with standard mixed-precision.

Index Terms—floating-point, deep learning, denormal numbers, neural network training

I. INTRODUCTION

Larger and more capable neural networks demand ever more computation to be trained. Industry has responded with specialized processors that can speed up machine arithmetic. Most deep learning training is done using GPUs and/or CPUs, which support IEEE-754 binary32 (single-precision) and (for GPUs) binary16 (half-precision).

Arithmetic and load/store instructions with 16-bit numbers can be significantly faster than arithmetic with longer 32-bit numbers. Since 32-bit is expensive for large workloads, industry and the research community have tried to use mixed-precision methods that perform some of the work in 16-bit and the rest in 32-bit. Training neural networks using 16-bit formats is challenging because the magnitudes and the range of magnitudes of the network tensors vary greatly from layer to layer, and can sometimes shift as training progresses.

In this effort, binary16 has not proved to be an unqualified success, mainly because its 5 exponent bits offer a narrow range of representable values, even when denormalized floating-point numbers (or denormals; see a formal definition in Section III), which add three orders of magnitude to the representation range, are used. Therefore, different allocations of the 15 bits beyond the sign bit, to allow greater representation range, have been tried. Most notably, Google introduced the bfloat16 format, with 8 exponent bits, which until recently was only available on TPU clouds, but is now also included in Nvidia's latest architecture, Ampere. Experiments have shown that both binary16 and bfloat16 can work, either out of the box or in conjunction with techniques like loss scaling that help control the range of values in software. Our understanding of these issues needs to be deepened.

To our knowledge, no study on the width of the range of values encountered during training has been performed, and it has not been determined whether denormal numbers are necessary. To fill this gap, we study here the distribution of tensor elements during training on public models like ResNet and BERT. The data show that the representation range offered by binary16 is fully utilized, with a high percentage of the values falling in the denormal range. Unfortunately, on many machines performance suffers with every operation involving denormals, due to special encoding and handling of the range [1], [2]. This raises the question of whether they are needed, or are encountered much less often, when more than 5 exponent bits are utilized. Indeed, bfloat16, with its large representation range, does not support and does not appear to require denormal numbers; but it loses an order of magnitude in numerical accuracy vis-a-vis binary16. Thus, we consider an alternative 1/6/9 format, i.e., one having 6 exponent bits and 9 explicit mantissa bits, both with the denormal range and without it. We demonstrate that this format dramatically decreases the frequency of denormals, which will improve performance on some machines.
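To make the narrowness of that range concrete, the snippet below (a small illustration assuming only NumPy; the constants are standard IEEE-754 binary16 properties, not measurements from this paper) prints binary16's limits and shows the product of two plausible small values landing in the denormal range, while a slightly smaller value rounds to zero.

```python
import numpy as np

# IEEE-754 binary16 (1 sign / 5 exponent / 10 mantissa bits):
#   largest finite value    = 65504
#   smallest normal value   = 2^-14  ~ 6.10e-05
#   smallest denormal value = 2^-24  ~ 5.96e-08
fi = np.finfo(np.float16)
print(fi.max, fi.tiny, np.float16(2.0 ** -24))

x = np.float16(1e-4)            # a small activation-scale value (still normal)
g = np.float16(1e-3)            # a small gradient-scale value
print(x * g)                    # ~1.2e-07: the product is a denormal
print(np.float16(2.0 ** -26))   # below the denormal range: rounds to 0.0
```

Assuming IEEE-style exponent biasing, a 6-bit exponent would instead put the smallest normal value near 1e-9 rather than 6e-5, which is consistent with the observation that the 1/6/9 format dramatically reduces how often denormals occur.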
In Section II we discuss 16-bit formats that have been previously used in neural network training. In Section III we recall basic notions of floating-point arithmetic and define the notations we use. Section IV presents our testing methodology and details the training settings for each tested network (Section IV-B). The results we gathered are detailed in Section V and, finally, we conclude in Section VI.

II. RELATED WORK

There have been several 16-bit floating-point (FP) formats proposed for mixed-precision training of neural networks, including half-precision [3]–[5], Google's bfloat16 [6], [7], IBM's DLFloat [8] and Intel's Flexpoint [9], each with dedicated configurations for the exponent and mantissa bits.

Mixed-precision training using half-precision was first explored by Baidu and Nvidia in [3]. Using half-precision for all training data reduces final accuracy, because the format fails to cover a wide enough representation range, i.e., small values in half-precision might fall into the denormal range or even become zero. To mitigate this issue, the authors of [3] proposed a mixed-precision training scheme. Three techniques were introduced to maintain SOTA accuracy (comparable to single-precision):
• "master weights", i.e., a single-precision copy of the weights, is maintained in memory and updated with the products of the unscaled half-precision weight gradients and the learning rate. This ensures that small gradients for each minibatch are not discarded as rounding error during weight updates;
• during the forward and backward pass, matrix multiplications take half-precision inputs and compute reductions using single-precision addition (using FMACS instructions as described in Section III). For that, activations and gradients are kept in half-precision, halving the required storage and bandwidth, while weights are downcast on the fly;
• "loss scaling", which empirically rescales the loss before backpropagation in order to prevent small activation gradients that would otherwise underflow in half-precision. Weight gradients must then be unscaled by the same amount before performing weight updates.
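The three techniques combine into the following training-step skeleton. This is a minimal NumPy sketch under assumed simplifications (a toy linear model, mean-squared-error loss, and a fixed loss scale of 1024 instead of the dynamic loss scaling discussed next); it is not the implementation from [3], but it shows where the float32 master weights, the half-precision compute, and the scale/unscale steps sit.

```python
import numpy as np

def mm(a16, b16):
    # Matrix multiply with half-precision inputs and single-precision
    # accumulation, emulated here by upcasting; the result is rounded
    # back to half precision.
    return (a16.astype(np.float32) @ b16.astype(np.float32)).astype(np.float16)

def train_step(master_w, x, y, lr=0.01, loss_scale=1024.0):
    """One SGD step for a linear model with MSE loss (illustrative sketch)."""
    w16 = master_w.astype(np.float16)       # weights are downcast on the fly
    x16 = x.astype(np.float16)
    y16 = y.astype(np.float16)

    pred = mm(x16, w16)                     # forward pass in half precision
    err = pred - y16

    # Backward pass on the scaled loss: d(s * MSE)/dw = s * (2/n) * x^T err.
    n = np.float16(x16.shape[0])
    scaled_grad16 = (mm(x16.T, err) / n) * np.float16(2.0 * loss_scale)

    # Unscale in single precision, then update the float32 master weights.
    grad32 = scaled_grad16.astype(np.float32) / np.float32(loss_scale)
    master_w -= lr * grad32
    return master_w

# Toy usage: 64 samples, 8 features, 1 output, float32 master weights.
rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8), dtype=np.float32)
y = rng.standard_normal((64, 1), dtype=np.float32)
w = np.zeros((8, 1), dtype=np.float32)
w = train_step(w, x, y)
```

In a real framework the forward and backward passes and the loss-scale bookkeeping are handled by the mixed-precision library, but the data flow is the same: half-precision activations and gradients, single-precision accumulation and weight updates.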
Loss scale factors were initially chosen on a model-by-model basis, empirically. Automated methods like dynamic loss scaling (DLS) [4], [5] and adaptive loss scale [10] were later developed. These methods allow the scale to be computed according to the gradient distribution. Using the above techniques has become the standard workflow for mixed-precision training. It is reported to be 2 to 11× faster than single-precision training [3]–[5], while also being able to achieve SOTA accuracy on widely used models.

Another mixed-precision format available in specialized hardware is Google's bfloat16 [6]. It was first reported as a technique to accelerate large-batch training on TPUs (Tensor Processing Units), where bfloat16 multiplications are followed by single-precision operations whose results are then rounded to values that are also representable in bfloat16. The results suggest that bfloat16 requires no denormal support and no loss scaling, and therefore decreases complexity in the context of mixed-precision training. Recently, Nvidia also made it available in its latest architecture, Ampere [14].

IBM's DLFloat [8] is a preliminary study of the 1/6/9 format that we also examine. DLFloat was tested on convolutional neural networks, LSTMs [15] and transformers [16] of small sizes, implying the potential for minimal or even no accuracy loss. The studied format did not support a denormal range. The authors briefly discuss the factors contributing to the hardware savings, without providing statistics on how they actually influence the training results.

Flexpoint [9] is a blocked floating-point format proposed by Intel. It differs from the previously mentioned formats in that it uses 16 bits for the mantissa and shares a 5-bit exponent across each layer. Flexpoint was proposed together with an exponent management algorithm, called Autoflex, aiming to predict tensor exponents based on statistics gathered from previous iterations. The format allowed un-normalized values to be stored, which precluded the possibility of a denormal range. For this reason we are not going to study it further in this work.

For all these studies, single-precision training accuracy was achieved. But there is no mention of whether denormals, which significantly impact hardware area and compute speed, are encountered; neither is it tested whether or not they must be supported to achieve that accuracy.

III. FLOATING-POINT ARITHMETIC

Formally defined in Def. 3.1, a floating-point (FP) number is most often used in computing as an approximation of a real number. Given a fixed bit width for representation, it offers a trade-off between representation range and numerical accuracy.

Definition 3.1: A real number X is approximated in a machine by a nearby floating-point number

    x = M · 2^(E−p+1),    (1)

where
• E is the exponent, a signed integer Emin ≤ E ≤ Emax;
• M · 2^(−p+1) is the significand (sometimes improperly referred to as the mantissa), where M is also a signed integer represented on p bits.
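Definition 3.1 also makes it easy to tabulate the limits compared throughout the paper. The helper below is a sketch assuming IEEE-754-style conventions (bias 2^(e-1) - 1, the all-ones exponent reserved for infinities and NaNs, one implicit leading mantissa bit); the conventions actually adopted for 1/6/9 may differ, so its row is illustrative, and the bfloat16 denormal entry is what the encoding would allow rather than what typical hardware supports.

```python
def format_limits(exp_bits: int, explicit_mantissa_bits: int):
    """Range limits of a binary FP format under IEEE-754-style conventions:
    bias = 2^(e-1) - 1, all-ones exponent reserved, one implicit leading
    mantissa bit (so precision p = explicit bits + 1)."""
    p = explicit_mantissa_bits + 1
    bias = 2 ** (exp_bits - 1) - 1
    e_max = (2 ** exp_bits - 2) - bias          # largest unreserved exponent
    e_min = 1 - bias                            # exponent of smallest normal
    largest = (2.0 - 2.0 ** (1 - p)) * 2.0 ** e_max
    smallest_normal = 2.0 ** e_min
    smallest_denormal = 2.0 ** (e_min - (p - 1))
    return largest, smallest_normal, smallest_denormal

# 1/5/10 = binary16, 1/8/7 = bfloat16, 1/6/9 = the format studied here.
for name, e, m in [("binary16", 5, 10), ("bfloat16", 8, 7), ("1/6/9", 6, 9)]:
    hi, lo_norm, lo_denorm = format_limits(e, m)
    print(f"{name:9s} max={hi:.3e}  min normal={lo_norm:.3e}  "
          f"min denormal={lo_denorm:.3e}")
```

With these conventions, binary16's denormals extend its lower limit from about 6e-5 down to about 6e-8, the three orders of magnitude mentioned in the introduction, while 1/6/9 already reaches about 1e-9 with normal values alone.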