RunPod
Created by AndrewL on 3/20/2024 in #⛅|pods
Cannot Install JAX
Hello, I am currently unable to properly install JAX on both the A100 SXM 80GB and the H100 80GB SXM5 in the Secure Cloud. When I run the command

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I get the following error (partially shown) about dependency conflicts with torch:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cublas-cu12==12.1.3.1; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.4.2.65 which is incompatible.
torch 2.1.1 requires nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.4.99 which is incompatible.

I do not believe I need torch for my program. Is there a way to either remove or edit these torch requirements? Also, I had been able to install JAX successfully over the last few weeks, up until this morning. Any help would be greatly appreciated!
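In case it helps narrow things down, the workaround I am considering (untested on my side, and the venv path below is just an example) is to either remove torch, since I do not use it, or install JAX into a clean virtual environment so the base image's torch pins are left alone:

# Option A (assumes torch really is unused in this pod): remove torch so its CUDA 12.1 pins no longer conflict
pip uninstall -y torch

# Option B: leave the base image untouched and install JAX into a fresh virtual environment
python -m venv /workspace/jax-env        # example path, adjust as needed
source /workspace/jax-env/bin/activate
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Would either of these be safe on the RunPod PyTorch templates, or does something else in the image depend on torch?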
3 replies
RunPod
Created by AndrewL on 1/25/2024 in #⛅|pods
No longer able to Use Jax on H100 machines
Hello, today I am no longer able to use JAX on newly launched H100 instances (yesterday was fine). I am following the usual install instructions:

pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In the interpreter, if I import jax and run a simple computation (e.g., take the square root of 2):

import jax
jax.numpy.sqrt(2)

I get the following error (part of the message shown):

2024-01-25 21:49:14.017380: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:235] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-01-25 21:49:14.017433: W external/xla/xla/stream_executor/gpu/asm_compiler.cc:238] Used ptxas at ptxas
2024-01-25 21:49:15.106726: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1352] failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
2024-01-25 21:49:15.106816: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1357] error log buffer (63 bytes): error : Binary format for key='0', ident='' is not recognize
2024-01-25 21:49:15.106873: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
...
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid; current tracing scope: fusion; current profiling annotation: XlaModule:#prefix=jit(<lambda>)/jit(main),hlomodule=jit__lambda,program_id=0#.

This was not happening when I launched new instances yesterday. Any help would be appreciated! Thanks!
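For reference, here is the diagnostic I plan to run on a fresh pod (a sketch; the CUDA toolkit path is an assumption and may differ per image), since the warning suggests XLA is picking up a system ptxas that is too old for the H100 (compute capability 9.0):

# Check which ptxas XLA finds on PATH and whether it supports sm_90 (as far as I know, that needs CUDA 11.8 or newer)
which ptxas
ptxas --version

# Check the driver and the CUDA version the pod reports
nvidia-smi

# If the system ptxas is too old, one thing to try (assumption, not verified) is pointing XLA at a CUDA 12 toolkit
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda-12   # path is an example; adjust to the image

Is there any chance the H100 images changed between yesterday and today (driver or bundled CUDA toolkit)? That would explain why the same install steps stopped working.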
4 replies