Skip to content

vllm.distributed.device_communicators.xpu_communicator

logger module-attribute

logger = init_logger(__name__)

XpuCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm/distributed/device_communicators/xpu_communicator.py
class XpuCommunicator(DeviceCommunicatorBase):
    """Device communicator whose collectives are issued through
    ``torch.distributed`` on ``self.device_group``.

    When ``self.use_all2all`` is set, ``__init__`` also installs an all2all
    manager that ``dispatch`` and ``combine`` delegate to.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        """Initialise the base communicator and pick an all2all manager.

        All arguments are forwarded unchanged to the base-class initialiser.
        If ``self.use_all2all`` is true afterwards, ``self.all2all_manager``
        is set according to ``self.all2all_backend``:

        * ``"naive"`` -> ``NaiveAll2AllManager``
        * ``"allgather_reducescatter"`` -> ``AgRsAll2AllManager``
        * anything else -> a warning is logged and ``AgRsAll2AllManager``
          is used as the fallback.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        if not self.use_all2all:
            return

        if self.all2all_backend == "naive":
            from .all2all import NaiveAll2AllManager

            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
            logger.info("Using naive all2all manager.")
            return

        # Every other backend resolves to the AgRs manager; only backends
        # that were not explicitly requested as AgRs get the warning.
        if self.all2all_backend != "allgather_reducescatter":
            logger.warning(
                "`%s` all2all manager is not supported on XPU. "
                "Falling back to AgRs manager for XPU, "
                "which is the Default backend",
                self.all2all_backend,  # type: ignore[has-type]
            )
        from .all2all import AgRsAll2AllManager

        self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
        logger.info("Using AgRs manager on XPU device.")

    def all_reduce(self, input_) -> torch.Tensor:
        """All-reduce ``input_`` in place over ``self.device_group``.

        Returns:
            The same ``input_`` object that was passed in.
        """
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
        """Reduce-scatter ``input_`` along ``dim`` over ``self.device_group``.

        Args:
            input_: Tensor whose size along ``dim`` is divisible by
                ``self.world_size``.
            dim: Axis to scatter along; negative values count from the end.

        Returns:
            A contiguous tensor shaped like ``input_`` except that ``dim``
            is ``world_size`` times smaller.

        Raises:
            AssertionError: If the size along ``dim`` is not divisible by
                ``self.world_size``.
        """
        world_size = self.world_size

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Bring the scatter axis to the front. ``movedim(dim, 0)`` (rather
        # than ``movedim(0, dim)``) is required for tensors of rank >= 3;
        # the two only coincide for dim 0 and for 2-D inputs.
        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(dim, 0).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size,) + input_tensor.shape[1:]

        output = torch.empty(
            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
        )

        # Run on this communicator's group, not the default process group,
        # so the collective matches ``self.world_size``.
        dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)

        # Move the scattered axis back to its original position.
        return output.movedim(0, dim).contiguous()

    def reduce_scatterv(
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
    ):
        """Reduce-scatter along ``dim`` with optional per-rank chunk sizes.

        Args:
            input_: Tensor to reduce and scatter.
            dim: Axis to scatter along; negative values count from the end.
            sizes: Optional chunk length for every rank along ``dim``. Must
                have ``self.world_size`` entries that sum to the size of
                ``input_`` along ``dim``. When omitted the axis is split
                evenly.

        Returns:
            A contiguous tensor holding this rank's chunk, with the
            scattered axis back at position ``dim``.

        Raises:
            AssertionError: If ``sizes`` is inconsistent with the input or,
                without ``sizes``, the axis is not divisible by
                ``self.world_size``.
        """
        world_size = self.world_size

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Bring the scatter axis to the front (see ``reduce_scatter`` for
        # why the direction matters for rank >= 3 tensors).
        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(dim, 0).contiguous()

        if sizes is not None:
            assert len(sizes) == world_size
            assert input_tensor.shape[0] == sum(sizes)
            chunk_size = sizes[self.rank_in_group]
        else:
            assert input_tensor.shape[0] % world_size == 0
            chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size,) + input_tensor.shape[1:]

        output = torch.empty(
            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
        )
        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
            # Chunk sizes differ between ranks, so use the list-based
            # reduce_scatter with one input split per rank.
            input_splits = list(input_tensor.split(sizes, dim=0))
            dist.reduce_scatter(output, input_splits, group=self.device_group)
        else:
            dist.reduce_scatter_tensor(
                output, input_tensor, group=self.device_group
            )
        # Move the scattered axis back to its original position.
        return output.movedim(0, dim).contiguous()

    def all_gatherv(
        self,
        input_: torch.Tensor | list[torch.Tensor],
        dim: int = 0,
        sizes: list[int] | None = None,
    ):
        """All-gather one tensor or a list of tensors along dim 0.

        Args:
            input_: A tensor, or a list of tensors gathered one by one.
            dim: Must be 0.
            sizes: Optional dim-0 length contributed by every rank. Ignored
                when all entries are equal.

        Returns:
            The gathered tensor, or a list of gathered tensors when a list
            was passed in.

        Raises:
            NotImplementedError: If ``dim`` is not 0.
            AssertionError: If ``sizes`` does not have ``self.world_size``
                entries or disagrees with the local input length.
        """
        if dim != 0:
            raise NotImplementedError("only dim 0 all-gatherv is supported")
        world_size = self.world_size

        # 'sizes' is not needed if all inputs in the same group have the same
        # shape
        if sizes is not None and all(s == sizes[0] for s in sizes):
            sizes = None

        def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
            """Gather a single tensor over ``self.device_group``."""
            if sizes is None:
                input_size = input_.size()
                output_tensor = torch.empty(
                    (input_size[0] * world_size,) + input_size[1:],
                    dtype=input_.dtype,
                    device=input_.device,
                )
                # NOTE(review): a one-element output list holding the whole
                # concatenated buffer is not the documented all_gather
                # contract; all_gather_into_tensor may be the intended
                # call - confirm on the XPU backend before changing.
                dist.all_gather([output_tensor], input_, group=self.device_group)
                return output_tensor

            assert len(sizes) == world_size
            assert input_.shape[dim] == sizes[self.rank_in_group], (
                f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
            )
            # One receive buffer per rank, each with that rank's length.
            all_gather_list = [
                torch.empty(
                    (size,) + input_.shape[1:],
                    dtype=input_.dtype,
                    device=input_.device,
                )
                for size in sizes
            ]
            dist.all_gather(all_gather_list, input_, group=self.device_group)
            return torch.cat(all_gather_list, dim=0)

        if isinstance(input_, torch.Tensor):
            return _all_gather_single(input_, sizes)

        return [_all_gather_single(inp, sizes=sizes) for inp in input_]

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """Gather ``input_`` from every rank and concatenate along ``dim``.

        Returns:
            The concatenated tensor on the rank whose ``rank_in_group``
            equals ``dst``; ``None`` on every other rank.

        Raises:
            AssertionError: If ``dim`` is out of range for ``input_``.
        """
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor with a leading per-rank axis.
        output_tensor = torch.empty(
            (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
        )
        # All-gather.
        dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
        if self.rank_in_group == dst:
            # Move the per-rank axis next to ``dim`` and merge the two.
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(
                input_size[:dim]
                + (self.world_size * input_size[dim],)
                + input_size[dim + 1 :]
            )
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
        """Broadcast ``input_`` from ``src`` over ``self.device_group``."""
        dist.broadcast(input_, src=src, group=self.device_group)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward all arguments to ``self.all2all_manager.dispatch``.

        Raises:
            AssertionError: If no all2all manager is installed.
        """
        assert self.all2all_manager is not None
        return self.all2all_manager.dispatch(
            hidden_states,
            router_logits,
            is_sequence_parallel,
            extra_tensors,  # type: ignore[call-arg]
        )

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Forward all arguments to ``self.all2all_manager.combine``.

        Raises:
            AssertionError: If no all2all manager is installed.
        """
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(hidden_states, is_sequence_parallel)

all2all_manager instance-attribute

all2all_manager = NaiveAll2AllManager(cpu_group)

__init__

__init__(
    cpu_group: ProcessGroup,
    device: device | None = None,
    device_group: ProcessGroup | None = None,
    unique_name: str = "",
)
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def __init__(
    self,
    cpu_group: ProcessGroup,
    device: torch.device | None = None,
    device_group: ProcessGroup | None = None,
    unique_name: str = "",
):
    """Initialise the base communicator and pick an all2all manager.

    All arguments are forwarded unchanged to the base-class initialiser.
    If ``self.use_all2all`` is true afterwards, ``self.all2all_manager`` is
    set from ``self.all2all_backend``: ``"naive"`` selects
    ``NaiveAll2AllManager``; every other value selects
    ``AgRsAll2AllManager``, with a warning logged first unless the backend
    is exactly ``"allgather_reducescatter"``.
    """
    super().__init__(cpu_group, device, device_group, unique_name)
    if not self.use_all2all:
        return

    if self.all2all_backend == "naive":
        from .all2all import NaiveAll2AllManager

        self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
        logger.info("Using naive all2all manager.")
        return

    # Anything that is not the naive backend ends up on the AgRs manager;
    # unsupported names are reported before falling back.
    if self.all2all_backend != "allgather_reducescatter":
        logger.warning(
            "`%s` all2all manager is not supported on XPU. "
            "Falling back to AgRs manager for XPU, "
            "which is the Default backend",
            self.all2all_backend,  # type: ignore[has-type]
        )
    from .all2all import AgRsAll2AllManager

    self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
    logger.info("Using AgRs manager on XPU device.")

all_gatherv

all_gatherv(
    input_: Tensor | list[Tensor],
    dim: int = 0,
    sizes: list[int] | None = None,
)
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def all_gatherv(
    self,
    input_: torch.Tensor | list[torch.Tensor],
    dim: int = 0,
    sizes: list[int] | None = None,
):
    if dim != 0:
        raise NotImplementedError("only dim 0 all-gatherv is supported")
    world_size = self.world_size

    # 'sizes' is not needed if all inputs in the same group have the same
    # shape
    if sizes is not None and all(s == sizes[0] for s in sizes):
        sizes = None

    def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
        input_size = input_.size()
        if sizes is not None:
            assert len(sizes) == world_size
            assert input_.shape[dim] == sizes[self.rank_in_group], (
                f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
            )
            output_size = (sum(sizes),) + input_size[1:]
        else:
            output_size = (input_size[0] * world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(
            output_size, dtype=input_.dtype, device=input_.device
        )

        if sizes is not None:
            all_gather_list = []
            for size in sizes:
                all_gather_list.append(
                    torch.empty(
                        (size,) + input_.shape[1:],
                        dtype=input_.dtype,
                        device=input_.device,
                    )
                )
            dist.all_gather(all_gather_list, input_)
            output_tensor = torch.cat(all_gather_list, dim=0)
        else:
            dist.all_gather([output_tensor], input_)
        return output_tensor

    if isinstance(input_, torch.Tensor):
        return _all_gather_single(input_, sizes)

    output_list = []
    for inp in input_:
        output_list.append(_all_gather_single(inp, sizes=sizes))
    return output_list

all_reduce

all_reduce(input_) -> Tensor
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def all_reduce(self, input_) -> torch.Tensor:
    """All-reduce ``input_`` in place over ``self.device_group``.

    Returns:
        The same ``input_`` object that was passed in.
    """
    group = self.device_group
    dist.all_reduce(input_, group=group)
    return input_

broadcast

broadcast(input_: Tensor, src: int = 0) -> None
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
    """Broadcast ``input_`` from rank ``src`` over ``self.device_group``."""
    group = self.device_group
    dist.broadcast(input_, src=src, group=group)

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def combine(
    self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
    """Forward both arguments to ``self.all2all_manager.combine``.

    Returns:
        Whatever the manager's ``combine`` returns.

    Raises:
        AssertionError: If ``self.all2all_manager`` is ``None``.
    """
    manager = self.all2all_manager
    assert manager is not None
    return manager.combine(hidden_states, is_sequence_parallel)

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[Tensor] | None = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.all2all_manager is not None
    return self.all2all_manager.dispatch(
        hidden_states,
        router_logits,
        is_sequence_parallel,
        extra_tensors,  # type: ignore[call-arg]
    )

gather

gather(
    input_: Tensor, dst: int = 0, dim: int = -1
) -> Tensor | None
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def gather(
    self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    )
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # For xpu path, gather doesn't work properly together with ray
    # cluster so we use all_gather instead for now.
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty(
        (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
    )
    # All-gather.
    dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
    if self.rank_in_group == dst:
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim]
            + (self.world_size * input_size[dim],)
            + input_size[dim + 1 :]
        )
    else:
        output_tensor = None
    return output_tensor

reduce_scatter

reduce_scatter(input_: Tensor, dim: int = -1)
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
    """Reduce-scatter ``input_`` along ``dim`` over ``self.device_group``.

    Args:
        input_: Tensor whose size along ``dim`` is divisible by
            ``self.world_size``.
        dim: Axis to scatter along; negative values count from the end.

    Returns:
        A contiguous tensor shaped like ``input_`` except that ``dim`` is
        ``world_size`` times smaller.

    Raises:
        AssertionError: If the size along ``dim`` is not divisible by
            ``self.world_size``.
    """
    world_size = self.world_size

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Bring the scatter axis to the front. ``movedim(dim, 0)`` (rather than
    # ``movedim(0, dim)``) is required for tensors of rank >= 3; the two
    # only coincide for dim 0 and for 2-D inputs.
    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(dim, 0).contiguous()

    assert input_tensor.shape[0] % world_size == 0
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size,) + input_tensor.shape[1:]

    output = torch.empty(
        output_shape, dtype=input_tensor.dtype, device=input_tensor.device
    )

    # Run on this communicator's group, not the default process group, so
    # the collective matches ``self.world_size``.
    dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)

    # Move the scattered axis back to its original position.
    return output.movedim(0, dim).contiguous()

reduce_scatterv

reduce_scatterv(
    input_: Tensor,
    dim: int = -1,
    sizes: list[int] | None = None,
)
Source code in vllm/distributed/device_communicators/xpu_communicator.py
def reduce_scatterv(
    self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
):
    world_size = self.world_size

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(0, dim).contiguous()

    if sizes is not None:
        assert len(sizes) == world_size
        assert input_tensor.shape[0] == sum(sizes)
        chunk_size = sizes[self.rank_in_group]
    else:
        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size,) + input_tensor.shape[1:]

    output = torch.empty(
        output_shape, dtype=input_tensor.dtype, device=input_tensor.device
    )
    if sizes is not None and sizes.count(sizes[0]) != len(sizes):
        # if inputs shape in different ranks is not the same using reduce_scatter
        input_splits = list(input_tensor.split(sizes, dim=0))
        dist.reduce_scatter(output, input_splits)
    else:
        dist.reduce_scatter_tensor(output, input_tensor)
    # Reshape before returning
    return output.movedim(0, dim).contiguous()