vllm.model_executor.layers.fused_moe.deep_gemm_moe

logger module-attribute

logger = init_logger(__name__)

DeepGemmExperts

Bases: FusedMoEPermuteExpertsUnpermute

Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
        super().__init__(moe_config=moe_config, quant_config=quant_config)
        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
        assert quant_config.quant_dtype == torch.float8_e4m3fn
        assert not quant_config.per_act_token_quant
        assert not quant_config.per_out_ch_quant

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        return is_deep_gemm_supported()

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        SUPPORTED_W_A = [
            (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: str) -> bool:
        return activation in ["silu"]

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        return True

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: str,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        assert self.block_shape is not None
        block_m = self.block_shape[0]
        M_sum = compute_aligned_M(
            M, topk, local_num_experts, block_m, expert_tokens_meta
        )
        assert M_sum % block_m == 0

        activation_out_dim = self.adjust_N_for_activation(N, activation)
        workspace1 = (M_sum, max(activation_out_dim, K))
        workspace2 = (M_sum, max(N, K))
        output = (M, K)
        return (workspace1, workspace2, output)

    def _act_mul_quant(
        self, input: torch.Tensor, output: torch.Tensor, activation: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert self.block_shape is not None
        block_k = self.block_shape[1]
        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()

        M_sum, N = input.size()
        activation_out_dim = self.adjust_N_for_activation(N, activation)

        # 1. DeepGemm UE8M0: use packed per-token-group quant
        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
            act_out = torch.empty(
                (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
            )
            self.activation(activation, act_out, input)
            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
                act_out,
                block_k,
                out_q=output,
            )
            return a2q, a2q_scale

        # 2. Hopper / non-UE8M0: prefer the fused SiLU+mul+quant kernel
        if activation == "silu":
            use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
            return silu_mul_per_token_group_quant_fp8_colmajor(
                input=input,
                output=output,
                use_ue8m0=use_ue8m0,
            )

        # 3. Fallback path for non-SiLU activations in non-UE8M0 cases.
        act_out = torch.empty(
            (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
        )
        self.activation(activation, act_out, input)
        return per_token_group_quant_fp8(
            act_out, block_k, column_major_scales=True, out_q=output
        )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert a1q_scale is not None
        assert a2_scale is None
        assert self.block_shape is not None
        assert self.w1_scale is not None
        assert self.w2_scale is not None

        a1q = hidden_states
        _, N, K = w1.size()

        local_num_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = local_num_experts

        assert w2.size(1) == K

        M_sum = compute_aligned_M(
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
            alignment=get_mk_alignment_for_contiguous_layout()[0],
            expert_tokens_meta=expert_tokens_meta,
        )

        a1q_perm = _resize_cache(
            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
        )
        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
            aq=a1q,
            aq_scale=a1q_scale,
            topk_ids=topk_ids,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            expert_tokens_meta=expert_tokens_meta,
            aq_out=a1q_perm,
        )
        assert a1q.size(0) == M_sum

        mm1_out = _resize_cache(workspace2, (M_sum, N))
        m_grouped_fp8_gemm_nt_contiguous(
            (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
        )

        activation_out_dim = self.adjust_N_for_activation(N, activation)
        quant_out = _resize_cache(
            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
        )
        a2q, a2q_scale = self._act_mul_quant(
            input=mm1_out.view(-1, N), output=quant_out, activation=activation
        )

        mm2_out = _resize_cache(workspace2, (M_sum, K))
        m_grouped_fp8_gemm_nt_contiguous(
            (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
        )

        if apply_router_weight_on_input:
            topk_weights = torch.ones_like(topk_weights)

        deepgemm_unpermute_and_reduce(
            a=mm2_out,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            inv_perm=inv_perm,
            expert_map=expert_map,
            output=output,
        )

__init__

__init__(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
)
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
    super().__init__(moe_config=moe_config, quant_config=quant_config)
    assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
    assert quant_config.quant_dtype == torch.float8_e4m3fn
    assert not quant_config.per_act_token_quant
    assert not quant_config.per_out_ch_quant

_act_mul_quant

_act_mul_quant(
    input: Tensor, output: Tensor, activation: str
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def _act_mul_quant(
    self, input: torch.Tensor, output: torch.Tensor, activation: str
) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.block_shape is not None
    block_k = self.block_shape[1]
    scale_fmt = DeepGemmQuantScaleFMT.from_oracle()

    M_sum, N = input.size()
    activation_out_dim = self.adjust_N_for_activation(N, activation)

    # 1. DeepGemm UE8M0: use packed per-token-group quant
    if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
        act_out = torch.empty(
            (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
        )
        self.activation(activation, act_out, input)
        a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
            act_out,
            block_k,
            out_q=output,
        )
        return a2q, a2q_scale

    # 2. Hopper / non-UE8M0: prefer the fused SiLU+mul+quant kernel
    if activation == "silu":
        use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
        return silu_mul_per_token_group_quant_fp8_colmajor(
            input=input,
            output=output,
            use_ue8m0=use_ue8m0,
        )

    # 3. Fallback path for non-SiLU activations in non-UE8M0 cases.
    act_out = torch.empty(
        (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
    )
    self.activation(activation, act_out, input)
    return per_token_group_quant_fp8(
        act_out, block_k, column_major_scales=True, out_q=output
    )
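_act_mul_quant chooses one of three activation-plus-quantization paths based on the DeepGemm quant scale format and the activation. A condensed sketch of the dispatch order (the return values are descriptive labels only, not vllm functions):

def _quant_path(scale_fmt: DeepGemmQuantScaleFMT, activation: str) -> str:
    # 1. UE8M0 scales: run the activation, then the packed per-token-group quant.
    if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
        return "activation + per_token_group_quant_fp8_packed_for_deepgemm"
    # 2. SiLU on non-UE8M0: fused SiLU+mul+quant with column-major scales.
    if activation == "silu":
        return "silu_mul_per_token_group_quant_fp8_colmajor"
    # 3. Anything else: run the activation, then the generic per-token-group quant.
    return "activation + per_token_group_quant_fp8 (column_major_scales=True)"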

_supports_activation staticmethod

_supports_activation(activation: str) -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def _supports_activation(activation: str) -> bool:
    return activation in ["silu"]

_supports_current_device staticmethod

_supports_current_device() -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def _supports_current_device() -> bool:
    return is_deep_gemm_supported()

_supports_no_act_and_mul staticmethod

_supports_no_act_and_mul() -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def _supports_no_act_and_mul() -> bool:
    return False

_supports_parallel_config staticmethod

_supports_parallel_config(
    moe_parallel_config: FusedMoEParallelConfig,
) -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    return True

_supports_quant_scheme staticmethod

_supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    SUPPORTED_W_A = [
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
    ]
    return (weight_key, activation_key) in SUPPORTED_W_A
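Only a single weight/activation pairing is accepted: statically block-quantized fp8 weights with dynamically quantized fp8 activations over the same 128-wide groups. For example:

DeepGemmExperts._supports_quant_scheme(kFp8Static128BlockSym, kFp8Dynamic128Sym)  # True
DeepGemmExperts._supports_quant_scheme(None, None)                                 # False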

activation_format staticmethod

activation_format() -> FusedMoEActivationFormat
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    return mk.FusedMoEActivationFormat.Standard

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Tensor | None,
    a1q_scale: Tensor | None,
    a2_scale: Tensor | None,
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
)
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
):
    assert a1q_scale is not None
    assert a2_scale is None
    assert self.block_shape is not None
    assert self.w1_scale is not None
    assert self.w2_scale is not None

    a1q = hidden_states
    _, N, K = w1.size()

    local_num_experts = w1.size(0)
    if global_num_experts == -1:
        global_num_experts = local_num_experts

    assert w2.size(1) == K

    M_sum = compute_aligned_M(
        M=topk_ids.size(0),
        num_topk=topk_ids.size(1),
        local_num_experts=local_num_experts,
        alignment=get_mk_alignment_for_contiguous_layout()[0],
        expert_tokens_meta=expert_tokens_meta,
    )

    a1q_perm = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
    )
    a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
        aq=a1q,
        aq_scale=a1q_scale,
        topk_ids=topk_ids,
        local_num_experts=local_num_experts,
        expert_map=expert_map,
        expert_tokens_meta=expert_tokens_meta,
        aq_out=a1q_perm,
    )
    assert a1q.size(0) == M_sum

    mm1_out = _resize_cache(workspace2, (M_sum, N))
    m_grouped_fp8_gemm_nt_contiguous(
        (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
    )

    activation_out_dim = self.adjust_N_for_activation(N, activation)
    quant_out = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
    )
    a2q, a2q_scale = self._act_mul_quant(
        input=mm1_out.view(-1, N), output=quant_out, activation=activation
    )

    mm2_out = _resize_cache(workspace2, (M_sum, K))
    m_grouped_fp8_gemm_nt_contiguous(
        (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
    )

    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)

    deepgemm_unpermute_and_reduce(
        a=mm2_out,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        inv_perm=inv_perm,
        expert_map=expert_map,
        output=output,
    )
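For orientation, a shape-level sketch of the pipeline above, assuming w1 is (E, N, K), w2 is (E, K, N') with N' = adjust_N_for_activation(N, activation), and M_sum is the aligned permuted-token count:

# a1q, a1q_scale        : pre-quantized hidden states, (M, K) fp8 + per-group scales
# deepgemm_moe_permute  : group tokens by expert -> a1q (M_sum, K), expert_ids, inv_perm
# grouped GEMM 1 (w1)   : (M_sum, K) @ w1 -> mm1_out (M_sum, N) in workspace2
# _act_mul_quant        : activation (gated SiLU+mul for "silu") + fp8 re-quant
#                         -> a2q (M_sum, N') in workspace13
# grouped GEMM 2 (w2)   : (M_sum, N') @ w2 -> mm2_out (M_sum, K) in workspace2
# unpermute_and_reduce  : scatter back to (M, K), weighting by topk_weights -> output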

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    return TopKWeightAndReduceNoOP()

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def supports_chunking(self) -> bool:
    return True

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def supports_expert_map(self) -> bool:
    return True

workspace_shapes

workspace_shapes(
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: ExpertTokensMetadata | None,
    activation: str,
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]
]
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def workspace_shapes(
    self,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    assert self.block_shape is not None
    block_m = self.block_shape[0]
    M_sum = compute_aligned_M(
        M, topk, local_num_experts, block_m, expert_tokens_meta
    )
    assert M_sum % block_m == 0

    activation_out_dim = self.adjust_N_for_activation(N, activation)
    workspace1 = (M_sum, max(activation_out_dim, K))
    workspace2 = (M_sum, max(N, K))
    output = (M, K)
    return (workspace1, workspace2, output)
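A concrete sizing example under assumed values: block_m taken as 128 from the quant block shape, adjust_N_for_activation assumed to halve N for the gated "silu" activation, and M_sum shown as one possible aligned value returned by compute_aligned_M:

M, topk, N, K = 300, 8, 4096, 7168
block_m = 128                                     # assumed block_shape[0]
M_sum = 2560                                      # illustrative; always a multiple of block_m
activation_out_dim = N // 2                       # assumption for "silu" (gated activation)
workspace1 = (M_sum, max(activation_out_dim, K))  # (2560, 7168): permuted fp8 input / quantized act output
workspace2 = (M_sum, max(N, K))                   # (2560, 7168): mm1_out (M_sum, N), then mm2_out (M_sum, K)
output = (M, K)                                   # (300, 7168)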

_valid_deep_gemm

_valid_deep_gemm(
    hidden_states: Tensor, w1: Tensor, w2: Tensor
) -> bool

Check whether the given problem size is supported by the DeepGemm grouped GEMM kernel. M must be at least the alignment returned by get_mk_alignment_for_contiguous_layout(), and N and K must be multiples of that alignment; the weights must be fp8 (e4m3fn) and all inputs contiguous.

Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def _valid_deep_gemm(
    hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
    """
    Check if the given problem size is supported by the DeepGemm grouped
    gemm kernel.  All of M, N, K and the quantization block_shape must be
    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
    """
    if not has_deep_gemm():
        logger.debug_once("DeepGemm disabled: deep_gemm not available.")
        return False

    M = hidden_states.size(0)
    _, K, N = w2.size()

    align = get_mk_alignment_for_contiguous_layout()[0]

    if not _valid_deep_gemm_shape(M, N, K):
        logger.debug_once(
            "DeepGemm disabled due to unaligned problem size. "
            "M: %s, N: %s, K: %s. M should >= %s "
            "and N and K must be multiples of %s. "
            "This is not an error and we will fall back to triton.",
            M,
            N,
            K,
            align,
            align,
        )
        return False
    elif N <= 512:
        logger.debug_once(
            "DeepGemm disabled for N <= 512. M: %s, N: %s, K: %s. "
            "This means we will fallback to triton "
            "for this specific shape for further speed up.",
            M,
            N,
            K,
        )
        return False

    if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn:
        logger.debug_once(
            "DeepGemm disabled: invalid weight dtype(s). w1.dtype: %s, w2.dtype: %s",
            w1.dtype,
            w2.dtype,
        )
        return False

    if (
        not hidden_states.is_contiguous()
        or not w1.is_contiguous()
        or not w2.is_contiguous()
    ):
        logger.debug_once(
            "DeepGemm disabled: weights or activations not contiguous. "
            "hidden_states.is_contiguous(): %s, w1.is_contiguous(): %s, "
            "w2.is_contiguous(): %s",
            hidden_states.is_contiguous(),
            w1.is_contiguous(),
            w2.is_contiguous(),
        )
        return False

    return True
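A hypothetical gating sketch showing the intended use: take the DeepGemm path only when all of the checks above pass, otherwise fall back to the Triton fused-MoE kernels.

if _valid_deep_gemm(hidden_states, w1, w2):
    ...  # DeepGemm grouped-GEMM path (DeepGemmExperts)
else:
    ...  # Triton fallback for unsupported shapes or dtypes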

_valid_deep_gemm_shape

_valid_deep_gemm_shape(M: int, N: int, K: int) -> bool
Source code in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    align = get_mk_alignment_for_contiguous_layout()[0]
    return align <= M and N % align == 0 and K % align == 0
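Worked examples, assuming the contiguous-layout alignment is 128:

_valid_deep_gemm_shape(M=256, N=4096, K=7168)  # True: 128 <= 256, 4096 and 7168 are multiples of 128
_valid_deep_gemm_shape(M=64, N=4096, K=7168)   # False: M is smaller than the alignment
_valid_deep_gemm_shape(M=256, N=4100, K=7168)  # False: N is not a multiple of 128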