| | import copy |
| |
|
| |
|
class ChatTemplate:
    """Tokenize ChatML-style turns (``<|im_start|>role ... <|im_end|>``) for a llama.cpp model.

    Caches the token ids of role headers and remembers which role names have
    been used, so the ``eos_in_*`` helpers can detect when generated text
    starts a new role header or a blank line — both treated as soft
    end-of-turn signals.
    """

    def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
        """Precompute token ids for the template delimiters.

        model: llama.cpp-style object exposing ``tokenize()``,
               ``str_detokenize()`` and ``_token_eos`` — assumed, not
               verified here.
        im_start/im_end/nl: the ChatML delimiter strings.
        """
        self.model = model
        self.nl = nl
        # BUGFIX: cache/roles used to be class-level attributes, silently
        # shared across every instance (and every model/template pair).
        # They are per-instance state and are created here instead.
        self.cache = {}
        self.roles = set()
        self.im_start = im_start
        self.im_start_token = model.tokenize(self.im_start.encode('utf-8'), add_bos=False, special=True)
        self.im_end = im_end
        self.im_end_nl = model.tokenize((self.im_end + self.nl).encode('utf-8'), add_bos=False, special=True)
        # Tokens that can terminate a turn: the model's own EOS plus the
        # first token of "<|im_end|>\n".
        self.eos = [model._token_eos, self.im_end_nl[0]]
        # Token id(s) decoding to one newline: '\n', plus '\r\n' if the
        # vocabulary happens to encode it as a single token.
        self.onenl = [self.im_end_nl[-1]]
        tmp = model.tokenize(('\r' + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.onenl.append(tmp[0])
        self.onerl = model.tokenize(b'\r', add_bos=False, special=True)
        # Token id for a double newline, when the vocabulary has one.
        self.nlnl = None
        tmp = model.tokenize((self.nl + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.nlnl = tmp[0]
        print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)

    def _get(self, key: str):
        """Return a private copy of the token ids for ``im_start + key + nl``."""
        if key not in self.cache:
            self.cache[key] = self.model.tokenize(
                (self.im_start + key + self.nl).encode('utf-8'), add_bos=False, special=True)
        # Copy so callers cannot mutate the cached list.  Token ids are
        # plain ints, so a shallow copy suffices (was copy.deepcopy).
        return list(self.cache[key])

    def _add_role(self, _role):
        """Remember ``'\\n' + role`` so eos_in_role can spot role headers later."""
        if _role:
            self.roles.add('\n' + _role)

    def eos_in_role(self, history: str, t_bot):
        """Count trailing tokens of *t_bot* forming a known role header.

        If *history* (the decoded text so far) ends, after stripping trailing
        whitespace, with a previously-seen ``'\\nrole'`` header, return how
        many tokens at the end of *t_bot* decode to that header; otherwise 0.
        """
        if not (history.endswith('\n') or history.endswith('\r')):
            return 0
        stripped = history.rstrip()
        for _role in self.roles:
            if stripped.endswith(_role):
                n = len(t_bot)
                # Grow the trailing token window until it decodes to the header.
                for i in range(1, n):
                    tail = self.model.str_detokenize(t_bot[n - i:])
                    if tail.rstrip().endswith(_role):
                        print('eos_in_role', t_bot[n - i:], repr(tail))
                        return i
                # History matched but no token suffix reproduces the header
                # (token/byte boundary mismatch) — give up for this call.
                print('eos_in_role missing')
                break
        return 0

    def eos_in_nlnl(self, history: str, t_bot):
        """Count trailing tokens of *t_bot* forming a blank line.

        If *history* ends with a double newline, return how many tokens at
        the end of *t_bot* decode to it; 0 if not found or if the decoded
        tail starts with ']' (presumably to avoid cutting inside bracketed
        output — TODO confirm intent with the caller).
        """
        if not (history.endswith('\n\n') or history.endswith('\n\r\n')):
            return 0
        n = len(t_bot)
        for i in range(1, n):
            tail = self.model.str_detokenize(t_bot[n - i:])
            if tail.endswith('\n\n') or tail.endswith('\n\r\n'):
                if tail.startswith(']'):
                    return 0
                print('eos_in_nlnl', t_bot[n - i:], repr(tail))
                return i
        print('eos_in_nlnl missing')
        return 0

    def __call__(self, _role, prompt=None):
        """Render one chat turn as token ids.

        With ``prompt=None`` return only the role-header tokens (cached);
        otherwise return header + prompt + ``im_end + nl`` tokens.
        """
        self._add_role(_role)
        if prompt is None:
            return self._get(_role)

        text = self.im_start + _role + self.nl + prompt
        return self.model.tokenize(text.encode('utf-8'), add_bos=False, special=True) + self.im_end_nl
| |
|