Weighted root-mean-square layer normalization
Source code in vllm/ir/ops/layernorm.py
@register_op(has_reduction=True)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """Weighted root-mean-square layer normalization.

    Computes RMSNorm in float32 for numerical stability, then casts the
    result back to the input's original dtype.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Optional per-element scale applied after normalization.
            When provided, the normalized tensor is cast to ``weight.dtype``
            before multiplication.
        epsilon: Small constant added to the variance to avoid division by
            zero.
        variance_size: If given, the variance is computed over only the
            first ``variance_size`` elements of the last dimension, while
            the full tensor is still normalized.

    Returns:
        The normalized (and optionally weighted) tensor in ``x``'s
        original dtype.
    """
    orig_dtype = x.dtype
    # Upcast so the mean-of-squares accumulates at full float32 precision.
    x = x.to(torch.float32)
    # Optionally restrict the variance statistic to a prefix of the last dim.
    x_var = x if variance_size is None else x[..., :variance_size]
    variance = x_var.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    if weight is not None:
        # Match the weight's dtype (weight may be lower precision than fp32).
        x = x.to(weight.dtype) * weight
    return x.to(orig_dtype)