SolidWorks Videos

Tuesday, April 15, 2025

Muon: An optimizer for the hidden layers of neural networks (XPU version)

     With help from [1], the Muon optimizer now works on Intel XPUs. So far, it has been tested on the Intel Iris Xe integrated GPU that comes with the Core i7-1185G7. The purpose of this post is to help hobbyists such as yours truly use Muon on whatever hardware they have available. The version posted in [2] appears to be hardcoded for CUDA.
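Since the point of this version is to avoid hardcoding CUDA, a small helper can pick whichever accelerator PyTorch reports. This is a minimal sketch: `pick_device` is a hypothetical helper name (not part of [2]), and it assumes a PyTorch build recent enough to expose `torch.xpu`, hence the `hasattr` guard for older builds.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer an Intel XPU, then CUDA, else fall back to CPU."""
    # torch.xpu only exists in newer PyTorch builds, so guard with hasattr
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = pick_device()
print(device)
```

The returned device can be used to move the network with `model.to(device)` and passed as `device=` to the Muon class in the Code section.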

     My interest in this optimizer is in using it on pressing problems in engineering mechanics, such as vortex shedding and fracture of materials. For solid mechanics, refer to here, and for fluid dynamics, please read this post.

     The Muon optimizer also fails to predict vortex shedding. As of writing, SOAP is the best optimizer I have found for training physics-informed neural networks (PINNs). With SOAP, lid-driven cavity flow (Re = 3200), flow past airfoils, squares and circles (Re = 50), and flow over a backward-facing step (Re = 400) have been validated in fluid dynamics. In solid mechanics, the neural network has successfully learned the 3-point bending, tensile and compression tests on specimens. In a sea of optimizers that claim to perform better than Adam, SOAP works best, at least for PINN training. Note that none of my PINNs use any kind of training data. This is important because the finite difference method (read free code here) and the finite element method do not require data either; they work with boundary conditions only.
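Training "without any kind of training data" means the loss is built only from the governing equation evaluated at collocation points plus the boundary conditions. A minimal sketch for a hypothetical 1D Poisson problem (not one of the cases above), -u''(x) = π² sin(πx) on (0, 1) with u(0) = u(1) = 0; its exact solution sin(πx) is never shown to the network. The network size, the point count and the name `pinn_loss` are arbitrary assumptions.

```python
import math

import torch
import torch.nn as nn

# Hypothetical toy problem: -u''(x) = pi^2 sin(pi x) on (0, 1), u(0) = u(1) = 0.
# The exact solution u = sin(pi x) is NOT used anywhere as training data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 32), nn.Tanh(), nn.Linear(32, 1))


def pinn_loss(model: nn.Module, n_interior: int = 64) -> torch.Tensor:
    # Interior collocation points; requires_grad so derivatives w.r.t. x can be taken
    x = torch.rand(n_interior, 1, requires_grad=True)
    u = model(x)
    # First and second derivatives via autograd; create_graph keeps them differentiable
    du = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    d2u = torch.autograd.grad(du, x, torch.ones_like(du), create_graph=True)[0]
    residual = -d2u - math.pi**2 * torch.sin(math.pi * x)
    # Boundary points: only the boundary condition u = 0 is imposed there
    xb = torch.tensor([[0.0], [1.0]])
    bc = model(xb)
    return residual.pow(2).mean() + bc.pow(2).mean()


loss = pinn_loss(model)
loss.backward()
```

A loss of this kind is what the `loss_fn(train_points)` call in the Implementation section stands for; it can be minimized with any of the optimizers discussed here.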

     The complex case of vortex shedding is still a mystery. Maybe a new kind of activation function is required. I have tried constant and learnable sinusoidal and polynomial activation functions with SOAP, with no success in predicting vortex shedding, again without training data. With training data, even Adam is enough.
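One common form of learnable sinusoidal activation is sin(ωx) with a trainable frequency ω per neuron. The exact functions tried are not given in this post, so the sketch below is an assumption; `LearnableSine` and `omega0` are hypothetical names.

```python
import torch
import torch.nn as nn


class LearnableSine(nn.Module):
    """Hypothetical learnable sinusoidal activation: y = sin(omega * x), one omega per feature."""

    def __init__(self, num_features: int, omega0: float = 1.0):
        super().__init__()
        # omega is an nn.Parameter, so the optimizer updates it together with the weights
        self.omega = nn.Parameter(torch.full((num_features,), omega0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # omega broadcasts over the batch dimension
        return torch.sin(self.omega * x)


act = LearnableSine(32)
model = nn.Sequential(nn.Linear(2, 32), act, nn.Linear(32, 1))
out = model(torch.randn(8, 2))
```

Because `omega` is 1D, the parameter split in the Implementation section would hand it to AdamW rather than Muon.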

Code

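import torch
from torch import Tensor


def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """Approximately orthogonalize G (push its singular values toward 1) with a quintic Newton-Schulz iteration."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)

    # Ensure consistent type and device (float32 on the same XPU/CUDA/CPU device as G)
    X = G.to(dtype=torch.float32, device=G.device)

    # Iterate on the wide orientation so that X @ X.mT is the smaller Gram matrix
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Normalize by the Frobenius norm so that all singular values are at most 1
    norm = X.norm(dim=(-2, -1), keepdim=True) + 1e-7
    X = X / norm

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    # Undo the transpose
    if G.size(-2) > G.size(-1):
        X = X.mT

    return X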



class Muon(torch.optim.Optimizer):
    """Muon for a single XPU/CUDA/CPU device.

    Only pass parameters with ndim >= 2 (weight matrices, conv filters);
    0D/1D parameters such as biases should go to AdamW or another optimizer.
    """

    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, nesterov=True, ns_steps=5, rank=None, world_size=None, device=None):
        if device is None:
            device = torch.device("cpu")  # or use "xpu" if always on Intel
        # Set in both cases, so a user-supplied device also works
        self.device = torch.device(device)
        if (rank is None) or (world_size is None):
            raise ValueError("world_size and rank params are required; to use this optimizer on a single device, pass rank=0 and world_size=1.")
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        param_groups = []
        # Group parameters by size. The bfloat16 update buffers mirror the distributed
        # version in [2]; step() below does not use them on a single device.
        for size in {p.numel() for p in params}:
            b = torch.empty(world_size, size, dtype=torch.bfloat16, device=self.device)
            # b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
            group = dict(params=[p for p in params if p.numel() == size],
                         update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Optional closure, as in the other torch.optim optimizers
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Fetch optimizer settings
            params: list[Tensor] = group["params"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]

            for p in params:
                g = p.grad
                if g is None:
                    continue
                if p.ndim < 2:
                    # Skip Muon orthogonalization for 0D/1D (handled by AdamW)
                    continue

                # Momentum buffer: created on the first step, updated on every step
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1 - momentum)
                g = g.lerp_(buf, momentum) if nesterov else buf

                # Handle convolutional filters: flatten (out, in, kh, kw) to (out, in*kh*kw)
                if g.ndim == 4:
                    g = g.view(len(g), -1)

                # Orthogonalize gradient; reshape (not view) because the result may be a transposed view
                g = zeropower_via_newtonschulz5(g, steps=ns_steps).reshape(p.shape)

                # Weight decay + update
                p.mul_(1 - lr * weight_decay)
                p.add_(g, alpha=-lr * max(1, p.size(-2) / p.size(-1))**0.5)

        return loss
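Why five steps with these coefficients are enough can be seen by following a single singular value. If X = U S Vᵀ, then A = X Xᵀ = U S² Uᵀ, so one loop iteration in zeropower_via_newtonschulz5 maps X to U (aS + bS³ + cS⁵) Vᵀ: every singular value σ goes through the scalar quintic p(σ) = aσ + bσ³ + cσ⁵ on its own, and the singular vectors are unchanged. The sketch below, an illustration that is not part of [2], iterates p on a few starting values in (0, 1], the range guaranteed by the Frobenius normalization.

```python
# Coefficients copied from zeropower_via_newtonschulz5 above
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def quintic(sigma: float) -> float:
    """One Newton-Schulz step acting on a single singular value."""
    return A_COEF * sigma + B_COEF * sigma**3 + C_COEF * sigma**5


def iterate(sigma: float, steps: int = 5) -> float:
    """Apply the quintic `steps` times, like the loop in zeropower_via_newtonschulz5."""
    for _ in range(steps):
        sigma = quintic(sigma)
    return sigma


# Singular values after Frobenius normalization lie in (0, 1]
starts = [0.01, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0]
finals = [iterate(s) for s in starts]
for s, f in zip(starts, finals):
    print(f"{s:5.2f} -> {f:.3f}")
```

Small singular values grow by roughly a factor of a ≈ 3.44 per step, and after five steps every start in this range lands roughly between 0.68 and 1.2 rather than at exactly 1, which is why the result is only approximately orthogonal. A much smaller start, such as 0.001, is still below 0.5 after five steps.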

Implementation

     First, split the network's parameters between Muon and the other optimizer of choice using the following statements: Muon updates the weight matrices (2D and higher), while AdamW updates the 0D/1D parameters such as biases. Then set the hyperparameters of each optimizer.

import torch.optim as optim

muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2] # define parameters to be updated by Muon
AdamW_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2] # define parameters to be updated by AdamW
muon_optimizer = Muon(muon_params, lr=0.02, momentum=0.95, rank=0, world_size=1)
AdamW_optimizer = optim.AdamW(AdamW_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)
optimizer = [muon_optimizer, AdamW_optimizer]
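The split above sends every matrix-shaped parameter, including the first and last layers, to Muon. The README of [2] describes Muon as an optimizer for hidden weight layers and suggests a standard method such as AdamW for input/output layers and gains or biases, so that variant may be worth trying. A minimal sketch for a plain nn.Sequential MLP, assuming the first and last nn.Linear modules are the input and output layers; the small model is a hypothetical stand-in for `model`.

```python
import torch.nn as nn

# Hypothetical stand-in for the PINN used in this post
model = nn.Sequential(
    nn.Linear(2, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 1),
)

# Treat the first and last nn.Linear as input/output layers; the ones in between are hidden
linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
hidden_layers = linears[1:-1]

# Muon gets only the hidden weight matrices
muon_params = [m.weight for m in hidden_layers if m.weight.requires_grad]
muon_ids = {id(p) for p in muon_params}

# Everything else (input/output layers and all biases) goes to AdamW
AdamW_params = [p for p in model.parameters() if p.requires_grad and id(p) not in muon_ids]
```

The two lists can then be passed to Muon and optim.AdamW exactly as above.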

     After defining the optimizers and choosing which parameters Muon and AdamW (or any other optimizer of choice) update, the training loop could look basically like this.

for epoch in range(5001):
    for opt in optimizer:
        opt.zero_grad() # clear gradients of all optimized variables
    loss_value = loss_fn(train_points) # compute the loss at the training points
    loss_value.backward() # compute gradient of loss with respect to model parameters
    for opt in optimizer:
        opt.step() # perform parameter update
    if epoch % 100 == 0:
        print(f"{epoch} {loss_value.item()}")
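The two loops over `optimizer` can be folded into a small wrapper so the training loop reads like the usual single-optimizer one. A minimal sketch: `MultiOptimizer` is a hypothetical helper (not part of PyTorch or [2]), and torch.optim.SGD stands in for Muon so the example runs on its own on a toy linear regression.

```python
import torch
import torch.nn as nn


class MultiOptimizer:
    """Hypothetical helper that drives several optimizers as one."""

    def __init__(self, *optimizers: torch.optim.Optimizer):
        self.optimizers = optimizers

    def zero_grad(self) -> None:
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self) -> None:
        for opt in self.optimizers:
            opt.step()


# Toy linear regression, only to exercise the wrapper
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.3
model = nn.Linear(3, 1)

# SGD stands in for Muon on the weight matrix; AdamW handles the bias
optimizer = MultiOptimizer(
    torch.optim.SGD([model.weight], lr=0.1),
    torch.optim.AdamW([model.bias], lr=0.05, weight_decay=0.0),
)

losses = []
for epoch in range(200):
    optimizer.zero_grad()
    loss_value = nn.functional.mse_loss(model(x), y)
    loss_value.backward()
    optimizer.step()
    losses.append(loss_value.item())
```

With such a wrapper built from muon_optimizer and AdamW_optimizer, the two inner loops of the training loop become single calls to optimizer.zero_grad() and optimizer.step().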

References

[1] OpenAI, ChatGPT, April 15, 2025. [Online]. Available: https://chat.openai.com/

[2] K. Jordan, "Muon," GitHub repository. [Online]. Available: https://github.com/KellerJordan/Muon