tokenizer_config['bos_token'] = special_bos = special_cls
if not special_eos and special_sep and tokenizer_config:
tokenizer_config['eos_token'] = special_eos = special_sep
- post_processor = tokenizer.get('post_processor', {})
- for processor in post_processor.get('processors', [post_processor]):
- if processor.get('type') == 'RobertaProcessing':
- self.add_special_token['bos'] = True
- self.add_special_token['eos'] = True
- self.add_special_token['sep'] = True
- if not special_cls and tokenizer_config:
- special_cls = processor.get('cls', [special_bos])[0]
- tokenizer_config['cls_token'] = special_cls
- if not special_sep and tokenizer_config:
- special_sep = processor.get('sep', [special_eos])[0]
- tokenizer_config['sep_token'] = special_sep
- continue
- # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
- # Only works with simple templates, **will** get it wrong on unusual sequences
- if processor.get('type') == 'TemplateProcessing':
- tmpl_single = processor.get('single', [])
- tmpl_pair = processor.get('pair', [])
- special_first = None
- special_last = None
- if len(tmpl_single) > 1:
- if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
- if not tokenizer_config:
- special_bos = special_first
- self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
- if special_first not in (special_bos, special_cls):
- logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
- if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
- if not tokenizer_config:
- special_eos = special_last
- elif special_last != special_eos:
- if 'eot' not in self.special_token_types:
- self.special_token_types = tuple(self.special_token_types) + ('eot', )
- tokenizer_config['eot_token'] = special_eos
- elif 'eom' not in self.special_token_types:
- self.special_token_types = tuple(self.special_token_types) + ('eom', )
- tokenizer_config['eom_token'] = special_eos
- else:
- logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
- tokenizer_config['eos_token'] = special_eos = special_last
- self.add_special_token['eos'] = True if special_last == special_eos else False
- if special_last != special_eos:
- logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
- if tmpl_pair:
- seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
- seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
- if (special_first and seq_start == 0) or (special_last and seq_stop is None):
- logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
- if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
- tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
- tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
- if tmpl_a != 'A' or tmpl_b != 'B':
- logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
- # A [sep] [eos] B
- if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
- add_sep = False
- if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
- if special_entry in (special_sep, special_eos) and not special_last:
- add_sep = True
- if special_entry not in (special_sep, special_eos):
- logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
- else:
- logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
- if len(tmpl_pair) == 2:
- if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
- if special_entry in (special_sep, special_eos):
+ if post_processor := tokenizer.get('post_processor'):
+ for processor in post_processor.get('processors', [post_processor]):
+ if processor.get('type') == 'RobertaProcessing':
+ self.add_special_token['bos'] = True
+ self.add_special_token['eos'] = True
+ self.add_special_token['sep'] = True
+ if not special_cls and tokenizer_config:
+ special_cls = processor.get('cls', [special_bos])[0]
+ tokenizer_config['cls_token'] = special_cls
+ if not special_sep and tokenizer_config:
+ special_sep = processor.get('sep', [special_eos])[0]
+ tokenizer_config['sep_token'] = special_sep
+ continue
+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+ # Only works with simple templates, **will** get it wrong on unusual sequences
+ if processor.get('type') == 'TemplateProcessing':
+ tmpl_single = processor.get('single', [])
+ tmpl_pair = processor.get('pair', [])
+ special_first = None
+ special_last = None
+ if len(tmpl_single) > 1:
+ if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_bos = special_first
+ self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+ if special_first not in (special_bos, special_cls):
+ logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+ if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_eos = special_last
+ elif special_last != special_eos:
+ if 'eot' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eot', )
+ tokenizer_config['eot_token'] = special_eos
+ elif 'eom' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eom', )
+ tokenizer_config['eom_token'] = special_eos
+ else:
+ logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
+ tokenizer_config['eos_token'] = special_eos = special_last
+ self.add_special_token['eos'] = True if special_last == special_eos else False
+ if special_last != special_eos:
+ logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+ if tmpl_pair:
+ seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+ seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+ if (special_first and seq_start == 0) or (special_last and seq_stop is None):
+ logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+ if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+ tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+ tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+ if tmpl_a != 'A' or tmpl_b != 'B':
+ logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+ # A [sep] [eos] B
+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+ add_sep = False
+ if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
- logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+ logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
- logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
- self.add_special_token['sep'] = add_sep
- if add_sep and not special_sep and tokenizer_config:
- tokenizer_config['sep_token'] = special_eos
- continue
+ logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+ if len(tmpl_pair) == 2:
+ if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos):
+ add_sep = True
+ if special_entry not in (special_sep, special_eos):
+ logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+ else:
+ logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+ self.add_special_token['sep'] = add_sep
+ if add_sep and not special_sep and tokenizer_config:
+ tokenizer_config['sep_token'] = special_eos
+ continue
if not tokenizer_config:
return True
chat_template_alt = None