Streaming is not quite working
Hi, I was trying to deploy a simple Hugging Face model and ran into a problem with streaming.
What is happening is that on the server side I can see that it is streaming, but on the client side I only get the response once streaming has finished.
Here is how my inference (engine) code looks:
import asyncio
import os
from queue import Empty
from threading import Thread
from typing import Any, Dict, List, Union

from dotenv import load_dotenv


class HFEngine:
    def __init__(self) -> None:
        load_dotenv()
        # _initialize_llm (defined elsewhere in this module) returns the model,
        # the tokenizer, and the streamer that generate() writes into.
        self.model, self.tokenizer, self.streamer = self._initialize_llm(
            model_name_or_path=os.environ.get("HF_MODEL_NAME"),
            tokenizer_name_or_path=os.environ.get("HF_TOKENIZER_NAME"),
            device=os.environ.get("DEVICE") or "cpu",
        )
        self.device = os.environ.get("DEVICE") or "cpu"

    async def stream(self, chat_input: Union[str, List[Dict[str, str]]], generation_parameters: Dict[str, Any]):
        try:
            async for output in self._stream(chat_input=chat_input, generation_parameters=generation_parameters):
                yield output
        except Exception as e:
            yield {"error": str(e)}

    async def _stream(self, chat_input: Union[str, List[Dict[str, str]]], generation_parameters: Dict[str, Any]):
        if isinstance(chat_input, str):
            # apply_chat_template expects role/content message dicts
            chat_input = [{"role": "user", "content": chat_input}]
        input_ids = self.tokenizer.apply_chat_template(
            conversation=chat_input, tokenize=True, return_tensors="pt"
        ).to(self.device)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=self.streamer,
            **generation_parameters,
        )
        # Run generation in a background thread so the streamer can be drained here.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        # Drain the streamer token by token while generation runs in the thread.
        streamer_iter = iter(self.streamer)
        while True:
            try:
                next_token = next(streamer_iter)
            except StopIteration:
                break
            except Empty:
                # Raised by the streamer's queue if it was created with a timeout; back off briefly.
                await asyncio.sleep(0.001)
                continue
            if next_token is not None:
                yield {"status": 200, "delta": next_token}
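_initialize_llm is not shown above; roughly, it just loads the tokenizer and model and builds the streamer. A simplified sketch, assuming a standard transformers TextIteratorStreamer (the class names and streamer settings here are assumptions, not the original helper):

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

def _initialize_llm(self, model_name_or_path, tokenizer_name_or_path, device="cpu"):
    # Hypothetical version of the helper: load tokenizer and model, and build the
    # streamer that model.generate() writes decoded text into.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path or model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
    # skip_prompt=True so only newly generated text is streamed, not the prompt;
    # a timeout here is what would make the streamer raise queue.Empty in _stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    return model, tokenizer, streamer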
And this is how the handler code looks:
import runpod

from engine import HFEngine
from constants import DEFAULT_MAX_CONCURRENCY


class JobInput:
    def __init__(self, job):
        self.llm_input = job.get("messages")
        self.stream = job.get("stream", False)
        self.sampling_params = job.get(
            "sampling_params", {
                "temperature": 0.1,
                "top_p": 0.7,
                "max_new_tokens": 512,
            }
        )


async def handler(job):
    engine = HFEngine()
    job_input = JobInput(job["input"])
    async for delta in engine.stream(
        chat_input=job_input.llm_input,
        generation_parameters=job_input.sampling_params,
    ):
        yield delta


runpod.serverless.start(
    {
        "handler": handler,
        "concurrency_modifier": lambda x: DEFAULT_MAX_CONCURRENCY,
    }
)
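For reference, the job payload this handler expects looks roughly like this (key names mirror JobInput above; the values are placeholders):

example_job = {
    "input": {
        "messages": "Your prompt",   # passed to engine.stream as chat_input
        "stream": True,              # read by JobInput but not otherwise used above
        "sampling_params": {         # optional; these defaults come from JobInput
            "temperature": 0.1,
            "top_p": 0.7,
            "max_new_tokens": 512,
        },
    }
}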
I even tried setting this in the serverless start config:
"return_aggregate_stream": True,
And here is my client code, which is just a simple POST request:
url = "http://localhost:8000/runsync"
headers = {
"Content-Type": "application/json"
}
data = {
"input": {
"messages": "Your prompt",
"stream": True
}
}
response = requests.post(url, headers=headers, data=json.dumps(data), timeout=600, stream=True)
if response.status_code == 200:
for res in response.iter_lines():
print(res)
url = "http://localhost:8000/runsync"
headers = {
"Content-Type": "application/json"
}
data = {
"input": {
"messages": "Your prompt",
"stream": True
}
}
response = requests.post(url, headers=headers, data=json.dumps(data), timeout=600, stream=True)
if response.status_code == 200:
for res in response.iter_lines():
print(res)
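The same request can also be written with the requests json= shortcut and the body parsed as JSON; a sketch (the id/status/output field names follow the usual RunPod job response shape and are assumptions here):

# Same request, but reading the /runsync reply as a single JSON document.
response = requests.post(url, headers=headers, json=data, timeout=600)
if response.status_code == 200:
    body = response.json()
    print(body.get("id"), body.get("status"))
    # with a generator handler, "output" is typically the list of yielded chunks
    for delta in body.get("output") or []:
        print(delta)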