最近在做图片CLIP搜索的时候,需要计算两个向量的余弦相似度[1],
它的本质,是需要计算两个浮点数数组的点乘。这种场景可以用neon的向量运算来加速。
在aarch64架构上,gcc提供了这个头文件,提供了相关的接口来直接使用neon的能力。
在sqlite-vec等向量搜索插件的源码中,均是使用这个接口来实现相关功能。
乘加运算[2],简称fma,在接口中看到fma的,就是这种运算,x * y + z。
math.h也提供了fma()函数。
double fma(double x, double y, double z);
float32x4_t,这个就是4个浮点数,128位。neon里面有很多128位浮点数寄存器,里面可以表示4个独立的32位浮点数
neon就是一个次可以进行4个浮点数的各种运算。这4个浮点数就是一个vector。
还有其他的数据类型,这里我们只关注float,其它是类似的。如uint8x16_t,float64x2_t。64位浮点数最多只有2个。因为最多是128位。
加载数组到向量
float values[5] = { 1.0, 2.0, 3.0, 4.0, 5.0 };
float32x4_t v = vld1q_f32(values);
名字:v vector,ld load,1q 注意这里是数字1,q就是quard表示向量是4个成员。 f32就是32位浮点数。
向量的每个成员都设置为同一个常量
float32x4_t v = vmovq_n_f32(1.5);
从浮点数变量复制到每一个成员
float val = 3.0;
float32x4_t v = vld1q_dup_f32(&val);
还有一种最直接的方法,直接初始化变量
float32x4_t a = { 2.0, 3.0, 4.0, 5.0 };
将向量存到数组
float t[4];
vst1q_f32(t, v);
将向量中的某一个保存到浮点数变量
float a
vst1q_lane_f32(&a, v, 0)
没有在文档里面,找到对应的具体含义,只能写代码确定了:
#include <stdio.h>
#include <arm_neon.h>
static void printf_v(const char *name, float32x4_t v)
{
float t[4];
vst1q_f32(t, v);
printf("%s %f %f %f %f\n", name, t[0], t[1], t[2], t[3]);
}
void test_vfmaq_f32()
{
float32x4_t a = { 2.0, 3.0, 4.0, 5.0 };
float32x4_t b = { 1.0, 2.0, 3.0, 4.0 };
float32x4_t c = { 0.0, 1.0, 2.0, 3.0 };
float32x4_t d;
d = vfmaq_f32(a, b, c);
printf_v("a", a);
printf_v("b", b);
printf_v("c", c);
printf_v("vfmaq_f32", d);
}
int main(int argc, char **argv)
{
test_vfmaq_f32()
return 0;
}输出结果:
a 2.000000 3.000000 4.000000 5.000000 b 1.000000 2.000000 3.000000 4.000000 c 0.000000 1.000000 2.000000 3.000000 vfmaq_f32 2.000000 5.000000 10.000000 17.000000
因此
d = vfmaq_f32(a, b, c);
等价于
d[i] = a[i] + b[i] * c[i];
i等于0-3。
d = vmulq_f32(a, b)
等价于
d[i] = a[i] * b[i]
d = vfmaq_laneq_f32(a, b, c, lane)
d[i] = a[i] + b[i] *c[lane]
float d = vaddvq_f32(a)
d = a[0] + a[1] + a[2] + a[3]
有了上述基础后,计算点乘非常简单。
float dot_product(float *a, float *b, int n)
{
float32x4_t C = { 0.0 };
int i;
for (i = 0; i < n; i += 4) {
float32x4_t A;
float32x4_t B;
A = vld1q_f32(a + i);
B = vld1q_f32(b + i);
C = vfmaq_f32(C, A, B);
}
return vaddvq_f32(C);
}
float cosine_similarity_neon(float *a, float *b, int n)
{
float A = dot_product(a, a, n);
float B = dot_product(b, b, n);
float C = dot_product(a, b, n);
return C/sqrt(A)/sqrt(B);
}参考:
[1] https://en.wikipedia.org/wiki/Cosine_similarity
[2] https://www.gnu.org/software/c-intro-and-ref/manual/html_node/Fused-Multiply_002dAdd.html
[3] https://github.com/thenifty/neon-guide/blob/master/README.md
[4] https://developer.arm.com/documentation/102467/0201/Example---matrix-multiplication