Torchlens vs Marimo Parallel Execution

At a glance

The community member is trying to run a code snippet that inspects a PyTorch model with torchlens. Inside marimo, the call fails with a RuntimeError saying that parallel execution is in use, which torchlens does not support. The same code runs fine outside marimo, which suggests the issue is related to marimo's executor.

In the comments, another community member suggests a workaround: rename the current process to "MainProcess" with multiprocessing.current_process().name = "MainProcess". The original poster later mentions that torchview gives them what they need and works well from marimo.

The suggested workaround is to set the current process name to "MainProcess" using multiprocessing.current_process().name = "MainProcess".

I'm trying to run the following code snippet, whose goal is to examine a PyTorch model. It should print a list of layers and tensor dimensions.

Plain Text
import torch
import torchvision
import torchlens as tl

alexnet = torchvision.models.alexnet()
x = torch.rand(1, 3, 224, 224)
model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
print(model_history)

I get the following stack trace, which warns me that parallel execution is being used (a no-go for torchlens):

Plain Text
Traceback (most recent call last):
  File "/home/jacob/.venv/lib/python3.13/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell
    exec(cell.body, glbls)
    ~~~~^^^^^^^^^^^^^^^^^^
  Cell marimo://lab/test.py#cell=cell-0, line 7, in <module>
    model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
  File "/home/jacob/.venv/lib/python3.13/site-packages/torchlens/user_funcs.py", line 138, in log_forward_pass
    warn_parallel()
    ~~~~~~~~~~~~~^^
  File "/home/jacob/.venv/lib/python3.13/site-packages/torchlens/helper_funcs.py", line 705, in warn_parallel
    raise RuntimeError(
    ...<3 lines>...
    )
RuntimeError: WARNING: It looks like you are using parallel execution; only run pytorch-xray in the main process, since certain operations depend on execution order.

I tried running the same snippet outside of marimo in an IPython REPL. There I get no such error and it works fine, so it seems to be related to marimo's executor.

Is there perhaps some way to work around this?
Thanks
Marked as solution
torchlens uses a pretty brittle check for the process name. We could fix this upstream by forcing all execution threads to be called "MainProcess", but I'm not sure of the potential consequences.
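
The suggested fix implies that the check compares the current process name against "MainProcess". The sketch below shows a guard of that kind under that assumption. `check_main_process` is a hypothetical name, and this is not torchlens's actual code.

By default, `multiprocessing` names the interpreter's original process "MainProcess" and gives child processes names like "Process-1". So code running in any differently named process fails the check even when nothing runs in parallel. That is presumably what happens inside marimo.

```python
import multiprocessing


def check_main_process() -> None:
    """Hypothetical process-name guard (not torchlens's actual code).

    Assumes the guard only looks at the process name, as the suggested
    fix implies.
    """
    name = multiprocessing.current_process().name
    # Only the interpreter's original process is named "MainProcess" by
    # default; child processes get names such as "Process-1".
    if name != "MainProcess":
        raise RuntimeError(
            f"looks like parallel execution (process name {name!r}); "
            "only run in the main process"
        )


if __name__ == "__main__":
    # Diagnostic: running these lines in a notebook cell shows which
    # process name the cell executes under.
    print("process name:", multiprocessing.current_process().name)
    try:
        check_main_process()
        print("check passed")
    except RuntimeError as exc:
        print("check failed:", exc)
```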

In the meantime, here's a hack. Try:

Plain Text
import multiprocessing
multiprocessing.current_process().name = "MainProcess"
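
Given the unknown consequences of renaming the process for the whole session, one option is to apply the same rename only around the torchlens call and restore the original name afterwards. Below is a minimal sketch of that idea.

`pretend_main_process` is a hypothetical helper, not part of marimo or torchlens, and it has not been tested inside marimo. The name is process-wide, so other threads also see "MainProcess" while the block is active.

```python
import contextlib
import multiprocessing
from typing import Iterator


@contextlib.contextmanager
def pretend_main_process() -> Iterator[None]:
    """Temporarily rename the current process to "MainProcess".

    Hypothetical helper: the same hack as above, scoped to a `with` block so
    the original name is restored afterwards, even if the body raises.
    """
    proc = multiprocessing.current_process()
    original = proc.name
    proc.name = "MainProcess"
    try:
        yield
    finally:
        # Always put the original name back.
        proc.name = original


if __name__ == "__main__":
    print("before:", multiprocessing.current_process().name)
    with pretend_main_process():
        # In the marimo cell, the torchlens call would go here, e.g.
        # model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
        print("inside:", multiprocessing.current_process().name)
    print("after:", multiprocessing.current_process().name)
```
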
Thanks for the prompt feedback! I'll keep this in mind.
I realized I could use torchview instead to get what I want, and it works great from marimo.