1. 磐创AI-开放猫官方网站首页
  2. 系列教程
  3. Transformers

Transformers 加载预训练模型 | 七

本文是全系列中第9 / 13篇:Transformers

作者|huggingface
编译|VK
来源|Github

加载Google AI或OpenAI预训练权重或PyTorch转储

from_pretrained()方法

要加载Google AI、OpenAI的预训练模型或PyTorch保存的模型(用torch.save()保存的BertForPreTraining实例),PyTorch模型类和tokenizer可以被from_pretrained()实例化:

model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None, from_tf=False, state_dict=None, *input, **kwargs)

其中

  • BERT_CLASS要么是用于加载词汇表的tokenizer(BertTokenizerOpenAIGPTTokenizer类),要么是加载八个BERT或三个OpenAI GPT PyTorch模型类之一(用于加载预训练权重):BertModelBertForMaskedLMBertForNextSentencePredictionBertForPreTrainingBertForSequenceClassificationBertForTokenClassificationBertForMultipleChoiceBertForQuestionAnsweringOpenAIGPTModelOpenAIGPTLMHeadModelOpenAIGPTDoubleHeadsModel

  • PRE_TRAINED_MODEL_NAME_OR_PATH为:

    • Google AI或OpenAI的预定义的快捷名称列表,其中的模型都是已经训练好的模型:
    • bert-base-uncased:12个层,768个隐藏节点,12个heads,110M参数量。
    • bert-large-uncased:24个层,1024个隐藏节点,16个heads,340M参数量。
    • bert-base-cased:12个层,768个隐藏节点,12个heads,110M参数量。
    • bert-large-cased:24个层,1024个隐藏节点,16个heads,340M参数量。
    • bert-base-multilingual-uncased:(原始,不推荐)12个层,768个隐藏节点,12个heads,110M参数量。
    • bert-base-multilingual-cased:(新的,推荐)12个层,768个隐藏节点,12个heads,110M参数量。
    • bert-base-chinese:简体中文和繁体中文,12个层,768个隐藏节点,12个heads,110M参数量。
    • bert-base-german-cased:仅针对德语数据训练,12个层,768个隐藏节点,12个heads,110M参数量。性能评估(https://deepset.ai/german-bert)
    • bert-large-uncased-whole-word-masking:24个层,1024个隐藏节点,16个heads,340M参数量。经过Whole Word Masking模式训练(该单词对应的标记全部掩码处理)
    • bert-large-cased-whole-word-masking:24个层,1024个隐藏节点,16个heads,340M参数量。经过Whole Word Masking模式训练(该单词对应的标记全部掩码处理)
    • bert-large-uncased-whole-word-masking-finetuned-squad:在SQuAD上微调的bert-large-uncased-whole-word-masking模型(使用run_bert_squad.py)。结果:exact_match:86.91579943235573,f1:93.1532499015869
    • bert-base-german-dbmdz-cased:仅针对德语数据训练,12个层,768个隐藏节点,12个heads,110M参数量。性能评估(https://deepset.ai/german-bert)
    • bert-base-german-dbmdz-uncased:仅针对德语数据(无大小写),12个层,768个隐藏节点,12个heads,110M参数量。性能评估(https://github.com/dbmdz/german-bert)
    • openai-gpt:OpenAI GPT英文模型,12个层,768个隐藏节点,12个heads,110M参数量。
    • gpt2:OpenAI GPT-2英语模型,12个层,768个隐藏节点,12个heads,117M参数量。
      • gpt2-medium:OpenAI GPT-2英语模型,24个层,1024个隐藏节点、16个heads,345M参数量。
    • transfo-xl-wt103:使用Transformer-XL英语模型在wikitext的-103上训练的模型,24个层,1024个隐藏节点、16个heads,257M参数量。

    • 一个路径或URL包含一个预训练模型:

    • bert_config.jsonopenai_gpt_config.json是用于模型的配置文件
    • pytorch_model.binBertForPreTraining保存的OpenAIGPTModelTransfoXLModelGPT2LMHeadModel的预训练实例的PyTorch转储。(使用常用的torch.save()保存)

    如果PRE_TRAINED_MODEL_NAME_OR_PATH是快捷名称,则将从AWS S3下载预训练权重。可以参见链接(https://github.com/huggingface/transformers/blob/master/transformers/modeling_bert.py)并存储在缓存文件夹中以避免以后需要下载(可以在~/.pytorch_pretrained_bert/中找到该缓存文件夹)。

    • cache_dir可以是特定目录的可选路径,以下载和缓存预先训练的模型权重。该选项在使用分布式训练时特别有用:为避免同时访问相同的权重,你可以设置例如cache_dir='./pretrained_model_{}'.format(args.local_rank)。)。

    • from_tf :我们应该从本地保存的TensorFlow checkpoint加载权重

    • state_dict :可选状态字典(collections.OrderedDict对象),而不是使用Google的预训练模式
    • *inputs** kwargs:特定Bert类的附加输入(例如:BertForSequenceClassification的num_labels)

Uncased表示在WordPiece标记化之前,文本已小写,例如,John Smith变为john smith。Uncased模型还会删除任何重音标记。Cased表示保留了真实的大小写和重音标记。通常,除非你知道案例信息对于你的任务很重要(例如,命名实体识别或词性标记),否则Uncased模型会更好。有关多语言和中文模型的信息,请参见(https://github.com/google-research/bert/blob/master/multilingual.md)或原始的TensorFlow存储库。

当使用Uncased的模型时,请确保将–do_lower_case传递给示例训练脚本(如果使用自己的脚本,则将do_lower_case=True传递给FullTokenizer))。

示例:

# BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# OpenAI GPT
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')

# Transformer-XL
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')

# OpenAI GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

缓存目录

pytorch_pretrained_bert将预训练权重保存在缓存目录中(位于此优先级):

  • cache_dirfrom_pretrained()方法的可选参数(见上文),
  • shell环境变量PYTORCH_PRETRAINED_BERT_CACHE
  • PyTorch缓存目录+/pytorch_pretrained_bert/ ,其中PyTorch缓存目录由(按此顺序定义):
    • 外壳环境变量ENV_TORCH_HOME
    • shell环境变量ENV_XDG_CACHE_HOME +/torch/)
    • 默认值:~/.cache/torch/

通常,如果你未设置任何特定的环境变量pytorch_pretrained_bert缓存将位于~/.cache/torch/pytorch_pretrained_bert/中。

你可以始终安全地删除pytorch_pretrained_bert缓存,但是必须从我们的S3重新下载预训练模型权重和词汇文件。

原文链接:https://huggingface.co/transformers/serialization.html

原创文章,作者:磐石,如若转载,请注明出处:https://panchuang.net/2020/04/09/transformers-%e5%8a%a0%e8%bd%bd%e9%a2%84%e8%ae%ad%e7%bb%83%e6%a8%a1%e5%9e%8b-%e4%b8%83/

发表评论

登录后才能评论

联系我们

400-800-8888

在线咨询:点击这里给我发消息

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息