PyTorch comes with a rich set of built-in optimizers like SGD, Adam, and RMSProp — more than enough for most deep learning tasks. But what if you’re reading a cutting-edge paper with a brand-new optimizer? Or maybe you’ve come up with your own twist on an existing algorithm and want to see how it performs. That’s where writing your own optimizer comes in. In this hands-on tutorial, we’ll demystify the process by walking through how to create a custom optimizer in PyTorch. As a concrete example, we’ll reimplement the Adam optimizer — step by step — so you can understand exactly how everything fits together under the hood.

Before we dive in, I’ll assume you have a basic understanding of gradient-based optimization, the structure of PyTorch code, and how Python classes work. While brief refreshers are included for context, this tutorial builds on those core concepts. Our focus will be on machine learning models, but the techniques apply to optimizing any differentiable function.

1st order optimization

First-order optimization is a family of algorithms for finding a set of unknown parameters. In most cases, you are looking for the parameters that minimize some error metric, the loss function. These algorithms use the gradient of that function at the current point to estimate a descent direction, and they repeat this step iteratively until they reach (or get close enough to) a minimum, at which point we have, more or less, found our parameters. The classic algorithm of this kind is Gradient Descent (GD).

GD is defined by the following update rule:

$$\theta_{t+1} = \theta_t - \mu \nabla f(\theta_t)$$

Here θ denotes our unknown parameters, μ is the step size (also called the learning rate), and ∇f(θ) is the gradient of the loss function f with respect to the parameters. Let’s now ground this equation in the PyTorch training loop.
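To make the update rule concrete, here is a minimal sketch in plain Python that applies it to a toy one-dimensional loss f(θ) = (θ − 3)², whose gradient is 2(θ − 3). The loss, the starting point, and the step size μ = 0.1 are illustrative assumptions, not values from this tutorial.

```python
def grad_f(theta):
    # Gradient of the toy loss f(theta) = (theta - 3)**2
    return 2.0 * (theta - 3.0)

theta = 0.0   # initial guess for the unknown parameter
mu = 0.1      # step size (learning rate)

for t in range(100):
    # GD update: theta_{t+1} = theta_t - mu * grad f(theta_t)
    theta = theta - mu * grad_f(theta)

print(f"theta after 100 steps: {theta:.6f}")
```

Each update shrinks the distance to the minimum at θ = 3 by a factor of 1 − 2μ = 0.8, so the iterates converge geometrically. For this particular loss, any μ > 1 makes that factor larger than 1 in magnitude and the iterates diverge, which is why the choice of step size matters.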

Standard PyTorch training loop

The following snippet is the most important part of your optimization procedure; it is where all the magic happens. I have labeled the main steps of the process from 0 to 4.

```python
...
# Loss function
criterion = nn.CrossEntropyLoss()

# --- Initialize Optimizer ---
# 0. Init. the optimizer with the model parameters and the learning rate as input
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    # Loop over the training dataset for one epoch
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # --- Forward Pass ---
        # 1. Calculate the loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        # --- Backward Pass ---
        # 2. Zero out the gradients from the previous step
        optimizer.zero_grad()
        # 3. Perform backpropagation to compute gradients
        loss.backward()
        # 4. Update the model's weights
        optimizer.step()
...
```

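This is the standard PyTorch training loop. Notice that the optimizer appears in only three places: it is constructed once (step 0), it clears the previous gradients with .zero_grad() (step 2), and it updates the weights with .step() (step 4). The .zero_grad() method is inherited from torch.optim.Optimizer, so a custom optimizer gets it for free.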
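Because the snippet above elides the model and data, here is a self-contained sketch of the same steps on a synthetic regression problem. The data (y = 2x + 1 plus noise), the single nn.Linear model, and the hyperparameters are assumptions chosen only to keep the example small and runnable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic data: y = 2x + 1 plus a little noise
x = torch.randn(256, 1)
y = 2 * x + 1 + 0.05 * torch.randn(256, 1)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()

# 0. Init. the optimizer with the model parameters and the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(200):
    # 1. Forward pass: compute the loss on the full batch
    loss = criterion(model(x), y)
    # 2. Zero out gradients left over from the previous step
    optimizer.zero_grad()
    # 3. Backpropagate to fill p.grad for every parameter
    loss.backward()
    # 4. Let the optimizer update the weights in place
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
print("learned weight and bias:", model.weight.item(), model.bias.item())
```

Swapping torch.optim.SGD for any other optimizer, including a custom one, changes only step 0; the rest of the loop stays the same.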

With all of this in mind, building a custom optimizer essentially boils down to editing just two key parts: the __init__() function and the .step() function. By overriding these two methods, you can introduce virtually any functionality you want.

Now for the fun part, and the main point of this post: how exactly do we edit these functions? PyTorch makes this easy: we inherit from the base class torch.optim.Optimizer and override its __init__() and .step() methods. As promised, I will walk you through overriding both methods to implement our own custom Adam optimizer.

Before diving into the implementation details, let’s quickly define the Adam update step.

The Adam optimizer

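Here things are a bit more involved. The Adam update rule is:

$$m_{t+1} = \beta_1 m_t + (1 - \beta_1)\,\nabla f(\theta_t)$$

$$v_{t+1} = \beta_2 v_t + (1 - \beta_2)\,\big(\nabla f(\theta_t)\big)^2$$

$$\hat{m}_{t+1} = \frac{m_{t+1}}{1 - \beta_1^{\,t+1}}, \qquad \hat{v}_{t+1} = \frac{v_{t+1}}{1 - \beta_2^{\,t+1}}$$

$$\theta_{t+1} = \theta_t - \mu\,\frac{\hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon}$$

with m₀ = v₀ = 0 and the square taken element-wise. We first calculate the first and second moment estimates of the gradients, mₜ₊₁ and vₜ₊₁, using the decay rates β₁ and β₂. Because both estimates start at zero, they are biased toward zero in early iterations, so we next compute the bias-corrected terms m̂ₜ₊₁ and v̂ₜ₊₁. Finally, we use them to update the parameters, with a small ε in the denominator to avoid division by zero.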
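The equations can be traced by hand with a scalar sketch in plain Python. The toy loss f(θ) = (θ − 3)², the starting point, μ = 0.1, and the number of iterations are illustrative assumptions; β₁, β₂ and ε use the common defaults. The first update also shows a useful effect of bias correction: at the first step m̂ = g and v̂ = g², so the update is roughly μ·sign(g) regardless of how large the gradient is.

```python
import math

def grad_f(theta):
    # Gradient of the toy loss f(theta) = (theta - 3)**2
    return 2.0 * (theta - 3.0)

mu, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
theta, m, v = 0.0, 0.0, 0.0   # parameter and moment estimates (m_0 = v_0 = 0)
first_delta = None

for t in range(1, 501):        # t counts updates, starting at 1
    g = grad_f(theta)
    m = beta1 * m + (1 - beta1) * g        # first moment estimate
    v = beta2 * v + (1 - beta2) * g * g    # second (raw) moment estimate
    m_hat = m / (1 - beta1 ** t)           # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)           # bias-corrected second moment
    delta = mu * m_hat / (math.sqrt(v_hat) + eps)
    if first_delta is None:
        first_delta = delta
    theta -= delta

# first delta is about mu * sign(g) = -0.1, so theta moves from 0.0 to 0.1
print(f"first delta: {first_delta:.6f}")
print(f"theta after 500 steps: {theta:.4f}")
```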

Let’s begin by defining our custom optimizer class, CustomAdam. To do this, we'll inherit from PyTorch's standard torch.optim.Optimizer class and override its __init__() method.

__init__()

```python
import torch

class CustomAdam(torch.optim.Optimizer):
    """
    Custom Adam class
    """
    # Default hyperparameters follow the Adam paper
    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0):

        # Reject invalid inputs
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not (0.0 <= beta1 < 1.0 and 0.0 <= beta2 < 1.0):
            raise ValueError(f"Invalid beta values: {beta1}, {beta2}")

        # Define a dictionary for hyperparameters
        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
    ...
```

Python calls the __init__() method automatically whenever you create an instance of a class. In our case, however, __init__() does more than initialize the optimizer instance; it also defines its configuration. It lets us specify the optimizer's hyperparameters (lr, beta1, beta2, eps, weight_decay) and set default values for any that are not explicitly provided. It is also the natural place to reject invalid inputs, such as a learning rate or epsilon that is zero or negative.

Inside __init__(), we typically define a dictionary called defaults. This dictionary conveniently stores our hyperparameters, making them easy to access later in the .step() function or other internal methods we might create. PyTorch uses this defaults dictionary internally for managing param_groups (which we'll discuss shortly).
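To see what the base class does with defaults, here is a minimal sketch using a hypothetical ToyOptimizer. Its momentum hyperparameter is never used; it is there only to show that every key in defaults is copied into the parameter group. The sketch also checks that an invalid learning rate is rejected in __init__().

```python
import torch

class ToyOptimizer(torch.optim.Optimizer):
    """Hypothetical minimal optimizer, used only to inspect the base class."""

    def __init__(self, params, lr=1e-3, momentum=0.0):
        # Reject invalid hyperparameters before handing anything to the base class
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        pass  # the update rule is irrelevant for this demonstration

w = torch.zeros(3, requires_grad=True)
opt = ToyOptimizer([w], lr=0.05)

print(opt.defaults)                       # the dictionary passed to the base class
print(opt.param_groups[0]["lr"])          # copied into the single parameter group
print(opt.param_groups[0]["momentum"])

try:
    ToyOptimizer([w], lr=-1.0)
    rejected = False
except ValueError as err:
    rejected = True
    print("rejected:", err)
```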

With this foundation laid, we can now define the .step() method, which will handle the main computational logic.

step()

Recall that the .step() method is invoked at the end of each training iteration, immediately following the .backward() call. This means that at this stage, all parameters have their corresponding gradients computed and available.

For gradient descent (GD), the .step() method would typically loop through each parameter, updating it by subtracting the product of the step size and the gradient from its current value. Let's briefly illustrate how this update looks in practice for standard GD.

```python
    ...
    @torch.no_grad()
    def step(self):
        """
        Performs a single optimization step for GD.
        """
        # Iterate through each group
        for group in self.param_groups:
            lr = group['lr']
            # Iterate through each individual parameter in the group
            for p in group['params']:
                if p.grad is None:
                    continue
                # GD update rule:
                # new_weight = old_weight - learning_rate * gradient
                p.add_(p.grad, alpha=-lr)
```

PyTorch organizes parameters into “groups”. For our purposes, all you need to know is that a group is a Python dictionary that bundles a set of model parameters (under the key 'params') together with the hyperparameters that apply to them, such as the learning rate. When you pass model.parameters() directly, as we do here, all parameters end up in a single group. You can, however, create several groups by passing the optimizer a list of dictionaries, for example one for your weights and one for your biases, each with its own hyperparameters. Any hyperparameter a group does not specify is filled in from the defaults dictionary (we will not use multiple groups here).

The collection of all groups is self.param_groups, a list of these dictionaries. Since this tutorial uses a single group, group['lr'] has the same value for every parameter.
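As a sketch of how multiple groups are actually created, you pass the optimizer a list of dictionaries instead of model.parameters(). The small nn.Linear model and the choice to give the bias a larger learning rate are illustrative assumptions; torch.optim.SGD is used only because its constructor is simple.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Two groups: the weight uses the optimizer-wide lr, the bias overrides it
optimizer = torch.optim.SGD(
    [
        {"params": [model.weight]},
        {"params": [model.bias], "lr": 0.1},
    ],
    lr=0.01,
)

for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: lr={group['lr']}, momentum={group['momentum']}, "
          f"{len(group['params'])} tensor(s)")

# Passing model.parameters() directly yields one group holding every tensor
single = torch.optim.SGD(model.parameters(), lr=0.01)
print(len(single.param_groups), len(single.param_groups[0]["params"]))
```

The first group never mentions lr or momentum, so both are filled in from the constructor's defaults.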

To update every model parameter, we first iterate through the parameter groups in self.param_groups, and then through the individual parameters of each group via group['params']. We apply the update to each parameter with .add_(), which modifies the tensor in place instead of allocating a new tensor for the result. Parameters whose .grad is None (for example, ones that did not take part in computing the loss) are skipped.
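Wrapping the loop above in a complete class gives a runnable sketch of plain gradient descent. The class name CustomGD, the toy quadratic loss, and the check against a hand-computed update are our own illustration.

```python
import torch

class CustomGD(torch.optim.Optimizer):
    """Plain gradient descent, p <- p - lr * grad (hypothetical example class)."""

    def __init__(self, params, lr=0.01):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue  # parameter did not take part in this backward pass
                # In-place update; alpha scales p.grad without allocating -lr * grad
                p.add_(p.grad, alpha=-lr)

w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
opt = CustomGD([w], lr=0.1)

loss = (w ** 2).sum()          # gradient is 2 * w
loss.backward()
expected = w.detach() - 0.1 * 2 * w.detach()   # hand-computed GD step
opt.step()
print(w)                       # each entry scaled by 1 - 0.1 * 2 = 0.8
```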

It’s worth noting that parameter groups and the defaults dictionary are PyTorch conventions rather than strict requirements. While grouping provides substantial flexibility, it’s often sufficient to use the same hyperparameters for all parameters, in which case the defaults dictionary might seem unnecessary. You could instead save hyperparameters directly as instance variables (e.g., self.lr = lr). Keep in mind, though, that PyTorch tools built around param_groups, such as the learning-rate schedulers in torch.optim.lr_scheduler, read and modify group['lr'] and would not see an instance variable. We’ll stick to PyTorch’s recommended structure in this tutorial for clarity and consistency.

Adam step()

With this, we can now define the .step() method for Adam:

```python
    ...
    @torch.no_grad()
    def step(self):
        """Performs a single Adam optimization step."""

        for group in self.param_groups:
            lr = group['lr']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # State initialization: store per-parameter information across iterations
                if len(state) == 0:
                    state['step'] = 0
                    # m_t: biased first moment estimate (like momentum)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # v_t: biased second raw moment estimate (adaptive learning rate)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Get state variables for the current parameter
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Increment step counter t
                state['step'] += 1
                t = state['step']

                grad = p.grad
                if weight_decay != 0:
                    # L2 penalty added to the gradient, as in torch.optim.Adam
                    # (shrinking p directly by lr * weight_decay would give AdamW instead)
                    grad = grad.add(p, alpha=weight_decay)

                # --- Core Adam Logic ---

                # 1. Update the biased estimates m_t and v_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # 2. Calculate the NUMERATOR: the bias-corrected first moment (m_hat_t)
                bias_correction1 = 1 - beta1 ** t
                numerator = exp_avg / bias_correction1

                # 3. Calculate the DENOMINATOR: the bias-corrected second moment (sqrt(v_hat_t)) + eps
                bias_correction2 = 1 - beta2 ** t
                # First, get v_hat_t
                v_hat_t = exp_avg_sq / bias_correction2
                # Then, calculate the full denominator
                denominator = v_hat_t.sqrt().add(eps)

                # 4. Calculate the final update amount
                update_step = numerator / denominator

                # 5. Apply the final update to the parameter
                # new_weight = old_weight - lr * (numerator / denominator)
                p.add_(update_step, alpha=-lr)
```

Adam maintains estimates of both the first moment (mean of the gradients) and the second raw moment (mean of the squared gradients, i.e., the uncentered variance) for each parameter in order to compute adaptive update steps. In PyTorch, this is managed through the optimizer's self.state, which maps each parameter to its own dictionary. Because self.state is a defaultdict, self.state[p] returns an empty dictionary the first time a parameter is accessed, which is why the code initializes the entries only when len(state) == 0. This dictionary lets us store persistent, per-parameter information across iterations.

Specifically, for Adam, we typically initialize and update three entries in the state dictionary for each parameter: 'step': the iteration counter, 'exp_avg': the exponential moving average of past gradients (first moment), 'exp_avg_sq': the exponential moving average of squared gradients (second moment). These values are continuously updated during training to compute the bias-corrected parameter updates.
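PyTorch's own torch.optim.Adam keeps the same per-parameter entries, which makes a quick comparison easy. The sketch below runs one step and inspects optimizer.state; the tensor shape and loss are arbitrary choices. Recent PyTorch versions store 'step' as a tensor rather than a Python int, so the printed form may vary by version.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 3, requires_grad=True)
opt = torch.optim.Adam([w], lr=1e-3)

# self.state is a defaultdict: looking up w creates an empty dict
n_before = len(opt.state[w])
print("entries before first step:", n_before)

loss = (w ** 2).sum()
loss.backward()
opt.step()

state = opt.state[w]
print("entries after first step:", sorted(state.keys()))
# The moment estimates have the same shape as the parameter itself
print(state["exp_avg"].shape, state["exp_avg_sq"].shape)
```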

With this setup, everything becomes fairly straightforward. All that’s left is to compute the Adam update and apply it to our parameters, which is handled in steps 1–5 under the # --- Core Adam Logic --- comment.

One thing to note is how the calculations are performed. Functions like .add_(), .mul_(), and .addcmul_() may look unusual at first. The trailing underscore marks them as in-place operations: they modify the tensor directly instead of creating a new one. This is more memory-efficient than writing expressions like exp_avg = beta1 * exp_avg + (1 - beta1) * p.grad, which allocate new tensors at every step. You can use either form, but with the out-of-place version you must write the result back (state['exp_avg'] = exp_avg), because rebinding the local name exp_avg does not update the state dictionary.
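The sketch below demonstrates the difference between the two forms. A plain Python dictionary stands in for self.state[p] (an assumption to keep the example self-contained), and the gradient and beta1 values are arbitrary.

```python
import torch

state = {"exp_avg": torch.zeros(3)}   # stands in for self.state[p]
grad = torch.ones(3)
beta1 = 0.9

# Out-of-place: builds a new tensor and rebinds the local name only
exp_avg = state["exp_avg"]
exp_avg = beta1 * exp_avg + (1 - beta1) * grad
stale = state["exp_avg"].clone()
print(stale)                          # still zeros: the state was never updated

# Fix 1: write the new tensor back into the state dictionary
state["exp_avg"] = exp_avg
print(state["exp_avg"])

# Fix 2 (what the tutorial does): modify the stored tensor in place
state2 = {"exp_avg": torch.zeros(3)}
exp_avg2 = state2["exp_avg"]          # same tensor object, not a copy
exp_avg2.mul_(beta1).add_(grad, alpha=1 - beta1)
print(state2["exp_avg"])              # updated, because exp_avg2 is that tensor
```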

And that’s it! With this, we’ve successfully implemented our own custom Adam optimizer.

Putting everything together

```python
import torch

class CustomAdam(torch.optim.Optimizer):
    """
    Custom Adam class
    """
    # Default hyperparameters follow the Adam paper
    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0):

        # Reject invalid inputs
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not (0.0 <= beta1 < 1.0 and 0.0 <= beta2 < 1.0):
            raise ValueError(f"Invalid beta values: {beta1}, {beta2}")

        # Define a dictionary for hyperparameters
        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        """Performs a single Adam optimization step."""

        for group in self.param_groups:
            lr = group['lr']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # State initialization: store per-parameter information across iterations
                if len(state) == 0:
                    state['step'] = 0
                    # m_t: biased first moment estimate (like momentum)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # v_t: biased second raw moment estimate (adaptive learning rate)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Get state variables for the current parameter
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Increment step counter t
                state['step'] += 1
                t = state['step']

                grad = p.grad
                if weight_decay != 0:
                    # L2 penalty added to the gradient, as in torch.optim.Adam
                    # (shrinking p directly by lr * weight_decay would give AdamW instead)
                    grad = grad.add(p, alpha=weight_decay)

                # --- Core Adam Logic ---

                # 1. Update the biased estimates m_t and v_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # 2. Calculate the NUMERATOR: the bias-corrected first moment (m_hat_t)
                bias_correction1 = 1 - beta1 ** t
                numerator = exp_avg / bias_correction1

                # 3. Calculate the DENOMINATOR: the bias-corrected second moment (sqrt(v_hat_t)) + eps
                bias_correction2 = 1 - beta2 ** t
                # First, get v_hat_t
                v_hat_t = exp_avg_sq / bias_correction2
                # Then, calculate the full denominator
                denominator = v_hat_t.sqrt().add(eps)

                # 4. Calculate the final update amount
                update_step = numerator / denominator

                # 5. Apply the final update to the parameter
                # new_weight = old_weight - lr * (numerator / denominator)
                p.add_(update_step, alpha=-lr)
```

Before wrapping up, it’s worth noting that what I’ve shown here isn’t the only way to implement a custom optimizer, though it follows PyTorch’s recommended structure. For instance, there are multiple ways to access model parameters; you don’t necessarily have to use group['params']. Similarly, you’re not required to rely on the defaults dictionary; you can store hyperparameters directly as instance variables if you prefer (the base-class constructor still expects a defaults argument, but it can be an empty dictionary).

The state dictionary is also optional. If you'd rather not use it, you can manage the first and second moment estimates with your own custom lists or dictionaries, initialized in the __init__() method. Additionally, you're free to modularize the logic however you like. For example, instead of placing all the Adam update steps directly inside .step(), you could define a separate function like AdamUpdate() and call it from within .step().
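As a sketch of the modular approach, the update math can live in a standalone helper that .step() calls for each parameter. The name adam_update() (a snake_case take on the AdamUpdate() mentioned above), its argument list, and the stripped-down ModularAdam class without weight decay or input validation are hypothetical, chosen only for illustration.

```python
import torch

def adam_update(p, grad, exp_avg, exp_avg_sq, t, lr, beta1, beta2, eps):
    """Hypothetical helper: one in-place Adam update for a single tensor.

    Must run under torch.no_grad(), as it does when called from step().
    """
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = exp_avg / (1 - beta1 ** t)
    v_hat = exp_avg_sq / (1 - beta2 ** t)
    p.sub_(lr * m_hat / (v_hat.sqrt() + eps))

class ModularAdam(torch.optim.Optimizer):
    """Hypothetical Adam variant whose step() only manages state and delegates the math."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        super().__init__(params, dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                adam_update(p, p.grad, state["exp_avg"], state["exp_avg_sq"],
                            state["step"], group["lr"], group["beta1"],
                            group["beta2"], group["eps"])

w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = ModularAdam([w], lr=0.1)
loss = (w ** 2).sum()
loss.backward()
opt.step()
print(w)   # first Adam step moves each entry by about lr * sign(grad)
```

Keeping the math in a plain function makes it easy to unit-test in isolation, at the cost of a slightly longer argument list.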

The beauty of PyTorch is the flexibility it offers. But with that flexibility comes a lot of responsibility. Inefficient implementations can quickly become performance bottlenecks during training. I have inadvertently turned training runs that should have taken minutes into hours. So while it’s great to experiment, always keep an eye on what you’re actually doing.

Minimal working example

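Let’s now put our CustomAdam optimizer to the test by training a small CNN on CIFAR-10 and comparing its performance with PyTorch’s built-in Adam optimizer. If we did everything correctly, the two loss curves should be essentially identical, differing at most by tiny floating-point discrepancies, since torch.optim.Adam carries out the same computations in a slightly different order.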

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import copy
import random
import numpy as np

class CustomAdam(torch.optim.Optimizer):
    """
    Custom Adam class
    """
    # Default hyperparameters follow the Adam paper
    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0):

        # Reject invalid inputs
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not (0.0 <= beta1 < 1.0 and 0.0 <= beta2 < 1.0):
            raise ValueError(f"Invalid beta values: {beta1}, {beta2}")

        # Define a dictionary for hyperparameters
        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        """Performs a single Adam optimization step."""

        for group in self.param_groups:
            lr = group['lr']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # State initialization: store per-parameter information across iterations
                if len(state) == 0:
                    state['step'] = 0
                    # m_t: biased first moment estimate (like momentum)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # v_t: biased second raw moment estimate (adaptive learning rate)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Get state variables for the current parameter
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Increment step counter t
                state['step'] += 1
                t = state['step']

                grad = p.grad
                if weight_decay != 0:
                    # L2 penalty added to the gradient, as in torch.optim.Adam
                    # (shrinking p directly by lr * weight_decay would give AdamW instead)
                    grad = grad.add(p, alpha=weight_decay)

                # --- Core Adam Logic ---

                # 1. Update the biased estimates m_t and v_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # 2. Calculate the NUMERATOR: the bias-corrected first moment (m_hat_t)
                bias_correction1 = 1 - beta1 ** t
                numerator = exp_avg / bias_correction1

                # 3. Calculate the DENOMINATOR: the bias-corrected second moment (sqrt(v_hat_t)) + eps
                bias_correction2 = 1 - beta2 ** t
                # First, get v_hat_t
                v_hat_t = exp_avg_sq / bias_correction2
                # Then, calculate the full denominator
                denominator = v_hat_t.sqrt().add(eps)

                # 4. Calculate the final update amount
                update_step = numerator / denominator

                # 5. Apply the final update to the parameter
                # new_weight = old_weight - lr * (numerator / denominator)
                p.add_(update_step, alpha=-lr)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Ensure full reproducibility ---
# Set seeds so both models start with the same random weights
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# For GPU reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Uses the full training set; shorten the range to train on a subset
subset_indices = list(range(len(trainset)))
trainset_subset = torch.utils.data.Subset(trainset, subset_indices)
# shuffle=False so both models see the batches in the same order
trainloader = torch.utils.data.DataLoader(trainset_subset, batch_size=64,
                                          shuffle=False, num_workers=2)

# --- Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 32 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# --- Create two identical models ---
model_custom = SimpleCNN().to(device)
model_pytorch = copy.deepcopy(model_custom).to(device)


LEARNING_RATE = 0.001
EPOCHS = 10
criterion = nn.CrossEntropyLoss()

# --- Instantiate Optimizers ---
optimizer_custom = CustomAdam(model_custom.parameters(), lr=LEARNING_RATE)
optimizer_pytorch = optim.Adam(model_pytorch.parameters(), lr=LEARNING_RATE)


def train_model(model, optimizer, model_name):
    """A helper function to train a given model."""
    print(f"--- Training {model_name} ---")
    start_time = time.time()
    losses = []
    model.train()  # Set model to training mode
    for epoch in range(EPOCHS):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(trainloader)
        losses.append(avg_loss)
        print(f'Epoch [{epoch + 1}/{EPOCHS}], Loss: {avg_loss:.4f}')

    end_time = time.time()
    print(f'Finished training {model_name}. Total time: {end_time - start_time:.2f}s\n')
    return losses


# The guard is required when DataLoader workers are started with "spawn"
# (the default on Windows and macOS)
if __name__ == "__main__":
    custom_adam_losses = train_model(model_custom, optimizer_custom, "Custom Adam")
    pytorch_adam_losses = train_model(model_pytorch, optimizer_pytorch, "PyTorch Adam")

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, EPOCHS + 1), custom_adam_losses, marker='o', linestyle='-', label='Custom Adam')
    plt.plot(range(1, EPOCHS + 1), pytorch_adam_losses, marker='x', linestyle='--', label='PyTorch optim.Adam')
    plt.title('Optimizer Comparison: Training Loss on CIFAR-10')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.legend()
    plt.grid(True)
    plt.xticks(range(1, EPOCHS + 1))
    plt.show()
```

Comparing the training loss of our custom optimizer with PyTorch’s implementation shows that the two curves are effectively identical. Keep in mind, however, that bit-for-bit reproducibility isn’t guaranteed: floating-point operations performed in a different order give slightly different results, and some GPU kernels are nondeterministic, so small discrepancies can appear and compound over many training steps.
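If you want a faster check that doesn’t download CIFAR-10, a sketch like the following compares the two optimizers on a tiny synthetic problem in float64, where rounding differences are negligible. CompactAdam is a hypothetical, trimmed-down version of our class (no weight decay or input validation), and the quadratic loss, step count, and tolerance are our assumptions.

```python
import torch

class CompactAdam(torch.optim.Optimizer):
    """Hypothetical trimmed-down CustomAdam: same math, no weight decay or validation."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        super().__init__(params, dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            b1, b2 = group["beta1"], group["beta2"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                t = state["step"]
                state["exp_avg"].mul_(b1).add_(p.grad, alpha=1 - b1)
                state["exp_avg_sq"].mul_(b2).addcmul_(p.grad, p.grad, value=1 - b2)
                m_hat = state["exp_avg"] / (1 - b1 ** t)
                v_hat = state["exp_avg_sq"] / (1 - b2 ** t)
                p.add_(m_hat / (v_hat.sqrt() + group["eps"]), alpha=-group["lr"])

torch.manual_seed(0)
w0 = torch.randn(5, dtype=torch.float64)
target = torch.randn(5, dtype=torch.float64)

# Two copies of the same starting point, one per optimizer
w_a = w0.clone().requires_grad_(True)
w_b = w0.clone().requires_grad_(True)
opt_a = CompactAdam([w_a], lr=0.01)
opt_b = torch.optim.Adam([w_b], lr=0.01)

for _ in range(100):
    for w, opt in ((w_a, opt_a), (w_b, opt_b)):
        opt.zero_grad()
        loss = ((w - target) ** 2).sum()
        loss.backward()
        opt.step()

max_diff = (w_a.detach() - w_b.detach()).abs().max().item()
initial_loss = ((w0 - target) ** 2).sum().item()
final_loss = ((w_a.detach() - target) ** 2).sum().item()
print(f"max |difference| after 100 steps: {max_diff:.2e}")
```

A check like this also makes a convenient regression test whenever you modify the custom optimizer.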

Summary

In this tutorial, we walked through a step-by-step guide for building custom optimizers in PyTorch. PyTorch makes this process straightforward by allowing you to inherit from the torch.optim.Optimizer class and override its two key methods: __init__() and step(). By customizing these methods—and using features like parameter groups and the state dictionary—you can create a wide range of optimizers tailored to your specific needs.