Test environment:
Anaconda3 + Python 3.10
Output of `pip list`:

    Package            Version
    ------------------ ------------
    attrs              24.3.0
    Automat            24.8.1
    buildtools         1.0.6
    causal-conv1d      1.1.1
    certifi            2024.12.14
    cffi               1.15.0
    charset-normalizer 3.4.1
    colorama           0.4.6
    constantly         23.10.4
    docopt             0.6.2
    einops             0.8.0
    filelock           3.16.1
    fsspec             2024.12.0
    furl               2.1.3
    greenlet           3.1.1
    huggingface-hub    0.27.0
    hyperlink          21.0.0
    idna               3.10
    incremental        24.7.2
    Jinja2             3.1.5
    mamba_ssm          1.1.3
    MarkupSafe         3.0.2
    mpmath             1.3.0
    networkx           3.4.2
    ninja              1.11.1.3
    numpy              1.24.1
    orderedmultidict   1.0.1
    packaging          24.2
    pillow             11.0.0
    pip                24.2
    pycparser          2.22
    python-dateutil    2.9.0.post0
    PyYAML             6.0.2
    redo               3.0.0
    regex              2024.11.6
    requests           2.32.3
    safetensors        0.4.5
    setuptools         68.2.2
    simplejson         3.19.3
    six                1.17.0
    SQLAlchemy         2.0.36
    sympy              1.13.3
    tokenizers         0.21.0
    tomli              2.2.1
    torch              2.1.1+cu118
    torchaudio         2.1.1+cu118
    torchvision        0.16.1+cu118
    tqdm               4.67.1
    transformers       4.47.1
    triton             2.1.0
    Twisted            24.11.0
    typing_extensions  4.12.2
    urllib3            2.3.0
    wheel              0.44.0
    zope.interface     7.2
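Before running the test, you can check that another machine matches the environment above with a small script that uses only the standard library. This is a sketch and is not part of the original post. `EXPECTED` and `find_mismatches` are hypothetical names. The script covers only a few of the listed packages: PyTorch and the Mamba-related packages. Package names are given as `pip list` prints them.

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical subset of the pip list above; extend or adjust for your own setup.
EXPECTED = {
    "torch": "2.1.1+cu118",
    "triton": "2.1.0",
    "causal-conv1d": "1.1.1",
    "mamba_ssm": "1.1.3",
    "transformers": "4.47.1",
}


def find_mismatches(expected: dict[str, str]) -> dict[str, str]:
    """Return {package: problem} for every package that is missing or has another version."""
    problems = {}
    for name, want in expected.items():
        try:
            have = version(name)  # installed distribution version string
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if have != want:
            problems[name] = f"installed {have}, expected {want}"
    return problems


if __name__ == "__main__":
    issues = find_mismatches(EXPECTED)
    if not issues:
        print("environment matches")
    for pkg, problem in issues.items():
        print(f"{pkg}: {problem}")
```

Run it inside the activated conda environment. An empty result means every listed package is installed at the expected version.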
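Test code:

    import torch
    from mamba_ssm import Mamba

    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim).to("cuda")  # input of shape (batch, length, dim) on the GPU
    model = Mamba(
        d_model=dim,  # model dimension
        d_state=16,   # SSM state expansion factor
        d_conv=4,     # local convolution width
        expand=2,     # block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape  # the Mamba block preserves the input shape
    print('success')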
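Some machines may not have the CUDA build of `mamba_ssm` or a GPU. On those machines you can wrap the same check so that it reports why it was skipped instead of raising an exception. This is a sketch: `run_mamba_smoke_test` is a hypothetical helper, not part of `mamba_ssm`. Like the test code, it runs the model on a CUDA device. It also assumes that a broken install shows up as an `ImportError` or `OSError` (for example, a failed DLL load on Windows).

```python
def run_mamba_smoke_test() -> str:
    """Run the Mamba forward-pass check, or return a reason it was skipped."""
    try:
        import torch
        from mamba_ssm import Mamba
    except (ImportError, OSError) as exc:
        return f"skipped: {exc}"
    if not torch.cuda.is_available():
        return "skipped: CUDA is not available"

    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim, device="cuda")
    model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
    with torch.no_grad():  # inference only; no gradients needed for a smoke test
        y = model(x)
    if y.shape == x.shape:
        return "success"
    return f"shape mismatch: {tuple(y.shape)}"


if __name__ == "__main__":
    print(run_mamba_smoke_test())
```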
Result:
![[python] Test code after installing mamba on Windows](https://img.php.cn/upload/article/001/503/042/175832761373541.jpg)