Add docs for the distributed namespace (#1184)

parent 578842954c
commit 0163a8e57a
@@ -43,6 +43,7 @@ are the CPU and GPU.

     usage/function_transforms
     usage/compile
     usage/numpy
+    usage/distributed
     usage/using_streams

  .. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.

     python/metal
     python/nn
     python/optimizers
+    python/distributed
     python/tree_utils

  .. toctree::
docs/src/python/distributed.rst (new file, 19 lines)

@@ -0,0 +1,19 @@
.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
==========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
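A minimal sketch of how these pieces are typically used together; only the
functions listed on this page are referenced, and the exact behaviour when MPI
is missing is described in the usage guide:

.. code:: python

   import mlx.core as mx

   # ``is_available`` reports whether the MPI library could be loaded at runtime
   if mx.distributed.is_available():
       world = mx.distributed.init()
       print(world.rank(), world.size())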
.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
docs/src/usage/distributed.rst (new file, 166 lines)

@@ -0,0 +1,166 @@
.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported yet, and others may not be as fast as
   they should be. We are adding more operations and tuning the existing ones
   as we figure out the best way to do distributed computing on Macs using MLX.
Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all distributed
processes. If simply run with ``python``, however, only one process is
launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:

.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mpirun -np 4 ...``
would print arrays filled with 4, and so on.
Installing MPI
--------------

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install openmpi
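If installing with Homebrew instead, the command is along these lines; the
formula name is an assumption, so check ``brew search mpi`` for the exact name
on your system:

.. code:: shell

   $ brew install open-mpi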
Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines (a sample entry is sketched
  below).

.. note::
   For an example hostname ``foo.bar.com``, MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.
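For instance, a minimal ``~/.ssh/config`` entry could look like the following
sketch. The host alias, fully qualified name, user and key path are
placeholders; the exact fields depend on your setup.

.. code::

   Host host1
       HostName host1.example.com
       User myuser
       IdentityFile ~/.ssh/id_ed25519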
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.

.. code::

   host1 slots=1
   host2 slots=1

When using MLX, it is very likely that you want to use 1 slot per host, i.e.,
one process per host. The host file also needs to contain the current host if
you want to run on the local host. Passing the host file to ``mpirun`` is
simply done using the ``--hostfile`` command line argument, for example as
shown below.
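A sketch of a launch command, assuming the host file above is saved as
``hosts.txt`` (the file name is arbitrary):

.. code:: shell

   $ mpirun --hostfile hosts.txt -np 2 python test.py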
Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together, our training loop step looks as follows with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss
Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX, but for now
the two main things one can do to get the most out of distributed training with
MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (a sketch of fusing the gradient reduction is shown
   after this list).
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each host to improve bandwidth.
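Below is a sketch of the first point; it is not part of MLX's API and the name
``all_avg_fused`` is made up for illustration. The gradients are flattened,
concatenated into a single vector, reduced with one :func:`all_sum`, and then
split back into their original shapes. Whether the extra copies pay off depends
on the model, so treat this as a starting point rather than a recommendation.

.. code:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten

   def all_avg_fused(grads):
       # Flatten the gradient tree into a list of (name, array) pairs.
       # Assumes all gradients share the same dtype.
       flat = tree_flatten(grads)
       sizes = [v.size for _, v in flat]

       # One long vector means a single all_sum instead of one per parameter.
       big = mx.concatenate([v.reshape(-1) for _, v in flat])
       big = mx.distributed.all_sum(big) / mx.distributed.init().size()

       # Split the vector back up and restore the original shapes.
       offsets, total = [], 0
       for s in sizes[:-1]:
           total += s
           offsets.append(total)
       pieces = mx.split(big, offsets)
       return tree_unflatten(
           [(k, p.reshape(v.shape)) for (k, v), p in zip(flat, pieces)]
       )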
@@ -16,7 +16,7 @@ int main() {
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

   array x = ones({10});
-  array out = distributed::all_reduce_sum(x, global_group);
+  array out = distributed::all_sum(x, global_group);

   std::cout << out << std::endl;
 }
@@ -56,7 +56,7 @@ namespace detail {
 Stream communication_stream();

 /* Perform an all reduce sum operation */
-void all_reduce_sum(Group group, const array& input, array& output);
+void all_sum(Group group, const array& input, array& output);

 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
@@ -260,7 +260,7 @@ Stream communication_stream() {
   return comm_stream;
 }

-void all_reduce_sum(Group group, const array& input_, array& output) {
+void all_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
       (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
@@ -31,7 +31,7 @@ Stream communication_stream() {
   return comm_stream;
 }

-void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}

 } // namespace detail
@@ -17,7 +17,7 @@ Group to_group(std::optional<Group> group) {

 } // namespace

-array all_reduce_sum(const array& x, std::optional<Group> group_) {
+array all_sum(const array& x, std::optional<Group> group_) {
   auto group = to_group(group_);

   if (group.size() == 1) {
@@ -8,7 +8,7 @@

 namespace mlx::core::distributed {

-array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(const array& x, std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);

 } // namespace mlx::core::distributed
@@ -24,7 +24,7 @@ void AllReduce::eval_cpu(

   switch (reduce_type_) {
     case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+      distributed::detail::all_sum(group(), inputs[0], outputs[0]);
       break;
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
@@ -36,7 +36,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_reduce_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
@@ -48,7 +48,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_reduce_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
@@ -69,13 +69,13 @@ void init_distributed(nb::module_& parent_module) {
       )pbdoc");

   m.def(
-      "all_reduce_sum",
-      &distributed::all_reduce_sum,
+      "all_sum",
+      &distributed::all_sum,
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.
@@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase):
         ]
         for dt in dtypes:
             x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_reduce_sum(x)
+            y = mx.distributed.all_sum(x)
             self.assertTrue(mx.all(y == world.size()))

         sub = world.split(world.rank() % 2)
         for dt in dtypes:
             x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_reduce_sum(x, group=sub)
+            y = mx.distributed.all_sum(x, group=sub)
             self.assertTrue(mx.all(y == sub.size()))

     def test_all_gather(self):
@@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
         sub_2 = world.split(world.rank() % 2)

         x = mx.ones((1, 8)) * world.rank()
-        y = mx.distributed.all_reduce_sum(x, group=sub_1)
+        y = mx.distributed.all_sum(x, group=sub_1)
         z = mx.distributed.all_gather(y, group=sub_2)
         z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)