
Apple MLX: A Complete Guide to Machine Learning on Apple Silicon

Author: Aadit Agrawal

Why MLX Matters for Apple Silicon Developers

Apple released MLX in December 2023 as an open-source array framework built specifically for machine learning on Apple Silicon. Unlike PyTorch with the MPS backend or TensorFlow with the Metal plugin, MLX was designed from the ground up to exploit the unified memory architecture of M-series chips. The framework has matured through 2024 and 2025, reaching version 0.30+ and adding support for the Neural Accelerators in the M5 chip.

This guide covers everything you need to know to start using MLX for machine learning research and inference on your Mac.


What is MLX

MLX is an array framework for numerical computing and machine learning. It provides a NumPy-like API with support for automatic differentiation, JIT compilation, and GPU acceleration through Metal. The framework was developed by Apple’s machine learning research team and is actively maintained on GitHub at ml-explore/mlx.

Core Design Principles

MLX follows several design principles that distinguish it from other frameworks:

Familiar API: The Python interface mirrors NumPy closely. If you know NumPy, you can start using MLX immediately. Higher-level neural network APIs in mlx.nn follow PyTorch conventions.

Lazy Evaluation: Computations are not executed immediately. Instead, MLX builds a computation graph that is only evaluated when results are needed. This allows the framework to optimize and fuse operations before execution.

Unified Memory: Arrays live in shared memory accessible by both CPU and GPU. There is no need to copy data between devices.

Composable Function Transformations: Operations like grad(), vmap(), and compile() can be composed together. You can take the gradient of a compiled function, or compile a vectorized gradient computation.

Multi-Language Support: MLX has bindings for Python, Swift, C++, and C. The same models can run on macOS, iOS, iPadOS, and visionOS.

Basic Array Operations

Here is a simple example showing MLX array operations:

import mlx.core as mx

# Create arrays
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

# Operations work like NumPy
c = a + b
d = mx.sin(a) * mx.cos(b)

# Arrays are not computed until needed
print(c)  # This triggers evaluation

The framework supports all common array operations including broadcasting, slicing, reshaping, and linear algebra routines.
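
A few of these in action (a quick sketch; the specific operations shown here are arbitrary examples):

import mlx.core as mx

m = mx.random.normal((3, 4))
v = mx.array([1.0, 2.0, 3.0, 4.0])

scaled = m * v              # broadcasting: v is applied to every row of m
top = m[:2, :]              # slicing
flat = m.reshape(12)        # reshaping
norm = mx.linalg.norm(m)    # a linear algebra routine

print(flat.shape, norm)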


The Unified Memory Model

The most significant architectural advantage of MLX comes from Apple Silicon’s unified memory design. On traditional systems with discrete GPUs, data must be copied between CPU RAM and GPU VRAM. This copying introduces latency and limits the effective memory available for large models.

[Figure: Memory architecture comparison. A discrete GPU must copy tensors from CPU memory to GPU VRAM over PCIe (~32 GB/s, roughly 5-10 ms per transfer), doubling memory usage. Apple Silicon's unified memory gives the CPU, GPU, and ANE one shared pool (~400 GB/s) with no transfer overhead, a single copy of each tensor, the full RAM available to the GPU, and seamless handoff between processors.]

How Unified Memory Works

On Apple Silicon, the CPU, GPU, and Neural Engine share the same physical memory pool. When you create an MLX array, it exists in this shared space. Any processor can access it directly without copying.

import mlx.core as mx

# Create an array - it lives in unified memory
x = mx.random.normal((1000, 1000))

# Run on GPU - no copy needed
y = mx.matmul(x, x.T, stream=mx.gpu)

# Run on CPU - still no copy needed
z = mx.sum(y, stream=mx.cpu)

In MLX, you specify the device when calling an operation, not when creating the array. The array itself has no device affiliation. This eliminates common bugs in PyTorch where tensors end up on the wrong device.

Memory Advantages for Large Models

The unified memory model provides practical benefits for running large language models. A Mac with 128GB of unified memory can load models that would require a workstation with a high-end GPU to run on CUDA systems.

However, there are limits. By default, the GPU typically cannot use more than about 75% of total system memory, so a 128GB Mac can allocate roughly 96GB for GPU tasks. This is still substantial compared to the 24GB available on high-end consumer NVIDIA GPUs.
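
You can check what Metal reports for your machine with mx.metal.device_info(); the exact keys vary by MLX version, so treat this as an inspection aid rather than a stable API:

import mlx.core as mx

# Prints device properties such as total memory and the recommended
# GPU working-set size (key names vary by MLX version)
print(mx.metal.device_info())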


Lazy Evaluation Deep Dive

MLX uses lazy evaluation, meaning operations are recorded but not executed immediately. Understanding this model is essential for writing efficient MLX code.

How Lazy Evaluation Works

When you write c = a + b, MLX does not compute the sum. Instead, it creates a node in a computation graph with a and b as inputs. The actual computation only happens when you need the result.

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

# No computation happens here
c = a + b
d = c * 2
e = mx.sum(d)

# Computation happens when we need the value
print(e)  # Triggers evaluation of entire graph

You can explicitly trigger evaluation with mx.eval():

mx.eval(e)  # Force computation

Memory Benefits

Lazy evaluation enables memory optimizations. Consider initializing a large model:

import mlx.core as mx
import mlx.nn as nn

# Model weights are "created" but not yet allocated
# (LargeModel stands in for any nn.Module; parameters default to float32)
model = LargeModel()

# Swap in float16 weights before anything is evaluated
model.load_weights("weights.safetensors")  # Loads as float16

# Only float16 memory is ever allocated
mx.eval(model.parameters())

If evaluation were eager, the model would first allocate float32 weights, then allocate float16 weights, doubling peak memory usage.

Graph Fusion and Optimization

The lazy evaluation model allows MLX to fuse operations. Multiple element-wise operations can be combined into a single GPU kernel, reducing memory bandwidth requirements and kernel launch overhead.
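
A rough way to see the effect is to time a chain of element-wise operations with and without mx.compile; numbers vary by machine, and this is only a sketch:

import time
import mlx.core as mx

def silu(x):
    # several element-wise ops that can be fused into fewer kernels
    return x * mx.sigmoid(x)

x = mx.random.normal((4096, 4096))

for label, fn in [("eager", silu), ("compiled", mx.compile(silu))]:
    mx.eval(fn(x))                      # warm-up (and compile, in the second case)
    start = time.perf_counter()
    mx.eval(fn(x))
    print(label, f"{time.perf_counter() - start:.4f}s")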


Automatic Differentiation

MLX implements automatic differentiation through function transformations rather than tape-based recording. This approach, inspired by JAX, provides more flexibility and composability.

The grad Function

The mx.grad() function takes a function and returns a new function that computes gradients:

import mlx.core as mx

def f(x):
    return mx.sum(x ** 2)

# Create gradient function
grad_f = mx.grad(f)

x = mx.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2.0, 4.0, 6.0]

Unlike PyTorch, there is no backward() method, no zero_grad(), no requires_grad property, and no detach(). Gradients are computed by transforming functions.
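
The closest analogue to detach() is mx.stop_gradient(), which blocks gradients from flowing through a value:

import mlx.core as mx

def f(x):
    # Treat the mean as a constant: no gradient flows through it
    baseline = mx.stop_gradient(mx.mean(x))
    return mx.sum((x - baseline) ** 2)

x = mx.array([1.0, 2.0, 3.0])
print(mx.grad(f)(x))  # [-2.0, 0.0, 2.0]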

value_and_grad

Computing both the function value and gradient is common in optimization. Rather than calling the function twice, use value_and_grad(). For plain functions this is mx.value_and_grad(); for models, mlx.nn provides nn.value_and_grad(), which differentiates with respect to the model's trainable parameters:

import mlx.nn as nn

def loss_fn(model, x, y):
    pred = model(x)
    return mx.mean((pred - y) ** 2)

# Get both the loss value and gradients with respect to the model parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)

Higher-Order Gradients

Function transformations compose naturally. You can compute the gradient of a gradient:

def f(x):
    return x ** 3

grad_f = mx.grad(f)       # 3x^2
grad2_f = mx.grad(grad_f)  # 6x

x = mx.array(2.0)
print(grad2_f(x))  # 12.0

JIT Compilation with mx.compile

MLX provides a compile() transformation that optimizes computation graphs. Compilation fuses operations, eliminates redundant computations, and generates optimized Metal kernels.

Basic Usage

import mlx.core as mx

def model_forward(x, w1, w2):
    h = mx.tanh(x @ w1)
    return h @ w2

# Compile the function
compiled_forward = mx.compile(model_forward)

# First call builds and compiles the graph (slow)
y = compiled_forward(x, w1, w2)

# Subsequent calls reuse compiled code (fast)
y = compiled_forward(x, w1, w2)

The first call to a compiled function is slow because MLX must trace the computation, optimize the graph, and compile Metal shaders. Subsequent calls with the same input shapes reuse the compiled code.

Composing compile with Other Transformations

Compilation works with other function transformations:

# Compile a gradient function
compiled_grad = mx.compile(mx.grad(loss_fn))

# Compile a vectorized function
compiled_vmap = mx.compile(mx.vmap(batch_fn))
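
Since vmap has not appeared on its own yet, here is a minimal example (batch_fn above is a placeholder for any per-example function like this one):

import mlx.core as mx

def dot(a, b):
    return mx.sum(a * b)

# vmap maps dot over the leading axis of both inputs: (8, 3) pairs -> (8,) results
batched_dot = mx.vmap(dot)

a = mx.random.normal((8, 3))
b = mx.random.normal((8, 3))
print(batched_dot(a, b).shape)  # (8,)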

Debugging Compiled Functions

Compiled functions cannot contain print statements or other side effects because they are traced with placeholder inputs. For debugging, disable compilation:

# Globally disable compilation
mx.disable_compile()

# Or use environment variable
# MLX_DISABLE_COMPILE=1 python script.py

Metal Integration

MLX uses Metal, Apple’s GPU programming framework, as its backend. Most users never interact with Metal directly, but understanding the integration helps with debugging and optimization.

Default Stream Behavior

MLX operations run on a default GPU stream:

import mlx.core as mx

# These operations run on the GPU by default
a = mx.random.normal((1000, 1000))
b = mx.matmul(a, a.T)

You can explicitly specify CPU execution:

# Force CPU execution
c = mx.add(a, b, stream=mx.cpu)

Custom Metal Kernels

For operations not covered by built-in primitives, MLX allows custom Metal kernels:

import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = tmp * tmp;
"""

kernel = mx.fast.metal_kernel(
    name="square",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.array([1.0, 2.0, 3.0, 4.0])

# The kernel source is templated on T, so the concrete type is passed at call time,
# along with launch dimensions and output metadata
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
result = outputs[0]  # [1.0, 4.0, 9.0, 16.0]

M5 Neural Accelerator Support

With macOS 26.2 and MLX’s latest versions, the framework can leverage the Neural Accelerators in the M5 chip. MLX uses the Tensor Operations (TensorOps) and Metal Performance Primitives frameworks introduced with Metal 4 to access dedicated matrix multiplication hardware.


MLX-LM: Running Large Language Models

MLX-LM is the official package for running and fine-tuning large language models on Apple Silicon. It provides a simple interface for loading models from Hugging Face, generating text, and fine-tuning with LoRA.

Installation

pip install mlx-lm

Loading and Generating Text

from mlx_lm import load, generate

# Load a model from Hugging Face
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Generate text
prompt = "Explain quantum computing in simple terms:"
response = generate(model, tokenizer, prompt=prompt, max_tokens=200)
print(response)

Command-line usage:

mlx_lm.generate --model mlx-community/Qwen3-4B-Instruct-2507-4bit --prompt "hello"

Supported Model Architectures

MLX-LM supports most popular LLM architectures including:

  • LLaMA and LLaMA 2/3
  • Mistral and Mixtral (MoE)
  • Qwen 2, 2.5, 3 and Qwen3 MoE
  • Phi 2, 3, 4
  • Gemma 1, 2, 3
  • OLMo and OLMoE
  • MiniCPM
  • DeepSeek

Most models available on Hugging Face can be loaded directly if they follow standard architectures.

Quantization

Quantization reduces memory usage and increases generation speed. MLX-LM supports 4-bit, 6-bit, and 8-bit quantization:

# Convert and quantize a model
python -m mlx_lm.convert \
    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
    -q \
    --q-bits 4 \
    --q-group-size 64

Using a quantized model in Python:

from mlx_lm import load, generate

# Load a pre-quantized model
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Generate as normal
response = generate(model, tokenizer, prompt="Hello!", max_tokens=100)

Quantized models from the mlx-community organization on Hugging Face are ready to use without conversion.

Memory Requirements

Approximate memory usage for different quantization levels on a 7B parameter model:

Precision   Memory Usage
float32     ~28 GB
float16     ~14 GB
8-bit       ~7 GB
4-bit       ~3.5 GB
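
These numbers follow directly from parameter count times bits per weight, ignoring activation and KV-cache overhead:

# Rough estimate for a 7B-parameter model: parameters x bits per weight / 8
params = 7e9
for bits in (32, 16, 8, 4):
    print(f"{bits}-bit: ~{params * bits / 8 / 1e9:.1f} GB")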

A 4-bit quantized 3B-parameter model (such as Llama 3.2 3B) can generate around 50 tokens/second on recent Apple Silicon.
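
To measure throughput on your own machine, pass verbose=True to generate(), which prints prompt and generation tokens-per-second (the model below is simply one of the repos used earlier in this guide):

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# verbose=True prints token counts and tokens-per-second for the run
generate(model, tokenizer, prompt="Hello", max_tokens=100, verbose=True)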


LoRA Fine-Tuning with MLX-LM

Low-Rank Adaptation (LoRA) enables fine-tuning large models with limited memory by training small adapter matrices instead of full model weights. MLX-LM includes built-in support for LoRA and QLoRA (quantized LoRA).
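
The core idea fits in a few lines. This is a simplified sketch of a LoRA-wrapped linear layer for illustration, not the mlx-lm implementation:

import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, linear: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        out_dim, in_dim = linear.weight.shape
        self.linear = linear                                 # frozen base layer
        self.lora_a = 0.01 * mx.random.normal((in_dim, r))   # trainable
        self.lora_b = mx.zeros((r, out_dim))                 # trainable, starts at zero
        self.scale = alpha / r

    def __call__(self, x):
        # Base projection plus a low-rank update: W x + (alpha / r) * B A x
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)

# In practice the base model is frozen (model.freeze()) and only the
# small adapter matrices are trained.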

Preparing Training Data

Training data should be in JSONL format with a “text” field:

{"text": "Question: What is the capital of France?\nAnswer: Paris"}
{"text": "Question: What is 2+2?\nAnswer: 4"}

Running LoRA Training

python -m mlx_lm.lora \
    --model mlx-community/gemma-3-4b-it-4bit \
    --train \
    --data ./data \
    --adapter-path ./lora_adapters \
    --lora-r 16 \
    --lora-alpha 32 \
    --iters 500 \
    --batch-size 1 \
    --learning-rate 2e-4

Key parameters:

  • --lora-r: Rank of the LoRA matrices (higher = more capacity, more memory)
  • --lora-alpha: Scaling factor for LoRA updates
  • --lora-layers: Number of layers to apply LoRA (default 16)
  • --batch-size: Training batch size (reduce for memory)
  • --grad-checkpoint: Enable gradient checkpointing for memory savings

QLoRA for Memory Efficiency

When using a 4-bit quantized base model, training automatically uses QLoRA:

# Using a 4-bit model enables QLoRA automatically
python -m mlx_lm.lora \
    --model mlx-community/Llama-3-8B-Instruct-4bit \
    --train \
    --data ./data \
    --adapter-path ./qlora_adapters

Memory comparison for Llama 7B:

Method                 Memory Usage
Full fine-tuning       ~28 GB
LoRA (r=8)             ~14 GB
QLoRA (4-bit + LoRA)   ~7 GB

Using Trained Adapters

Generate text with your trained adapter:

from mlx_lm import load, generate

# Load base model with adapter
model, tokenizer = load(
    "mlx-community/gemma-3-4b-it-4bit",
    adapter_path="./lora_adapters"
)

response = generate(model, tokenizer, prompt="Your prompt here")

Merging Adapters

For deployment, merge adapters into the base model:

python -m mlx_lm.fuse \
    --model mlx-community/gemma-3-4b-it-4bit \
    --adapter-path ./lora_adapters \
    --save-path ./merged_model \
    --de-quantize  # Optional: convert back to float16

OpenAI-Compatible API Server

MLX-LM includes a built-in server that exposes an OpenAI-compatible API. This allows any tool designed for OpenAI’s API to work with local models.

Starting the Server

mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit

Using the API

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
    messages=[
        {"role": "user", "content": "What is machine learning?"}
    ]
)

print(response.choices[0].message.content)

Alternative Servers

For production use cases, consider these community servers:

FastMLX: High-performance server with support for vision-language models and efficient resource management. Available at Blaizzy/fastmlx.

vllm-mlx: OpenAI-compatible server with continuous batching, MCP tool calling, and multimodal support; the project reports 400+ tokens/second. Available on GitHub.

mlx-openai-server: FastAPI-based server supporting LLMs, VLMs (via mlx-vlm), image generation (via mflux), and Whisper transcription. Available on PyPI.


The MLX Ecosystem

Beyond MLX core and MLX-LM, a rich ecosystem of packages has developed for various machine learning tasks.

mlx-community on Hugging Face

The mlx-community organization on Hugging Face hosts over 1,000 models converted to MLX format. These include:

  • LLaMA 3.3 and 3.2 variants
  • Qwen 2.5, Qwen3, QwQ
  • Gemma 2 and 3
  • Mistral and Mixtral
  • Phi-3 and Phi-4
  • Whisper models for speech recognition

Models are available in various quantization levels (4-bit, 8-bit, fp16) and can be loaded directly with MLX-LM.

mlx-examples Repository

The ml-explore/mlx-examples repository contains standalone examples including:

Language Models: Transformer training, LLaMA/Mistral text generation, Mixtral 8x7B (MoE), LoRA and QLoRA fine-tuning.

Image Generation: Stable Diffusion and SDXL with text-to-image and image-to-image generation. Supports quantization for memory-constrained systems.

Audio: OpenAI Whisper for speech recognition, Meta’s EnCodec for audio compression, MusicGen for music generation.

Vision-Language: CLIP for joint text-image embeddings, LLaVA for text generation from images.

mlx-vlm

MLX-VLM provides inference and fine-tuning for vision-language models:

from mlx_vlm import load, generate

model, processor = load("mlx-community/llava-1.5-7b-4bit")

output = generate(model, processor, "path/to/image.jpg", "What is in this image?")
print(output)

Supported features:

  • Multi-image analysis
  • Video captioning and summarization
  • LoRA and QLoRA fine-tuning for VLMs
  • Audio and video support with Gemma 3n

whisper-mlx

MLX Whisper provides speech recognition using OpenAI’s Whisper models:

import mlx_whisper

result = mlx_whisper.transcribe("audio.mp3")
print(result["text"])

# With word-level timestamps
result = mlx_whisper.transcribe("audio.mp3", word_timestamps=True)

Pre-converted models are available from the Hugging Face MLX Community. For faster transcription, lightning-whisper-mlx provides optimized implementations.
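
You can point transcribe() at a specific converted model from the Hub; the repo id below is one example from mlx-community, and the parameter name reflects current mlx-whisper releases:

import mlx_whisper

result = mlx_whisper.transcribe(
    "audio.mp3",
    path_or_hf_repo="mlx-community/whisper-large-v3-mlx",
)
print(result["text"])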

Stable Diffusion

The mlx-examples repository includes Stable Diffusion implementations:

# Text to image
python txt2image.py "A photo of an astronaut riding a horse on Mars"

# With quantization for 8GB Macs
python txt2image.py --quantize "A sunset over mountains"

# Image to image
python image2image.py --image input.png --strength 0.7 "Make it look like a painting"

Both SD 1.5 and SDXL are supported. Quantization (4-bit text encoder, 8-bit UNet) enables generation on 8GB Macs without swapping.


Moving from CUDA to MLX: A Migration Guide

If you are coming from PyTorch with CUDA, this section covers the key differences and provides practical guidance for porting code.

Prerequisites

Hardware: Any Mac with Apple Silicon (M1, M2, M3, M4, M5 series).

Software:

  • macOS 14.0 or later (macOS 26.2 for M5 Neural Accelerator support)
  • Python 3.10 or later
  • Native ARM Python (not Rosetta x86 emulation)

Verify your Python installation:

python -c "import platform; print(platform.processor())"
# Should print "arm", not "i386"

Installation

# Install MLX core
pip install mlx

# Install MLX-LM for language models
pip install mlx-lm

# Install additional packages as needed
pip install mlx-whisper  # For speech recognition
pip install mlx-vlm      # For vision-language models

Key API Differences

No Device Management: In PyTorch, you explicitly move tensors between devices. In MLX, arrays live in unified memory and you specify the device when calling operations:

# PyTorch
x = torch.tensor([1, 2, 3]).cuda()
y = x + 1  # Runs on CUDA

# MLX
x = mx.array([1, 2, 3])
y = x + 1  # Runs on GPU by default
z = mx.add(x, y, stream=mx.cpu)  # Force CPU

Lazy Evaluation: Operations are not executed immediately:

# PyTorch - computed immediately
y = x + 1

# MLX - creates computation graph
y = x + 1  # Not computed yet
mx.eval(y)  # Now computed
print(y)   # Also triggers computation

Function-Based Gradients: No backward() method or gradient tape:

# PyTorch
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)

# MLX
def f(x):
    return x ** 2

grad_f = mx.grad(f)
x = mx.array([1.0])
print(grad_f(x))

Porting Neural Network Code

The mlx.nn module closely follows PyTorch’s API:

# PyTorch
import torch.nn as nn

class PyTorchMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# MLX
import mlx.nn as nn
import mlx.core as mx

class MLXMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def __call__(self, x):
        x = mx.maximum(self.fc1(x), 0)  # ReLU
        return self.fc2(x)

Key differences:

  • Use __call__ instead of forward
  • Use mx.maximum(x, 0) or nn.relu(x) instead of torch.relu(x)
  • No .cuda() calls needed

Converting Model Weights

When porting a trained PyTorch model, you need to convert the weights:

import torch
import mlx.core as mx
import numpy as np

# Load PyTorch weights
pytorch_weights = torch.load("model.pt", map_location="cpu")

# Convert to MLX format
mlx_weights = {}
for key, value in pytorch_weights.items():
    # Convert to numpy, then to MLX array
    np_array = value.numpy()

    # Handle weight format differences (NCHW -> NHWC for convolutions)
    if "conv" in key and "weight" in key:
        np_array = np.transpose(np_array, (0, 2, 3, 1))

    mlx_weights[key] = mx.array(np_array)

# Save in MLX format
mx.savez("model.npz", **mlx_weights)

For LLMs, the mlx_lm.convert script handles this automatically:

python -m mlx_lm.convert --hf-path meta-llama/Llama-3-8B

Training Loop Pattern

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = MLXMLP(784, 256, 10)
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, x, y):
    logits = model(x)
    return mx.mean(nn.losses.cross_entropy(logits, y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        loss, grads = loss_and_grad_fn(model, x_batch, y_batch)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

Common Patterns

Model Evaluation Mode:

# PyTorch
model.eval()

# MLX
model.eval()  # Same API

Saving and Loading:

# Save model weights
model.save_weights("model.safetensors")

# Load model weights
model.load_weights("model.safetensors")

Mixed Precision: MLX handles dtypes explicitly:

# Convert model to float16
model = model.apply(lambda x: x.astype(mx.float16) if x.dtype == mx.float32 else x)

Performance Optimization

Use Compilation

Wrap hot paths in mx.compile():

from functools import partial

# Capture the model and optimizer state so compiled updates are tracked
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def training_step(x, y):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

Batch Operations

MLX has lower overhead for small operations than CUDA, but batching still helps:

# Less efficient
for x in data:
    y = model(x)

# More efficient
y = model(mx.stack(data))

Quantization for Inference

Use 4-bit or 8-bit models for inference when possible:

from mlx_lm import load

# 4-bit model uses 75% less memory and runs faster
model, tokenizer = load("mlx-community/Llama-3-8B-Instruct-4bit")

Memory Management

Check memory usage:

print(f"Active: {mx.get_active_memory() / 1e9:.2f} GB")
print(f"Peak: {mx.get_peak_memory() / 1e9:.2f} GB")

Clear cached memory:

mx.clear_cache()

Profiling

Build MLX with Metal debugging for GPU profiling:

mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON
make -j

Then use Xcode’s GPU profiler to analyze kernel execution.
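
MLX can also capture a Metal trace programmatically; this sketch assumes a debug build (MLX_METAL_DEBUG=ON) and MTL_CAPTURE_ENABLED=1 set in the environment:

import mlx.core as mx

a = mx.random.normal((1024, 1024))

# Requires MLX built with MLX_METAL_DEBUG=ON and MTL_CAPTURE_ENABLED=1
mx.metal.start_capture("mlx_trace.gputrace")
mx.eval(mx.matmul(a, a.T))
mx.metal.stop_capture()
# Open mlx_trace.gputrace in Xcode to inspect kernel execution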


Debugging Tips

Disable Compilation for Debugging

Compiled functions cannot contain print statements. Disable compilation to debug:

mx.disable_compile()
# or
# MLX_DISABLE_COMPILE=1 python script.py

Force Evaluation

When debugging lazy evaluation issues, force computation:

result = complex_operation(x)
mx.eval(result)  # Force computation
print(result)    # Now safe to inspect

Check Array Properties

x = mx.array([1.0, 2.0, 3.0])
print(f"Shape: {x.shape}")
print(f"Dtype: {x.dtype}")
print(f"Size: {x.size}")

Memory Issues

If you encounter memory errors during training:

  1. Reduce batch size
  2. Enable gradient checkpointing
  3. Use quantized models
  4. Reduce sequence length
  5. Reduce LoRA rank

Practical Example: End-to-End LLM Inference

Here is a complete example for running local LLM inference:

from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

def main():
    # Load a quantized model
    print("Loading model...")
    model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")

    # System prompt
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain how transformers work in 3 sentences."}
    ]

    # Format with chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Generate (recent mlx-lm versions take sampling settings via a sampler)
    print("Generating...")
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=200,
        sampler=make_sampler(temp=0.7),
    )

    print(response)

if __name__ == "__main__":
    main()

Practical Example: Fine-Tuning a Model

Complete example for fine-tuning with LoRA:

import json
import subprocess

# Step 1: Prepare training data
train_data = [
    {"text": "### Instruction: Translate to French\n### Input: Hello\n### Response: Bonjour"},
    {"text": "### Instruction: Translate to French\n### Input: Goodbye\n### Response: Au revoir"},
    # Add more examples...
]

# Save training data
with open("train.jsonl", "w") as f:
    for item in train_data:
        f.write(json.dumps(item) + "\n")

# Step 2: Run training
cmd = [
    "python", "-m", "mlx_lm.lora",
    "--model", "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
    "--train",
    "--data", ".",
    "--adapter-path", "./adapters",
    "--lora-r", "8",
    "--iters", "100",
    "--batch-size", "1",
    "--learning-rate", "1e-4"
]

subprocess.run(cmd)

# Step 3: Use the fine-tuned model
from mlx_lm import load, generate

model, tokenizer = load(
    "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
    adapter_path="./adapters"
)

prompt = "### Instruction: Translate to French\n### Input: Thank you\n### Response:"
print(generate(model, tokenizer, prompt=prompt, max_tokens=50))

MLX vs PyTorch Performance Summary

Based on benchmarks from late 2025:

Aspect                    MLX                         PyTorch + CUDA
Setup complexity          Simple pip install          CUDA toolkit, drivers
Memory efficiency         Unified, no copies          Explicit transfers
Small batch latency       Lower overhead              Higher kernel launch cost
Large batch throughput    Competitive on M2 Ultra+    Faster on RTX 4090
LLM inference             Optimized for local use     More optimized on server GPUs
Training speed            Slower for large models     Faster for production training
Debugging                 Straightforward             Mature tooling

MLX is well suited for:

  • Local LLM inference and experimentation
  • Development and prototyping on Mac
  • Running models too large for consumer GPU VRAM (up to 128GB+ of unified memory)
  • Deploying models to Apple devices

PyTorch + CUDA remains better for:

  • Large-scale training runs
  • Maximum throughput requirements
  • Production server deployment
