Cannot Install JAX
Hello, I am currently unable to properly install JAX on both the A100 SXM 80GB and the H100 80GB SXM5 in the Secure Cloud. When I run the command
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I get the following error (partially shown), which reports dependency conflicts with torch:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cublas-cu12==12.1.3.1; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.4.2.65 which is incompatible.
torch 2.1.1 requires nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.4.99 which is incompatible.
I do not believe I need torch for my program. Is there a way to either remove or edit these torch requirements? Also, I was able to install JAX successfully over the last few weeks, up until this morning.
Any help would be greatly appreciated!
1 Reply
Best to ask on the GitHub repo/project page.