I'm trying to run the following snippet, which inspects a PyTorch model with torchlens. It should print a summary of the model's layers and their tensor dimensions.
```python
import torch
import torchvision
import torchlens as tl

alexnet = torchvision.models.alexnet()
x = torch.rand(1, 3, 224, 224)
model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
print(model_history)
```
Instead I get the following stack trace: a RuntimeError claiming that parallel execution is in use, which torchlens does not support:
```
Traceback (most recent call last):
  File "/home/jacob/.venv/lib/python3.13/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell
    exec(cell.body, glbls)
    ~~~~^^^^^^^^^^^^^^^^^^
  Cell marimo://lab/test.py#cell=cell-0, line 7, in <module>
    model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
  File "/home/jacob/.venv/lib/python3.13/site-packages/torchlens/user_funcs.py", line 138, in log_forward_pass
    warn_parallel()
    ~~~~~~~~~~~~~^^
  File "/home/jacob/.venv/lib/python3.13/site-packages/torchlens/helper_funcs.py", line 705, in warn_parallel
    raise RuntimeError(
    ...<3 lines>...
    )
RuntimeError: WARNING: It looks like you are using parallel execution; only run pytorch-xray in the main process, since certain operations depend on execution order.
```
When I run the same snippet outside of marimo, in an IPython REPL, there is no such error and it works fine. So the problem seems to be related to how marimo's executor runs cells.
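A quick way to test that hypothesis is to print the name of the process the code runs in, once in an IPython REPL and once in a marimo cell. This is only a sketch resting on two assumptions I haven't verified: that torchlens's `warn_parallel` decides "parallel execution" by comparing `multiprocessing.current_process().name` with `"MainProcess"`, and that marimo runs the notebook kernel in a child process. `running_in_main_process` is a hypothetical helper, not part of torchlens.

```python
import multiprocessing as mp


def running_in_main_process() -> bool:
    """Hypothetical helper: True if Python reports this as the 'MainProcess'.

    Assumption: this mirrors the kind of check torchlens's warn_parallel makes.
    """
    return mp.current_process().name == "MainProcess"


# Run this in both environments and compare the output.
print("process name:", mp.current_process().name)
print("main process:", running_in_main_process())
```

If the REPL reports `MainProcess` while the marimo cell reports some other name, that would be consistent with the check being triggered by marimo's process layout rather than by any actual parallelism in the model code.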
Is there perhaps some way to work around this?
Thanks