|
| 1 | +--- |
| 2 | +on_github: huggingface/kernels-uvnotes |
| 3 | +--- |
| 4 | + |
| 5 | +# PyTorch Native - Rotary Position Embeddings |
| 6 | + |
| 7 | +## GPU Info |
| 8 | + |
| 9 | +```python id=nv |
| 10 | +import subprocess |
# Invoke nvidia-smi without a shell and echo whatever it wrote to stdout,
# so the GPU this notebook ran on is recorded in the cell output.
gpu_report = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
print(gpu_report.stdout)
| 12 | +``` |
| 13 | + |
| 14 | +## Rotary Embeddings Benchmark (PyTorch Native) |
| 15 | + |
| 16 | +```python id=benchmark outputs=rotary.jsonl |
| 17 | +# /// script |
| 18 | +# requires-python = ">=3.10" |
| 19 | +# dependencies = [ |
| 20 | +# "numpy", |
| 21 | +# "torch==2.8.0", |
| 22 | +# "kernels-benchmark-tools", |
| 23 | +# ] |
| 24 | +# |
| 25 | +# [tool.uv.sources] |
| 26 | +# kernels-benchmark-tools = { path = "../../../../../tools", editable = true } |
| 27 | +# /// |
| 28 | +import torch |
| 29 | +import sys |
| 30 | +from kernels_benchmark_tools import KernelTypeEnum, run_benchmark |
| 31 | + |
| 32 | + |
def apply_rotary_torch(x1, x2, cos, sin, conj=False):
    """Rotate the pair ``(x1, x2)`` using the given cosine and sine factors.

    Computes the 2-D rotation::

        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos

    When ``conj`` is truthy the sign of the ``sin`` terms is flipped, which
    applies the opposite rotation.

    Args:
        x1: First component of each pair.
        x2: Second component of each pair.
        cos: Cosine factors (assumed broadcastable against ``x1``/``x2``;
            confirm against the caller).
        sin: Sine factors, same expectations as ``cos``.
        conj: Select the sign-flipped (conjugate) rotation.

    Returns:
        A ``(out1, out2)`` tuple of newly computed values; the arguments are
        only read by the arithmetic here.
    """
    if conj:
        # Conjugate form: same rotation with ``sin`` negated.
        return x1 * cos + x2 * sin, -x1 * sin + x2 * cos
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos
| 42 | + |
| 43 | + |
def torch_rotary(query, key, cos, sin, conj=False):
    """Apply the rotary transform to ``query`` and ``key`` without mutating them.

    The rotary width is taken from ``cos.shape[-1]``. Along the last axis,
    features ``[0, rotary_dim)`` and ``[rotary_dim, 2 * rotary_dim)`` are
    treated as the two halves of each rotated pair; any features beyond
    ``2 * rotary_dim`` are carried over unchanged from the cloned input.

    Args:
        query: Query tensor; cloned before any write.
        key: Key tensor; cloned before any write.
        cos: Cosine factors (assumed to broadcast against the sliced halves;
            confirm against the benchmark harness).
        sin: Sine factors, same expectations as ``cos``.
        conj: Forwarded to ``apply_rotary_torch`` to select the conjugate
            rotation.

    Returns:
        ``(q_out, k_out)``: rotated copies of ``query`` and ``key``.
    """
    width = cos.shape[-1]
    first_half = slice(None, width)
    second_half = slice(width, 2 * width)

    # Work on copies so the caller's tensors are left untouched.
    rotated = (query.clone(), key.clone())

    for out in rotated:
        # The helper returns fresh tensors, so both halves are fully computed
        # from the pre-rotation values before either slice is overwritten.
        new_first, new_second = apply_rotary_torch(
            out[..., first_half], out[..., second_half], cos, sin, conj
        )
        out[..., first_half] = new_first
        out[..., second_half] = new_second

    q_out, k_out = rotated
    return q_out, k_out
| 66 | + |
| 67 | + |
# Labels attached to this implementation when it is handed to the harness.
eager_tags = {"family": "pytorch", "backend": "eager"}

# Run the rotary benchmark against the eager PyTorch implementation above.
# (Results presumably land in the cell's declared output file; confirm
# against kernels_benchmark_tools.)
run_benchmark(
    kernel_type=KernelTypeEnum.ROTARY,
    impl_name="torch_eager",
    impl_tags=eager_tags,
    impl_func=torch_rotary,
)
| 74 | +``` |
0 commit comments