AndrewL
Cannot Install JAX
Hello, I am currently unable to properly install JAX on both the A100 SXM 80GB and the H100 80GB SXM5 in the Secure Cloud. When I run the command
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I get the following error (partially shown) reporting dependency conflicts with torch:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cublas-cu12==12.1.3.1; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.4.2.65 which is incompatible.
torch 2.1.1 requires nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.4.99 which is incompatible.
I do not believe I need torch for my program. Is there a way to either remove or edit these torch requirements? Also, I had been able to install JAX successfully over the last few weeks, up until this morning.
Any help would be greatly appreciated!
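To see exactly which pins are being violated, a small script can compare the `==` requirements a package declares against what is actually installed. This is a rough sketch using only the standard library; the helper names (`parse_pin`, `pinned_requirements`, `mismatches`) are made up for illustration, and the parsing only handles simple `name==version` pins, not extras or version ranges.

```python
import re
from importlib import metadata


def parse_pin(requirement):
    """Return (name, version) for a simple `name==version` requirement, else None."""
    # Drop the environment marker (everything after ';') and any parentheses
    # that older metadata formats put around the version specifier.
    spec = requirement.split(";", 1)[0].replace("(", "").replace(")", "").strip()
    match = re.match(r"^([A-Za-z0-9_.\-]+)\s*==\s*([^\s,]+)$", spec)
    return (match.group(1), match.group(2)) if match else None


def pinned_requirements(dist_name):
    """Return {package: pinned_version} declared by an installed distribution."""
    try:
        requirements = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return {}  # the distribution is not installed at all
    pins = {}
    for requirement in requirements:
        pin = parse_pin(requirement)
        if pin is not None:
            pins[pin[0]] = pin[1]
    return pins


def mismatches(dist_name):
    """Return {package: (pinned, installed)} where the installed version differs."""
    result = {}
    for package, wanted in pinned_requirements(dist_name).items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            continue  # pinned but not installed; pip reports this separately
        if installed != wanted:
            result[package] = (wanted, installed)
    return result


if __name__ == "__main__":
    for package, (wanted, installed) in sorted(mismatches("torch").items()):
        print(f"torch pins {package}=={wanted}, installed: {installed}")
```

If torch really is unused, an empty result after uninstalling it (or a fresh virtual environment without it) would indicate that nothing else is pinning those CUDA wheels.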
3 replies
No longer able to use JAX on H100 machines
Hello, as of today I am no longer able to use JAX on newly launched H100 instances (yesterday it worked fine). I am following the usual install instructions:
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
In the interpreter, if I enter import jax and run a simple computation (e.g., take the square root of 2):
import jax
jax.numpy.sqrt(2)
I get the following error (partial output shown):
2024-01-25 21:49:14.017380: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:235] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-01-25 21:49:14.017433: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:238] Used ptxas at ptxas
2024-01-25 21:49:15.106726: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1352] failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
2024-01-25 21:49:15.106816: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1357] error log buffer (63 bytes): error : Binary format for key='0', ident='' is not recognize
2024-01-25 21:49:15.106873: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
...
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid; current tracing scope: fusion; current profiling annotation: XlaModule:#prefix=jit(<lambda>)/jit(main),hlomodule=jit__lambda,program_id=0#.
This was not happening when I launched new instances yesterday. Any help would be appreciated! Thanks!
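Since the first warning mentions `ptxas`, one thing worth checking is which `ptxas` binary is found on `PATH` and which versions of the relevant wheels are installed. The following is only a diagnostic sketch using the standard library; it does not fix anything. The function names are made up for illustration, and the choice of `nvidia-cuda-nvcc-cu12` as a package to inspect is an assumption about where a pip-installed `ptxas` would come from.

```python
import shutil
import subprocess
from importlib import metadata


def ptxas_info():
    """Return (path, version_output) for the ptxas found on PATH, or None if absent."""
    path = shutil.which("ptxas")
    if path is None:
        return None
    try:
        completed = subprocess.run(
            [path, "--version"], capture_output=True, text=True, timeout=10
        )
        return path, completed.stdout.strip()
    except (OSError, subprocess.SubprocessError):
        # The binary exists but could not be run; report the path only.
        return path, ""


def installed_versions(packages):
    """Return {package: version or None} for each requested distribution name."""
    versions = {}
    for package in packages:
        try:
            versions[package] = metadata.version(package)
        except metadata.PackageNotFoundError:
            versions[package] = None
    return versions


if __name__ == "__main__":
    print("ptxas on PATH:", ptxas_info())
    # Package list is an assumption; adjust it to whatever pip actually installed.
    for name, version in installed_versions(
        ["jax", "jaxlib", "nvidia-cuda-nvcc-cu12"]
    ).items():
        print(f"{name}: {version}")
```

Comparing this output between an instance that works and one that fails may show whether a package version changed between the two launches.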
4 replies