learn CUDA

CUDA is hard to learn, to be honest. If you read https://docs.nvidia.com/cuda/cuda-c-programming-guide/, things can look complicated, and that is true: CUDA is complicated.

That guide is actually good, but if you are totally new, there is nothing there you can run as-is. For example, from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#kernels,

```cuda
// Kernel definition
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = threadIdx.x;
    C[i] = A[i] + B[i];
}

int main()
{
    ...
    // Kernel invocation with N threads
    VecAdd<<<1, N>>>(A, B, C);
    ...
}
```
    

So, let's try to complete it,

```cuda
#include <stdio.h>
#include <cuda_runtime.h>

// Kernel definition
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = threadIdx.x;
    C[i] = A[i] + B[i];
}

int main()
{
    int N = 1000000;
    size_t size = N * sizeof(float);

    // h == host == cpu
    float *h_a, *h_b, *h_c;
    
    // d == device == gpu
    float *d_a, *d_b, *d_c;
    
    h_a = (float*)malloc(size);
    h_b = (float*)malloc(size);
    h_c = (float*)malloc(size);
    
    for (int i = 0; i < N; i++) {
        h_a[i] = i;
        h_b[i] = i * 2;
    }
    
    // allocate device memory
    cudaMalloc(&d_a, size);
    cudaMalloc(&d_b, size);
    cudaMalloc(&d_c, size);
    
    // copy the inputs from host to device
    cudaMemcpy(d_a, h_a, size, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, size, cudaMemcpyHostToDevice);


    // Kernel invocation with N threads
    VecAdd<<<1, N>>>(d_a, d_b, d_c);

    // copy the result from device back to host
    cudaMemcpy(h_c, d_c, size, cudaMemcpyDeviceToHost);
    
    free(h_a); free(h_b); free(h_c);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
}
```

Save it as `test.cu`, then compile and run it,

```bash
nvcc test.cu -o test
./test
```
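If you run it, it compiles and exits quietly, but nothing actually tells you whether the addition happened, because no result is ever checked. A small sketch you could append just before the `free` calls (it is the same verification loop the fixed version later in this post uses) looks like this,

```cuda
    // spot-check the result on the host after the device-to-host copy
    for (int i = 0; i < N; i++) {
        if (h_c[i] != h_a[i] + h_b[i]) {
            printf("Error: %f + %f != %f\n", h_a[i], h_b[i], h_c[i]);
            break;
        }
    }
```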

Everything else is just standard C++; we are only going to focus on the CUDA extensions,

1. For `VecAdd<<<1, N>>>`, the first parameter `1` means only 1 block is launched.

2. For `VecAdd<<<1, N>>>`, the second parameter `N` means N threads are launched inside that single block.

3. Each thread executes the kernel body for one element, one addition in this case. 1000000 threads means 1000000 operations can happen simultaneously, logically, but not physically.

4. For anyone experienced in CUDA, it is obvious that 1000000 threads in a single block is not going to work as intended. Here is why,

```bash
git clone https://github.com/NVIDIA/cuda-samples
cd cuda-samples/Samples/1_Utilities/deviceQuery
make
./deviceQuery
```

Below is the output,

```
Device 0: "NVIDIA GeForce RTX 3090 Ti"
CUDA Driver Version / Runtime Version          12.5 / 12.1
CUDA Capability Major/Minor version number:    8.6
Total amount of global memory:                 24149 MBytes (25322520576 bytes)
(084) Multiprocessors, (128) CUDA Cores/MP:    10752 CUDA Cores
GPU Max Clock rate:                            1935 MHz (1.93 GHz)
Memory Clock rate:                             10501 Mhz
Memory Bus Width:                              384-bit
L2 Cache Size:                                 6291456 bytes
Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
Total amount of constant memory:               65536 bytes
Total amount of shared memory per block:       49152 bytes
Total shared memory per multiprocessor:        102400 bytes
Total number of registers available per block: 65536
Warp size:                                     32
Maximum number of threads per multiprocessor:  1536
Maximum number of threads per block:           1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
Maximum memory pitch:                          2147483647 bytes
Texture alignment:                             512 bytes
Concurrent copy and kernel execution:          Yes with 2 copy engine(s)
Run time limit on kernels:                     No
Integrated GPU sharing Host Memory:            No
Support host page-locked memory mapping:       Yes
Alignment requirement for Surfaces:            Yes
Device has ECC support:                        Disabled
Device supports Unified Addressing (UVA):      Yes
Device supports Managed Memory:                Yes
Device supports Compute Preemption:            Yes
Supports Cooperative Kernel Launch:            Yes
Supports MultiDevice Co-op Kernel Launch:      Yes
Device PCI Domain ID / Bus ID / location ID:   0 / 1 / 0
Compute Mode:
    < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >
```

Look at,

```
(084) Multiprocessors, (128) CUDA Cores/MP:    10752 CUDA Cores
Maximum number of threads per multiprocessor:  1536
Maximum number of threads per block:           1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
```

- The maximum grid size in the x dimension is 2147483647 blocks.

- Each block can hold at most 1024 threads.

- Each multiprocessor (SM) can hold at most 1536 resident threads.

- There are 84 multiprocessors, each with 128 CUDA cores.

- 84 * 1536 = 129024 threads can be resident across the whole GPU at any one time. Back to our 1000000 threads: to push them through the GPU efficiently, we must split them into blocks. (These same limits can also be queried programmatically; see the sketch after this list.)

- Why do we need blocks at all? It is all about parallelism. Compare a design where a kernel can only ever use 1 block of N threads with one that uses M blocks of N threads each,

(Diagram: a single block of 1024 threads compared with 4 blocks of 256 threads each. Generated by Claude Sonnet 3.5.)

-- Nvidia calls these hardware units `Streaming Multiprocessors (SM)`; they are what actually runs the computation in parallel. Physically, the SM count is the multiprocessor count: `deviceQuery` reports `(084) Multiprocessors`, so 84 SMs can run in parallel.

-- If we use only 1 block, there is nothing to split among the SMs. With multiple blocks, the scheduler can distribute the work like below,

(Diagram: a GPU with SM 1, SM 2, SM 3, and blocks 1 through 7 being distributed across them. Generated by Claude Sonnet 3.5.)
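
By the way, you do not have to build `deviceQuery` just to read these numbers; a minimal sketch using `cudaGetDeviceProperties` (the runtime API call that `deviceQuery` itself is built on) prints the few limits we care about here,

```cuda
#include <stdio.h>
#include <cuda_runtime.h>

int main()
{
    // query the properties of device 0
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);

    printf("Device 0:                %s\n", prop.name);
    printf("Multiprocessors (SMs):   %d\n", prop.multiProcessorCount);
    printf("Max threads per SM:      %d\n", prop.maxThreadsPerMultiProcessor);
    printf("Max threads per block:   %d\n", prop.maxThreadsPerBlock);
    printf("Max block dim (x,y,z):   (%d, %d, %d)\n",
           prop.maxThreadsDim[0], prop.maxThreadsDim[1], prop.maxThreadsDim[2]);
    printf("Max grid dim (x,y,z):    (%d, %d, %d)\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}
```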

5. The physical limit is 1024 threads per block. If we launch beyond that limit, like 1000000 threads in a single block, CUDA will not execute the kernel at all. You can see this with the CUDA debugger,

```bash
nvcc -G -g -o test test.cu
cuda-gdb test
# inside cuda-gdb:
break VecAdd
run
```

```text
[New Thread 0x7fa4d4012000 (LWP 140199)]
[New Thread 0x7fa4d2d02000 (LWP 140200)]
[Detaching after fork from child process 140201]
[New Thread 0x7fa4cbfff000 (LWP 140208)]
[New Thread 0x7fa4cb7fe000 (LWP 140209)]
warning: Cuda API error detected: cudaLaunchKernel returned (0x9)
```
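
That `0x9` is `cudaErrorInvalidConfiguration`: the launch asked for more threads per block than the hardware allows, so the kernel never ran. You do not even need the debugger to catch it; a minimal standalone sketch (the `Dummy` kernel and the hard-coded thread count are just for illustration) checks `cudaGetLastError()` right after the launch,

```cuda
#include <stdio.h>
#include <cuda_runtime.h>

// empty kernel, only here to trigger the launch error
__global__ void Dummy() {}

int main()
{
    // deliberately request far more threads per block than the 1024 limit
    Dummy<<<1, 1000000>>>();

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        // expected to print something like "invalid configuration argument"
        printf("kernel launch failed: %s\n", cudaGetErrorString(err));
    }
    return 0;
}
```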

So to fix this, use multiple blocks!

```cuda
#include <stdio.h>
#include <cuda_runtime.h>

// Kernel definition
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    C[i] = A[i] + B[i];
}

int main()
{
    int N = 10240;  // an exact multiple of 1024, so every launched thread maps to a valid element
    size_t size = N * sizeof(float);

    // h == host == cpu
    float *h_a, *h_b, *h_c;
    
    // d == device == gpu
    float *d_a, *d_b, *d_c;
    
    h_a = (float*)malloc(size);
    h_b = (float*)malloc(size);
    h_c = (float*)malloc(size);
    
    for (int i = 0; i < N; i++) {
        h_a[i] = i;
        h_b[i] = i * 2;
    }
    
    cudaMalloc(&d_a, size);
    cudaMalloc(&d_b, size);
    cudaMalloc(&d_c, size);
    
    cudaMemcpy(d_a, h_a, size, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, size, cudaMemcpyHostToDevice);


    // Kernel invocation with N threads split across multiple blocks
    int threadsPerBlock = 1024;
    // ceiling division so every element is covered even when N is not a multiple of 1024
    int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
    VecAdd<<<blocksPerGrid, threadsPerBlock>>>(d_a, d_b, d_c);
    cudaMemcpy(h_c, d_c, size, cudaMemcpyDeviceToHost);

    for (int i = 0; i < N; i++) {

        if (h_c[i] != h_a[i] + h_b[i]) {
            printf("Error: %f + %f != %f\n", h_a[i], h_b[i], h_c[i]);
            break;
        }
    }
    
    free(h_a); free(h_b); free(h_c);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    
    return 0;
}
```

Save it as `test-fix.cu`, then compile and run,

```bash
nvcc test-fix.cu -o test-fix
./test-fix
```

If any value is inconsistent, it will hit the `printf` and break early. How about the debugger?

```bash
nvcc -G -g -o test-fix test-fix.cu
cuda-gdb test-fix
# inside cuda-gdb:
break VecAdd
run
```

```text
[New Thread 0x7fd4e4efe000 (LWP 140517)]
[New Thread 0x7fd4df4f5000 (LWP 140518)]
[Detaching after fork from child process 140519]
[New Thread 0x7fd4dccb2000 (LWP 140532)]
[New Thread 0x7fd4d1fff000 (LWP 140533)]
[Switching focus to CUDA kernel 0, grid 1, block (0,0,0), thread (0,0,0), device 0, sm 0, warp 0, lane 0]

Thread 1 "test-fix" hit Breakpoint 1, VecAdd<<<(10,1,1),(1024,1,1)>>> (A=0x7fd4b3a00000, B=0x7fd4b3a0a000, C=0x7fd4b3a14000) at test-fix.cu:7
7           int i = blockDim.x * blockIdx.x + threadIdx.x;
```

It executed safely. The kernel invocation is now `VecAdd<<<blocksPerGrid, threadsPerBlock>>>`, where `threadsPerBlock = 1024` and `blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock`, a ceiling division that partitions `N` nicely: with `N = 10240` that is 10 blocks, matching the `(10,1,1)` grid in the debugger output, and with the original 1000000 elements it would be 977 blocks.

(Diagram: multiple blocks of 1024 threads each covering the whole vector. Generated by Claude Sonnet 3.5.)

I did not visualize every block, but you get the gist. How about `blockDim.x * blockIdx.x + threadIdx.x` in the kernel?

(Diagram: Block 0 (blockIdx.x = 0), Block 1 (blockIdx.x = 1), and Block 2 (blockIdx.x = 2), each containing threads threadIdx.x = 0, 1, ... Example calculation: for Block 1, threadIdx.x = 2, the global index = blockDim.x * blockIdx.x + threadIdx.x = 1024 * 1 + 2 = 1026. Generated by Claude Sonnet 3.5.)
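
One detail worth calling out: with the original 1000000 elements, 977 blocks give 977 * 1024 = 1000448 threads, so the last 448 global indices would point past the end of the arrays. The code above sidesteps this by picking N = 10240, an exact multiple of 1024. When N does not divide evenly, a common pattern (a sketch here, and it means also passing `N` into the kernel) is to guard the access,

```cuda
// sketch: the kernel also receives N so out-of-range threads do nothing
__global__ void VecAdd(float* A, float* B, float* C, int N)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    // threads in the last, partially used block fall outside N and skip the write
    if (i < N) {
        C[i] = A[i] + B[i];
    }
}
```

The launch then becomes `VecAdd<<<blocksPerGrid, threadsPerBlock>>>(d_a, d_b, d_c, N)`.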

6. To check whether it really runs in parallel, you can put a `printf` in the kernel,

```cuda
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    // don't do this in a real application, device printf is slow
    printf("%d\n", i);
    C[i] = A[i] + B[i];
}
```

```text
1527
1528
1529
1530
1531
1532
1533
1534
1535
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
```

You can see that the values of `i` printed from the kernel are not in order: blocks and warps (groups of 32 threads, see the `Warp size` line in `deviceQuery`) are scheduled in whatever order the hardware decides.
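
If you also want to see which block each index comes from, a small variation of the same kernel (again just a sketch, not something for a real application) prints the block and thread indices too,

```cuda
__global__ void VecAdd(float* A, float* B, float* C)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    // device printf is slow, only useful for poking at execution order
    printf("block %d, thread %d, i = %d\n", blockIdx.x, threadIdx.x, i);
    C[i] = A[i] + B[i];
}
```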

7. Actually you can still use `VecAdd<<<1, N>>>(d_a, d_b, d_c)` with a single block, as long as N is within the physical per-block limit of 1024 threads. The index `blockDim.x * blockIdx.x + threadIdx.x` still works, because `blockIdx.x` is 0, so `blockDim.x * blockIdx.x` is 0 and the global index reduces to `threadIdx.x`.

8. If you understand `blockDim`, `blockIdx`, `threadIdx`, and pointers, you are good to go; CUDA already provides a lot of abstractions on top of the hardware for us.