跳转至

C++如何实现FP16(IEEE754-2008)类型?

计算机在进行数值计算时,首先需要对数据进行表示,比如F32、F16、F64、int32、int16、int8等等。特别在深度学习中,如果模型的参数是F32类型,那么计算耗时且保存的参数占用内存大。

为了准确率同时保证计算的速度,一般需要把模型转换成F16类型。F16即是通过16个bit(位)表示浮点数类型。

我们一起看一下:

  • IEEE754-2008如何定义Float32和Float16的?
  • float32和float16如何转换?在C++23版本以下标准库是没有f16类型的。
  • Float16不用转换成float32比较,如何直接比较大小?

F16是如何表示的?

F32类型和F16类型对比

float16是指采用2字节(16位)进行编码存储的一种数据类型;同理float32是指采用4字节(32位)存储的数据类型。FP16(float16)按照IEEE754-2008的标准表示。

使用float16代替float32有如下优点:内存占用减少,应用float16内存占用比float32更小,可以设置更大的batch_size。 同时,加速计算​。

为了实现更高的精度,尾数位只有23位,但是有个隐式的1(当Exponent不全部为0时),多了一位表示,所以有24位表示。在计算时就是1.fraction。即1 + fraction表示的尾数。

  • Significand precision: 24 bits (23 explicitly stored)。

当然,在F16中同样:

  • Significand precision: 11 bits (10 explicitly stored)。

还有一个是BF16类型

从表示上看, BF16类型就是把F32类型的小数部分(Fraction)从23bit位截取成了7bit位。

  • bfloat16 的指数位比 float16 多,这使得它在动态范围上与单精度浮点数(float32)更为接近。由于尾数位较少,bfloat16 的精度相对较低。

  • bfloat16 更适合训练过程,尤其是在需要较大的动态范围的情况下。float16 更适合推理过程,尤其是在需要较高精度的情况下。

  • 通过numpy得到的16位浮点数,是FP16,不是BF16。

C++如何实现F16类型?

F16和F32互相转换

F16是用16bit去表示了符号位、指数、尾数。F32是用的32bit,同样表示的符号位、指数、尾数

那么要实现F16到F32的转换,只要建立了对应的映射关系就可以了。比如把F16的符号位转换到F32,那么就是向左移动16位就可以了。

然后指数和尾数,还要多考虑一丢丢儿!

就是下面这两张表,F16什么条件下表示NaN,什么条件下表示无穷大。建立起和F32的映射关系。完成转换代码就可以了。

F16的格式和对应值的关系:

F32的格式和对应值的关系:

参考:

Fallback的实现

F16和F32转换存在高效的硬件指令实现,这里介绍Fallback的实现,当硬件不支持时,采用这种最通用的形式进行转换。在c++中,采用uint16_t类型去保存float16的二进制内容。 在c++中,采用uint16_t类型去保存float16的二进制内容。

F16转换成F32,跟着代码中的每一行解释结合上面的格式定义图示,是不是很清楚了。

// f16的格式定义:https://en.wikipedia.org/wiki/Half-precision_floating-point_format
// f32的格式定义:https://en.wikipedia.org/wiki/Single-precision_floating-point_format
// Fraction和man、Significand、mantissa 指的同一个东西,都是尾数位,F16的最后10位,F32的最后23位。
static float f16_to_f32_fallback(uint16_t i)
{
    // Check for signed zero, 根据F16的表示规则,如果Exponent(5位)和Significand(10位)都为0,那么表示浮点数zero, -0
    // 转换成f32只需要向右移16位构成32位表示就可以了,即符号位+31位0
    if ((i & 0x7FFF) == 0)
    {
        uint32_t result = (static_cast<uint32_t>(i) << 16);
        float f = 0.0;
        std::memset(&f, 0, sizeof(f));
        std::memcpy(&f, &result, sizeof(result));
        return f;
    }

    // 根据f16的IEEE754-2008标准,获取sign,Exponent,Significand的值
    uint32_t half_sign = (i & 0x8000);
    uint32_t half_exp = (i & 0x7C00);
    uint32_t half_man = (i & 0x03FF);

    // Check for an infinity or NaN when all exponent bits set
    // 如果Exponent(5位)对应的bit位全是1,即11111,那么可能是infinity or NaN 
    if (half_exp == 0x7C00)
    {
        // 如果Significand(10位)是0,就表示+-infinity(无穷)
        // Check for signed infinity if mantissa is zero
        if (half_man == 0)
        {
            // 转换位float32就是,符号位 + float32的+-infinity表示
            // 即符号位:(half_sign << 16)
            // float32的+-infinity表示: Exponent(占用8bit)全为1,fraction(23bit)全为0,即0x7F800000
            uint32_t result = (half_sign << 16) | 0x7F800000;
            float f = 0.0;
            std::memset(&f, 0, sizeof(f));
            std::memcpy(&f, &result, sizeof(result));
            return f;
        }
        else
        {
            // 如果Significand(10位)不是0, 就表示NaN
            // 转换为对应Float32的NaN,即Exponent(占用8bit)全为1;fraction(23bit)为half_man右移动13位(f32的fraction表示位数减去f16的fraction表示的位数,即23 - 10等于13)
            // NaN, keep current mantissa but also set most significiant mantissa bit
            // 为啥不是 0x7F800000??? 
            uint32_t result = (half_sign << 16) | 0x7FC00000 | (half_man << 13);
            float f = 0.0;
            std::memset(&f, 0, sizeof(f));
            std::memcpy(&f, &result, sizeof(result));
            return f;
        }
    }

    // Calculate single-precision components with adjusted exponent
    // 转换为f32的符号位,右移动16位
    uint32_t sign = half_sign << 16;
    // Unbias exponent
    // 因为F16的Exponent的表示的不是e为底,Exponent为指数的指数函数,而是指数为Exponent-15,偏移量为15
    // 对应F32的偏移量为127,所以换成F32的Exponent就要E - 15 + 127表示
    int32_t unbiased_exp = (static_cast<int32_t>(half_exp) >> 10) - 15;

    // 通过前面的条件过滤,这里表示Exponent全为0,Significand不全为0,表示subnormal number
    // Check for subnormals, which will be normalized by adjusting exponent
    if (half_exp == 0)
    {
        // Calculate how much to adjust the exponent by
        int e = countl_zero(half_man) - 6;

        // Rebias and adjust exponent
        uint32_t exp = (127 - 15 - e) << 23;
        uint32_t man = (half_man << (14 + e)) & 0x7FFFFF;
        uint32_t result = sign | exp | man;
        float f = 0.0;
        std::memset(&f, 0, sizeof(f));
        std::memcpy(&f, &result, sizeof(result));
        return f;
    }

    // Rebias exponent for a normalized normal
    // 这里的加127,对应上面说的F16转F32时Exponent要加上F32的Exponent偏移量127;向右移动23位到达表示Exponent对应的bit位置
    uint32_t exp = (static_cast<uint32_t>(unbiased_exp + 127)) << 23;
    // 向右移动13位,Significand值由F16的10位表示转换成23位表示
    uint32_t man = (half_man & 0x03FF) << 13;
    uint32_t result = sign | exp | man;
    float f = 0.0;
    std::memset(&f, 0, sizeof(f));
    std::memcpy(&f, &result, sizeof(result));
    return f;
}
F32如何转换成F16呢?实现如下

// Fraction和man、Significand、mantissa 指的同一个东西,都是尾数位,F16的最后10位,F32的最后23位。
// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd. F32的Fraction右移动13位后,剩下部分是奇数,可以进位
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// F32的Fraction右移动13位后,剩下部分是奇数,可以进位
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
static uint16_t f32_to_f16_fallback(float value)
{
    // Convert to raw bytes
    uint32_t x;
    std::memset(&x, 0, sizeof(uint32_t));
    std::memcpy(&x, &value, sizeof x);

    // Extract IEEE754 components
    uint32_t sign = x & 0x80000000u;
    uint32_t exp = x & 0x7F800000u;
    uint32_t man = x & 0x007FFFFFu;

    // Check for all exponent bits being set, which is Infinity or NaN
    // Exponent全为1,表示Infinity or NaN
    if (exp == 0x7F800000u)
    {
        uint32_t nan_bit = (man == 0) ? 0 : 0x0200u;
        // 0x7C00u 表示F16的Exponent全为1
        // nan_bit:如果man==0,表示+-infinity,所以直接把(man >> 13)就变成了F16的man
        // 如果man != 0, 把(man >> 13)可能变成了0,所以加上一个nan_bit,确保转换成的F16!=0
        return (sign >> 16) | 0x7C00u | nan_bit | (man >> 13);
    }

    // 右移16bit转换成F16的sign
    uint32_t half_sign = sign >> 16;
    // 127是F32的Exponent的偏移
    // 从F32的Exponent转换成F16,就需要exp的值 - 127 + 15, 15是F16的偏移
    int32_t unbiased_exp = ((exp >> 23) - 127);
    int32_t half_exp = unbiased_exp + 15;

    // 表示half_exp超过F16的Exponent最大表示,11111
    // Check for exponent overflow, return +infinity
    // 表示infinity
    if (half_exp >= 0x1F)
    {
        // 0x7C00u表示,F16表示的bit位,Fraction全部为0, Exponent全部为1
        return half_sign | 0x7C00u;
    }

    // Check for underflow
    if (half_exp <= 0)
    {
        // Check mantissa for what we can do
        if ((14 - half_exp) > 24)
        {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign;
        }
        man = man | 0x00800000u;
        uint32_t half_man = man >> (14 - half_exp);
        uint32_t round_bit = 1 << (13 - half_exp);
        if ((man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0)
        {
            half_man++;
        }
        return half_sign | half_man;
    }
    // Rebias the exponent, 左移10位到F16的Exponent表示位置
    half_exp = (half_exp << 10);
    // 右移动13位,到F16的Significand表示位置
    uint32_t half_man = man >> 13;
    // round_bit表示F32的Fraction中的从右往左的第13位,也就是转换成F16(10位)时要移除的最后一个位置。
    uint32_t round_bit = 0x00001000u;
    // Check for rounding (see comment above functions)
    if ((man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0)
    {
        // Round it
        return (half_sign | half_exp | half_man) + (uint32_t)1;
    }
    else
    {
        return half_sign | half_exp | half_man;
    }
}

F16高效的比较大小

如果将F16转换成F32,然后再比较F16的大小,那么就会多了转换的开销。为了减少程序耗时,提高性能。那么是否可以直接比较F16类型呢?

有朋友可能会说,F16不是用uint16_t类型表示的吗?直接比较uint16_t类型可以吗? 答案是不行的。

直接比较uint16就是bit位之间的相互比较,而且F16还存在正数和负数,显然不能直接比较。

在c++中实现,定义了F16 class,只需要重载用于比较的operate函数就可以实现比较大小了。

class F16{
    bool is_nan() const
    {
        // 通过上面的定义判断是否为NaN
        return (value & 0x7FFF) > 0x7C00;
    }

    bool operator==(const F16 &other) const
    {
        if (is_nan() || other.is_nan())
        {
            return false;
        }
        else
        {
            // 直接判断二进制内容是否相等,或者exponent和fraction全为0,此时表示正负0相等
            return (value == other.value) || ((value | other.value) & 0x7FFF) == 0;
        }
    }

     bool operator<(const F16 &other) const
    {
        if (is_nan() || other.is_nan())
        {
            return false;
        }
        else
        {
            // 判断正负符号位
            bool neg = (value & 0x8000) != 0;
            bool other_neg = (other.value & 0x8000) != 0;
            // 都为正数,直接比较value
            if (!neg && !other_neg)
            {
                return value < other.value;
            }
            // neg为正,other_neg负,返回false
            else if (!neg && other_neg)
            {
                return false;
            }
            // neg为负,other_neg正
            // 进一步判断this和other是否都为0,不都为0,负数肯定小于正数
            else if (neg && !other_neg)
            {
                return ((value | other.value) & 0x7FFF) != 0;
            }
            // 都为负数,直接比较value
            else
            {
                return value > other.value;
            }
        }
    }

    bool operator<=(const F16 &other) const
    {
        if (is_nan() || other.is_nan())
        {
            return false;
        }
        else
        {
            bool neg = (value & 0x8000) != 0;
            bool other_neg = (other.value & 0x8000) != 0;
            if (!neg && !other_neg)
            {
                return value <= other.value;
            }
            else if (!neg && other_neg)
            {
                return ((value | other.value) & 0x7FFF) == 0;
            }
            else if (neg && !other_neg)
            {
                return true;
            }
            else
            {
                return value >= other.value;
            }
        }
    }
  private:
  uint16_t value;
}

完整的代码实现和demo,在这里KenForever1/F16cpp

bye!感谢您的阅读!

评论