AVX+DGEMMのコードを読む
目次
- 概要
- コードを読む前に
- コードを読む
- つまづいたポイント
- まとめ
- 実際に使用したコード
概要
パタヘネの「3.8 高速化:半語並列性および行列の演算」で示されていたコードでつまづきました
備忘録として、この記事を書きます
- 作者:デイビッド・A・パターソン
- 発売日: 2016/10/26
- メディア: Kindle版
コード
実際に本に示されていた、行列の乗算を行うプログラム DGEMM (Double precision GEneral Matrix Multiply:倍精度汎用行列乗算) は以下の通りです
// AVXを利用して高速化したver void dgemm_use_avx(int n, double *A, double *B, double *C) { for(int i = 0; i < n; i += 4) { for(int j = 0; j < n; ++ j) { __m256d c0 = _mm256_load_pd(C+i+j*n); for(int k = 0; k < n; ++ k) { __m256d a0 = _mm256_load_pd(A+i+k*n); __m256d b0 = _mm256_broadcast_sd(B+k+j*n); c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0)); } _mm256_store_pd(C+i+j*n, c0); } } }
コードを読む前に
コードを読む部分で重要になってくる 半語並列性、AVX について書きます
半語並列性
1つの命令で複数のデータを演算できる特性のこと
SIMD(Single Instruction Multi Data)
128bitなど、広いレジスタ内をいくつかのパーティションに分けることで (16bit * 8, 32bit * 4 ...)、複数の演算を同時に行うことができたりする
AVX
2011年に登場したintelの拡張命令セット
レジスタ幅がそれまでの128bitから倍増され256bitとなり、単一の操作で32bitの浮動小数点演算を8つ、64bitの浮動小数点演算を4つ行うことができるようになった
コードを読む
1. 関数定義
void dgemm_use_avx(int n, double *A, double *B, double *C) {
dgemm_use_avx
関数を定義しています
4つの引数をもち、それぞれ以下の意味を持ちます
- n : 行列の次元
- A, B, C : 行列を展開して1次元配列に変換したもの
2. ループ
for(int i = 0; i < n; i += 4) { for(int j = 0; j < n; ++ j) {
行列の全ての要素を対象にして演算を行うために2重ループを使用しています
ただし、AVXを利用することで double(64bit) を対象とした演算は 同時に 4 つ 行うことができるため、変数iの増加量は 4 になっています
これは、同時に4列分の演算を行うため、一次元配列内において通常の4倍の増加させてあげる必要があります (行列を1次元配列で表現しているところがミソです)
3. 演算結果格納用行列Cへのアクセス
__m256d c0 = _mm256_load_pd(C+i+j*n);
演算結果を格納するための行列Cへのアクセスを行います
指定アドレスから256bit分読み込みを行うため、double 配列 4 要素を変数c0で 1 度に扱うことができます(下に説明があります)
4. 演算
for(int k = 0; k < n; ++ k) { __m256d a0 = _mm256_load_pd(A+i+k*n); __m256d b0 = _mm256_broadcast_sd(B+k+j*n) c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0)); }
行列乗算は対応する成分同士で掛け算を行い、その和を足し合わせることで行います
演算中に成分を足し合わせるために、変数kを用いたループを行っています
コードで使用している関数は以下のような処理を行います
_mm256_load_pd(double *p)
AVX命令を使用して4つの倍精度小数点(64bit * 4)を、指定したアドレスから連続した256bitの領域から読み込む
_mm256_broadcast_sd(double *p)
指定したアドレスからスカラ型の倍精度小数点を読み込んで、同一のコピーを4つ生成する
_mm256_add_pd(m256d a, m256d b)
足し算、要素ごとに演算が行われる
_mm256_mul_pd(m256d a, m256d b)
掛け算、要素ごとに演算が行われる
大きく分けて3演算を行っているのですが、それぞれ次のようなことをやっています
_mm256_load_pd
を用いて Aの行列の4成分を読み込む_mm256_broadcast_sd
を用いてBの行列の1成分を読み込む- それらを掛け合わせた後に足す
実際に行われている演算の様子を図示します
256bitという広い幅のレジスタを用いて、同時に4データを保持しつつ演算していることに注目します
また、この時行われている行列乗算は B*A であることに注意します (A*Bではない!!!!)(ここで時間かけた)
5. ストア
c0変数に格納していた演算結果を行列Cに書き出します
自分がつまづいたポイント
B*A
単純に自分が気づけなかっただけですが、演算は A*B ではなく B*A が行われます
degmmが元からこのような仕様なのか、それとも本のコードが間違っているかは謎ですが、戸惑いました
(一部コードを変更するだけでA*B を行うことができます)
単純に半語並列性の理解不足
複数データを同時に処理する場合のアドレス指定について、単純に理解不足でした
コードでは行列を1次元配列で表現しているのですが、その場合のアドレス指定の仕方で混乱していました
まとめ
AVXを用いた行列乗算について、コードを読んでいきました
並列処理への理解が少し(少し...)だけ深まりました
自分の思い込みによる寄り道が多すぎたので、もっと慎重になるべきだな…といった気持ちです
稚拙な文章ですが、最後まで読んでくださりありがとうございました
もし書いてある内容に誤り、不備があればご指摘をよろしくお願いします
実際に使用したコード
デバッグ用であったり、main関数を加えたコードの完全版を示します
// 実行 gcc -O0 -mavx -o degmm degmm.c
#include <stdio.h> #include <stdlib.h> #include <x86intrin.h> void print(int n, double *A); void print_m256d(__m256d m); void dgemm(int n, double *A, double *B, double *C) { for(int i = 0; i < n; ++ i) { for(int j = 0; j < n; ++ j) { double c0 = C[i+j*n]; for(int k = 0; k < n; ++ k) { c0 += A[i+k*n] * B[k+j*n]; } C[i+j*n] = c0; /* printf("%d %d\n", i, j); */ /* print(4, C); */ /* printf("\n"); */ } } } void dgemm_use_avx(int n, double *A, double *B, double *C) { for(int i = 0; i < n; i += 4) { for(int j = 0; j < n; ++ j) { __m256d c0 = _mm256_load_pd(C+i+j*n); for(int k = 0; k < n; ++ k) { __m256d a0 = _mm256_load_pd(A+i+k*n); __m256d b0 = _mm256_broadcast_sd(B+k+j*n); c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0)); /* print_m256d(a0); */ /* print_m256d(b0); */ /* print_m256d(c0); */ /* printf("\n"); */ } _mm256_store_pd(C+i+j*n, c0); /* printf("%d %d\n", i, j); */ /* print(4, C); */ /* printf("\n"); */ } } } void print(int n, double *A) { for(int i = 0; i < n; ++ i) { for(int j = 0; j < n; ++ j) { printf("%f ", A[i*n+j]); } printf("\n"); } } void print_m256d(__m256d m) { double M[4] = {0}; _mm256_store_pd(M, m); for(int i = 0; i < 4; ++ i) printf("%f ", m[i]); printf("\n"); } int main() { double A[16] = { 14, 12, 2, 6, 1, 10, 9, 16, 7, 8, 4, 11, 13, 5, 3, 15 }; double B[16] = { 11, 6, 4, 10, 13, 8, 3, 5, 12, 7, 14, 16, 9, 1, 2, 15 }; double C1[16] = {0}, C2[16] = {0}; dgemm(4, A, B, C1); dgemm_use_avx(4, A, B, C2); print(4, C1); print(4, C2); return 0; }