No longer able to use JAX on H100 machines
Hello, as of today I can no longer use JAX on newly launched H100 instances (it worked fine yesterday). I am following the usual install instructions:
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
In the interpreter, if I enter import jax and run a simple computation (e.g., take the square root of 2):
import jax
jax.numpy.sqrt(2)
I get the following error (only part of the message is shown):
2024-01-25 21:49:14.017380: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:235] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-01-25 21:49:14.017433: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:238] Used ptxas at ptxas
2024-01-25 21:49:15.106726: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1352] failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
2024-01-25 21:49:15.106816: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1357] error log buffer (63 bytes): error : Binary format for key='0', ident='' is not recognize
2024-01-25 21:49:15.106873: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
...
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid; current tracing scope: fusion; current profiling annotation: XlaModule:#prefix=jit(<lambda>)/jit(main),hlomodule=jit__lambda,program_id=0#.
This was not happening when I launched new instances yesterday. Any help would be appreciated! Thanks!
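The warnings "ptxas does not support CC 9.0" and "Used ptxas at ptxas" in the log above suggest that XLA picked up whichever ptxas is first on PATH, and that this binary is too old for the H100. The following is a minimal sketch for checking which ptxas that is and what release it reports. The helper names parse_ptxas_release and find_ptxas_release are hypothetical, and the sketch assumes the usual "release X.Y" line in the output of ptxas --version; it only diagnoses and does not fix anything.

```python
import re
import shutil
import subprocess


def parse_ptxas_release(version_text):
    """Return (major, minor) parsed from `ptxas --version` output, or None."""
    # Assumption: the output contains a line like
    # "Cuda compilation tools, release 12.3, V12.3.52".
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def find_ptxas_release():
    """Locate the first ptxas on PATH and report its path and release."""
    path = shutil.which("ptxas")
    if path is None:
        return None, None
    try:
        result = subprocess.run(
            [path, "--version"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.SubprocessError):
        # The binary exists but could not be run; report the path only.
        return path, None
    return path, parse_ptxas_release(result.stdout)


if __name__ == "__main__":
    path, release = find_ptxas_release()
    print("ptxas on PATH:", path)
    print("ptxas release:", release)
```

If the reported path points at a system CUDA toolkit rather than the pip-installed one, or the release is much older than the CUDA 12 wheels, that would be consistent with the warning. This is a guess from the log, not a confirmed cause.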
2 Replies
Which type of H100? Community Cloud or Secure Cloud?
H100 80GB SXM5 (26 vCPU 251 GB RAM). Secure Cloud. Thank you!
Hello, is there an update on this issue? I noticed that JAX works well on the A100 SXM 80GB instances in Secure Cloud, which run CUDA 12.2, but it still does not work on the H100s, which run CUDA 12.3.
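To make the A100 versus H100 comparison easier to report, here is a rough sketch that collects the installed package versions and the GPU and driver details on each instance. The function names are hypothetical. It assumes that nvidia-smi is on PATH and supports the compute_cap query field, and that the cuda12_pip install pulls in a package named nvidia-cuda-nvcc-cu12; any value that cannot be determined is returned as None.

```python
import subprocess
from importlib import metadata


def package_version(name):
    """Return the installed version of a pip package, or None if missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def gpu_summary():
    """Return GPU name, driver version and compute capability, or None."""
    cmd = [
        "nvidia-smi",
        "--query-gpu=name,driver_version,compute_cap",
        "--format=csv,noheader",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError):
        # nvidia-smi is absent or could not be run on this machine.
        return None
    return result.stdout.strip() or None


def environment_report():
    """Collect the details worth comparing between two instances."""
    return {
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
        # Assumption: this is the pip package that provides ptxas.
        "nvidia-cuda-nvcc-cu12": package_version("nvidia-cuda-nvcc-cu12"),
        "gpu": gpu_summary(),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running this on both a working A100 instance and a failing H100 instance and posting the two outputs side by side should show whether the difference lies in the pip packages or in the machine image.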