Result:
User:
类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领
ChatGLM-6B:
WHICH衣的衣的衣,衣,衣,衣的衣,衣,衣,衣,衣,衣的衣,衣的衣,衣的衣,衣的衣,衣,衣的 "&衣,衣,衣的 "\"的 "\"的衣,衣, "\"的衣
Code:
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_PATH = "/data/nfs/llm/model/chatglm-6b"
CHECKPOINT_PATH = "/home/guodong.li/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-500"

# Load the tokenizer and the base model, passing the P-Tuning prefix length
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(MODEL_PATH, config=config, trust_remote_code=True).cuda()

# Load only the prefix-encoder weights from the P-Tuning checkpoint
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Quantization is disabled; run in fp16 with the prefix encoder kept in fp32
# model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

print("用户:你好\n")
response, history = model.chat(tokenizer, "你好", history=[])
print("ChatGLM-6B:\n", response)

print("\n------------------------------------------------\n用户:")
line = input()
while line:
    response, history = model.chat(tokenizer, line, history=history)
    print("ChatGLM-6B:\n", response)
    print("\n------------------------------------------------\n用户:")
    line = input()
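
Since the garbled reply only appears with the fine-tuned prefix loaded, one thing worth ruling out is a prefix-encoder mismatch. Below is a minimal diagnostic sketch (my addition, not part of the original repro) meant to run right after the loading code above; it checks that pre_seq_len matches training and that the prefix weights survived the fp16 cast:

# Minimal diagnostic sketch (assumed follow-up, run right after the code above).
# pre_seq_len must equal the value used during P-Tuning training
# (128 here, matching "pt-128" in the checkpoint name).
print("pre_seq_len:", config.pre_seq_len)

# After model.half() followed by prefix_encoder.float(), these parameters
# should report torch.float32; seeing torch.float16 would point to the
# half-precision cast as a possible culprit for the degraded output.
for name, param in model.transformer.prefix_encoder.named_parameters():
    print(name, param.dtype, tuple(param.shape))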
Environment
- OS: CentOS 7
- Python: 3.10
- Transformers: 4.28.0
- PyTorch: 1.13.1
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`):