With help from [1], the Muon optimizer now runs on Intel XPUs. So far it has been tested on the Intel Iris Xe GPU that comes bundled with the Core i7-1185G7. The purpose of this post is to help hobbyists such as yours truly use Muon on whatever hardware they have available; the version posted at [2] appears to be hardcoded for CUDA.
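For reference, here is a minimal sketch of how the device can be picked at runtime. It assumes a PyTorch build that exposes XPU support (PyTorch 2.4+ or one with intel_extension_for_pytorch); otherwise it falls back to the CPU.

import torch

# Pick an Intel XPU if the build exposes one, otherwise fall back to CPU.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
print(f"Running on: {device}")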
My interest in this optimizer is in using it to tackle pressing problems in engineering mechanics such as vortex shedding and fracture of materials; for solid mechanics refer to here, and for fluid dynamics please read this post.
The Muon optimizer also fails to predict vortex shedding. As of this writing, SOAP is the best optimizer for training physics-informed neural networks. With SOAP, the lid-driven cavity (Re = 3200), flow past airfoils, squares and circles (Re = 50), and flow past a backward-facing step (Re = 400) have been validated on the fluid dynamics side. As far as solid mechanics is concerned, the neural network has successfully learnt the 3-point bending, tensile and compression tests on specimens. In a sea of optimizers that claim to perform better than Adam, SOAP works best, at least for PINN training. Note that all of my PINNs are trained without any kind of training data; this is important because the finite difference method (read free code here) and the finite element method likewise require no data and work with boundary conditions only. A minimal sketch of such a data-free loss is shown below.
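To illustrate what "no training data" means here, the following is a hypothetical sketch of a data-free PINN loss for a 1D Poisson problem u''(x) = f(x) with Dirichlet boundary conditions. The model, the source term f and the unweighted sum of the two terms are illustrative assumptions, not the exact setup used for the cases listed above.

import torch

def pinn_loss(model, x_interior, x_boundary, u_boundary, f):
    # Physics residual: u''(x) - f(x) = 0 on interior collocation points
    x = x_interior.requires_grad_(True)
    u = model(x)
    du = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
    d2u = torch.autograd.grad(du.sum(), x, create_graph=True)[0]
    residual = (d2u - f(x)).pow(2).mean()
    # Boundary loss: match prescribed Dirichlet values; no field data required
    bc = (model(x_boundary) - u_boundary).pow(2).mean()
    return residual + bc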
The complex case of vortex shedding remains a mystery. Maybe a new kind of activation function is required. Constant and learnable sinusoidal as well as polynomial-based activation functions have been tried with SOAP by yours truly, with no success in predicting vortex shedding, again without training data. With training data, even Adam is enough.
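As an example of the activation functions mentioned above, a learnable sinusoidal activation can be written as a small module with a trainable frequency. This is only a sketch; the initial value w0 is an assumption.

import torch
from torch import nn

class LearnableSin(nn.Module):
    """sin(w * x) with one trainable frequency w per layer."""
    def __init__(self, w0: float = 1.0):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(w0))

    def forward(self, x):
        return torch.sin(self.w * x)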
Code
import torch
from torch import Tensor


def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    # Ensure consistent dtype and device
    X = G.to(dtype=torch.float32, device=G.device)
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Normalize
    norm = X.norm(dim=(-2, -1), keepdim=True) + 1e-7
    X = X / norm
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95,
                 nesterov=True, ns_steps=5, rank=None, world_size=None, device=None):
        if device is None:
            device = torch.device("cpu")  # or use "xpu" if always on Intel
        self.device = device
        if (rank is None) or (world_size is None):
            raise Exception("world_size and rank params are required; to use this optimizer on a single GPU, pass rank=0 and world_size=1.")
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        param_groups = []
        for size in {p.numel() for p in params}:
            b = torch.empty(world_size, size, dtype=torch.bfloat16, device=self.device)
            # b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
            group = dict(params=[p for p in params if p.numel() == size],
                         update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
            param_groups.append(group)
        super().__init__(param_groups, defaults)
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            # Fetch optimizer settings
            params: list[Tensor] = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            for p in params:
                g = p.grad
                if g is None:
                    continue
                # Momentum buffer
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1 - momentum)
                g = g.lerp_(buf, momentum) if nesterov else buf
                # Handle convolutional filters
                if g.ndim == 4:
                    g = g.view(len(g), -1)
                if g.ndim < 2:
                    # Skip Muon orthogonalization for 0D/1D (handled by AdamW)
                    continue
                # Orthogonalize gradient
                g = zeropower_via_newtonschulz5(g, steps=ns_steps).view_as(p)
                # Weight decay + update
                p.mul_(1 - lr * weight_decay)
                p.add_(g, alpha=-lr * max(1, p.size(-2) / p.size(-1))**0.5)
Implementation
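Below is a minimal usage sketch of the class above on a single device. The two-layer network and tensor shapes are placeholders; rank=0 and world_size=1 follow the single-device convention stated in the constructor, and the device falls back to CPU when no XPU is available.

import torch

device = torch.device("xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu")
model = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)).to(device)

# Muon handles the 2D weight matrices; biases (ndim < 2) are skipped inside step()
opt = Muon(model.parameters(), lr=0.02, rank=0, world_size=1, device=device)

x = torch.randn(128, 2, device=device)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()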
References
[1] OpenAI, ChatGPT, April 15, 2025. [Online]. Available: https://chat.openai.com/
[2] K. Jordan, Muon optimizer. [Online]. Available: https://github.com/KellerJordan/Muon