Skip to content

vllm.model_executor.layers.fused_moe.router.base_router

BaseRouter

Bases: FusedMoERouter

Base router class that provides common functionality for all router implementations.

This class implements the template method pattern where select_experts() handles common pre-processing and post-processing, delegating the actual routing logic to the abstract _compute_routing() method.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
class BaseRouter(FusedMoERouter):
    """
    Common base for all router implementations.

    ``select_experts()`` is a template method: it owns the shared
    pre-processing (EPLB state validation, resolving the indices dtype) and
    post-processing (logical -> physical EPLB mapping, indices dtype
    conversion), while the routing algorithm itself is delegated to the
    abstract ``_compute_routing()`` hook that each subclass implements.
    """

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        enable_eplb: bool = False,
        # TODO(bnell): Once the MK is constructed at layer init time, we
        # can make this a plain value instead of a callback.
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """
        Store the router configuration.

        Args:
            top_k: Stored as ``self.top_k``; not read by this base class
                (presumably the number of experts picked per token).
            global_num_experts: Stored as ``self.global_num_experts``; not
                read by this base class.
            eplb_state: Holder of the EPLB tensors (``expert_load_view``,
                ``logical_to_physical_map``, ``logical_replica_count``).
            enable_eplb: When true, those EPLB tensors must all be set and
                routed expert ids are mapped from logical to physical ids.
            indices_type_getter: Optional zero-argument callback returning
                the desired dtype of the expert indices (or None).

        The indices dtype is resolved lazily through a callback because it
        is supplied by the modular kernels, which are created after the MoE
        layer and its router are constructed.
        """
        super().__init__()
        self.top_k = top_k
        self.global_num_experts = global_num_experts
        self.eplb_state = eplb_state
        self.enable_eplb = enable_eplb
        self.indices_type_getter = indices_type_getter

    def _validate_eplb_state(self) -> None:
        """
        Check that every EPLB tensor is present when EPLB is enabled.

        Raises:
            ValueError: If ``enable_eplb`` is set and any of
                ``expert_load_view``, ``logical_to_physical_map`` or
                ``logical_replica_count`` on ``eplb_state`` is None
                (checked in that order).
        """
        if not self.enable_eplb:
            return
        required_fields = (
            (
                "expert_load_view",
                "enable_eplb=True requires expert_load_view != None",
            ),
            (
                "logical_to_physical_map",
                "enable_eplb=True requires logical_to_physical_map != None",
            ),
            (
                "logical_replica_count",
                "enable_eplb=True requires logical_replica_count != None",
            ),
        )
        for field_name, message in required_fields:
            if getattr(self.eplb_state, field_name) is None:
                raise ValueError(message)

    def _get_indices_type(self) -> torch.dtype | None:
        """
        Return the dtype produced by ``indices_type_getter``.

        Returns None when no getter was supplied.
        """
        getter = self.indices_type_getter
        if getter is None:
            return None
        return getter()

    def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
        """
        Translate logical expert ids into physical expert ids under EPLB.

        Returns ``topk_ids`` untouched when EPLB is disabled; otherwise
        returns the result of ``eplb_map_to_physical_and_record``.
        """
        if not self.enable_eplb:
            return topk_ids
        state = self.eplb_state
        # When reached via select_experts(), _validate_eplb_state() has
        # already rejected missing tensors; these asserts presumably exist
        # to narrow the Optional types.
        assert state.expert_load_view is not None
        assert state.logical_to_physical_map is not None
        assert state.logical_replica_count is not None
        return eplb_map_to_physical_and_record(
            topk_ids=topk_ids,
            expert_load_view=state.expert_load_view,
            logical_to_physical_map=state.logical_to_physical_map,
            logical_replica_count=state.logical_replica_count,
        )

    def _convert_indices_dtype(
        self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
    ) -> torch.Tensor:
        """
        Cast ``topk_ids`` to ``indices_type`` when one is requested.

        A None ``indices_type`` or an already-matching dtype returns the
        input tensor itself; otherwise ``Tensor.to`` produces the cast.
        """
        if indices_type is None:
            return topk_ids
        if topk_ids.dtype != indices_type:
            topk_ids = topk_ids.to(dtype=indices_type)
        assert topk_ids.dtype == indices_type
        return topk_ids

    @abstractmethod
    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Routing hook that every concrete router must implement.

        Subclasses supply the actual selection algorithm here (for example
        grouped_topk, fused_topk or a custom routing function).

        Args:
            hidden_states: Input hidden states.
            router_logits: Router logits used for expert selection.
            indices_type: Requested dtype of the expert indices, or None.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(topk_weights, topk_ids)``.
        """
        raise NotImplementedError

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        Pipeline (template method):
            1. validate the EPLB state;
            2. resolve the desired indices dtype;
            3. run the subclass routing via ``_compute_routing()``;
            4. map logical ids to physical ids if EPLB is enabled;
            5. cast the ids to the desired dtype if needed.

        Args:
            hidden_states: Input hidden states, forwarded to
                ``_compute_routing()``.
            router_logits: Router logits, forwarded to
                ``_compute_routing()``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(topk_weights, topk_ids)``,
            the routing weights and the selected expert ids. When EPLB is
            not enabled the ids are global logical expert ids, so they are
            compatible with plain MoE implementations without redundant
            experts.

        Raises:
            ValueError: If EPLB is enabled but its state is incomplete.
        """
        self._validate_eplb_state()
        target_dtype = self._get_indices_type()

        weights, expert_ids = self._compute_routing(
            hidden_states, router_logits, target_dtype
        )
        physical_ids = self._apply_eplb_mapping(expert_ids)
        return weights, self._convert_indices_dtype(physical_ids, target_dtype)

enable_eplb instance-attribute

enable_eplb = enable_eplb

eplb_state instance-attribute

eplb_state = eplb_state

global_num_experts instance-attribute

global_num_experts = global_num_experts

indices_type_getter instance-attribute

indices_type_getter = indices_type_getter

top_k instance-attribute

top_k = top_k

__init__

__init__(
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], dtype | None]
    | None = None,
)

Note: the indices dtype might not be available at router construction time, so we need to supply a callback to get it at runtime. This is because the indices type is supplied by modular kernels which are created after MoE layer/router construction.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def __init__(
    self,
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    enable_eplb: bool = False,
    # TODO(bnell): Once the MK is constructed at layer init time, we
    # can make this a plain value instead of a callback.
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
    """
    Store the router configuration.

    Args:
        top_k: Stored as ``self.top_k`` (presumably the number of experts
            picked per token).
        global_num_experts: Stored as ``self.global_num_experts``.
        eplb_state: Holder of the EPLB state; stored as ``self.eplb_state``.
        enable_eplb: Whether EPLB is active; stored as ``self.enable_eplb``.
        indices_type_getter: Optional zero-argument callback returning the
            desired dtype of the expert indices (or None).

    The indices dtype is resolved lazily through a callback because it is
    supplied by the modular kernels, which are created after the MoE layer
    and its router are constructed.
    """
    super().__init__()
    self.top_k = top_k
    self.global_num_experts = global_num_experts
    self.eplb_state = eplb_state
    self.enable_eplb = enable_eplb
    self.indices_type_getter = indices_type_getter

_apply_eplb_mapping

_apply_eplb_mapping(topk_ids: Tensor) -> Tensor

Apply EPLB mapping to convert logical expert IDs to physical expert IDs.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
    """
    Translate logical expert ids into physical expert ids under EPLB.

    Returns ``topk_ids`` untouched when EPLB is disabled; otherwise returns
    the result of ``eplb_map_to_physical_and_record``.
    """
    if not self.enable_eplb:
        return topk_ids
    state = self.eplb_state
    # When reached via select_experts(), _validate_eplb_state() has already
    # rejected missing tensors; these asserts presumably narrow Optional types.
    assert state.expert_load_view is not None
    assert state.logical_to_physical_map is not None
    assert state.logical_replica_count is not None
    return eplb_map_to_physical_and_record(
        topk_ids=topk_ids,
        expert_load_view=state.expert_load_view,
        logical_to_physical_map=state.logical_to_physical_map,
        logical_replica_count=state.logical_replica_count,
    )

_compute_routing abstractmethod

_compute_routing(
    hidden_states: Tensor,
    router_logits: Tensor,
    indices_type: dtype | None,
) -> tuple[Tensor, Tensor]

Compute the actual routing logic.

This method must be implemented by subclasses to provide the specific routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).

Parameters:

Name Type Description Default
hidden_states Tensor

Input hidden states

required
router_logits Tensor

Router logits for expert selection

required
indices_type dtype | None

Desired dtype for expert indices (may be None)

required

Returns:

Type Description
tuple[Tensor, Tensor]

tuple of (topk_weights, topk_ids)

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
@abstractmethod
def _compute_routing(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the actual routing logic.

    This method must be implemented by subclasses to provide the specific
    routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).

    Args:
        hidden_states: Input hidden states
        router_logits: Router logits for expert selection
        indices_type: Desired dtype for expert indices (may be None)

    Returns:
        tuple of (topk_weights, topk_ids)
    """
    raise NotImplementedError

_convert_indices_dtype

_convert_indices_dtype(
    topk_ids: Tensor, indices_type: dtype | None
) -> Tensor

Convert topk_ids to the desired dtype if needed.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def _convert_indices_dtype(
    self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
) -> torch.Tensor:
    """Convert topk_ids to the desired dtype if needed."""
    if (indices_type is not None) and topk_ids.dtype != indices_type:
        topk_ids = topk_ids.to(dtype=indices_type)

    assert topk_ids.dtype == indices_type or indices_type is None
    return topk_ids

_get_indices_type

_get_indices_type() -> dtype | None

Get the desired indices dtype from the getter function.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def _get_indices_type(self) -> torch.dtype | None:
    """Get the desired indices dtype from the getter function."""
    return (
        self.indices_type_getter() if self.indices_type_getter is not None else None
    )

_validate_eplb_state

_validate_eplb_state() -> None

Validate that EPLB state is properly initialized if EPLB is enabled.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def _validate_eplb_state(self) -> None:
    """
    Check that every EPLB tensor is present when EPLB is enabled.

    Raises:
        ValueError: If ``enable_eplb`` is set and any of
            ``expert_load_view``, ``logical_to_physical_map`` or
            ``logical_replica_count`` on ``eplb_state`` is None (checked in
            that order).
    """
    if not self.enable_eplb:
        return
    required_fields = (
        (
            "expert_load_view",
            "enable_eplb=True requires expert_load_view != None",
        ),
        (
            "logical_to_physical_map",
            "enable_eplb=True requires logical_to_physical_map != None",
        ),
        (
            "logical_replica_count",
            "enable_eplb=True requires logical_replica_count != None",
        ),
    )
    for field_name, message in required_fields:
        if getattr(self.eplb_state, field_name) is None:
            raise ValueError(message)

select_experts

select_experts(
    hidden_states: Tensor, router_logits: Tensor
) -> tuple[Tensor, Tensor]

Route the input hidden states to the top-k experts based on the router logits.

This method implements the template method pattern:

1. Validates EPLB state
2. Gets indices type
3. Calls _compute_routing() to get topk_weights and topk_ids
4. Applies EPLB mapping if enabled
5. Converts indices dtype if needed

Returns:

Type Description
tuple[Tensor, Tensor]

(topk_weights, topk_ids): the weights and expert ids computation result.

Compatibility: When EPLB is not enabled, the returned ids are equivalent to global logical ids, so they should be compatible with plain MoE implementations without redundant experts.

Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def select_experts(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Route the input hidden states to the top-k experts based on the router
    logits.

    Pipeline (template method):
        1. validate the EPLB state;
        2. resolve the desired indices dtype;
        3. run the subclass routing via ``_compute_routing()``;
        4. map logical ids to physical ids if EPLB is enabled;
        5. cast the ids to the desired dtype if needed.

    Args:
        hidden_states: Input hidden states, forwarded to
            ``_compute_routing()``.
        router_logits: Router logits, forwarded to ``_compute_routing()``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(topk_weights, topk_ids)``, the
        routing weights and the selected expert ids. When EPLB is not
        enabled the ids are global logical expert ids, so they are
        compatible with plain MoE implementations without redundant experts.

    Raises:
        ValueError: Propagated from ``_validate_eplb_state()`` when EPLB is
            enabled but its state is incomplete.
    """
    self._validate_eplb_state()
    target_dtype = self._get_indices_type()

    weights, expert_ids = self._compute_routing(
        hidden_states, router_logits, target_dtype
    )
    physical_ids = self._apply_eplb_mapping(expert_ids)
    return weights, self._convert_indices_dtype(physical_ids, target_dtype)

eplb_map_to_physical_and_record

eplb_map_to_physical_and_record(
    topk_ids: Tensor,
    expert_load_view: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/router/base_router.py
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> torch.Tensor:
    """
    CPU fallback for EPLB logical -> physical expert id mapping.

    EPLB is not available on this path, so ``topk_ids`` is returned
    unchanged: no remapping is performed and no expert load is recorded.
    The remaining parameters are accepted but ignored (presumably to mirror
    the signature of a non-CPU implementation — confirm).
    """
    del expert_load_view, logical_to_physical_map, logical_replica_count
    return topk_ids