AVX+DGEMMのコードを読む

目次

  1. 概要
  2. コードを読む前に
  3. コードを読む
  4. つまづいたポイント
  5. まとめ
  6. 実際に使用したコード

概要

パタヘネの「3.8 高速化:半語並列性および行列の演算」で示されていたコードでつまづきました

備忘録として、この記事を書きます

コード

実際に本に示されていた、行列の乗算を行うプログラム DGEMM (Double precision GEneral Matrix Multiply:倍精度汎用行列乗算) は以下の通りです

// AVXを利用して高速化したver
void dgemm_use_avx(int n, double *A, double *B, double *C) {
    for(int i = 0; i < n; i += 4) {
        for(int j = 0; j < n; ++ j) {
            __m256d c0 = _mm256_load_pd(C+i+j*n);
            for(int k = 0; k < n; ++ k) {
                __m256d a0 = _mm256_load_pd(A+i+k*n);
                __m256d b0 = _mm256_broadcast_sd(B+k+j*n);
                c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0));
            }
            _mm256_store_pd(C+i+j*n, c0);
        }
    }
}

コードを読む前に

コードを読む部分で重要になってくる 半語並列性AVX について書きます

半語並列性

1つの命令で複数のデータを演算できる特性のこと

SIMD(Single Instruction Multi Data)

128bitなど、広いレジスタ内をいくつかのパーティションに分けることで (16bit * 8, 32bit * 4 ...)、複数の演算を同時に行うことができたりする

AVX

2011年に登場したintelの拡張命令セット

レジスタ幅がそれまでの128bitから倍増され256bitとなり、単一の操作で32bitの浮動小数点演算を8つ、64bitの浮動小数点演算を4つ行うことができるようになった

コードを読む

1. 関数定義

void dgemm_use_avx(int n, double *A, double *B, double *C) {

dgemm_use_avx 関数を定義しています

4つの引数をもち、それぞれ以下の意味を持ちます

  • n : 行列の次元
  • A, B, C : 行列を展開して1次元配列に変換したもの

2. ループ

    for(int i = 0; i < n; i += 4) {
        for(int j = 0; j < n; ++ j) {

行列の全ての要素を対象にして演算を行うために2重ループを使用しています

ただし、AVXを利用することで double(64bit) を対象とした演算は 同時に 4 つ 行うことができるため、変数iの増加量は 4 になっています

これは、同時に4列分の演算を行うため、一次元配列内において通常の4倍の増加させてあげる必要があります (行列を1次元配列で表現しているところがミソです)

3. 演算結果格納用行列Cへのアクセス

__m256d c0 = _mm256_load_pd(C+i+j*n);

演算結果を格納するための行列Cへのアクセスを行います

指定アドレスから256bit分読み込みを行うため、double 配列 4 要素を変数c0で 1 度に扱うことができます(下に説明があります)

4. 演算

for(int k = 0; k < n; ++ k) {
     __m256d a0 = _mm256_load_pd(A+i+k*n);
     __m256d b0 = _mm256_broadcast_sd(B+k+j*n)
    c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0));
}

行列乗算は対応する成分同士で掛け算を行い、その和を足し合わせることで行います

演算中に成分を足し合わせるために、変数kを用いたループを行っています

コードで使用している関数は以下のような処理を行います

_mm256_load_pd(double *p)

AVX命令を使用して4つの倍精度小数点(64bit * 4)を、指定したアドレスから連続した256bitの領域から読み込む

f:id:guguru0014:20200323001002p:plain

_mm256_broadcast_sd(double *p)

指定したアドレスからスカラ型の倍精度小数点を読み込んで、同一のコピーを4つ生成する

_mm256_add_pd(m256d a, m256d b)

足し算、要素ごとに演算が行われる

_mm256_mul_pd(m256d a, m256d b)

掛け算、要素ごとに演算が行われる


大きく分けて3演算を行っているのですが、それぞれ次のようなことをやっています

  1. _mm256_load_pd を用いて Aの行列の4成分を読み込む
  2. _mm256_broadcast_sd を用いてBの行列の1成分を読み込む
  3. それらを掛け合わせた後に足す

実際に行われている演算の様子を図示します

256bitという広い幅のレジスタを用いて、同時に4データを保持しつつ演算していることに注目します

また、この時行われている行列乗算は B*A であることに注意します (A*Bではない!!!!)(ここで時間かけた)

f:id:guguru0014:20200322235654g:plain

5. ストア

c0変数に格納していた演算結果を行列Cに書き出します

自分がつまづいたポイント

B*A

単純に自分が気づけなかっただけですが、演算は A*B ではなく B*A が行われます

degmmが元からこのような仕様なのか、それとも本のコードが間違っているかは謎ですが、戸惑いました

(一部コードを変更するだけでA*B を行うことができます)

単純に半語並列性の理解不足

複数データを同時に処理する場合のアドレス指定について、単純に理解不足でした

コードでは行列を1次元配列で表現しているのですが、その場合のアドレス指定の仕方で混乱していました

まとめ

AVXを用いた行列乗算について、コードを読んでいきました

並列処理への理解が少し(少し...)だけ深まりました

自分の思い込みによる寄り道が多すぎたので、もっと慎重になるべきだな…といった気持ちです

稚拙な文章ですが、最後まで読んでくださりありがとうございました

もし書いてある内容に誤り、不備があればご指摘をよろしくお願いします

実際に使用したコード

デバッグ用であったり、main関数を加えたコードの完全版を示します

// 実行
gcc -O0 -mavx -o degmm degmm.c
#include <stdio.h>
#include <stdlib.h>
#include <x86intrin.h>

void print(int n, double *A);
void print_m256d(__m256d m);

void dgemm(int n, double *A, double *B, double *C) {
    for(int i = 0; i < n; ++ i) {
        for(int j = 0; j < n; ++ j) {
            double c0 = C[i+j*n];
            for(int k = 0; k < n; ++ k) {
                c0 += A[i+k*n] * B[k+j*n];
            }
            C[i+j*n] = c0;
            /* printf("%d %d\n", i, j); */
            /* print(4, C); */
            /* printf("\n"); */
        }
    }
}

void dgemm_use_avx(int n, double *A, double *B, double *C) {
    for(int i = 0; i < n; i += 4) {
        for(int j = 0; j < n; ++ j) {
            __m256d c0 = _mm256_load_pd(C+i+j*n);
            for(int k = 0; k < n; ++ k) {
                __m256d a0 = _mm256_load_pd(A+i+k*n);
                __m256d b0 = _mm256_broadcast_sd(B+k+j*n);
                c0 = _mm256_add_pd(c0, _mm256_mul_pd(a0, b0));
                /* print_m256d(a0); */
                /* print_m256d(b0); */
                /* print_m256d(c0); */
                /* printf("\n"); */
            }
            _mm256_store_pd(C+i+j*n, c0);
            /* printf("%d %d\n", i, j); */
            /* print(4, C); */
            /* printf("\n"); */
        }
    }
}

void print(int n, double *A) {
    for(int i = 0; i < n; ++ i) {
        for(int j = 0; j < n; ++ j) {
            printf("%f ", A[i*n+j]);
        }
        printf("\n");
    }
}

void print_m256d(__m256d m) {
    double M[4] = {0};
    _mm256_store_pd(M, m);
    for(int i = 0; i < 4; ++ i)
        printf("%f ", m[i]);
    printf("\n");
}

int main() {
    double A[16] = {
        14, 12,  2,  6,
         1, 10,  9, 16,
         7,  8,  4, 11,
        13,  5,  3, 15
    };
    double B[16] = {
        11,  6,  4, 10,
        13,  8,  3,  5,
        12,  7, 14, 16,
         9,  1,  2, 15
    };
    double C1[16] = {0}, C2[16] = {0};
    dgemm(4, A, B, C1);
    dgemm_use_avx(4, A, B, C2);
    print(4, C1);
    print(4, C2);
    return 0;
}