mlx: 2D heat equation solver needs periodic mx.eval to avoid a segmentation fault

I have implemented a simple solver for the 2D heat conduction equation with two Neumann and two Dirichlet BCs, in both PyTorch and MLX, and I am comparing their relative performance on an M2 Ultra with 128GB of memory.

The MLX code is included below. So far, tests on the same machine show the MLX version to be somewhere between 2x and 10x faster, depending on the problem size.

However, I have an issue that I need to understand. Depending on the problem size, I need to include the line

if step % 15000 == 0: mx.eval(T)

to avoid a segmentation fault. I imagine this has to do with lazy evaluation and arrays accumulating in a buffer? My issue is that I currently determine empirically, for each problem, how often I need to call mx.eval. Is there a programmatic, more elegant way to automatically issue the mx.eval at the right frequency based on the problem size?

The complete code is below. Thank you for all your help, @awni!

# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
import numpy as np
import matplotlib.pyplot as plt
import time
import mlx.core as mx

# Convergence tolerance to stop early (currently disabled)
#convergence_tolerance = 1e-8

# Grid size and material properties setup
nx, ny = 5000, 5000  # Set grid dimensions
k = 1.0              # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 10000 # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0]   # Grid spacing in x direction
dy = y[1] - y[0]   # Grid spacing in y direction

# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))

# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0   # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0   # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

start_time = time.time()  # Capture start time
for step in range(max_steps):
    T_old = mx.broadcast_to(T, shape=T.shape)  # snapshot of the previous step (functionally the same as T_old = T)

    # Update interior points with the explicit finite-difference stencil,
    # then pad the columns back on: right column = 1, left column = 0
    # (restoring the Dirichlet boundary values)
    T = mx.pad(mx.pad((T_old[1:-1, 1:-1] + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_old[1:-1, 1:-1] + T_old[:-2, 1:-1]) / dx**2 +
        (T_old[1:-1, 2:] - 2 * T_old[1:-1, 1:-1] + T_old[1:-1, :-2]) / dy**2
    )), ((0, 0), (0, 1)), 1), ((0, 0), (1, 0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], 0), T, mx.expand_dims(T[-1, :], 0)], axis=0)

    if step % 15000 == 0:
        mx.eval(T)

end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

# Visualize the final temperature field with matplotlib
plt.imshow(np.array(T), cmap='hot', interpolation='nearest')
plt.colorbar()  # Add a color bar to indicate temperature scales
plt.show()


Most upvoted comments

it might be useful to have a section of examples on engineering/physics applications

That’s a great idea! We could potentially put a physics/ directory in the MLX examples repo, or make something standalone. Since I don’t know much about those use cases, if you are interested in this, maybe a good starting point is to make your own repo with some of these examples + comments / explanations, and we can go from there.

To add to the conversation about conv2D - it’s a process of creating and optimizing many specializations. Even currently, the shapes that our Winograd conv supports will run faster than alternatives. It’s work in progress to get a wider range of shapes to good speeds, and it is no small priority for us!

Does this mean that MLX is faster in creating the comp graphs but MLX conv2D is slower when it is time to do the actual computation?

Exactly right.

  1. Great, it would be awesome to have a fast conv2d, but I do realize how full your hands are right now; it is actually amazing how fast things are progressing @jagrit06

  2. I forked the repo and will try to create a library of engineering/physics examples of increasing complexity and once I have sufficient material in place you could review and decide whether to merge or keep it separate.

It looks like we have some work to do for those shapes in our conv2D. I think your implementation has much more potential to be fast. If we can make the conv comparable to torch it should really sing! In the meantime, bear with us… we’re a small team with a lot to do. But things will only get much better from here!

I wrote a little benchmark and for the sizes in your problem we are > 10x slower than Torch’s conv2d. CC @jagrit06 who will hopefully have some time to work on this in the near future.

import numpy as np
import torch.nn.functional
import mlx.core as mx
import time

### Time Torch
T = np.random.randn(1000, 1000).astype(np.float32)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32)

T = T.reshape(1, 1, 1000, 1000)
W = W.reshape(1, 1, 3, 3)

device = torch.device("mps")
T = torch.tensor(T).to(device)
W = torch.tensor(W).to(device)

for _ in range(5):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()

tic = time.time()
for _ in range(100):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()
toc = time.time()

print(f"Torch: {toc - tic}")

### Time MLX
T = np.random.randn(1000, 1000).astype(np.float32).reshape(1, 1000, 1000, 1)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32).reshape(1, 3, 3, 1)
T = mx.array(T)
W = mx.array(W)

for _ in range(5):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)

tic = time.time()
for _ in range(100):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)
toc = time.time()
print(f"MLX: {toc - tic}")

I am able to reproduce the segfault on my machine.

Current (and likely) hypothesis is that the segfault is a result of a stack overflow during the eval when we recurse on the inputs to build the compute graph. It definitely makes sense that it depends on the size of the graph…

For now, I would recommend just doing a fixed amount of compute per eval, so adding something like:

if step % 1000 == 0:
    mx.eval(T_old)

I am not sure if we should fix this… we could try using an iterative graph construction, but at the same time I don’t recommend letting graphs get so big. It’s usually a sign that eval should be used more frequently and/or the number of operations should be reduced.
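
To make the eval cadence less of a hand-tuned magic number (the question above about picking the frequency programmatically), one rough option is to derive the interval from the problem size so that a roughly fixed amount of lazy work accumulates between evals. A minimal sketch, with an arbitrary placeholder budget rather than any MLX-recommended value:

import mlx.core as mx

nx, ny, max_steps = 5000, 5000, 100000
elements_per_step = nx * ny                      # each step touches roughly one full grid
budget = 1e9                                     # arbitrary cap on accumulated lazy elements
eval_every = max(1, int(budget // elements_per_step))  # ~40 for a 5000x5000 grid

T = mx.zeros([nx, ny])
for step in range(max_steps):
    T = T + 0.0                                  # stand-in for the real stencil update
    if step % eval_every == 0:
        mx.eval(T)
mx.eval(T)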

Another comment:

You should see if it’s possible to replace the whole inner computation with something like a convolution or matmul with the appropriate kernel. It would dramatically reduce the number of operations which in this case would speed things up substantially. I think the whole update from T_old to T is linear so it should be very doable with the right operation.
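
To make that concrete, here is a small, self-contained sketch (not code from this thread; the grid size, time step, and eval interval are placeholders) of how the interior stencil update could be folded into a single mx.conv2d call, with the boundary conditions re-applied after each step:

import mlx.core as mx

# Placeholder problem setup
nx, ny = 1000, 1000
dx = dy = 1.0 / (nx - 1)
k = 1.0
dt = 0.25 * (1 / (2 * k)) * (1 / (1 / dx**2 + 1 / dy**2))  # well inside the stability limit

# Fold the explicit update T_new = T + dt*k*(d2T/dx2 + d2T/dy2) into one 3x3 kernel
# (first spatial axis = x, second = y, matching the code above)
cx = dt * k / dx**2
cy = dt * k / dy**2
kernel = mx.array([[0.0, cx, 0.0],
                   [cy, 1.0 - 2.0 * (cx + cy), cy],
                   [0.0, cx, 0.0]]).reshape(1, 3, 3, 1)

# mx.conv2d expects NHWC input and (out_channels, kH, kW, in_channels) weights
T = mx.zeros([1, nx, ny, 1])
T[:, :, -1, :] = 1.0                          # Dirichlet right boundary

for step in range(1000):
    T = mx.conv2d(T, kernel, padding=(1, 1))  # whole interior update in one op
    # Re-apply boundary conditions (the zero-padded border is overwritten)
    T[:, :, 0, :] = 0.0                       # Dirichlet left
    T[:, :, -1, :] = 1.0                      # Dirichlet right
    T[:, 0, :, :] = T[:, 1, :, :]             # Neumann (zero-flux) top
    T[:, -1, :, :] = T[:, -2, :, :]           # Neumann (zero-flux) bottom
    if step % 100 == 0:
        mx.eval(T)
mx.eval(T)

Collapsing the dozen or so elementwise ops per step into a single conv node also keeps the lazily built graph much smaller between evals.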

One more observation that might provide a clue. These two slightly different implementations (with the problem size nx, ny kept the same between them) require different frequencies of mx.eval(T_old) to avoid the segfault. The difference between the two is that in the first case mx.pad is used to add the top and bottom rows and then assignment is made for the Neumann BCs, while in the second, the rows are added and assigned values via mx.concatenate.

**Implementation 1: eval every 21000 steps**

    if step % 21000 == 0:
        mx.eval(T_old)

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(mx.pad((T_mid + dt * k * (
            (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
            (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((1, 1), (0, 1)), 1), ((0, 0), (1, 0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T[0, :] = T[1, :]
    T[-1, :] = T[-2, :]

**Implementation 2: eval every 16000 steps**

    if step % 16000 == 0:
        mx.eval(T_old)

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(mx.pad((T_mid + dt * k * (
            (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
            (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0, 0), (0, 1)), 1), ((0, 0), (1, 0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], 0), T, mx.expand_dims(T[-1, :], 0)], axis=0)

Following your remark, I tested again with T_old = mx.broadcast_to(T, shape=T.shape) and with T_old = T, and they are about the same (within load noise). I must have been misled by some other change I made in parallel before, so this does not need to be checked. Sorry for the wrong claim 🫣

It runs orders of magnitude faster (438 sec vs 2.37 sec for 100K steps)!!! And whatever the bug is, it’s not surfacing any more. I’m thrilled with the speedup 👍

Ok knowing that, I will try to do some debugging and see if I can provide some additional feedback.