
This is another post in the CUDA series of introductory tutorials; this time we will explore how to implement a naive version of PyTorch's EmbeddingBag, mapping a batch of variable-length, so-called "ID-list" features to a set of pooled Tensors.
Getting Started
To run this code you will need a Linux box with a CUDA-capable GPU and the appropriate NVIDIA SDK; please see my previous post to learn how to instantiate an AWS EC2 instance to try out the code.
The TL;DR version is as follows:
# Clone the repository
$ git clone git@github.com:massenz/cuda-learn.git
$ cd cuda-learn
# Build the CLI
$ make cli
--- 🛠️ Building CLI tool
cd go-aws-cli && go build -o ../build/cuda-learn_0.3.0_arm64_cli ./cmd
# Optionally put it somewhere on your PATH
$ ln -s $(pwd)/build/cuda-learn_0.3.0_arm64_cli ${USR_LOCAL}/cl
# Create the EC2 instance
$ cl setup
Remember that you will need to authenticate to your AWS account; the easiest way is to set the AWS_PROFILE environment variable to a profile that has sufficient permissions to create a VPC and set up IAM roles.
cl help will give you a list of available commands and options:
$ cl help
A CLI tool to create and manage AWS infrastructure for CUDA-Learn project.
Usage:
cuda-learn [command]
Available Commands:
completion Generate the autocompletion script for the specified shell
help Help about any command
instance Setup EC2 instance
setup Setup AWS infrastructure
teardown Terminate EC2 instance
vpc Setup VPC infrastructure
Flags:
--gh-auth string Path to PEM key for GitHub authentication
--gh-auth-key-name string Name for GitHub authentication key in SecretsManager (default "gh-auth")
-h, --help help for cuda-learn
--instance-type string EC2 instance type (default "g4dn.xlarge")
--key-name string SSH key name (default "gpu-key")
--project string Project tag value (default "cuda-learn")
--region string AWS region (default "us-west-2")
--subnet-cidr string Subnet CIDR block (default "10.0.1.0/24")
-v, --verbose Enable verbose logging to console
--vpc-cidr string VPC CIDR block (default "10.0.0.0/16")
Finally, you will need the SSH key to authenticate to the remote server, as well as its public IP; both are provided by the CLI:
$ cl setup
...
Successfully created EC2 instance:
IAM Instance Profile: arn:aws:iam::...:instance-profile/cuda-learn-profile
Instance ID: i-067ac4bc3b39355a5
Public IP: 54.200.145.18
To SSH into the instance use:
ssh -i private/gpu-key.pem ubuntu@54.200.145.18
To copy and run the setup script:
scp -i private/gpu-key.pem scripts/setup-host.sh ubuntu@54.200.145.18:~/
ssh -i private/gpu-key.pem ubuntu@54.200.145.18 "chmod +x ~/setup-host.sh && ~/setup-host.sh"
Running the Code
With that out of the way, and after you have successfully cloned the repository on the remote server, you should be able to build the demo binary:
$ ssh -i private/gpu-key.pem ubuntu@54.200.145.18
ubuntu@ip-10-0-1-15:~$ ./setup-host.sh && cd cuda-learn
ubuntu@ip-10-0-1-15:~/cuda-learn$ make build
This can then be executed with:
ubuntu@ip-10-0-1-15:~/cuda-learn$ ./build/bin/demo id-list data/simple.txt
Successfully read 3 samples, containing 8 values in total.
Moving host data to CUDA Global Memory
Prepared buffers: 3 offsets, 8 values
Creating the Embedding Lookup Table
Matrix: 20 x 8
Grid dimensions: 5 x 11
Block dimensions: 2 x 2
Launching the CUDA kernel to pool embeddings
Launching kernel with 1 blocks of 256 threads each
Copying pooled features back to host (3 x 8 floats)
Read 3 samples in batch.
Pooled features:
0.3856 -1.0440 -0.7609 0.5064 -0.0022 0.2260 -1.5498 1.7106
0.7684 -0.6514 0.6314 0.7123 -0.2337 -1.8443 1.9074 -0.6575
-0.6210 -0.1179 2.6052 2.9875 -5.8569 1.0351 -2.7455 1.4992
Understanding the Code
Reading Data in Chunks
One of the challenges when using kernel functions is that the memory needs to be pre-allocated on the host (CPU) and then the data moved to the device (GPU): we cannot use anything like std::vector or std::list to handle variable inputs, and, in fact, we need to lay it out in a “linearized” fashion (see below for more on this).
On the other hand, the size of the input features cannot be known in advance; the data needs to be read efficiently from disk, avoiding more copies than strictly necessary, and bearing in mind that we will then need to (efficiently) copy it to device memory for processing.
To add further complexity, the features for each sample are of variable length; in other words, each line can have any number of IDs, which we will then need to map to indices (entries) in the Embeddings table.
The solution is to read the data in “chunks” and transform the data into two arrays: offsets and values.
The offsets array contains, one entry per sample, the ending offset for that sample (we take advantage of the fact that (a) the first sample always has a starting offset of 0, and (b) the ending offset of a sample is also the starting offset of the next one), while the values array contains the actual values read.
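For example, a (made-up) batch of three samples with three, two, and three IDs respectively would be flattened as follows:
// Hypothetical input, one sample (line) per row:
//   11 42 7
//   3 99
//   5 12 8
// Flattened representation:
//   values  = [11, 42, 7, 3, 99, 5, 12, 8]   (all IDs, concatenated)
//   offsets = [3, 5, 8]                      (ending offset of each sample; sample i
//                                             spans [offsets[i-1], offsets[i]),
//                                             with sample 0 starting at 0)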

See the Chunks class for more details; this is the (simplified) code that reads the data in from a file:
std::shared_ptr<Chunks> readInputFile(const std::string &filename) {
    auto result = std::make_shared<Chunks>();
    Chunk currentChunk;
    uint16_t currentOffset = 0;
    std::ifstream file(filename);
    std::string line;
    while (std::getline(file, line)) {
        // ...
        std::istringstream lineStream(line);
        int64_t value;
        // As we don't know how many values there are in the line,
        // nor whether there is enough room in the current chunk,
        // we first read them into a temporary vector.
        std::vector<int64_t> values;
        while (lineStream >> value) {
            values.push_back(value);
        }
        if (!currentChunk.hasRoom(values.size())) {
            result->push_back(currentChunk);
            // NOTE we don't reset currentOffset here, as it is used to
            // calculate the offsets for the next chunk.
            // This simplifies moving the data to the GPU, as we can just
            // copy the values into GPU memory without having to adjust the offsets.
            currentChunk = Chunk();
            currentChunk.offsetAdjust = currentOffset;
        }
        currentChunk.offsets[currentChunk.size] = currentOffset + values.size();
        // Copy the values to the current chunk.
        for (size_t i = 0; i < values.size(); ++i) {
            currentChunk.values[currentOffset - currentChunk.offsetAdjust + i] = values[i];
        }
        currentChunk.size++;
        currentOffset += values.size();
    }
    // Add the last chunk if it has any values.
    if (currentChunk.size > 0) {
        result->push_back(currentChunk);
    }
    return result;
}
The advantage of this format is that it is then extremely simple to copy the chunks' data to device buffers, which are simply the concatenation of the data in all the Chunks.
Note that we do not reset the offset when reading the data: if we did, we would have to apply a transform (adding the ending offset of the previous chunk) to every offset while copying the data to the GPU, and the CUDA memcpy API does not provide that kind of hook (e.g., via a lambda).
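To make this concrete, here is a minimal sketch (not the actual repository code; it assumes the Chunk members shown above, that Chunks is iterable, that the host arrays use the same element types as the size_t / int64_t device buffers, and that d_offsets and d_values have already been allocated with cudaMalloc) of how the chunks can be concatenated into the device buffers with plain cudaMemcpy calls:
size_t offsetsCopied = 0;
size_t valuesCopied = 0;
for (const auto &chunk : *chunks) {
    // The offsets were never reset on the host, so they are already "global"
    // and can be copied verbatim, one chunk after the other.
    cudaMemcpy(d_offsets + offsetsCopied, chunk.offsets,
               chunk.size * sizeof(size_t), cudaMemcpyHostToDevice);
    // Number of values in this chunk: last (global) ending offset minus the
    // chunk's starting offset (its offsetAdjust).
    size_t numValues = chunk.offsets[chunk.size - 1] - chunk.offsetAdjust;
    cudaMemcpy(d_values + valuesCopied, chunk.values,
               numValues * sizeof(int64_t), cudaMemcpyHostToDevice);
    offsetsCopied += chunk.size;
    valuesCopied += numValues;
}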
Linearizing Matrices (row-major layout)
Embedding tables map “sparse” IDs to Tensor weights; they are used (for example) in recommendation systems to process data from discrete events (e.g., liked posts, watched videos, ads clicked, etc.) that do not lend themselves to natural “dense” processing using Linear ML modules: see the DLRM paper for more details (and its open source implementation).
They are typically initialized with random values, which are then updated during the training phase using backpropagation, until the model converges to a desired state (typically measured using Normalized Binary Cross-Entropy (PDF)).
Embeddings can be extremely large tables (up to millions of rows) with deep tensors (of length 128 or more), so much so that when training large models they may need to be split across several GPUs (on the same host) or even across hosts.
Here, in this toy example, we generate a relatively small matrix (of size numTensors by nDim) which is not persisted anywhere (as it would be in a real DLRM model).
Again, we need to “linearize” the data, keeping it in “row-major” format (one row after the other): this means that to get the element (i, j) of the matrix M with N columns, we actually need to access
M(i, j) = mat[i * N + j]
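For example, a tiny (hypothetical) accessor makes the pattern explicit:
// Hypothetical helper, usable from both host and device code: element (i, j)
// of a row-major matrix with `cols` columns lives at mat[i * cols + j].
__host__ __device__ inline float &at(float *mat, size_t cols, size_t i, size_t j) {
    return mat[i * cols + j];
}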
The code that generates a (rows x cols) matrix is in the mat-gen.cu file and it essentially boils down to:
__global__ void fillMatrixKernel(
        float *mat, void *strategy,
        float mean, float stddev, unsigned long seed) {
    SizeCheck checker{strategy, BoundaryType::Rectangular};
    if (checker()) {
        auto idx = checker.idx();
        curandState state;
        curand_init(seed, idx, 0, &state);
        mat[idx] = curand_normal(&state) * stddev + mean;
    }
}

// Used here:
RectangularCheckStrategy strategy(rows, cols);
// Launch kernel
fillMatrixKernel<<<gridDim, blockDim>>>(
    d_mat, d_strategy, mean, stddev, time(nullptr));
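The gridDim and blockDim values are not shown above; a typical way to compute them (a sketch, not necessarily what mat-gen.cu does) is to pick a fixed block size and round the grid up so that it covers the whole rows x cols matrix:
dim3 blockDim(16, 16);  // 256 threads per block
dim3 gridDim((cols + blockDim.x - 1) / blockDim.x,   // enough blocks to cover all columns...
             (rows + blockDim.y - 1) / blockDim.y);  // ...and all rows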
See the post on “Boundary Checking” for an explanation of how the RectangularCheckStrategy abstracts the linearization pattern of a 2D matrix.
The Kernel Code
We are now ready to understand the CUDA implementation code:
- Generate an Embedding lookup matrix of size (numTensors x nDim) with random values;
- For each sample (there are numSamples of those), read all the values between offsets[i-1] and offsets[i], compute their modulo (so that the index falls within the Embedding row index bound), and read up all the Tensor values (nDim of them);
- Add them all up and store them in the corresponding row of the d_pooledFeatures matrix (of numSamples x nDim dimensions).
__global__ void poolEmbeddings(
        const size_t *d_offsets,
        const int64_t *d_values,
        void *d_strategy,
        float *d_embeddings,
        float *d_pooledFeatures,
        size_t numTensors, size_t nDim) {
    SizeCheck checker{d_strategy, BoundaryType::Linear};
    if (checker()) {
        auto idx = checker.idx();
        // Get the start and end offsets for the current sample.
        size_t start = idx > 0 ? d_offsets[idx - 1] : 0;
        size_t end = d_offsets[idx];
        // Initialize the pooled features for this sample.
        for (size_t i = 0; i < nDim; ++i) {
            d_pooledFeatures[idx * nDim + i] = 0.0f;
        }
        // Iterate over the values for this sample.
        for (size_t i = start; i < end; ++i) {
            // Hash the value to get the index in the embedding table.
            size_t tensorIdx = d_values[i] % numTensors;
            // Add the embedding to the pooled features.
            for (size_t j = 0; j < nDim; ++j) {
                d_pooledFeatures[idx * nDim + j] +=
                    d_embeddings[tensorIdx * nDim + j];
            }
        }
    }
}
A couple of points worth noting:
- we (implicitly) pass the numSamples via the LinearCheckStrategy; it needs to be "linear" as the kernel threads iterate over the offsets array (and, implicitly, the values one too);
- the "hashing" is as naive as it gets: we simply take the modulo numTensors; in reality, the mapping from IDs to Embedding index would be rather more sophisticated (but, ultimately, the concept is the same: it is a mapping strategy from a "sparse" feature to a "dense" representation).
The demo.cpp code shows an example of invoking the Pooling Embedding kernel from the host (CPU).
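In outline, the host-side flow looks like the sketch below (a simplification, not the actual demo.cpp; it reuses the buffer names from above and the blocks / threadsPerBlock launch configuration discussed in the next section):
// Allocate the output buffer on the device.
float *d_pooledFeatures;
cudaMalloc((void **)&d_pooledFeatures, numSamples * nDim * sizeof(float));
// Launch the kernel and wait for it to complete.
poolEmbeddings<<<blocks, threadsPerBlock>>>(
    d_offsets, d_values, d_strategy, d_embeddings,
    d_pooledFeatures, numTensors, nDim);
cudaDeviceSynchronize();
// Copy the pooled features back to the host for printing.
std::vector<float> pooled(numSamples * nDim);
cudaMemcpy(pooled.data(), d_pooledFeatures,
           pooled.size() * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(d_pooledFeatures);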
Optimizations
Something that we have not explored at all (but which has a massive impact on the code's efficiency) is the sizing of the "grid" of threads being started on the CUDA GPU:
uint threadsPerBlock = 256;
uint blocks = (numSamples + threadsPerBlock - 1) / threadsPerBlock;
poolEmbeddings<<<blocks, threadsPerBlock>>>(
    d_offsets, d_values, d_strategy, d_embeddings,
    d_pooledFeatures, numTensors, nDim);
For the typical matrix sizes in this simple example this makes virtually no difference, but when we start dealing with thousands, or potentially millions, of samples/features, the impact of grid sizing can be significant: see NVIDIA's Matrix Multiplication Background User's Guide to get a sense of the considerations involved; in a future post I will also add a few examples (including GPU traces) to show their impact.
Please let me know in the comments if you found this interesting, and what else would you like to see in this series.



