AVX512 assembly breaks when called concurrently from different goroutines

I have a custom piece of Go (1.23.0) assembly that uses AVX-512 operations to speed up a very common code path. The function checks whether any player in a group holds a poker hand, with hands represented as int64 bitsets. The code looks like this (a CardSet is simply an int64):

// func SubsetAVX512(cs []CardSet, hs []CardSet) int
// Returns 1 if any card set in cards contains any hand in hands, 0 otherwise

#include "textflag.h"

#define cs_data 0(FP)
#define cs_len  8(FP)
#define cs_cap  16(FP)
#define hs_data 24(FP)
#define hs_len  32(FP)
#define hs_cap  40(FP)
#define ret_off 48(FP)

// Define the function
TEXT ·SubsetAVX512(SB), NOSPLIT, $0-56

// Start of the function
    // Load parameters into registers
    MOVQ cs+cs_data, R8         // R8 = cards_ptr
    MOVQ cs+cs_len, R9          // R9 = cards_len

    MOVQ hs+hs_data, R10        // R10 = hands_ptr
    MOVQ hs+hs_len, R11         // R11 = hands_len

    // Check if hands_len == 0
    TESTQ R11, R11
    JE return_false

    // Check if cards_len == 0
    TESTQ R9, R9
    JE return_false

    // Initialize loop counters
    XORQ R12, R12                 // R12 = i = 0 (hands index)

    // Main loop over hands
outer_loop:
    CMPQ R12, R11                 // Compare i (R12) with hands_len (R11)
    JGE return_false              // If i >= hands_len, no match found

    // Load 8 hands into Z0 (512-bit register)
    LEAQ (R10)(R12*8), R13        // R13 = &hands[i]
    VMOVDQU64 0(R13), Z0          // Load 8 int64s from [R13] into Z0

    // Inner loop over cards
    XORQ R14, R14                 // R14 = j = 0 (cards index)
inner_loop:
    CMPQ R14, R9                  // Compare j (R14) with cards_len (R9)
    JGE next_hands_block          // If j >= cards_len, move to next hands block

    // Load cs from cards[j]
    LEAQ (R8)(R14*8), R15         // R15 = &cards[j]
    MOVQ 0(R15), AX               // AX = cards[j]

    // Broadcast cs into Z1
    VPBROADCASTQ AX, Z1           // Broadcast RAX into all lanes of Z1

    // Compute cs_vec & h_vec
    VPANDQ Z0, Z1, Z2             // Z2 = Z0 & Z1

    // Compare (cs_vec & h_vec) == h_vec
    VPCMPEQQ Z0, Z2, K1           // Compare Z0 == Z2, store result in mask K1

    // Check if any comparison is true
    KORTESTW K1, K1               // Test if any bits in K1 are set
    JNZ found_match               // If so, a match is found

    // Increment card index
    INCQ R14                      // j++
    JMP inner_loop                // Repeat inner loop

next_hands_block:
    // Increment hands index by 8
    ADDQ $8, R12                  // i += 8
    JMP outer_loop                // Repeat outer loop

found_match:
    // Match found, return 1
    MOVQ $1, AX                   // Set return value to 1 (true)
    RET

return_false:
    // No match found, return 0
    XORQ AX, AX                   // Set return value to 0 (false)
    RET

This code works fine as long as it is not called concurrently. For example, this test passes:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    for i := 0; i < 5; i++ {
        if SubsetAVX512(cs, hs) {
            atomic.AddInt64(&count, 1)
        }
    }
    require.Equal(t, int64(5), count)
}

However, this test, which calls the function from several goroutines, fails:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    wg := sync.WaitGroup{}
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            if SubsetAVX512(cs, hs) {
                atomic.AddInt64(&count, 1)
            }
        }()
    }
    wg.Wait()
    require.Equal(t, int64(5), count)
}

I believe the issue is that some of the registers I’m using are being overwritten by concurrent goroutines. My guess is that it’s the mask register K1, but that’s only a slightly educated guess.

Solution:

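Your problem is that you return the result in AX, but assembly functions in Go use the stack-based calling convention (ABI0), so the caller reads the result from the ret slot on the stack. Your code never writes that slot, so the caller sees whatever value was left there: presumably something nonzero in the sequential test and zero on a fresh goroutine stack. No register is being clobbered by another goroutine. In found_match, change the return to

MOVQ $1, ret+ret_off

and in return_false store the zero the same way:

MOVQ $0, ret+ret_off

With the result properly returned, you’ll see your problems disappear. Note that the Go declaration returns bool while the comment in the assembly file says int. For a bool result, go vet expects a one-byte MOVB store and a frame declared as $0-49 rather than $0-56.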
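To check the corrected routine, it can help to compare it against a plain Go version of the same test under the same concurrent load. The sketch below is only an illustration: subsetGeneric and countConcurrent are hypothetical helper names, not part of the original code, and the semantics (any card set c and hand h with c&h == h counts as a match, empty inputs return false) are taken from the comments in the assembly above.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// CardSet mirrors the question's type: a set of cards packed into an int64.
type CardSet int64

// subsetGeneric is a hypothetical pure-Go reference for SubsetAVX512.
// It reports whether any card set in cs contains every card of any hand in hs.
func subsetGeneric(cs, hs []CardSet) bool {
	for _, h := range hs {
		for _, c := range cs {
			if c&h == h { // every bit of h is also set in c
				return true
			}
		}
	}
	return false
}

// countConcurrent (also hypothetical) calls f from n goroutines and
// returns how many of the calls reported true.
func countConcurrent(n int, f func() bool) int64 {
	var count int64
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if f() {
				atomic.AddInt64(&count, 1)
			}
		}()
	}
	wg.Wait()
	return count
}

func main() {
	cs := []CardSet{3, 1}
	hs := []CardSet{3, 0}
	// Same inputs as the failing test in the question.
	fmt.Println(countConcurrent(5, func() bool { return subsetGeneric(cs, hs) })) // prints 5
}
```

Passing the assembly function to countConcurrent in place of subsetGeneric gives a direct comparison: once the result is stored to the stack slot, both should report the same count.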
