Meta's "Lightweight" KernelLLM Shakes Up GPU Kernel Generation: 8B Parameters Beat GPT-4o

36kr
05-27

[Introduction] Meta has released KernelLLM, an 8B model fine-tuned from Llama 3.1 that automatically converts PyTorch code into efficient Triton GPU kernels. Benchmark results show that its single-pass performance surpasses GPT-4o and DeepSeek V3, and that its scores climb sharply when multiple candidates are generated.

In the AI field, parameter scale was once viewed as the "performance ceiling".

Meta's latest model, KernelLLM, has a "small frame" of only 8B parameters, yet it soundly beats the 200B-parameter GPT-4o at GPU kernel generation.

It is an 8B-parameter model fine-tuned from Llama 3.1 Instruct, designed to automatically convert PyTorch modules into efficient Triton GPU kernels.

KernelLLM is a practical tool for GPU kernel development: it achieves stronger performance with fewer parameters and is simple to use.

With only 8B parameters, it outperforms GPT-4o and DeepSeek V3 in single-pass inference (pass@1) on KernelBench-Triton Level 1.

When allowed multiple inference attempts, KernelLLM also outperforms DeepSeek R1.

All of this comes from a model with up to two orders of magnitude fewer parameters than its competitors.

@Denis Kanonik quips, "Is this trained on the test set again?"

KernelLLM Makes Kernel Development More Accessible

KernelLLM is an 8B model based on Llama 3.1 Instruct, specifically trained for tasks of writing GPU kernels with Triton.

It can make GPU programming simpler and automate high-performance GPU kernel generation.

KernelLLM meets the growing demand for high-performance GPU kernels by automatically generating efficient Triton implementations.

With increasing workloads and diverse accelerator architectures, the demand for customized kernel solutions has significantly increased.

Most existing approaches are limited to test-time optimization, or are tuned on the KernelBench problems themselves, which makes it hard to tell how well they handle broader scenarios.

According to Meta, KernelLLM is the first LLM fine-tuned on external (PyTorch, Triton) code pairs.

Triton Kernel Generation Workflow

Given PyTorch code as input, KernelLLM generates candidate Triton kernels.

The candidates are then verified with unit tests, which run them on random inputs and check that the outputs are correct. If multiple candidates are generated, they can be compared to pick the best one.

KernelLLM's Triton kernel generation process: KernelLLM translates PyTorch code into candidate Triton kernels. The generated code is verified with unit tests that run each kernel on random input data of known shapes. The process supports generating multiple candidates (evaluated with pass@k), where increasing the number of candidates improves quality, and the best Triton kernel implementation is selected as the final output (the green part of the figure).
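The verify-and-select loop described above can be sketched in plain Python. This is only an illustration of the idea, not Meta's harness: the candidates here are ordinary Python functions standing in for generated Triton kernels, and the names `reference_relu`, `passes_unit_test`, and `select_best` are hypothetical.

```python
import random
import time

def reference_relu(xs):
    """Stand-in for the reference PyTorch module's forward pass."""
    return [max(0.0, x) for x in xs]

def passes_unit_test(candidate, reference, size=1024, trials=5, tol=1e-6, seed=0):
    """Run the candidate on random inputs of a known size and compare with the reference."""
    rng = random.Random(seed)
    for _ in range(trials):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(size)]
        want = reference(xs)
        try:
            got = list(candidate(xs))
        except Exception:
            return False  # a candidate that crashes counts as a failure
        if len(got) != len(want):
            return False
        if any(abs(g - w) > tol for g, w in zip(got, want)):
            return False
    return True

def select_best(candidates, reference):
    """Return the fastest candidate that passes the unit test, or None if none pass."""
    rng = random.Random(1)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(1024)]
    best, best_time = None, float("inf")
    for candidate in candidates:
        if not passes_unit_test(candidate, reference):
            continue
        start = time.perf_counter()
        candidate(xs)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best, best_time = candidate, elapsed
    return best

# Three hypothetical candidates: one wrong, one that crashes, one correct.
def wrong_candidate(xs):
    return [abs(x) for x in xs]

def crashing_candidate(xs):
    raise RuntimeError("kernel launch failed")

def correct_candidate(xs):
    return [x if x > 0 else 0.0 for x in xs]

best = select_best([wrong_candidate, crashing_candidate, correct_candidate], reference_relu)
print(best.__name__ if best else "no candidate passed")
```

A real harness would compile each generated Triton kernel, run it on GPU tensors, and time it over many iterations, but the control flow is the same: filter by correctness first, then rank the survivors.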

The model was trained on more than 25,000 paired (PyTorch, Triton) code examples, plus additional synthetic samples.

The data comes partly from filtered code in TheStack and was partly generated with torch.compile() and prompting techniques.

The dataset, KernelBook, is available at https://huggingface.co/datasets/GPUMODE/KernelBook.

Training started from the Llama3.1-8B-Instruct model and applied supervised fine-tuning (SFT) on this custom dataset. The model's ability to generate correct Triton kernels, along with the code that calls them, was then evaluated on KernelBench-Triton.

KernelBench-Triton is a variant of KernelBench [Ouyang et al. 2025] that focuses on Triton kernel generation.

During training and evaluation, the PyTorch code is wrapped in a prompt template that includes a format example as instructions.
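To make the idea of a prompt template with a format example concrete, here is a minimal sketch. The template wording, the section markers, and the `build_prompt` helper are all hypothetical; the actual template used by KernelLLM is not reproduced in this article.

```python
# Hypothetical template: the real KernelLLM prompt is not shown in the article.
PROMPT_TEMPLATE = """You are given a PyTorch module. Rewrite it using Triton kernels.

### Format example
{format_example}

### PyTorch code
{pytorch_code}

### Triton code
"""

def build_prompt(pytorch_code: str, format_example: str) -> str:
    """Wrap the PyTorch source and one format example in the instruction template."""
    return PROMPT_TEMPLATE.format(
        format_example=format_example.strip(),
        pytorch_code=pytorch_code.strip(),
    )

example = "import triton\nimport triton.language as tl\n# ... kernel and wrapper module ..."
source = "import torch\n\nclass Model(torch.nn.Module):\n    def forward(self, x):\n        return torch.relu(x)"
prompt = build_prompt(source, example)
print(prompt)
```

Using the same template at training and evaluation time keeps the input distribution consistent, which is presumably why the article mentions both.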

The model was trained for 10 epochs with a batch size of 32 using standard SFT, and hyperparameters were chosen based on validation-set perplexity.

Training ran on 16 GPUs for about 12 hours (192 GPU hours in total), and the reported results are the validation results of the best checkpoint.

Performance Evaluation

Despite its small model size, its performance is comparable to state-of-the-art LLMs.

On KernelBench-Triton, the 8B-parameter KernelLLM scores 20.2 with a single generation (pass@1), ahead of the 671B-parameter DeepSeek V3 (16) and the 200B-parameter GPT-4o (15).

Generating multiple candidates raises the score further, to 51.8 with 10 candidates (pass@10) and 57.1 with 20 candidates (pass@20).
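For readers unfamiliar with pass@k, the sketch below implements the unbiased estimator that is widely used for code generation benchmarks (introduced with OpenAI's Codex evaluation). The article does not say which estimator the KernelLLM team used, so treat this as an assumption: given n generated samples for a task, of which c are correct, it estimates the probability that at least one of k randomly chosen samples is correct.

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: 1 - C(n - c, k) / C(n, k)."""
    if not 0 <= c <= n:
        raise ValueError("c must satisfy 0 <= c <= n")
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= n")
    # comb(n - c, k) is 0 when fewer than k samples are incorrect,
    # so the estimate is exactly 1.0 in that case.
    return 1.0 - comb(n - c, k) / comb(n, k)

# Hypothetical task: 20 samples generated, 3 of them correct.
for k in (1, 10, 20):
    print(k, pass_at_k(n=20, c=3, k=k))
```

The benchmark score is then the average of this quantity over all tasks, which is why generating more candidates can only raise it.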

KernelLLM inference was run with temperature=1.0 and top_p=0.97.
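The two sampling settings can be illustrated with a small, library-free sketch of temperature scaling followed by nucleus (top-p) filtering. This is a generic version of the technique, not the exact code path inside any inference library, and the function names are hypothetical.

```python
import math
import random

def top_p_filter(logits, temperature=1.0, top_p=0.97):
    """Return {token_index: renormalized probability} for the nucleus."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for index in order:
        kept.append(index)
        cumulative += probs[index]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

def sample_token(logits, rng, temperature=1.0, top_p=0.97):
    """Draw one token index from the filtered, renormalized distribution."""
    nucleus = top_p_filter(logits, temperature, top_p)
    tokens = list(nucleus)
    weights = [nucleus[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]

# Toy distribution over three tokens with probabilities 0.90, 0.08, 0.02.
toy_logits = [math.log(0.90), math.log(0.08), math.log(0.02)]
nucleus = top_p_filter(toy_logits)
print(sorted(nucleus))
```

With temperature=1.0 the model's distribution is left unscaled, and top_p=0.97 only trims the least likely tail, so the settings keep enough diversity for the multi-candidate (pass@k) evaluation to be meaningful.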

The model was tested on KernelBench, an open-source benchmark for evaluating LLM-written efficient GPU kernels.

It contains 250 carefully selected PyTorch modules organized by workload, ranging from simple single operations (such as Conv2D or Swish, Level 1) to complete model architectures (Level 3).

KernelLLM performs consistently across tasks of different difficulty, from simple single operators to complex model architectures.

The tests measure both code correctness (by comparing against the reference PyTorch output) and performance (by measuring the speedup over a baseline implementation).

The team developed the new KernelBench-Triton variant specifically to evaluate an LLM's ability to generate Triton kernels, which makes it well suited to testing KernelLLM.

All tests were completed on NVIDIA H100 GPUs.

KernelLLM shows nearly log-linear scaling behavior in pass@k.
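As a rough check of this claim, the sketch below fits score ≈ a + b·ln(k) by least squares to the three scores reported in this article (20.2 at k=1, 51.8 at k=10, 57.1 at k=20). Three points are far too few to confirm a trend, so this is only an illustration of what a log-linear relationship means, not a reproduction of the team's analysis.

```python
import math

# pass@k scores reported in the article.
scores = {1: 20.2, 10: 51.8, 20: 57.1}

def fit_log_linear(points):
    """Ordinary least-squares fit of score = intercept + slope * ln(k)."""
    xs = [math.log(k) for k in points]
    ys = list(points.values())
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return intercept, slope

intercept, slope = fit_log_linear(scores)
residuals = {
    k: observed - (intercept + slope * math.log(k))
    for k, observed in scores.items()
}
for k, observed in scores.items():
    predicted = intercept + slope * math.log(k)
    print(f"k={k:2d}  observed={observed:5.1f}  fitted={predicted:5.1f}")
```

On these three points the fitted line stays within about two points of each reported score, which is consistent with "nearly" log-linear: each doubling of k adds a roughly constant number of points rather than a constant factor.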

How to Use KernelLLM?

First, install some dependency packages:

  • pip install transformers accelerate torch triton

To use it, import the library and call the generate_triton function to generate optimized Triton code.

KernelLLM provides a simple interface for generating Triton kernels from PyTorch code.
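As an alternative to the kernelllm.py helper, the checkpoint can presumably also be loaded directly through the transformers library installed above. The following is a minimal sketch under several assumptions: the model id `facebook/KernelLLM` is taken from the Hugging Face link in the references, the plain-text prompt is a hypothetical stand-in for the official prompt template (so output quality may differ), and the helper name `generate_candidates` is made up for illustration. The sampling settings match the ones reported in the evaluation section.

```python
from typing import List

MODEL_ID = "facebook/KernelLLM"  # assumed from the Hugging Face link in the references
SAMPLING_KWARGS = {"do_sample": True, "temperature": 1.0, "top_p": 0.97}

def generate_candidates(pytorch_code: str, k: int = 1, max_new_tokens: int = 512) -> List[str]:
    """Sample k candidate Triton implementations for the given PyTorch source."""
    # Heavy dependencies are imported lazily so that defining the function is cheap.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # requires the accelerate package
    )

    # Hypothetical prompt; the official script applies its own template.
    prompt = "Rewrite the following PyTorch module as Triton kernels:\n\n" + pytorch_code
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=k,
        **SAMPLING_KWARGS,
    )

    # Drop the prompt tokens so that only the generated continuation is returned.
    prompt_length = inputs["input_ids"].shape[1]
    return [
        tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
        for sequence in outputs
    ]
```

Calling `generate_candidates(source, k=10)` downloads the 8B checkpoint on first use and needs a GPU with enough memory for bfloat16 weights. The returned strings are untested candidates and still have to be verified against the reference module before use.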

If you don't want to write a script, you can run python kernelllm.py directly to use the built-in REPL interface and see results in real time.

kernelllm.py provides multiple ways to interact with the model.

  • python kernelllm.py

KernelLLM also provides several methods to customize the generation process.

The model still has limitations. It sometimes makes small mistakes, such as incorrect API references or syntax errors, and it sometimes fails to generate the kernel that the instructions ask for.

The generated code is structured somewhat like compiler-generated code, and it is prone to problems with variable naming, tensor shapes, type handling, and numerical precision.

References:

https://x.com/reach_vb/status/1924478755898085552

https://huggingface.co/facebook/KernelLLM

This article is from the WeChat public account "New Intelligence", editor: Ying Zhi, published by 36Kr with authorization.
