Orion-14B 모델 성능이 LLaMA2 13B 보다 좋다길래,
inference test를 해보려고 했다.
https://huggingface.co/OrionStarAI/Orion-14B-Base
모델을 허깅페이스에 공개해 두었기 때문에, 아래 코드로 간단하게 모델을 다운받고 테스트 해볼 수 있다.
"OrionStarAI/Orion-14B" 에 해당하는 부분만, 다운로드 받고 싶은 모델 경로로 변경해주면 된다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("OrionStarAI/Orion-14B", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("OrionStarAI/Orion-14B", device_map="auto",
torch_dtype=torch.bfloat16, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained("OrionStarAI/Orion-14B")
messages = [{"role": "user", "content": "Hello, what is your name? "}]
response = model.chat(tokenizer, messages, streaming=False)
print(response)
예를 들어, Base 모델을 테스트하려 한다면 "OrionStarAI/Orion-14B-Base" 로 바꿔서 실행하면 된다.
그런데 제공되어 있는 코드로 허깅페이스 모델 다운로드 받는 도중, 자꾸 아래와 같은 에러 메세지가 나오면서 다운이 안되는 것이다.
connectionerror httpsconnectionpool(host='cdn-lfs-us-1.huggingface.co' port=443)
지금껏 LFS(대용량 파일)을 많이 썼었는데 한번도 본 적 없는 에러였는데,
모델 파일이 커서 다운로드 받다가 중간에 커넥션이 끊기면 나는 에러인 듯하다.
커맨드에서도 동일한 에러가 발생하는건지 보려고
Orion github에 올라와있는 demo/cli_demo.py로 실행해도 동일한 에러가 발생하였다.
검색 결과, 아래 글을 보고 "resume_download" 옵션을 줘서 해결할 수 있었다.
https://huggingface.co/google/switch-large-128/discussions/5
모델을 불러오는 부분, 즉
model = AutoModelForCausalLM.from_pretrained("OrionStarAI/Orion-14B", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
이부분에 resume_download = True 를 추가하면 된다.
최종 코드는 다음과 같다.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("OrionStarAI/Orion-14B-Base", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("OrionStarAI/Orion-14B-Base", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, resume_download = True)
model.generation_config = GenerationConfig.from_pretrained("OrionStarAI/Orion-14B-Base")
messages = [{"role": "user", "content": "Hello, what is your name? "}]
response = model.chat(tokenizer, messages, streaming=False)
print(response)
cash 파일이 잘 저장되고 있는지 확인해보니, 잘 저장되고 있다!
에러 해결!
'Programming > Error' 카테고리의 다른 글
[Error] nohup 안됨 - MPIrun (0) | 2024.06.04 |
---|---|
[python] 폐쇄망 패키지 설치 에러 : transformers Installation Error - Failed building wheel for tokenizers (0) | 2023.11.22 |
댓글