
vllm.model_executor.layers.quantization.fp8

ACTIVATION_SCHEMES module-attribute

ACTIVATION_SCHEMES = ['static', 'dynamic']

logger module-attribute

logger = init_logger(__name__)

CopyNumelCounter

Bases: TorchDispatchMode

Tracks the total number of elements modified with `copy_`. Useful for tracking weight loading, where the underlying weights can be arbitrarily transformed (such as with `narrow`) before `copy_` is called.

Source code in vllm/model_executor/layers/quantization/fp8.py
class CopyNumelCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`. Useful for keeping
    track of weight loading where underlying weights can be arbitrarily
    transformed (such as with `narrow`) before calling copy.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        out = func(*args, **kwargs)
        if func == torch.ops.aten.copy_.default:
            self.copied_numel += args[0].numel()
        return out
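
A minimal usage sketch of `CopyNumelCounter` (tensor shapes here are illustrative, not taken from vLLM):

```python
import torch

dst = torch.empty(8, 16)
src = torch.randn(4, 16)

counter = CopyNumelCounter()
with counter:
    # narrow() returns a view; copying into it still dispatches
    # aten.copy_.default, so the copied elements are counted.
    dst.narrow(0, 0, 4).copy_(src)

assert counter.copied_numel == src.numel()  # 64 elements
```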

copied_numel instance-attribute

copied_numel = 0

__init__

__init__()
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self):
    super().__init__()
    self.copied_numel = 0

__torch_dispatch__

__torch_dispatch__(func, types, args=(), kwargs=None)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    out = func(*args, **kwargs)
    if func == torch.ops.aten.copy_.default:
        self.copied_numel += args[0].numel()
    return out

Fp8Config

Bases: QuantizationConfig

Config class for FP8.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            if not self.is_checkpoint_fp8_serialized:
                online_method = Fp8OnlineLinearMethod(self)
                online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
                return online_method
            else:
                offline_method = Fp8LinearMethod(self)
                offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
                return offline_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

activation_scheme instance-attribute

activation_scheme = activation_scheme

ignored_layers instance-attribute

ignored_layers = ignored_layers or []

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

weight_block_size instance-attribute

weight_block_size = weight_block_size

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
) -> None:
    super().__init__()

    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    if activation_scheme not in ACTIVATION_SCHEMES:
        raise ValueError(f"Unsupported activation scheme {activation_scheme}")
    self.activation_scheme = activation_scheme
    self.ignored_layers = ignored_layers or []
    if weight_block_size is not None:
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        if len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )
    self.weight_block_size = weight_block_size
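
A minimal construction sketch (values are illustrative) showing the constraints enforced above: block-wise quantization requires an fp8-serialized checkpoint, a 2-element block size, and the dynamic activation scheme.

```python
cfg = Fp8Config(
    is_checkpoint_fp8_serialized=True,
    activation_scheme="dynamic",
    ignored_layers=["lm_head"],
    weight_block_size=[128, 128],
)

# Any of these would raise ValueError:
#   Fp8Config(weight_block_size=[128, 128])          # checkpoint not fp8-serialized
#   Fp8Config(is_checkpoint_fp8_serialized=True,
#             activation_scheme="static",
#             weight_block_size=[128, 128])           # static + block-wise
#   Fp8Config(activation_scheme="int8")               # unsupported activation scheme
```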

apply_vllm_mapper

apply_vllm_mapper(hf_to_vllm_mapper: WeightsMapper)
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    if self.ignored_layers is not None:
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

from_config classmethod

from_config(config: dict[str, Any]) -> Fp8Config
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
    quant_method = cls.get_from_keys(config, ["quant_method"])
    is_checkpoint_fp8_serialized = "fp8" in quant_method
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
    if not ignored_layers:
        ignored_layers = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
    return cls(
        is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
        activation_scheme=activation_scheme,
        ignored_layers=ignored_layers,
        weight_block_size=weight_block_size,
    )
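
A sketch of parsing an HF-style `quantization_config` dict (keys and values are illustrative); note the fallback from `ignored_layers` to `modules_to_not_convert`:

```python
hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
    "modules_to_not_convert": ["lm_head"],
}

cfg = Fp8Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_fp8_serialized      # "fp8" in quant_method
assert cfg.ignored_layers == ["lm_head"]     # taken from modules_to_not_convert
assert cfg.weight_block_size == [128, 128]
```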

get_cache_scale

get_cache_scale(name: str) -> str | None

Check whether the param name matches the format for k/v cache scales in compressed-tensors. If so, return the equivalent param name expected by vLLM.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | The param name. | required |

Returns: the matching param name for the KV cache scale in vLLM, or `None` if there is no match.

Source code in vllm/model_executor/layers/quantization/fp8.py
def get_cache_scale(self, name: str) -> str | None:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    # If no matches, return None
    return None
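
An illustrative mapping (parameter names are examples):

```python
cfg = Fp8Config(is_checkpoint_fp8_serialized=True)

name = "model.layers.0.self_attn.k_proj.output_scale"
assert cfg.get_cache_scale(name) == "model.layers.0.self_attn.attn.k_scale"

# Names that do not match any k/v/q/prob scale pattern return None.
assert cfg.get_cache_scale("model.layers.0.mlp.gate_proj.weight") is None
```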

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return []

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_min_capability(cls) -> int:
    return 75

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "fp8"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    if current_platform.is_xpu():
        return self.get_xpu_quant_method(layer, prefix)
    if isinstance(layer, LinearBase):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedLinearMethod()
        if not self.is_checkpoint_fp8_serialized:
            online_method = Fp8OnlineLinearMethod(self)
            online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return online_method
        else:
            offline_method = Fp8LinearMethod(self)
            offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return offline_method
    elif isinstance(layer, FusedMoE):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        if self.is_checkpoint_fp8_serialized:
            moe_quant_method = Fp8MoEMethod(self, layer)
        else:
            moe_quant_method = Fp8OnlineMoEMethod(self, layer)
        return moe_quant_method
    elif isinstance(layer, Attention):
        return Fp8KVCacheMethod(self)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

get_xpu_quant_method

get_xpu_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_xpu_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    from vllm.model_executor.layers.quantization.ipex_quant import (
        XPUFp8LinearMethod,
        XPUFp8MoEMethod,
    )

    fp8_config = Fp8Config(
        is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
        activation_scheme=self.activation_scheme,
        ignored_layers=self.ignored_layers,
        weight_block_size=self.weight_block_size,
    )

    if isinstance(layer, LinearBase):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedLinearMethod()
        return XPUFp8LinearMethod(fp8_config)
    elif isinstance(layer, FusedMoE):
        if is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedFusedMoEMethod(layer.moe_config)

        return XPUFp8MoEMethod(fp8_config, layer)
    elif isinstance(layer, Attention):
        return Fp8KVCacheMethod(self)
    return None

Fp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    super().__init__(quant_config)

Fp8LinearMethod

Bases: LinearMethodBase

Linear method for FP8. Supports loading FP8 checkpoints with a static weight scale and a dynamic or static activation scale.

Limitations: 1. Only the float8_e4m3fn data type is supported, due to a limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quant_config` | `Fp8Config` | The quantization config. | required |
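
A rough per-tensor FP8 quantization sketch (illustrative, not vLLM's kernel path) showing the weight/scale pair this method consumes: the scale maps the largest weight magnitude to the float8_e4m3fn maximum, and dequantization multiplies the FP8 weight back by the scale.

```python
import torch

def quantize_fp8_per_tensor(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a weight to float8_e4m3fn with a single per-tensor scale."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().clamp(min=1e-12).float() / finfo.max
    w_fp8 = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_fp8, scale

w = torch.randn(4096, 4096)
w_fp8, scale = quantize_fp8_per_tensor(w)
w_dequant = w_fp8.to(torch.float32) * scale  # approximate reconstruction of w
```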
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=torch.get_default_dtype(),
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)

act_q_static instance-attribute

act_q_static = activation_scheme == 'static'

block_quant instance-attribute

block_quant = weight_block_size is not None

cutlass_block_fp8_supported instance-attribute

cutlass_block_fp8_supported = cutlass_block_fp8_supported()

fp8_linear instance-attribute

fp8_linear = init_fp8_linear_kernel(
    activation_quant_key=activation_quant_key,
    weight_quant_key=kFp8StaticTensorSym,
    out_dtype=get_default_dtype(),
    module_name=__name__,
)

marlin_input_dtype instance-attribute

marlin_input_dtype = None

out_dtype instance-attribute

out_dtype = get_default_dtype()

quant_config instance-attribute

quant_config = quant_config

use_aiter_and_is_supported instance-attribute

use_aiter_and_is_supported = is_linear_fp8_enabled()

use_deep_gemm instance-attribute

use_deep_gemm = is_deep_gemm_supported()

use_marlin instance-attribute

use_marlin = (
    not has_device_capability(89)
    or VLLM_TEST_FORCE_FP8_MARLIN
)

w8a8_block_fp8_linear instance-attribute

w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(*(weight_block_size)),
    act_quant_group_shape=GroupShape(
        1, weight_block_size[0]
    ),
    cutlass_block_fp8_supported=cutlass_block_fp8_supported,
    use_aiter_and_is_supported=use_aiter_and_is_supported,
)

weight_block_size instance-attribute

weight_block_size = weight_block_size

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.marlin_input_dtype = None
    self.use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False
    if vllm_is_batch_invariant():
        self.use_marlin = False

    self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
    self.use_deep_gemm = is_deep_gemm_supported()

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.weight_block_size is not None
    self.act_q_static = self.quant_config.activation_scheme == "static"

    if self.block_quant:
        assert not self.act_q_static
        assert self.weight_block_size is not None
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )
    else:
        # Use per-token quantization for better perf if dynamic and cutlass
        if self.act_q_static:
            activation_quant_key = kFp8StaticTensorSym
        elif cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym

        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

apply

apply(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
    # we will use BF16 dequant when DeepGEMM is not supported.
    if vllm_is_batch_invariant():
        if self.block_quant:
            assert self.weight_block_size is not None
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )
        else:
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # Try to infer correct broadcasting
                # weight is [K, N], scale could be [num_logical_weights]
                # Need to figure out how to broadcast - for now just try
                # direct multiplication
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

    if self.use_marlin:
        if self.block_quant:
            weight_scale = layer.weight_scale_inv
        else:
            weight_scale = layer.weight_scale

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            input_dtype=self.marlin_input_dtype,
            bias=bias,
        )

    if self.block_quant:
        assert self.weight_block_size is not None

        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )

    return self.fp8_linear.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        validate_fp8_block_shape(
            layer,
            input_size,
            output_size,
            input_size_per_partition,
            output_partition_sizes,
            self.weight_block_size,
        )

    weight = create_fp8_weight_parameter(
        output_size_per_partition, input_size_per_partition, weight_loader
    )
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    if not self.block_quant:
        scale = create_fp8_scale_parameter(
            PerTensorScaleParameter,
            output_partition_sizes,
            input_size_per_partition,
            None,
            weight_loader,
        )
        set_weight_attrs(scale, {"scale_type": "weight_scale"})
        layer.register_parameter("weight_scale", scale)
    else:
        assert not self.act_q_static
        assert self.weight_block_size is not None
        scale = create_fp8_scale_parameter(
            BlockQuantScaleParameter,
            output_partition_sizes,
            input_size_per_partition,
            self.weight_block_size,
            weight_loader,
        )
        set_weight_attrs(scale, {"scale_type": "weight_scale"})
        # The weight_scale_inv name is intentional for deepseekv3
        layer.register_parameter("weight_scale_inv", scale)

    # INPUT ACTIVATION SCALE
    if self.act_q_static:
        scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
        set_weight_attrs(scale, {"scale_type": "input_scale"})
        layer.register_parameter("input_scale", scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    size_k_first = True
    input_scale = None
    # TODO(rob): refactor block quant into separate class.
    if self.block_quant:
        assert not self.act_q_static
        size_k_first = False

        weight, weight_scale_inv = process_fp8_weight_block_strategy(
            layer.weight, layer.weight_scale_inv
        )

        # Update layer with new values
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

    # If checkpoint not serialized fp8, quantize the weights.
    else:
        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        weight = layer.weight
        weight_scale = layer.weight_scale

        # If using w8a8, torch._scaled_mm needs per tensor, so
        # requantize the logical shards as a single weight.
        if not self.use_marlin:
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                weight,
                weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            if self.act_q_static:
                assert input_scale is not None
                input_scale = input_scale.max()
        weight = weight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

    if input_scale is not None:
        replace_parameter(layer, "input_scale", input_scale)
    else:
        layer.input_scale = None

    if self.use_marlin:
        prepare_fp8_layer_for_marlin(
            layer, size_k_first, input_dtype=self.marlin_input_dtype
        )
        # Activations not quantized for marlin.
        del layer.input_scale
        return

    if self.block_quant:
        maybe_post_process_fp8_weight_block(layer)

Fp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quant_config` | `Fp8Config` | The quantization config. | required |
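
An illustrative sketch (numbers are examples) of the block-wise scale shapes created below for 128x128 quantization blocks; each scale entry covers one block_n x block_k tile of the corresponding weight.

```python
num_experts, hidden_size, intermediate = 8, 4096, 14336
block_n = block_k = 128

w13_scale_shape = (
    num_experts,
    2 * ((intermediate + block_n - 1) // block_n),   # gate and up projections
    (hidden_size + block_k - 1) // block_k,
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,
    (intermediate + block_k - 1) // block_k,
)
assert w13_scale_shape == (8, 224, 32)
assert w2_scale_shape == (8, 32, 112)
```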
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

        # TODO(rob): convert this to MK.
        if layer.enable_eplb:
            raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )

        if self.block_quant:
            import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

            e_score_correction_bias = (
                layer.e_score_correction_bias.to(x.dtype)
                if layer.e_score_correction_bias is not None
                else None
            )
            routing_method_type = layer.routing_method_type
            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits.to(torch.float32)
                if routing_method_type == RoutingMethodType.DeepSeekV3
                else router_logits,
                routing_bias=e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.weight_block_size,
                routing_method_type=routing_method_type,
                routed_scaling=layer.routed_scaling_factor,
            )
        else:
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.kernel is not None
        assert not self.is_monolithic
        return self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

allow_inplace property

allow_inplace: bool

block_quant instance-attribute

block_quant: bool = weight_block_size is not None

is_monolithic property

is_monolithic: bool

kernel instance-attribute

kernel: FusedMoEModularKernel | None = None

quant_config instance-attribute

quant_config = quant_config

supports_eplb property

supports_eplb: bool

topk_indices_dtype property

topk_indices_dtype: dtype | None

weight_block_size instance-attribute

weight_block_size = weight_block_size

weight_scale_name instance-attribute

weight_scale_name = (
    "weight_scale_inv" if block_quant else "weight_scale"
)

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(layer.moe_config)
    self.quant_config = quant_config
    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant: bool = self.weight_block_size is not None
    self.weight_scale_name = (
        "weight_scale_inv" if self.block_quant else "weight_scale"
    )

    # Set weight key and activation key for kernel compatibility
    if self.block_quant:
        weight_key = kFp8Static128BlockSym
        activation_key = kFp8Dynamic128Sym
    else:
        weight_key = kFp8StaticTensorSym
        activation_key = (
            kFp8StaticTensorSym
            if self.quant_config.activation_scheme == "static"
            else kFp8DynamicTensorSym
        )

    # Select Fp8 MoE backend
    self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
        config=self.moe,
        weight_key=weight_key,
        activation_key=activation_key,
        allow_vllm_cutlass=False,
    )

    # Delay creation of the kernel until after process-weights.
    self.kernel: mk.FusedMoEModularKernel | None = None

_setup_kernel

_setup_kernel(
    layer: Module,
    w13: Tensor,
    w2: Tensor,
    w13_scale: Tensor,
    w2_scale: Tensor,
    w13_input_scale: Tensor | None,
    w2_input_scale: Tensor | None,
) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def _setup_kernel(
    self,
    layer: Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_input_scale: torch.Tensor | None,
    w2_input_scale: torch.Tensor | None,
) -> None:
    # Shuffle weights to runtime format.
    w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
        fp8_backend=self.fp8_backend,
        layer=layer,
        w13=w13,
        w2=w2,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        w13_input_scale=w13_input_scale,
        w2_input_scale=w2_input_scale,
    )

    # Replace parameters with updated versions. Note that this helper
    # function ensures the replacement is compatible with RL weight reloads.
    replace_parameter(layer, "w13_weight", w13)
    replace_parameter(layer, "w2_weight", w2)
    replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
    replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

    # Setup modular kernel for TP case and naive DP/EP case.
    # In non-naive DP/EP case, we will create a ModularKernelMethod.
    # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
    # in both cases.
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
    if self.moe_quant_config and (
        (not self.moe.moe_parallel_config.use_all2all_kernels)
        or self.moe.moe_parallel_config.use_naive_all2all_kernels
    ):
        assert self.experts_cls is not None
        self.kernel, self.use_inplace = make_fp8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            fp8_backend=self.fp8_backend,
            experts_cls=self.experts_cls,
        )

apply

apply(
    layer: FusedMoE,
    x: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    assert self.kernel is not None
    assert not self.is_monolithic
    return self.kernel(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        inplace=self.use_inplace,
        activation=layer.activation,
        global_num_experts=layer.global_num_experts,
        expert_map=layer.expert_map,
        apply_router_weight_on_input=layer.apply_router_weight_on_input,
    )

apply_monolithic

apply_monolithic(
    layer: FusedMoE, x: Tensor, router_logits: Tensor
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply_monolithic(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    assert self.is_monolithic
    assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    # TODO(rob): convert this to MK.
    if layer.enable_eplb:
        raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
    assert layer.activation == "silu", (
        f"Expected 'silu' activation but got {layer.activation}"
    )

    if self.block_quant:
        import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

        e_score_correction_bias = (
            layer.e_score_correction_bias.to(x.dtype)
            if layer.e_score_correction_bias is not None
            else None
        )
        routing_method_type = layer.routing_method_type
        return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
            routing_logits=router_logits.to(torch.float32)
            if routing_method_type == RoutingMethodType.DeepSeekV3
            else router_logits,
            routing_bias=e_score_correction_bias,
            x=x,
            w13_weight=layer.w13_weight,
            w13_weight_scale_inv=layer.w13_weight_scale_inv,
            w2_weight=layer.w2_weight,
            w2_weight_scale_inv=layer.w2_weight_scale_inv,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            intermediate_size=layer.intermediate_size_per_partition,
            expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            block_shape=self.weight_block_size,
            routing_method_type=routing_method_type,
            routed_scaling=layer.routed_scaling_factor,
        )
    else:
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    assert self.quant_config.is_checkpoint_fp8_serialized
    params_dtype = torch.float8_e4m3fn

    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        tp_size = get_tensor_model_parallel_world_size()
        block_n, block_k = (
            self.weight_block_size[0],
            self.weight_block_size[1],
        )
        # NOTE: To ensure proper alignment of the block-wise quantization
        # scales, the output_size of the weights for both the gate and up
        # layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if intermediate_size_per_partition % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
            # Required by row parallel
            raise ValueError(
                f"The input_size of down's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_k = {block_k}."
            )

    # WEIGHTS
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if not self.block_quant:
        # For per-tensor quant, the scales are per expert and weight.
        w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
        w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
    else:
        # For block quant, the scales are per block (typically 128x128).
        w13_scale_data = torch.ones(
            num_experts,
            2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
            (hidden_size + block_k - 1) // block_k,
            dtype=torch.float32,
        )
        w2_scale_data = torch.ones(
            num_experts,
            (hidden_size + block_n - 1) // block_n,
            (intermediate_size_per_partition + block_k - 1) // block_k,
            dtype=torch.float32,
        )
    w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
    w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
    # Note: name is weight_scale for tensor, weight_scale_inv for block.
    layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
    layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        if self.block_quant
        else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
    )
    set_weight_attrs(w13_weight_scale, extra_weight_attrs)
    set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme == "static":
        assert not self.block_quant
        w13_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    else:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
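
To make the weight and scale layouts allocated above concrete, here is a small worked example; the sizes are hypothetical and only chosen to satisfy the 128x128 block constraints described in the comments.

# Hypothetical sizes, chosen only to illustrate the shapes allocated above.
num_experts = 8
hidden_size = 4096
intermediate = 14336                      # intermediate_size_per_partition
block_n = block_k = 128                   # weight_block_size = [128, 128]

def ceil_div(a, b):
    return (a + b - 1) // b

# fp8 weights
w13_weight_shape = (num_experts, 2 * intermediate, hidden_size)   # (8, 28672, 4096)
w2_weight_shape = (num_experts, hidden_size, intermediate)        # (8, 4096, 14336)

# fp32 block scales (one value per 128x128 tile)
w13_scale_shape = (num_experts,
                   2 * ceil_div(intermediate, block_n),
                   ceil_div(hidden_size, block_k))                # (8, 224, 32)
w2_scale_shape = (num_experts,
                  ceil_div(hidden_size, block_n),
                  ceil_div(intermediate, block_k))                # (8, 32, 112)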

get_fused_moe_quant_config

get_fused_moe_quant_config(
    layer: Module,
) -> FusedMoEQuantConfig | None
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    # TRTLLM does not use Modular Kernel.
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        return None

    w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
    w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
    a1_scale = layer.w13_input_scale
    a2_scale = layer.w2_input_scale

    return make_fp8_moe_quant_config(
        fp8_backend=self.fp8_backend,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=self.weight_block_size,
    )

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    routing_tables: tuple[Tensor, Tensor, Tensor]
    | None = None,
) -> FusedMoEPrepareAndFinalize | None
Source code in vllm/model_executor/layers/quantization/fp8.py
def maybe_make_prepare_finalize(
    self,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        return None
    elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        # For no-EP case, don't use the MKM framework.
        if not self.moe.moe_parallel_config.use_all2all_kernels:
            return None

        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
    return super().maybe_make_prepare_finalize(routing_tables)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    # Allow for accessing weights and scales in standard way.
    w13 = layer.w13_weight
    w2 = layer.w2_weight
    w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
    w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
    w13_input_scale = layer.w13_input_scale
    w2_input_scale = layer.w2_input_scale

    # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
    if current_platform.is_fp8_fnuz():
        w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            w13,
            w13_scale,
            w13_input_scale,
        )
        w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            w2,
            w2_scale,
            w2_input_scale,
        )

    # Per tensor kernels require single activation scale. Use the max.
    if self.quant_config.activation_scheme == "static":
        assert not self.block_quant
        assert w13_input_scale is not None and w2_input_scale is not None
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

    # Per tensor kernels require single weight scale for w13 per expert, but
    # on disk there is a scale for w1 and w3. Use the max to requantize.
    if not self.block_quant:
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13, w13_scale, shard_size, layer.local_num_experts
        )

    # Shuffle weights to runtime format and setup kernel.
    self._setup_kernel(
        layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
    )
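
As the comments above note, per-tensor kernels need a single w13 scale per expert even though the checkpoint stores separate scales for w1 and w3. Below is a minimal sketch of the max-based requantization idea for one expert; it is independent of the actual process_fp8_weight_tensor_strategy_moe implementation, and the names are illustrative only.

import torch

def merge_w13_scales_sketch(w13_q, w13_scales, shard_size):
    # w13_q: (2 * shard_size, hidden_size) fp8 weight of one expert
    # w13_scales: (2,) per-tensor fp32 scales for w1 and w3
    max_scale = w13_scales.max()
    merged = torch.empty_like(w13_q)
    for i in range(2):
        chunk = w13_q[i * shard_size:(i + 1) * shard_size].float()
        # dequantize with the original scale, requantize with the shared max
        merged[i * shard_size:(i + 1) * shard_size] = (
            chunk * w13_scales[i] / max_scale
        ).to(w13_q.dtype)
    return merged, max_scale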

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/fp8.py
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    assert self.moe_quant_config is not None
    assert self.experts_cls is not None
    return make_fp8_moe_kernel_for_mkm(
        moe_config=self.moe,
        quant_config=self.moe_quant_config,
        experts_cls=self.experts_cls,
        prepare_finalize=prepare_finalize,
    )

Fp8OnlineLinearMethod

Bases: Fp8LinearMethod

Online version of Fp8LinearMethod: loads an fp16/bf16 checkpoint and quantizes the weights during loading.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
    and quantized the weights during loading."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

            return res

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
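
Fp8OnlineLinearMethod handles the case where FP8 is requested for a checkpoint that is not already FP8-serialized. As a rough usage sketch (the model name is only an example, and the exact behaviour depends on your vLLM version), requesting quantization="fp8" for a BF16/FP16 checkpoint triggers this online-quantization path:

from vllm import LLM

# Hypothetical example: load a bf16/fp16 checkpoint and quantize its linear
# weights to fp8 while loading; activation scales stay dynamic at runtime.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", quantization="fp8")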

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    # WEIGHT
    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        # track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0

        # load the current weight chunk
        copy_numel_counter = CopyNumelCounter()
        with copy_numel_counter:
            res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
        layer._loaded_numel += copy_numel_counter.copied_numel

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call from doing
            # anything
            layer._already_called_process_weights_after_loading = True

        return res

    weight = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=params_dtype,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=patched_weight_loader,
    )
    layer.register_parameter("weight", weight)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    # TODO(future): support block_quant in online quant path
    assert not self.block_quant

    layer.input_scale = None
    qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
    weight = qweight.t()

    # Update layer with new values.
    replace_parameter(layer, "weight", weight.data)
    replace_parameter(layer, "weight_scale", weight_scale.data)

    if self.use_marlin:
        size_k_first = True
        prepare_fp8_layer_for_marlin(
            layer, size_k_first, input_dtype=self.marlin_input_dtype
        )

Fp8OnlineMoEMethod

Bases: Fp8MoEMethod

MoE method for online FP8 quantization. Supports loading unquantized FP16/BF16 model checkpoints with dynamic activation scaling; the weight scaling factors are initialized after the model weights are loaded.

Parameters:

    quant_config (Fp8Config): The quantization config. Required.
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.
    Supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factors will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, so patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # create a new holder to prevent modifying behavior of any other
        # objects which might depend on the old one
        new_extra_weight_attrs = extra_weight_attrs

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate one scale per expert for w13 and one for w2. They are
        # filled in during online quantization after the weights are loaded.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is fp16/bf16; quantize each expert's weights to fp8.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(quant_config, layer)
    assert not quant_config.is_checkpoint_fp8_serialized
    assert quant_config.activation_scheme == "dynamic"
    assert quant_config.weight_block_size is None

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    # We are doing online quantization, so patch the weight loader
    # to call `process_weights_after_loading` in a streaming fashion
    # as soon as the last weight chunk is loaded.
    weight_loader = extra_weight_attrs["weight_loader"]
    # create a new holder to prevent modifying behavior of any other
    # objects which might depend on the old one
    new_extra_weight_attrs = extra_weight_attrs

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        # add a counter to track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0

        # load the current weight chunk
        copy_numel_counter = CopyNumelCounter()
        with copy_numel_counter:
            res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
        layer._loaded_numel += copy_numel_counter.copied_numel

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything
            layer._already_called_process_weights_after_loading = True

        return res

    new_extra_weight_attrs["weight_loader"] = patched_weight_loader
    extra_weight_attrs = new_extra_weight_attrs

    # WEIGHTS
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    # Allocate one scale per expert for w13 and one for w2. They are
    # filled in during online quantization after the weights are loaded.
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
    )
    w2_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
    )
    layer.register_parameter("w13_weight_scale", w13_weight_scale)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)
    set_weight_attrs(w13_weight_scale, extra_weight_attrs)
    set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    layer.w13_input_scale = None
    layer.w2_input_scale = None

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return

    # The checkpoint is fp16/bf16; quantize each expert's weights to fp8.
    fp8_dtype = current_platform.fp8_dtype()
    w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
    w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
    w13_scale = layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale

    for expert in range(layer.local_num_experts):
        w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
            layer.w13_weight[expert, :, :]
        )
        w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
            layer.w2_weight[expert, :, :]
        )

    # Shuffle weights to runtime format and setup kernel.
    self._setup_kernel(
        layer,
        w13,
        w2,
        w13_scale,
        w2_scale,
        layer.w13_input_scale,
        layer.w2_input_scale,
    )
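
For reference, the per-expert quantization performed above by ops.scaled_fp8_quant can be approximated in plain PyTorch as shown below. This is only a sketch of the dynamic per-tensor case; the real fused kernel may differ in clamping and rounding details.

import torch

def scaled_fp8_quant_sketch(w, fp8_dtype=torch.float8_e4m3fn):
    # Dynamic per-tensor scale: scale = amax / fp8_max, weight stored as w / scale.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = (w.abs().amax().float() / fp8_max).clamp(min=1e-12)
    q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale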