| |
|
| | import os
|
| | import torch
|
| | import pickle
|
| | from model import GPTConfig, GPT
|
| | import tiktoken
|
| | from rich.traceback import install
|
| |
|
# Route uncaught exceptions through rich's colourised traceback printer.
install()

# ---- Runtime configuration -------------------------------------------------
# Location of the training checkpoint and of the dataset metadata pickle.
ckpt_path = "out/ckpt.pt"
meta_path = "data/mydata/meta.pkl"

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer name (tiktoken registry) and sampling hyper-parameters.
tokenizer_name = "cl100k_base"
max_new_tokens = 1024
temperature = 0.8
top_k = 100

# Marker strings passed to tiktoken as `allowed_special` and suppressed from
# the streamed output.
# NOTE(review): cl100k_base only registers <|endoftext|> (plus the FIM and
# <|endofprompt|> markers) as special tokens, so <|im_start|>/<|im_stop|>
# are encoded as ordinary text — verify this matches the training data.
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}
|
| |
|
| |
|
# Byte-pair encoder used both to encode prompts and to decode sampled tokens.
enc = tiktoken.get_encoding(tokenizer_name)
# Short module-level aliases for the encoder's bound methods.
encode, decode = enc.encode, enc.decode
|
| |
|
| |
|
# Dataset metadata produced at data-preparation time; this script reads
# only the vocabulary size from it.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(meta_path, mode="rb") as f:
    meta = pickle.load(f)

vocab_size = meta["vocab_size"]
|
| |
|
| |
|
# Restore the training checkpoint, mapping every tensor onto `device`.
checkpoint = torch.load(ckpt_path, map_location=device)

# Hyper-parameters recorded by the training run. The vocabulary size is
# replaced with the value from meta.pkl.
# NOTE(review): if this differs from the checkpoint's own vocab_size, the
# saved weight shapes will presumably no longer match — confirm they agree.
model_args = checkpoint["model_args"]
model_args.update(vocab_size=vocab_size)

# Context window used to crop the running sequence during generation;
# 1024 is used when the checkpoint does not record one.
block_size = model_args.get("block_size", 1024)
|
| |
|
| |
|
# Build the network from the saved hyper-parameters and restore its weights.
model = GPT(GPTConfig(**model_args))

# A checkpoint written from a torch.compile()-wrapped model stores every
# parameter key under an "_orig_mod." prefix that an uncompiled GPT does not
# expect; strict load_state_dict would reject those keys. Strip the prefix in
# place (keys without it are left untouched, so plain checkpoints still load).
state_dict = checkpoint['model']
_compiled_prefix = '_orig_mod.'
for _key in list(state_dict.keys()):
    if _key.startswith(_compiled_prefix):
        state_dict[_key[len(_compiled_prefix):]] = state_dict.pop(_key)

model.load_state_dict(state_dict)
model.to(device)
# Switch modules such as dropout to inference behaviour.
model.eval()
|
| |
|
@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
    """Sample up to ``max_new_tokens`` tokens, printing them as they arrive.

    Sampled token bytes are fed through an incremental UTF-8 decoder, so a
    character whose bytes are split across several BPE tokens is printed once
    it is complete instead of as U+FFFD replacement characters. Generation
    stops early when ``<|endoftext|>`` is sampled. Tokens whose text is one
    of the module-level ``special_tokens`` strings are never printed.

    Args:
        model: Callable as ``model(ids)`` returning ``(logits, _)``; logits are
            indexed as ``logits[:, -1, :]`` (assumed (batch, seq, vocab) —
            confirm against the GPT implementation).
        input_ids: Long tensor of prompt token ids with a leading batch
            dimension. Only a batch of one is supported, because the sampled
            token is read with ``.item()``.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Divisor applied to the logits. Values ``<= 0`` select
            greedy (argmax) decoding instead of sampling.
        top_k: If not None, sample only among the ``top_k`` most likely
            tokens (clamped to the vocabulary size).

    Returns:
        The prompt followed by every sampled token, including a terminating
        ``<|endoftext|>`` if one was produced. Only the model input is cropped
        to ``block_size``; the returned sequence is never truncated.
    """
    import codecs

    model.eval()
    eot_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]
    # Byte form of the marker strings whose tokens must not be echoed.
    hidden_pieces = {tok.encode("utf-8") for tok in special_tokens}
    # Holds back incomplete multi-byte sequences until their last byte arrives.
    utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")

    for _ in range(max_new_tokens):
        # Only the most recent `block_size` tokens fit the context window.
        context = input_ids[:, -block_size:]
        logits, _ = model(context)
        logits = logits[:, -1, :]

        if temperature <= 0:
            # Greedy decoding; avoids dividing the logits by zero.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        next_token_id = next_token.item()
        input_ids = torch.cat((input_ids, next_token), dim=1)

        if next_token_id == eot_id:
            break

        token_bytes = enc.decode_single_token_bytes(next_token_id)
        if token_bytes in hidden_pieces:
            continue
        text = utf8.decode(token_bytes)
        if text:
            print(text, end='', flush=True)

    # Flush anything still buffered; a truncated trailing sequence becomes U+FFFD.
    tail = utf8.decode(b"", final=True)
    if tail:
        print(tail, end='', flush=True)
    print()
    return input_ids
|
| |
|
def main():
    """Run an interactive chat loop until the user quits.

    Each non-blank line typed by the user is wrapped in the chat template
    below, encoded, and passed to ``generate_stream``, which prints the reply.
    Typing ``exit`` or ``quit`` (any case, surrounding whitespace ignored),
    pressing Ctrl+C, or closing stdin (Ctrl+D / EOF) ends the session.
    """
    print("π€ AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")
    try:
        while True:
            user_input = input("You: ")
            command = user_input.strip()
            if command.lower() in {"exit", "quit"}:
                print("π Exiting assistant.")
                break
            if not command:
                # Nothing to answer; prompt again instead of generating.
                continue

            # Chat template, spelled out with explicit newlines so editor
            # indentation can never leak into it.
            # NOTE(review): presumably mirrors the training-data format —
            # keep the two in sync.
            prompt = (
                "\n<|im_start|>user\n"
                f"{user_input}<|endoftext|>\n"
                "<|im_stop|>\n"
                "\n"
                "<|im_start|>assistant\n"
                "\n"
            )
            input_ids = torch.tensor(
                encode(prompt, allowed_special=special_tokens),
                dtype=torch.long,
                device=device,
            )[None, ...]

            print("π€ Assistant:", end=' ', flush=True)
            generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
            print("-" * 50)

    except (KeyboardInterrupt, EOFError):
        # EOFError: stdin was closed (Ctrl+D or end of piped input).
        print("\nπ Exiting assistant.")


if __name__ == "__main__":
    main()
|
| |
|