Automating Algorithm Discovery: A Case Study in Kernel Generation with Datadog BitsEvolve

Jai Menon, Rohan Kulkarni, Sesh Nalla, and the ADRS Team
This post is part of the AI-Driven Research for Systems (ADRS) blog series, where we explore how AI can be applied to systems research. We feature exciting work from Datadog this week!

In this blog post, we examine the problem of generating production-ready, optimized GPU code from an evolutionary search perspective. Specifically, we share results from BitsEvolve, an ADRS framework built at Datadog. BitsEvolve targets a range of optimization problems, from hotspots in CPU-bound code to policy and configuration tuning for applications such as inference serving frameworks (e.g., vLLM), load balancers, garbage collectors, and more. Through profile guidance and robust evaluation mechanisms, we show how BitsEvolve-generated code can outperform compiled models, achieving speedups of up to 1.6x with reasonable search costs.

The Problem: The Optimization Bottleneck

AI workloads are devouring compute. Across both datacenters and the edge, current trends point to a rising share for GPUs and accelerators. This momentum has also led to rapid maturity and increasing complexity in the GPU software stack. We have evolved from programming GPUs using raw CUDA to a landscape of DSLs and compiler-driven code generation, all with varying levels of efficacy. While there are well-known examples of hand-optimized kernels that perform well relative to peak speed-of-light (SOL) throughput, such efforts are restricted to a handful of core primitives. As with high-performance CPU optimization, the specialized skill set required is rare, and applying it to every niche problem is often not ROI-positive.

Furthermore, given the pace of innovation, the set of possible optimization targets is ever-growing. We must constantly adapt to new GPU architectures and compute capability levels, cost-efficient SKUs, evolving numeric formats (e.g., quantized types, microscaling formats), and shifting model configurations.

So, in the spirit of ADRS, we ask the question: Can we use LLM-based coding agents as our "GPU kernel engineers" to continuously optimize AI/ML workloads?

To explore answers to that question, we built a GPU code optimization and kernel generation flow in BitsEvolve, an agentic optimization system.

Related Work

Automated kernel generation has seen significant interest recently.

  • KernelBench provides reference computation graphs across different complexity levels to evaluate correctness and performance on open-weight and proprietary models.
  • KernelFalcon takes an agentic approach, combining techniques to achieve 100% correctness on the KernelBench suite.
  • KernelLLM explores fine-tuning existing models specifically for GPU kernel generation.

BitsEvolve builds on these ideas but targets a more holistic, production-first approach.

BitsEvolve for GPU Code Optimization

BitsEvolve is an ADRS framework that takes a base ML model (e.g., PyTorch model code) as input, generates an evaluation harness with the model code built in, and executes an LLM-guided evolutionary search as described in our previous Datadog blog post. The result is an optimized model that functions as a drop-in replacement for the original base model.

We build BitsEvolve on top of ShinkaEvolve, adding customizations that we are currently upstreaming (including support for languages like Rust and LLM query streaming). In the core evolutionary loop, we use frontier models: specifically GPT-5, GPT-5.1 (varying reasoning efforts), and Gemini 2.5 Pro (dynamic thinking).

In comparison to the related work mentioned previously, BitsEvolve aims to take a more holistic, layered approach to GPU code optimization. More concretely:

  • Modular Evolutionary Search: We retain the core LLM-guided evolutionary search but use modular layers to specialize our approach. For example:
    • Flexible Targets: The code generation agent can target raw CUDA, Triton, CuTe, or other vendor-specific DSLs.
    • Model Flexibility: The underlying model itself can be a proprietary external model, or a fine-tuned open-weights model.
  • Holistic Optimization: Telemetry feedback and whole-model optimization are important anchors for BitsEvolve. We don't restrict ourselves to generating optimized kernels alone but also optimize CPU/GPU synchronization and orchestration overheads, redundant data movement, etc. (see later examples).
  • "Merge-Ready" Code: The end goal of a BitsEvolve workflow is deployable production code. This means we need strong validation harnesses guaranteeing correctness, performance, and compatibility.

Real-World Application

We apply this new BitsEvolve workflow to models used internally at Datadog. These models use PyTorch as the model framework and serve production use cases. Our immediate goal was to optimize the forward pass to improve inference serving performance and cost-efficiency.

Below is an example fragment from BitsEvolve output. Some notable details:

  • The input to the MLP is decomposed into independent branches, with partial precomputation.
  • GEMMs are dispatched to architecture-specific optimized implementations where appropriate.
  • The remaining operations run in fused kernels.
  • The output is preallocated, and views are used carefully to avoid extra copies.

Optimized Code Example: GliNER Scoring

Reference PyTorch Code

self.out_mlp = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size * 4),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size * 4, 3)  # start, end, score
        )

Optimized Code

# Fragment of the optimized forward pass. B, S, C, H, l1, t1, Pl, Pt, Wmul,
# W2T, b1, b2 and scores are defined earlier in the method (not shown).
tile_s = min(self.tile_s, S)
for b in range(B):
    l1_b = l1[b].contiguous()          # [C, H]
    Pl_b = Pl[b].contiguous()          # [C, 4H]
    Pt_b = Pt[b].contiguous()          # [S, 4H]
    for s0 in range(0, S, tile_s):
        T = min(tile_s, S - s0)

        # 5) Build elementwise-product branch A_tile = l1_b * t1[b, s0:s0+T]
        scales = t1[b, s0:s0+T].contiguous()  # [T, H]
        A_tile = self._kernels.colscale_tile_bf16(l1_b, scales)  # [T, C, H]

        # 6) GEMM: (T*C, H) @ (H, 4H) = (T*C, 4H)
        A2D = A_tile.view(T * C, H).contiguous()
        Ms2D = torch.matmul(A2D, Wmul)  # [T*C, 4H]

        # 7) Fuse (+Pt_tile broadcast +Pl_b broadcast +bias) and ReLU in-place
        Pt_tile = Pt_b[s0:s0+T].contiguous()  # [T, 4H]
        Ms2D = self._kernels.fused_bias_relu_broadcast_inplace(Ms2D, Pt_tile, Pl_b, b1)

        # 8) Final GEMM to logits with bias: (T*C, 4H) @ (4H, 3) + b2
        S2D = torch.addmm(b2, Ms2D, W2T)  # [T*C, 3]

        # 9) Scatter to [B, S, C, 3]
        scores[b, s0:s0+T] = S2D.view(T, C, 3)

// Kernel: out[t, c, h] = l1[c, h] * scales[t, h]
// Shapes: l1 [C, H], scales [T, H], out [T, C, H]
__global__ void bf16_colscale_tile_kernel(
    const __nv_bfloat16* __restrict__ l1,
    const __nv_bfloat16* __restrict__ scales,
    __nv_bfloat16* __restrict__ out,
    int T, int C, int H
){
    int c = blockIdx.x;
    int t = blockIdx.y;
    int tid = threadIdx.x;

    if (c >= C || t >= T) return;

    const __nv_bfloat16* l1_row = l1 + c * H;
    const __nv_bfloat16* scale_row = scales + t * H;
    __nv_bfloat16* out_row = out + (t * C + c) * H;

    for (int h = tid; h < H; h += blockDim.x) {
        float a = __bfloat162float(l1_row[h]);
        float s = __bfloat162float(scale_row[h]);
        out_row[h] = __float2bfloat16_rn(a * s);
    }
}

torch::Tensor colscale_tile_bf16(torch::Tensor l1, torch::Tensor scales) {
    ...
    dim3 block(256);
    dim3 grid(C, T);
    bf16_colscale_tile_kernel<<<grid, block>>>(
        reinterpret_cast<const __nv_bfloat16*>(l1.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(scales.data_ptr<at::BFloat16>()),
        reinterpret_cast<__nv_bfloat16*>(out.data_ptr<at::BFloat16>()),
        T, C, H
    );

    return out;
}

// Kernel: Ms2D_inout[(t*C + c), h4] = relu( Ms + Pt[t, h4] + Pl[c, h4] + bias[h4] )
// Shapes: Ms2D [T*C, H4], Pt [T, H4], Pl [C, H4], bias [H4]
__global__ void bf16_fused_bias_relu_broadcast_tile_kernel(
    __nv_bfloat16* __restrict__ Ms2D,           // [T*C, H4]
    const __nv_bfloat16* __restrict__ Pt,       // [T, H4]
    const __nv_bfloat16* __restrict__ Pl,       // [C, H4]
    const __nv_bfloat16* __restrict__ bias,     // [H4]
    int T, int C, int H4
){
    int c = blockIdx.x;
    int t = blockIdx.y;
    int tid = threadIdx.x;

    if (c >= C || t >= T) return;

    __nv_bfloat16* Ms_row = Ms2D + (t * C + c) * H4;
    const __nv_bfloat16* Pt_row = Pt + t * H4;
    const __nv_bfloat16* Pl_row = Pl + c * H4;

    for (int h = tid; h < H4; h += blockDim.x) {
        float m = __bfloat162float(Ms_row[h]);
        float pt = __bfloat162float(Pt_row[h]);
        float pl = __bfloat162float(Pl_row[h]);
        float b = __bfloat162float(bias[h]);
        float v = m + pt + pl + b;
        v = v > 0.f ? v : 0.f;
        Ms_row[h] = __float2bfloat16_rn(v);
    }
}

torch::Tensor fused_bias_relu_broadcast_inplace(
    torch::Tensor Ms2D, torch::Tensor Pt, torch::Tensor Pl, torch::Tensor bias
) {
    ...

    dim3 block(256);
    dim3 grid(C, T);
    bf16_fused_bias_relu_broadcast_tile_kernel<<<grid, block>>>(
        reinterpret_cast<__nv_bfloat16*>(Ms2D.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(Pt.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(Pl.data_ptr<at::BFloat16>()),
        reinterpret_cast<const __nv_bfloat16*>(bias.data_ptr<at::BFloat16>()),
        T, C, H4
    );

    return Ms2D;
}

BitsEvolve Framework

The core of this new GPU code optimization pipeline remains an LLM-guided evolutionary search, similar to AlphaEvolve. However, we have added several critical components described below.

Harness Generation and Deployability

Benchmark suites like KernelBench ship with a framework that enables experimentation on a suite of example problems. Generalizing such a framework requires us to handle arbitrary real-world models where the complexity extends beyond just the size of the operator graph.

These models involve complex codebases (modular, multi-source file repos), external dependencies (e.g., a model built on top of the transformers library), lots of configuration knobs, and an established performance baseline (e.g., torch.compile() with max-autotune).

Our overarching goal is to enable BitsEvolve to produce optimized models that are a drop-in replacement for the existing model, in a way that passes evals and meets existing correctness and performance criteria.

To enable this goal, we build a harness generation system that produces a single Python module combining both the base model code as well as a test harness capable of running correctness and performance tests, loading weights/checkpoints, and profiling.

The figure below shows the harness logic when evaluating a candidate:

%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'14px'}}}%%
graph TB
    subgraph Setup["Setup"]
    A[CLI Args] --> B[Random Seed]
    B --> C[Device Config]
    C --> D{Model<br/>Checkpoint?}
    D -->|exists| E[Load State]
    D -->|new| F[Save State]
    E --> G{torch.compile?}
    F --> G
    G -->|yes| H[Compiled Model]
    G -->|no| I[Eager Model]
    end

    subgraph Benchmark["Performance Benchmarking"]
    H --> J[Generate Test Data]
    I --> J
    J --> K[Warmup: 3 runs]
    K --> L[GPU Sync if CUDA]
    L --> M[Timed: 10 runs]
    M --> N[Stats: μ σ min max]
    end

    subgraph Validate["Validation & Correctness"]
    N --> O{Validation<br/>Mode?}
    O -->|validate-only| P[Check NaN/Inf/Range]
    O -->|save-reference| Q[Save Output → Disk]
    O -->|compare-reference| R[Load Reference]
    R --> S[Compute Δ abs/rel]
    S --> T{Within<br/>rtol/atol?}
    T -->|yes| U[✓ PASS]
    T -->|no| V[✗ FAIL]
    end

    subgraph Profile["Profiling Optional"]
    P --> W{Profile<br/>Enabled?}
    U --> W
    V --> W
    W -->|yes| X[torch.profiler<br/>CPU+CUDA+Memory]
    X --> Y[Export Chrome Trace]
    Y --> Z[Export Stack Traces]
    Z --> AA[Summary Tables]
    W -->|no| AB[Complete]
    AA --> AB
    end

    style H fill:#51cf66,color:#fff
    style M fill:#51cf66,color:#fff
    style S fill:#51cf66,color:#fff
    style X fill:#51cf66,color:#fff
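The measurement and comparison steps in the harness flow above can be sketched in a few lines. This is a minimal, framework-free illustration rather than the BitsEvolve harness itself: the function names, the default tolerances, and the `sync` hook are assumptions made for the example, and the tolerance rule shown (`|out - ref| <= atol + rtol * |ref|`) is the convention used by `allclose`-style checks, which we assume here.

```python
import math
import statistics
import time


def benchmark(fn, warmup=3, iters=10, sync=None):
    """Time fn() after warmup runs; returns stats in milliseconds.

    `sync` is an optional callable that blocks until outstanding device work
    has finished (hypothetical hook; on CUDA this would typically be
    torch.cuda.synchronize).
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # include asynchronous device work in the measurement
        samples.append((time.perf_counter() - start) * 1e3)
    return {
        "mean": statistics.fmean(samples),
        "std": statistics.pstdev(samples),
        "min": min(samples),
        "max": max(samples),
        "n": len(samples),
    }


def compare_reference(out, ref, rtol=1e-2, atol=1e-3):
    """Check flat lists of floats elementwise against a saved reference.

    Fails on NaN/Inf, on length mismatch, or when any element violates
    |out - ref| <= atol + rtol * |ref|. Default tolerances are illustrative.
    """
    if len(out) != len(ref) or not all(math.isfinite(o) for o in out):
        return {"passed": False, "max_abs": float("inf"), "max_rel": float("inf")}
    max_abs = max_rel = 0.0
    passed = True
    for o, r in zip(out, ref):
        diff = abs(o - r)
        max_abs = max(max_abs, diff)
        max_rel = max(max_rel, diff / max(abs(r), 1e-12))
        if diff > atol + rtol * abs(r):
            passed = False
    return {"passed": passed, "max_abs": max_abs, "max_rel": max_rel}


if __name__ == "__main__":
    stats = benchmark(lambda: sum(i * i for i in range(10_000)))
    print(stats)
    print(compare_reference([1.0, 2.0], [1.0, 2.001]))
```

Keeping the timing and the correctness check as separate functions mirrors the diagram: a candidate can be benchmarked first and then validated against a reference that was saved once from the base model.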

Instructions and Prompting

The task prompt is customized for GPU code generation. We use few-shot prompting with handwritten/optimized examples and also include some high-level guidance around GPU architectural paradigms and optimization strategies. Crucially, we include fine-grained instructions around ensuring backward compatibility with model state (so old checkpoints still work). We also explicitly steer the model away from anti-patterns, such as forcing a lower GEMM precision than the base model intended or performing wasteful precomputation.

Profiling Feedback

We provide the LLM with two forms of feedback after every candidate evaluation:

  • Compilation & Correctness: If the code doesn't compile or produces incorrect results, all related logs (e.g., Python and NVCC logs) and harness output are provided as textual feedback.
  • Performance Feedback: If the code is correct, then we run additional iterations with profiling. Currently, we support:
    • Level 1: Low overhead profiles at the operator level. We capture both CPU and GPU profiles including operator durations, call counts, memory usage etc.
    • Level 2: Higher overhead but more detailed profiles at the kernel level. NVTX annotations and Nsight/CUPTI are used to emit detailed kernel-level profiles. We further summarize these using a separate meta LLM step to get a compact set of profile-guided recommendations (textual feedback) for the evolution LLM.

Common Setup and Configuration

For the case studies below, we targeted forward pass performance improvement.

  • Measurement: 3 warmup iterations, 10 measured iterations.
  • Hardware: L4 and A10G GPUs (matching our production configs).
  • Input Configuration: For all models, we used input shapes resembling production use cases (batch sizes, sequence lengths, etc.) and hyperparameters that were tuned outside of BitsEvolve.
  • Data Types: We optimize against the team's chosen types (e.g., BFloat16). While BitsEvolve can propose quantization (e.g., TorchAO), we leave those decisions to the model teams for these experiments.
  • Budget: We used a soft budget cap of $70 for the evolutionary search. This tracks external model API costs (not GPU run costs).
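To make the budget setting concrete, here is a toy sketch of an evolutionary loop with a soft API-cost cap. It is not the BitsEvolve or ShinkaEvolve implementation: `evolve`, `mutate`, and `evaluate` are hypothetical names, the parent selection rule is a placeholder, and we assume "soft" means the cap is checked before each generation starts, so the last generation may overshoot it slightly.

```python
import random


def evolve(seed, mutate, evaluate, budget_usd=70.0, max_generations=140, rng=None):
    """Toy evolutionary search with a soft cost cap.

    mutate(parent, rng) -> (child, cost_usd) stands in for an LLM call.
    evaluate(candidate) -> score, or None if the candidate failed to
    compile or produced incorrect results. The seed is assumed to be valid.
    """
    rng = rng or random.Random(0)
    population = [(evaluate(seed), seed)]
    spent = 0.0
    for _ in range(max_generations):
        if spent >= budget_usd:  # soft cap: only checked between generations
            break
        # Pick a parent among the current top candidates (placeholder policy).
        ranked = sorted(population, key=lambda p: p[0], reverse=True)
        parent = rng.choice(ranked[:3])[1]
        child, cost = mutate(parent, rng)
        spent += cost  # tracks model API cost only, not GPU time
        score = evaluate(child)
        if score is not None:  # failed candidates never enter the population
            population.append((score, child))
    best_score, best = max(population, key=lambda p: p[0])
    return best, best_score, spent


if __name__ == "__main__":
    # Stand-in problem: maximize -(x - 3)^2, with each "LLM call" costing $1.
    best, score, spent = evolve(
        0.0,
        mutate=lambda x, rng: (x + rng.gauss(0.0, 0.5), 1.0),
        evaluate=lambda x: -(x - 3.0) ** 2,
        budget_usd=20.0,
    )
    print(f"best={best:.3f} score={score:.3f} spent=${spent:.2f}")
```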

Case Study 1: Sensitive Data Scanning (SDS)

Datadog's SDS product scans logs, traces, RUM, and events for sensitive data in real time and at scale. It is powered by a custom model that must meet strict SLOs and cost-efficiency targets. We optimize the two main parts of this model for BFloat16 inference on an L4 GPU.

1. The Encoder (DeBERTa v2)

The encoder is based on DeBERTa v2 (a disentangled self-attention BERT variant). The base code is available here.

BitsEvolve produces an optimized implementation that achieves a 1.53x speedup. The primary improvements were:

  • Flash Attention: Switched to Scaled Dot Product Attention kernels.
  • Custom kernels for various fused operations
    • Fused Disentangled Bias Kernel: Includes micro-optimizations to avoid shared memory bank conflicts.
    • Fused Add + LayerNorm Kernel: Reduces memory overhead.
    • Grouped GEMM: Collapsed 3 separate GEMMs for the QKV Projection into a single GEMM.
    • Prescaled QKV Transform and a Fused Context Transform kernel.

%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'16px'}}}%%
graph LR
    subgraph Initial[" Initial Model (15 operations) "]
    direction TB
    I0[ ]
    I1[3 Linear projections]
    I2[3 Transpose ops]
    I3[Scale Q]
    I4[Q·K^T]
    I5[Scale]
    I6[c2p gather]
    I7[p2c gather]
    I8[Mask fill]
    I9[Softmax]
    I10[Dropout]
    I11[Multiply V]
    I12[Transpose]
    I13[Linear projection]
    I14[Add residual]
    I15[LayerNorm]
    end

    subgraph Optimized[" Optimized Model (8 operations) "]
    direction TB
    O0[ ]
    O1[Fused QKV projection]
    O2[Split QKV into Heads<br/>CUDA]
    O3[Dense matmuls<br/>cached pos emb]
    O4[Fused Bias + Mask<br/>CUDA]
    O5[Flash SDPA<br/>fused attention]
    O6[Merge Heads<br/>CUDA]
    O7[Linear projection]
    O8[Fused Add + LayerNorm<br/>CUDA]
    end

    I1 -.-> O1
    I2 -.-> O2
    I3 -.-> O2
    I6 -.-> O4
    I7 -.-> O4
    I8 -.-> O4
    I4 -.-> O5
    I5 -.-> O5
    I9 -.-> O5
    I10 -.-> O5
    I11 -.-> O5
    I12 -.-> O6
    I13 -.-> O7
    I14 -.-> O8
    I15 -.-> O8

    style I0 fill:none,stroke:none
    style O0 fill:none,stroke:none
    style O1 fill:#51cf66,color:#fff
    style O2 fill:#51cf66,color:#fff
    style O4 fill:#51cf66,color:#fff
    style O5 fill:#51cf66,color:#fff
    style O6 fill:#51cf66,color:#fff
    style O8 fill:#51cf66,color:#fff
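As an illustration of the fused QKV projection shown in the diagram, the sketch below checks that three separate projections can be replaced by one GEMM over concatenated weights. It uses plain Python lists and toy sizes so that it runs anywhere; the weight-concatenation approach is one common way to collapse the three GEMMs and is an assumption about the generated code, not a copy of it. Biases are omitted for brevity.

```python
import random


def matmul(a, b):
    """Naive (m x k) @ (k x n) matrix product on nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def max_abs_diff(a, b):
    return max(abs(p - r) for ra, rb in zip(a, b) for p, r in zip(ra, rb))


rng = random.Random(0)
tokens, hidden = 4, 6  # toy sizes
x = rand_matrix(tokens, hidden, rng)
w_q, w_k, w_v = (rand_matrix(hidden, hidden, rng) for _ in range(3))

# Baseline: three separate projections (three GEMM launches).
q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

# Fused: concatenate the weights along the output dimension once (offline),
# run a single GEMM, then slice the result back into Q, K and V.
w_qkv = [rq + rk + rv for rq, rk, rv in zip(w_q, w_k, w_v)]  # [hidden, 3 * hidden]
qkv = matmul(x, w_qkv)
q_f = [row[:hidden] for row in qkv]
k_f = [row[hidden:2 * hidden] for row in qkv]
v_f = [row[2 * hidden:] for row in qkv]

max_diff = max(max_abs_diff(q, q_f), max_abs_diff(k, k_f), max_abs_diff(v, v_f))
print(f"max |separate - fused| = {max_diff:.3e}")
```

On a GPU the benefit comes from launching one larger GEMM instead of three smaller ones and reading the input activations once; the arithmetic itself is unchanged.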

2. The Scorer (Feed-Forward Network)

The SDS model's scoring layer is a Feed-Forward Network that consumes the encoder output. The base code can be found here. Note that our production model performs additional post-processing, which we also targeted.

BitsEvolve achieves a ~2x speedup on the Scorer.

  • Fused Kernels: Most of the speedup comes from tile-wise processing on top of the fused GEMM and a custom fused kernel for "expand + element-wise multiply."
  • Decomposition: The large cat() -> MLP pattern is rewritten by splitting the first layer's weights per input branch, so that the token-only and label-only branches no longer depend on the (token, label) pair interaction.
  • Precomputation: The pair-independent branches are precomputed once with smaller GEMMs, hoisting redundant computation out of the per-pair path.
  • Micro-optimizations: We also see additional tensor access optimizations and other micro-optimizations.

%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'14px'}}}%%
graph LR
    subgraph Initial[" Initial Model (10+ operations) "]
    direction TB
    I0[ ]
    I1[Project tokens]
    I2[Project labels]
    I3[Reshape to add dims]
    I4[Expand to B×S×C]
    I5[Element-wise multiply]
    I6[Concatenate 3 branches]
    I7[Linear 3H→4H on pairs]
    I8[Add bias]
    I9[Dropout skip]
    I10[ReLU activation]
    I11[Linear 4H→3]
    I12[Output]
    end

    subgraph Optimized[" Optimized Model (7 operations) "]
    direction TB
    O0[ ]
    O1[Project tokens & labels]
    O2[Decompose weights]
    O3[Precompute branches<br/>Pt, Pl]
    O4[Tilewise loop]
    O5[Fused colscale_tile<br/>implicit broadcast]
    O6[GEMM + Fused bias+ReLU<br/>4-in-1 kernel]
    O7[GEMM final projection]
    O8[Output]
    end

    I1 -.-> O1
    I2 -.-> O1
    I3 -.-> O4
    I4 -.-> O5
    I5 -.-> O5
    I6 -.-> O3
    I7 -.-> O3
    I7 -.-> O6
    I8 -.-> O6
    I10 -.-> O6
    I11 -.-> O7

    style I0 fill:none,stroke:none
    style O0 fill:none,stroke:none
    style O5 fill:#51cf66,color:#fff
    style O6 fill:#51cf66,color:#fff
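The decomposition and precomputation steps rest on a simple identity: a linear layer applied to a concatenation equals the sum of per-branch linear layers. The sketch below verifies this on toy sizes with plain Python lists. The branch order `[token, label, token * label]` and all names are assumptions for illustration; dropout is omitted because it is an identity at inference time.

```python
import random


def matmul(a, b):
    """Naive (m x k) @ (k x n) matrix product on nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def relu(row):
    return [max(v, 0.0) for v in row]


rng = random.Random(0)
S, C, H = 3, 2, 4  # tokens, labels, hidden size (toy values)
F = 4 * H          # width of the first MLP layer
tok = rand_matrix(S, H, rng)
lab = rand_matrix(C, H, rng)
W = rand_matrix(3 * H, F, rng)  # first Linear, stored as [in, out]
bias = [rng.uniform(-1.0, 1.0) for _ in range(F)]

# Baseline: materialize cat([t, l, t * l]) for every (token, label) pair,
# then run one GEMM with inner dimension 3H over all S * C rows.
pairs = [t + l + [a * b for a, b in zip(t, l)] for t in tok for l in lab]
baseline = [relu([v + b for v, b in zip(row, bias)]) for row in matmul(pairs, W)]

# Decomposed: split W by input branch and precompute the pair-independent parts.
W_t, W_l, W_mul = W[:H], W[H:2 * H], W[2 * H:]
Pt = matmul(tok, W_t)  # [S, 4H], computed once per token
Pl = matmul(lab, W_l)  # [C, 4H], computed once per label
prod = [[a * b for a, b in zip(t, l)] for t in tok for l in lab]  # [S * C, H]
M = matmul(prod, W_mul)  # per-pair GEMM now has inner dimension H, not 3H
decomposed = [
    relu([m + pt + pl + b for m, pt, pl, b in zip(M[s * C + c], Pt[s], Pl[c], bias)])
    for s in range(S)
    for c in range(C)
]

max_diff = max(abs(a - b) for ra, rb in zip(baseline, decomposed) for a, b in zip(ra, rb))
print(f"max |baseline - decomposed| = {max_diff:.3e}")
```

The broadcast add of `Pt`, `Pl`, and the bias followed by ReLU corresponds to the work done by the fused bias + ReLU kernel in the optimized code, and the row index `s * C + c` matches the `t * C + c` layout used there.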

Case Study 2: Toto Time Series Forecasting

Toto is a SOTA time series forecasting model, optimized for observability data. A high-level view of the model architecture is shown below.

Toto Architecture

Code for the Toto v1 model can be found here. The v1 release is open weights and is used along with our production Toto configuration. We target NVIDIA A10 inference performance on Float32 data.

BitsEvolve produced an optimized implementation that achieves a 1.57x speedup. The score progression graph is visualized below over 140 generations. Individual candidates are plotted in red and the cumulative mean/max scores are plotted in green and blue respectively. "Best Path" traces the lineage of the best performing candidate. The step function increases in score correlate with one or more significant optimizations, for example a fused kernel that compiles and produces correct results (possibly iterating over multiple generations), or a crossover that incorporates multiple composable optimizations.

Performance Score Progression

The cost progression below tracks how the workflow performs against the budget cap ($70). The spikes correlate with full rewrites and crossovers, where token usage is typically higher.

Cost Progression

Toto Backbone

Since TotoBackbone is transformer-based, we see optimizations similar to SDS, specifically operator fusion and data layout changes for SDPA inputs.

Interestingly, the LLM attempted to fuse the last residual add into the preceding linear layer. However, profiling data showed that the layout changes required to fuse the add into the linear layer's GEMM made the fusion slower overall, so the agent correctly rejected it.

%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'16px'}}}%%
graph LR
    subgraph Initial[" Initial Model (16 operations) "]
    direction TB
    I0[ ]
    I1[RMSNorm: compute x²]
    I2[RMSNorm: mean + rsqrt]
    I3[RMSNorm: scale]
    I4[Compute cos/sin]
    I5[Rotate Q separately]
    I6[Rotate K separately]
    I7[Q·K^T attention<br/>SDPA]
    I8[Residual add]
    I9[RMSNorm: compute x²]
    I10[RMSNorm: mean + rsqrt]
    I11[RMSNorm: scale]
    I12[Linear expand to 2H]
    I13[Chunk into gate/val]
    I14[SiLU + multiply]
    I15[Linear project to D]
    I16[Residual add]
    end

    subgraph Optimized[" Optimized Model (6 operations) "]
    direction TB
    O0[ ]
    O1[Fused RMSNorm<br/>shared memory]
    O2[Fused QK Rotary<br/>time-aware scaling]
    O3[Q·K^T attention<br/>SDPA]
    O4[Fused Add + RMSNorm<br/>dual output]
    O5[Fused MLP<br/>cuBLAS + SwiGLU]
    O6[Residual add]
    end

    I1 -.-> O1
    I2 -.-> O1
    I3 -.-> O1
    I4 -.-> O2
    I5 -.-> O2
    I6 -.-> O2
    I7 -.-> O3
    I8 -.-> O4
    I9 -.-> O4
    I10 -.-> O4
    I11 -.-> O4
    I12 -.-> O5
    I13 -.-> O5
    I14 -.-> O5
    I15 -.-> O5
    I16 -.-> O6

    style I0 fill:none,stroke:none
    style O0 fill:none,stroke:none
    style O1 fill:#51cf66,color:#fff
    style O2 fill:#51cf66,color:#fff
    style O4 fill:#51cf66,color:#fff
    style O5 fill:#51cf66,color:#fff
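As a rough illustration of the "Fused Add + RMSNorm (dual output)" node above, the sketch below computes the residual sum and its RMS-normalized value in one pass over a row. We read "dual output" as returning both the updated residual stream and the normalized activations; that reading, the function name, and the `eps` value are assumptions, and the real kernel operates on GPU tensors rather than Python lists.

```python
import math


def add_rmsnorm(x, residual, weight, eps=1e-6):
    """Fused residual add + RMSNorm for a single row.

    Returns (new_residual, normalized):
      new_residual[i] = x[i] + residual[i]
      normalized[i]   = new_residual[i] / sqrt(mean(new_residual^2) + eps) * weight[i]
    An unfused version would write the sum to memory and read it back for
    the normalization; here the sum is reused directly.
    """
    new_residual = [a + b for a, b in zip(x, residual)]
    mean_sq = sum(v * v for v in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normalized = [v * inv_rms * w for v, w in zip(new_residual, weight)]
    return new_residual, normalized


if __name__ == "__main__":
    res, norm = add_rmsnorm([1.0, 2.0], [2.0, 2.0], [1.0, 1.0])
    print(res, norm)
```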

Toto Forecaster

On the forecaster side, BitsEvolve generates:

  • A fused kernel for the affine transform + KV cache write.
  • Basic optimizations around tensor accesses and indexing (improving speed without custom CUDA code).
  • A zero-copy setup phase, which contributes to the overall speedup.

graph TB
    subgraph "TotoForecaster.forecast()"
    direction TB
    A[Input Time Series] --> B[Setup Phase<br/>Zero-copy expand]
    B --> C[Autoregressive Loop<br/>N steps]

    subgraph "Each Step"
    C1[Model Forward Pass] --> C2[Sample Distribution]
    C2 --> C3[Fused Affine+Write]
    end

    C --> C1
    C3 --> C4{More steps?}
    C4 -->|Yes| C1
    C4 -->|No| D[Return Samples]
    end

    subgraph "Model Forward Pass Detail"
    M1[Fused RMSNorm] --> M2[Fused QK Rotary]
    M2 --> M3[SDPA Attention]
    M3 --> M4[Fused Add+RMSNorm]
    M4 --> M5[Fused MLP]
    M5 --> M6[Output Distribution]
    end

    C1 -.-> M1
    M6 -.-> C2

    style B fill:#51cf66,color:#fff
    style C3 fill:#51cf66,color:#fff
    style M1 fill:#51cf66,color:#fff
    style M2 fill:#51cf66,color:#fff
    style M4 fill:#51cf66,color:#fff
    style M5 fill:#51cf66,color:#fff

Optimized Code Example: Toto Rotary Embedding

Reference PyTorch Code

def apply_rotary_emb(
    freqs,
    t,
    start_index=0,
    scale=1.,
    seq_dim=-2,
    freqs_seq_dim=None
):
    """
    Apply rotary embeddings (from rotary_embedding_torch library).

    Args:
        freqs: Frequency tensor
        t: Input tensor to apply rotation to
        start_index: Starting index for rotation
        scale: Scaling factor (for xpos)
        seq_dim: Sequence dimension
        freqs_seq_dim: Frequency sequence dimension
    """
    dtype = t.dtype

    if not exists(freqs_seq_dim):
        if freqs.ndim == 2 or t.ndim == 3:
            freqs_seq_dim = 0

    if t.ndim == 3 or exists(freqs_seq_dim):
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    # Split t into three parts: left, middle (to be transformed), and right
    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    # Apply rotary embeddings
    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)

    out = torch.cat((t_left, t_transformed, t_right), dim=-1)

    return out.type(dtype)

def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: int = None,
        seq_pos_offset: int = 0,
    ):
        """
        Apply rotary embeddings with xpos scaling to queries and keys.
        Uses exact logic from rotary_embedding_torch library.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Get sequence positions
        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
        seq = seq + seq_pos_offset

        # Get frequencies and scale
        freqs = self.forward(seq)
        scale = self.get_scale(seq).to(dtype)

        # Apply rotary embeddings using library function
        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        return rotated_q.type(dtype), rotated_k.type(dtype)

Optimized Code

def apply_fused_qk_rotary_emb(q, k, cos_base, sin_base, scale_base, seq_pos_offset, seq_dim=-2):
    """Apply rotary embeddings to Q and K with an offset-aware fused kernel (prefers in-place)."""
    seq_len = q.shape[seq_dim]
    end_idx = seq_pos_offset + seq_len

    # OPT: Slice precomputed tables to current sequence window
    cos_freq = cos_base[seq_pos_offset:end_idx].contiguous()
    sin_freq = sin_base[seq_pos_offset:end_idx].contiguous()
    scale = scale_base[seq_pos_offset:end_idx].contiguous()

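    # OPT: In-place fused kernel. Only taken for contiguous inputs, because
    # .contiguous() on a non-contiguous tensor returns a copy: the rotation
    # would be applied to the copy while the unrotated q, k are returned.
    if (
        _fused_rotary_inplace_ext is not None
        and q.is_cuda
        and k.is_cuda
        and q.dtype in (torch.float16, torch.float32)
        and q.shape == k.shape
        and q.is_contiguous()
        and k.is_contiguous()
    ):
        _fused_rotary_inplace_ext.fused_qk_rotary_emb_inplace(q, k, cos_freq, sin_freq, scale)
        return q, k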

    # OPT: Single fused kernel
    if (
        _fused_rotary_ext is not None
        and q.is_cuda
        and k.is_cuda
        and q.dtype in (torch.float16, torch.float32)
        and q.shape == k.shape
    ):
        return _fused_rotary_ext.fused_qk_rotary_emb_forward(
            q.contiguous(), k.contiguous(), cos_base, sin_base, scale_base, seq_pos_offset
        )

    # Fallback pure PyTorch path
    dtype = q.dtype
    rotated_q = (q * cos_freq + rotate_half(q) * sin_freq) * scale
    rotated_k = (k * cos_freq + rotate_half(k) * sin_freq) * (scale**-1)
    return rotated_q.type(dtype), rotated_k.type(dtype)

template<typename T>
__global__ void fused_qk_rotary_emb_kernel(
    const T* __restrict__ q_in,
    const T* __restrict__ k_in,
    T* __restrict__ q_out,
    T* __restrict__ k_out,
    const T* __restrict__ cos_base,
    const T* __restrict__ sin_base,
    const T* __restrict__ scale_base,
    const int total_pairs,
    const int seq_len,
    const int dim,
    const int seq_pos_offset
) {
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_pairs; idx += gridDim.x * blockDim.x) {
        const int pair_idx_in_token = idx % (dim / 2);
        const int token_idx = idx / (dim / 2);
        const int seq_idx = token_idx % seq_len;
        const int effective_seq_idx = seq_idx + seq_pos_offset;

        const int base_in_idx = token_idx * dim;
        const int freq_idx = effective_seq_idx * dim + 2 * pair_idx_in_token;

        const float c = static_cast<float>(cos_base[freq_idx]);
        const float s = static_cast<float>(sin_base[freq_idx]);
        const float sc_q = static_cast<float>(scale_base[freq_idx]);
        const float sc_k = 1.0f / sc_q;

        // Q
        const float q1 = static_cast<float>(q_in[base_in_idx + 2 * pair_idx_in_token]);
        const float q2 = static_cast<float>(q_in[base_in_idx + 2 * pair_idx_in_token + 1]);
        q_out[base_in_idx + 2 * pair_idx_in_token]     = static_cast<T>((q1 * c - q2 * s) * sc_q);
        q_out[base_in_idx + 2 * pair_idx_in_token + 1] = static_cast<T>((q2 * c + q1 * s) * sc_q);

        // K
        const float k1 = static_cast<float>(k_in[base_in_idx + 2 * pair_idx_in_token]);
        const float k2 = static_cast<float>(k_in[base_in_idx + 2 * pair_idx_in_token + 1]);
        k_out[base_in_idx + 2 * pair_idx_in_token]     = static_cast<T>((k1 * c - k2 * s) * sc_k);
        k_out[base_in_idx + 2 * pair_idx_in_token + 1] = static_cast<T>((k2 * c + k1 * s) * sc_k);
    }
}

std::vector<torch::Tensor> fused_qk_rotary_emb_forward(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor cos_base,
    torch::Tensor sin_base,
    torch::Tensor scale_base,
    int seq_pos_offset
) {
    ...
    const int total_pairs = (q.numel() / dim) * (dim / 2);
    auto q_out = torch::empty_like(q);
    auto k_out = torch::empty_like(k);

    if (total_pairs == 0) {
        return {q_out, k_out};
    }

    const int block_size = 256;
    const int num_blocks = (total_pairs + block_size - 1) / block_size;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(q.scalar_type(), "fused_qk_rotary_emb_kernel", [&] {
        fused_qk_rotary_emb_kernel<scalar_t><<<num_blocks, block_size, 0, stream>>>(
            q.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(),
            q_out.data_ptr<scalar_t>(),
            k_out.data_ptr<scalar_t>(),
            cos_base.data_ptr<scalar_t>(),
            sin_base.data_ptr<scalar_t>(),
            scale_base.data_ptr<scalar_t>(),
            total_pairs,
            seq_len,
            dim,
            seq_pos_offset
        );
    });
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {q_out, k_out};
}
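For readers who want to check the per-pair arithmetic without a GPU, the following pure-Python sketch mirrors what the kernel above does for a single token: it rotates adjacent pairs `(x[2i], x[2i+1])`, scales Q by `scale` and K by `1 / scale`, and reads the table entry at the even index of each pair. The function names are hypothetical, and the example assumes the cos/sin/scale tables repeat each value for both elements of a pair.

```python
import math


def rotate_pairs(x, cos, sin, scale):
    """Rotate adjacent pairs of x and apply a per-pair scale (single token)."""
    out = [0.0] * len(x)
    for i in range(0, len(x), 2):
        c, s, sc = cos[i], sin[i], scale[i]  # entry at the even index of the pair
        out[i] = (x[i] * c - x[i + 1] * s) * sc
        out[i + 1] = (x[i + 1] * c + x[i] * s) * sc
    return out


def fused_qk_rotary(q, k, cos, sin, scale):
    """Apply the rotation to Q with `scale` and to K with the inverse scale."""
    inv_scale = [1.0 / s for s in scale]
    return rotate_pairs(q, cos, sin, scale), rotate_pairs(k, cos, sin, inv_scale)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


if __name__ == "__main__":
    angles = [0.3, 1.1]  # one angle per pair (toy values)
    cos = [math.cos(a) for a in angles for _ in range(2)]
    sin = [math.sin(a) for a in angles for _ in range(2)]
    scale = [1.1, 1.1, 0.9, 0.9]
    q = [1.0, 2.0, 3.0, 4.0]
    k = [0.5, -1.0, 2.0, 0.25]
    q_rot, k_rot = fused_qk_rotary(q, k, cos, sin, scale)
    # At the same position the rotation and the reciprocal scales cancel,
    # so the attention logit q . k is preserved.
    print(dot(q, k), dot(q_rot, k_rot))
```

A reference like this can also serve as a correctness oracle in a test harness when comparing the fused kernel against the original two-call PyTorch path.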

What's Next

For all the models mentioned in this blog post, the optimizations have been tested in our production environment and integrated with the serving codebase.

  • SDS: We are working with the model team on end-to-end evaluations to enable the optimized model by default.
  • Toto: The optimized model is currently serving forecasting requests in production. We plan to expand optimization to L4 targets.

Going forward, we plan to further improve our approach. We want to converge towards high-quality optimizations quickly by leveraging more detailed and fine-grained telemetry and persisting learnings across runs.

The AI-Driven Research for Systems (ADRS) initiative is an open, collaborative effort to explore how AI can accelerate scientific discovery itself, from evolving algorithms to optimizing real-world systems.

If you've built, optimized, or experimented with AI-driven research tools, we'd love to hear from you. Share your experiences, insights, or case studies with us in the ADRS Blog Series.

👉 Reach out to us via email: ucbskyadrs@gmail.com

💬 Join us: join.slack.com/t and Discord

If you have any questions, reach out to jai.menon@datadoghq.com!