모델 준비
모델은 huggingface에서 미리 준비해야 한다. 로컬에서 GPU 부담 없이 돌리기 위해 google/gemma-3-1b-it를 이용해 실험했다.
import os
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
load_dotenv()
login(token=os.getenv("HF_KEY"))
model_id = "google/gemma-3-1b-it"
save_directory = str(os.getenv("MODEL_PATH"))
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer.save_pretrained(save_directory)
model.save_pretrained(save_directory)
단순히 huggingface에서 모델을 받아 로컬에 저장하는 코드다.
vLLM 기본 사용법
API 서버를 만들기 전 vLLM의 기본 서빙 방법을 살펴보자.
import os
from dotenv import load_dotenv
from vllm import LLM, SamplingParams
load_dotenv()
model_path = str(os.getenv("MODEL_PATH"))
llm = LLM(model=model_path, gpu_memory_utilization=0.6, tensor_parallel_size=1)
sampling_params = SamplingParams(
temperature=0.5, top_p=0.7, repetition_penalty=1.1, max_tokens=512
)
query = "재밌는 농담 하나 해봐."
response = llm.generate(query, sampling_params)
print(response[0].outputs[0].text)
LLM과 SamplingParams를 이용해 기본적인 설정을 마친 후, generate를 통해 답변을 생성한다.
FastAPI 연결
API 서버로 서빙하기 위해서는 동시 사용자가 접근할 상황을 가정해 비동기 처리가 필요하다. 이는 AsyncEngineArgs, AsyncLLMEngine, SamplingParams를 통해 지원하며, FastAPI를 연결해 RESTful API로 서빙할 수 있다.
import os
import uuid
from dotenv import load_dotenv
import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
load_dotenv()
model_path = str(os.getenv("MODEL_PATH"))
engine_args = AsyncEngineArgs(
model=model_path, gpu_memory_utilization=0.5, tensor_parallel_size=1
)
llm = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(
temperature=0.2, top_p=0.7, repetition_penalty=1.1, max_tokens=64
)
app = FastAPI()
class QueryRequest(BaseModel):
query: str
@app.post("/generate")
async def generate_post(request: QueryRequest):
request_id = str(uuid.uuid4())
llm.add_request(request_id, request.query, sampling_params)
sent_text_idx = 0
async def stream_response():
nonlocal sent_text_idx
while True:
request_outputs = llm.step()
for output in request_outputs:
if output.request_id == request_id:
text = output.outputs[0].text
new_text = text[sent_text_idx:]
sent_text_idx = len(text)
# 띄어쓰기 단위로 새로운 텍스트를 yield
for word in new_text.split(" "):
if word:
yield word + " "
# 요청이 완료되면 종료
if output.finished:
return
await asyncio.sleep(0.1)
return StreamingResponse(stream_response(), media_type="text/plain")
중간에 반복문을 돌면서 sent_text_idx를 업데이트하는 이유는 step의 답변 생성 방식 때문이다. 위 코드에서 text는 반복을 돌 때마다 아래처럼 문자열이 누적된다.
- "안녕"
- "안녕하세요"
- "안녕하세요. "
- "안녕하세요. 저는"
따라서 이전에 출력했던 청크를 반복 출력하지 않기 위해 이전 출력 인덱스를 저장했다가 새로 생성된 부분만 잘라 출력한다.