Can't use GPU with JAX in serverless endpoint

Hi, I'm trying to run a serverless worker to perform point tracking on a video. It works OK, but I think it is running on CPU. I read that the telemetry in the UI isn't reliable, but the container logs indicate that too; there is an image of what the logs say. It finds the NVIDIA GPU, but I think there are problems with JAX.

I use the function in the first image to check the device, and the outputs I get are in the second image.

In my Dockerfile, I'm setting this as the base image:
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04

I'm running this command to install the JAX version that is supposed to work with CUDA 11.8:
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I install requirements.txt (I don't install JAX again there) and do other stuff. Finally, I do this to set the library path for CUDA:
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH

I still can't get it to work on the GPU. If someone could tell me where the problem could be, it would be extremely helpful. Thank you.
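The device-check function from the screenshots isn't included in the thread. As a stand-in, here is a minimal sketch that only uses JAX's public `jax.default_backend()` and `jax.devices()`; the name `describe_jax_backend` is hypothetical, and it reports `jax_installed: False` instead of raising when JAX is missing.

```python
import importlib.util


def describe_jax_backend() -> dict:
    """Report which backend JAX selected and which devices it can see.

    Returns {"jax_installed": False} instead of raising when JAX is missing.
    """
    info = {"jax_installed": importlib.util.find_spec("jax") is not None}
    if not info["jax_installed"]:
        return info

    import jax  # imported lazily so the check above works without JAX

    devices = jax.devices()                      # devices of the default backend
    info["backend"] = jax.default_backend()      # e.g. "cpu" or "gpu"
    info["devices"] = [str(d) for d in devices]
    info["platforms"] = sorted({d.platform for d in devices})
    info["using_gpu"] = info["backend"] == "gpu"
    return info


if __name__ == "__main__":
    print(describe_jax_backend())
```

If `backend` comes back as `"cpu"` on a GPU worker, the installed jaxlib most likely has no working CUDA support in that container.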
38 Replies
nerdylive (6mo ago)
Hey, before running the code, try setting this environment variable:
export CUDA_VISIBLE_DEVICES=0,1
Run that command in a CLI. Let me know if that works or not.
Madiator2011 (Work)
Try adding this to your Dockerfile:
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
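Note that `ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64` replaces any earlier value instead of appending to it, so its position relative to the Dockerfile's own `LD_LIBRARY_PATH` line matters. To see which CUDA libraries are actually reachable at runtime, a small standard-library sketch (`find_on_library_path` is a hypothetical helper) can walk the variable's directories:

```python
import glob
import os


def find_on_library_path(pattern: str, path_var: str = "LD_LIBRARY_PATH") -> list:
    """Return files matching `pattern` (e.g. "libcudnn*.so*") in each directory
    listed in the given environment variable, in the order they are listed."""
    matches = []
    for directory in os.environ.get(path_var, "").split(os.pathsep):
        if not directory or not os.path.isdir(directory):
            continue  # skip empty entries and directories that do not exist
        matches.extend(sorted(glob.glob(os.path.join(directory, pattern))))
    return matches


if __name__ == "__main__":
    for lib in ("libcudart*.so*", "libcudnn*.so*"):
        print(lib, find_on_library_path(lib))
```

The dynamic loader generally searches these directories in the listed order, so if `libcudnn*` appears in several of them, the earlier copy is the one likely to be picked up.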
nerdylive (6mo ago)
Yeah, I guess so, but that has been included in the newest image tag.
Madiator2011 (Work)
I'm also wondering why you'd use CUDA 11.8 rather than 12.1.
pip install -U "jax[cuda12]"
galakurpismo3 (OP, 6mo ago)
Yeah, I tried both CUDA 12 and 11.8.
Madiator2011 (Work)
@galakurpismo3 Do you have a use case I could try? I might make a better JAX template, though I'd need to understand how you test it.
galakurpismo3 (OP, 6mo ago)
Do I have to run this command in a shell inside the worker container? Or how does it work?
Madiator2011 (Work)
You would probably need to add it to the Docker container.
galakurpismo3 (OP, 6mo ago)
But the container is running on the serverless endpoint, right?
Madiator2011 (Work)
Workers are basically pods.
galakurpismo3 (OP, 6mo ago)
OK, I'll run that command at the beginning of the Python code and add your suggestion too.
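Setting the variable from Python instead of a shell `export` can look like the sketch below, assuming the call sits at the very top of the worker's entry script; `set_cuda_visible_devices` is a hypothetical helper name.

```python
import os
import sys
import warnings


def set_cuda_visible_devices(devices: str = "0,1") -> None:
    """Set CUDA_VISIBLE_DEVICES from Python instead of a shell `export`.

    CUDA reads this variable when it is initialised, which JAX does lazily on
    first use, so the safest place to call this is before `import jax`.
    """
    if "jax" in sys.modules:
        warnings.warn("jax is already imported; if its GPU backend has been "
                      "initialised, CUDA_VISIBLE_DEVICES will be ignored")
    os.environ["CUDA_VISIBLE_DEVICES"] = devices


# At the top of the worker's entry point (hypothetical layout):
set_cuda_visible_devices("0,1")
# import jax  # import JAX only after the environment is prepared
```

The helper only warns rather than raising, because setting the variable after `import jax` can still take effect as long as no JAX computation has initialised the GPU backend yet.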
Madiator2011 (Work)
Try running: pip install --upgrade "jax[cuda12_local]"
galakurpismo3 (OP, 6mo ago)
Okay, in the Dockerfile, right?
Madiator2011 (Work)
GitHub - NVIDIA/JAX-Toolbox: JAX-Toolbox
galakurpismo3 (OP, 6mo ago)
Okay, I'll try it, yes. Thank you!
Hi, I think that it worked, but there is a new error now, related to cuDNN I think. These are the logs:
Starting Serverless Worker |  Version 1.6.0 ---
{"requestId": "cbeb73b4-8679-43d1-aaa0-8c68101e76ac-e1", "message": "Started.", "level": "INFO"}
Get inside input_fn
xla_bridge.py       :889  Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
xla_bridge.py       :889  Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
inference.py        :172  Found device: cuda:0
inference.py        :176  JAX is not using the GPU. Check your JAX installation and environment configuration.
inference.py        :177  JAX backend: gpu
inference.py        :182  CUDA_VISIBLE_DEVICES: 0,1
inference.py        :183  LD_LIBRARY_PATH: /opt/venv/lib/python3.9/site-packages/cv2/../../lib64:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
inference.py        :187  libcudart.so loaded successfully.
inference.py        :189  libcudnn.so loaded successfully.
inference.py        :143  Read and resized video, number of frames: 107
E0716  cuda_dnn.cc:535 Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0716  cuda_dnn.cc:539 Memory usage: 84536328192 bytes free, 84986691584 bytes total.
E0716  cuda_dnn.cc:535 Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0716  cuda_dnn.cc:539 Memory usage: 84536328192 bytes free, 84986691584 bytes total.
inference.py        :162  Error during processing: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
{"requestId": "cbeb73b4-8679-43d1-aaa0-8c68101e76ac-e1", "message": "Finished.", "level": "INFO"}
I've tried with a 24 GB GPU and an 80 GB GPU. I'm using this base image:
FROM nvidia/cuda:12.0.0-cudnn8-devel-ubuntu20.04
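To reproduce the cuDNN failure without running the whole video pipeline, one option is a tiny convolution, which XLA normally dispatches to cuDNN on a CUDA backend. The sketch also sets `XLA_PYTHON_CLIENT_PREALLOCATE=false` (a real JAX setting that makes it allocate GPU memory on demand); that is one commonly suggested mitigation for "Could not create cudnn handle", not a confirmed fix for this worker. The other usual suspect is a mismatch between the cuDNN in the base image and the one the installed jaxlib expects. `cudnn_smoke_test` is a hypothetical name.

```python
import importlib.util
import os


def cudnn_smoke_test(disable_preallocation: bool = True) -> tuple:
    """Run one tiny convolution and return (ok, message).

    Environment variables only take effect if set before JAX initialises its
    backend, so call this before any other JAX code in the process.
    """
    if disable_preallocation:
        # Allocate GPU memory on demand instead of reserving most of it up front.
        os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
    if importlib.util.find_spec("jax") is None:
        return False, "jax is not installed"

    import jax
    import jax.numpy as jnp

    try:
        backend = jax.default_backend()
        x = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)  # NCHW input
        k = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)  # OIHW kernel
        y = jax.lax.conv_general_dilated(x, k, (1, 1), "SAME")
        y.block_until_ready()  # force execution so errors surface here
    except Exception as exc:  # e.g. XlaRuntimeError: DNN library initialization failed
        return False, f"convolution failed: {exc}"
    return True, f"{backend}: conv output shape {tuple(y.shape)}"


if __name__ == "__main__":
    print(cudnn_smoke_test())
```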
nerdylive (6mo ago)
GitHub: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN libr...
Description: I have a python virtual environment with a clean installation of JAX # Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer. # Note: wheels only available on linux. pip ins...
nerdylive (6mo ago)
This may be related.
galakurpismo3 (OP, 6mo ago)
It looks like an issue with VS Code there; I don't know if it's related. I've tried with all GPUs and I get the same error every time.
nerdylive (6mo ago)
So what are your versions now (JAX, cuDNN, CUDA)? The error messages look alike. Maybe try GPUs with more VRAM, or try later versions of JAX and CUDA.
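To answer the versions question precisely from inside the container, the installed Python distributions can be listed with the standard library's `importlib.metadata`. The package names beyond `jax` and `jaxlib` are assumptions: which CUDA plugin and `nvidia-*` wheels are present depends on the JAX release and on whether `jax[cuda12]` or `jax[cuda12_local]` was installed.

```python
from importlib import metadata

# Distribution names other than jax/jaxlib are assumptions; absent ones map to None.
PACKAGES = (
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "nvidia-cudnn-cu12",
    "nvidia-cuda-runtime-cu12",
)


def installed_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:28} {version}")
```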
nerdylive (6mo ago)
GitHub: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN libr...
galakurpismo3 (OP, 6mo ago)
Yeah, I've tried with all GPUs now.
nerdylive (6mo ago)
If it didn't work, try this first.
galakurpismo3 (OP, 6mo ago)
It's CUDA 12.0. With this base image, I think it installs cuDNN 8.8: https://hub.docker.com/layers/nvidia/cuda/12.0.0-cudnn8-runtime-ubuntu20.04/images/sha256-7d0f83420618c3b337d02cfa8243b8e4a7e002ee4b436dd5c70f71cee176f4a0?context=explore
And for JAX, I do this to install it:
RUN pip install --upgrade "jax[cuda12_local]"
nerdylive (6mo ago)
What about pip install -U "jax[cuda12]"? Also try using CUDA >= 12.1, and filter the serverless workers by CUDA version.
galakurpismo3 (OP, 6mo ago)
What does this mean?
nerdylive (6mo ago)
To get that version, update your Docker base image to use it too. Then, on the endpoint: edit the endpoint, expand the bottom section, and select the CUDA version (checkboxes).
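Assuming the endpoint's CUDA filter refers to the CUDA version supported by the host's NVIDIA driver, a worker can log that value itself: `nvidia-smi` prints it in its header as `CUDA Version: X.Y`. This is a sketch; the helper names are hypothetical, and both return `None` when the binary or the field is missing.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_driver_cuda_version(smi_output: str) -> Optional[str]:
    """Extract the 'CUDA Version: X.Y' field from `nvidia-smi` output."""
    match = re.search(r"CUDA Version:\s*([\d.]+)", smi_output)
    return match.group(1) if match else None


def driver_cuda_version() -> Optional[str]:
    """Highest CUDA version the host driver supports, or None if unknown."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return parse_driver_cuda_version(result.stdout)


if __name__ == "__main__":
    print("driver supports CUDA up to:", driver_cuda_version())
```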
galakurpismo3 (OP, 6mo ago)
Ah, okay. I'll try 11.8 too, thank you.
nerdylive (6mo ago)
Yeah, I just checked: I think JAX for CUDA 12 works with 12.1 or later.
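The 12.1 threshold is the one suggested in this thread. A minimal sketch that compares version strings numerically, assuming the `nvidia/cuda` base images export `CUDA_VERSION` (and, for the cudnn variants, `NV_CUDNN_VERSION`) as environment variables; `docker inspect` on the image shows whether they are actually set.

```python
import os
import re


def parse_version(text: str) -> tuple:
    """Turn '12.0.0' or '8.9.0.131' into a tuple of ints; non-numeric text gives ()."""
    return tuple(int(part) for part in re.findall(r"\d+", text))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """True if `version` is at least `minimum`, e.g. meets_minimum('12.1.0', (12, 1))."""
    # Tuples compare element by element, so (12, 0, 0) < (12, 1) <= (12, 1, 0).
    return parse_version(version) >= minimum


if __name__ == "__main__":
    cuda = os.environ.get("CUDA_VERSION", "")       # assumed to be set by the base image
    cudnn = os.environ.get("NV_CUDNN_VERSION", "")  # assumed, cudnn image variants only
    print(f"CUDA toolkit: {cuda or 'unknown'} | cuDNN: {cudnn or 'unknown'}")
    if cuda and not meets_minimum(cuda, (12, 1)):
        print("CUDA toolkit in this image is older than 12.1")
```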
galakurpismo3 (OP, 6mo ago)
I'll try this and tell you if it works. Thanks a lot for helping!
nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
Madiator2011 (Work)
@galakurpismo3 is your worker open source?
galakurpismo3 (OP, 6mo ago)
I can share it with you, but it's not simple to test. I'll try to share a simplified version.
galakurpismo3 (OP, 6mo ago)
Hi, here is a simple version of the worker: https://github.com/galakurpi/yekar_coaches_point_tracking_simple
To test it, send the video link I have in this code, in the same format:

import requests

url = 'https://api.runpod.ai/v2/sd1ylpcd55dj12/run'
data = {
    'input': {
        'video_url': 'https://drive.google.com/uc?export=download&id=1SER_MwYt0XyOHOX0UbN30iyMCmeWE-dd'
    }
}
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer <RUNPOD API KEY MISSING>'  # If authentication is needed
}
response = requests.post(url, json=data, headers=headers)
print(response.json())

Thank you!
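The `/run` call in that snippet returns right away with a job ID, so the worker's output (or its error) has to be fetched separately. A polling sketch, assuming RunPod's `/v2/{endpoint_id}/status/{job_id}` route and the status names listed in the code; `wait_for_job` is a hypothetical helper, and `get` can be replaced by a stub for offline testing.

```python
import time


def wait_for_job(endpoint_id, job_id, api_key, get=None, interval=2.0, timeout=600.0):
    """Poll a RunPod serverless job until it reaches a terminal status.

    Assumes the /v2/{endpoint_id}/status/{job_id} route and the status names
    below. `get` defaults to requests.get and can be stubbed out for testing.
    """
    if get is None:
        import requests  # only needed for real network calls
        get = requests.get
    url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{job_id}"
    headers = {"Authorization": f"Bearer {api_key}"}
    terminal = {"COMPLETED", "FAILED", "CANCELLED", "TIMED_OUT"}  # assumed names
    deadline = time.monotonic() + timeout
    while True:
        body = get(url, headers=headers, timeout=30).json()
        if body.get("status") in terminal:
            return body  # on failure, the worker's error details should be in here
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job {job_id} still {body.get('status')} after {timeout}s")
        time.sleep(interval)


# Usage with the snippet above (hypothetical; needs a real API key):
# job = response.json()
# result = wait_for_job("sd1ylpcd55dj12", job["id"], "<RUNPOD API KEY>")
# print(result.get("status"), result.get("output") or result.get("error"))
```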
galakurpismo3 (OP, 6mo ago)
Let me know if you test anything or need anything.
Madiator2011 (6mo ago)
BTW, did you make sure to filter the CUDA version of the machines in serverless?
nerdylive (6mo ago)
Have you tried the versions?
galakurpismo3 (OP, 6mo ago)
Actually, no, sorry. The logs showed that CUDA 12.1 was running, but I'll try again with that.
nerdylive (6mo ago)
Try your code and template in a pod if you want; it's faster. You can iterate there, and once something works, build your Docker image and code from the result on that pod.
galakurpismo3 (OP, 6mo ago)
I tried filtering for CUDA 12.1, and nothing changed.