// 1. 初始化掩码和累加器
// mask = _mm256_set1_epi8(0x03); 用于提取2-bit数据。
// accu = _mm256_setzero_si256(); 用于累加结果。
// 2. 分组处理
// 外层循环:每次处理32组。
// 内层循环:遍历每组中的32个元素。
// 对每一组数据,分别取出2-bit并与8-bit数据做乘法累加(点积)。
// 3. 数据解码
// 2-bit数据拆成4个bit组(shift和mask操作)。
// 8-bit直接载入。
// 使用 _mm256_maddubs_epi16 指令做乘法累加。
// 4. 累加所有结果
// 使用 _mm256_add_epi16 和 _mm256_add_epi32 指令合并结果。
// 最后用 hsum_i32_8 对 SIMD 累加结果求和,得到最终点积。
*s = (float)sumi;,将结果赋值给输出指针。
#include <vector>
#include <type_traits>
// #include "ggml-bitnet.h"
// #include "ggml-quants.h"
#include <cmath>
#include <cstring>
#define QK_I2_S 128
#define QK_I2 128
#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <iostream>
// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
void print_m256i(__m256i vec) {
    uint8_t bytes[32];
    _mm256_storeu_si256((__m256i*)bytes, vec);
    for (int i = 0; i < 32; ++i) {
        printf("%02x ", bytes[i]);
        if ((i + 1) % 16 == 0) printf("\n");
    }
}
// x 指向 2-bit 量化数据,y 指向 8-bit 数据。
// 数据被分成多个组,便于后面的 SIMD 并行处理。
// 主要变量:
// QK_I2_S 通常为 128,表示每组的元素数。
// nb = n / QK_I2_S,组数。
// group32_num = nb / 32,每 32 组为一大组。
// la_num = nb % 32,剩余组数。
void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
    const uint8_t *    x = (uint8_t *)vx;
    const int8_t  *    y = (int8_t *)vy;
    const int nb = n / QK_I2_S;
    const int group32_num = nb / 32;
    const int la_num = nb % 32;
    const int groupla_num = nb % 32 != 0 ? 1 : 0;
    // n : 128 nb : 1 group32_num : 0 la_num : 1 groupla_num : 1
    std::cout << "n : " << n
              << " nb : " << nb
              << " group32_num : " << group32_num
              << " la_num : " << la_num
              << " groupla_num : " << groupla_num
              << std::endl;
    __m256i mask = _mm256_set1_epi8(0x03);
    __m256i accu = _mm256_setzero_si256();
    for (int i=0; i < group32_num; i++){
        std::cout << "i : " << i << std::endl;
        __m256i accu32 = _mm256_setzero_si256();
        for (int j=0; j < 32; j++) {
        // 128 index
        __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
        __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
        __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
        __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
        // each 32 index
        xq8_3 = _mm256_and_si256(xq8_3, mask);
        xq8_2 = _mm256_and_si256(xq8_2, mask);
        xq8_1 = _mm256_and_si256(xq8_1, mask);
        xq8_0 = _mm256_and_si256(xq8_0, mask);
        // each 32 index
        __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
        __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
        __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
        __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
        // 128 index accumulation add
        // split into 32 accumulation block
        // each block each 128 index accumulated 4index
        // each index maximum 256
        // each block maximum 4 * 256
        // each block accumulation maximum 127 * 256
        // each 32 group index (128 index in one group) needs cast to int32
        xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
        xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
        xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
        xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
        accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
        accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
        }
        accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
    }
    for (int i = 0; i < groupla_num; i++){
        __m256i accula = _mm256_setzero_si256();
        for (int j = 0; j < la_num; j++) {
        // 128 index
        // 例子中的128个int2,分成4组,每组32个int2
        // 128个int2
        // xx1 xx2 xx3 xx4 xx5 xx6 xx7 xx8 (int 16)
        // 00  xx1 xx2 xx3 xx4 xx5 xx6 xx7 (shift right 2bit)
        // 00  00  xx1 xx2 xx3 xx4 xx5 xx6 (shift right 4bit)
        // 00  00  00  xx1 xx2 xx3 xx4 xx5 (shift right 6bit)
        __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
        __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
        __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
        __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
        // each 32 index
        // mask: 00000011 00000011
        // --  --  --  xx4 -- -- -- xx8 (int 16)
        // 00  --  --  xx3 -- -- -- xx7 (shift right 2bit)
        // 00  00  --  xx2 -- -- -- xx6 (shift right 4bit)
        // 00  00  00  xx1 -- -- -- xx5 (shift right 6bit)
        // int2通过移动bit位置和mask,拆分成了int8表示,就可以直接和int8进行运算了
        xq8_3 = _mm256_and_si256(xq8_3, mask);
        xq8_2 = _mm256_and_si256(xq8_2, mask);
        xq8_1 = _mm256_and_si256(xq8_1, mask);
        xq8_0 = _mm256_and_si256(xq8_0, mask);
        print_m256i(xq8_3);
        printf("-----\n");
        print_m256i(xq8_2);
        printf("-----\n");
        print_m256i(xq8_1);
        printf("-----\n");
        print_m256i(xq8_0);
        printf("-----\n");
        // x中的int2每间隔4个数字为一组
        // 03 03 03 03 03 03 03 03 03 03 03 03 03 03 03 03 
        // 03 03 03 03 03 03 03 03 03 03 03 03 03 03 03 03 
        // -----
        // 02 02 02 02 02 02 02 02 02 02 02 02 02 02 02 02 
        // 02 02 02 02 02 02 02 02 02 02 02 02 02 02 02 02 
        // -----
        // 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 
        // 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 
        // -----
        // 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 
        // 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
        // each 32 index
        // 例子中的128个int8数据,分成4组,每组32个int8数据
        // xx1 xx2
        __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
        __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
        __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
        __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
        print_m256i(yq8_0);
        printf("-----\n");
        print_m256i(yq8_1);
        printf("-----\n");
        print_m256i(yq8_2);
        printf("-----\n");
        print_m256i(yq8_3);
        // 按顺序存储,相邻的32个int8为一组
        // -----
        // 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 
        // 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 
        // -----
        // 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 
        // 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f 
        // -----
        // 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f 
        // 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f 
        // -----
        // 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 
        // 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f
        // 128 index accumulation add
        // split into 32 accumulation block
        // each block each 128 index accumulated 4index
        // each index maximum 256
        // each block maximum 4 * 256
        // each block accumulation maximum 127 * 256
        // each 32 group index (128 index in one group) needs cast to int32
        // 乘加运算
        // xx1, xx5 (16bit)
        // yy1, yy2 (16bit)
        // res = xx1 * yy1  + xx5 * yy2
        xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
        xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
        xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
        xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
        // printf("-----\n");
        // print_m256i(xq8_3);
        // printf("-----\n");
        // print_m256i(xq8_2);
        // printf("-----\n");、
        // 打印第1组乘加计算结果
        print_m256i(xq8_1);
        printf("-----\n");
        // print_m256i(xq8_0);
        // printf("-----\n");
        // 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 
        // 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01
        // 乘加计算
        // 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 
        // 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
        // 结果
        // 41 00 45 00 49 00 4d 00 51 00 55 00 59 00 5d 00 
        // 61 00 65 00 69 00 6d 00 71 00 75 00 79 00 7d 00
        accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
        accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
        }
        accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
    }
    int sumi = hsum_i32_8(accu);
    *s = (float)sumi;
}