文本分类插图

代码

解释

1. 导入必要的库

  • from rich import printrich 是一个 Python 库,主要用于在控制台打印漂亮的彩色输出。print 是它的一个函数,可以格式化和增强输出效果。
  • from rich.console import ConsoleConsolerich 库中的一个类,用于控制台输出的格式化。比如,它支持在输出中加入颜色、进度条、表格等。
  • from transformers import AutoTokenizer, AutoModelForSequenceClassificationtransformers 是 Hugging Face 提供的一个库,包含了许多预训练的自然语言处理模型和工具。AutoTokenizer 是自动加载合适的分词器,AutoModelForSequenceClassification 是加载一个用于文本分类的预训练模型。
  • import torchtorch 是 PyTorch 库,用于深度学习模型的构建与训练。
  • from torch.nn.functional import softmaxsoftmax 是一个函数,通常用在分类任务中,将模型输出的原始得分(logits)转换为概率。

2. 定义文本分类的类别和示例

class_examples = {
    '新闻报道': '今日,股市经历了一轮震荡,受到宏观经济数据和全球贸易紧张局势的影响。投资者密切关注美联储可能的政策调整,以适应市场的不确定性。',
    '财务报告': '本公司年度财务报告显示,去年公司实现了稳步增长的盈利,同时资产负债表呈现强劲的状况。经济环境的稳定和管理层的有效战略执行为公司的健康发展奠定了基础。',
    '公司公告': '本公司高兴地宣布成功完成最新一轮并购交易,收购了一家在人工智能领域领先的公司。这一战略举措将有助于扩大我们的业务领域,提高市场竞争力',
    '分析师报告': '最新的行业分析报告指出,科技公司的创新将成为未来增长的主要推动力。云计算、人工智能和数字化转型被认为是引领行业发展的关键因素,投资者应关注这些趋势'
}

  • 这里定义了一个字典 class_examples,其中每个键是一个文本分类的标签(如 '新闻报道''财务报告'),而每个值是对应类别的一个示例句子。

3. 初始化 prompt 函数

def init_prompts():
    class_list = list(class_examples.keys())
    pre_history = [
        (
            f'现在你是一个文本分类器,你需要按照要求将我给你的句子分类到:{class_list}类别中。',
            f'好的。'
        )
    ]
    for _type, exmpale in class_examples.items():
        pre_history.append((f'“{exmpale}”是 {class_list} 里的什么类别?', _type))
    return {'class_list': class_list, 'pre_history': pre_history}

  • class_list = list(class_examples.keys()):获取字典中的所有键,即分类的标签,存储在 class_list 中。
  • pre_history = [...]:初始化一个列表,里面包含了与模型交互的历史。第一个元素是一个提示,告诉模型它的任务是分类,并给出所有的类别。
  • for _type, exmpale in class_examples.items()::对 class_examples 字典进行遍历,取出每个类别及其示例。对于每个类别,构造一个问句和答案,提示模型该句子属于哪个类别。
  • return {'class_list': class_list, 'pre_history': pre_history}:最终返回一个字典,包含类别列表和提示信息。

4. 定义推理函数

def inference(sentences: list, custom_settings: dict, model, tokenizer, console):
    for sentence in sentences:
        with console.status("[bold bright_green]Model Inference..."):
            inputs = tokenizer(sentence, return_tensors="pt").to(device)
            outputs = model(**inputs)
            logits = outputs.logits
            probs = softmax(logits, dim=-1)  # 转换为概率
            predicted_class = probs.argmax(dim=-1).item()  # 获取预测类别
            confidence = probs.max().item()  # 置信度
            response = custom_settings['class_list'][predicted_class]

            # 输出结果
            print(f'>>> [bold bright_red]sentence: {sentence}')
            print(f'>>> [bold bright_green]Predicted Class: {response}')
            print(f'>>> [bold bright_blue]Confidence: {confidence:.2%}')

  • for sentence in sentences::遍历传入的多个句子,依次进行分类。
  • with console.status("[bold bright_green]Model Inference...")::使用 rich 库显示一个进度条,表示模型正在进行推理。
  • inputs = tokenizer(sentence, return_tensors="pt").to(device):使用分词器对句子进行处理,转换为模型能够理解的格式(PyTorch tensors),并将数据送到指定的设备(CPU或GPU)上。
  • outputs = model(**inputs):将处理过的输入数据送入模型,获取模型的输出。
  • logits = outputs.logits:获取模型的原始输出(logits)。logits 是模型的未归一化的得分,后面会转换成概率。
  • probs = softmax(logits, dim=-1):使用 softmax 函数将 logits 转换为概率分布,使得每个类别的得分变成一个概率值。
  • predicted_class = probs.argmax(dim=-1).item():通过 argmax 函数找出概率值最大的类别索引,即模型的预测结果。
  • confidence = probs.max().item():获取预测类别的概率值,也就是模型的置信度。
  • response = custom_settings['class_list'][predicted_class]:根据预测的类别索引,从 custom_settings 中获取类别名称。
  • 最后,使用 rich 库打印输出:
    • >>> sentence:显示原始句子。
    • >>> Predicted Class:显示模型预测的类别。
    • >>> Confidence:显示预测的置信度。

5. 主程序

if __name__ == '__main__':
    console = Console()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-2b", trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained("THUDM/glm-2b", num_labels=len(class_examples), trust_remote_code=True, ignore_mismatched_sizes=True)
    model.to(device)
    sentences = [
        "今日,央行发布公告宣布降低利率,以刺激经济增长。这一降息举措将影响贷款利率,并在未来几个季度内对金融市场产生影响。",
        "ABC公司今日发布公告称,已成功完成对XYZ公司股权的收购交易。本次交易是ABC公司在扩大业务范围、加强市场竞争力方面的重要举措。据悉,此次收购将进一步巩固ABC公司在行业中的地位。",
        "公司资产负债表显示,公司偿债能力强劲,现金流充足,为未来投资和扩张提供了坚实的财务基础。",
        "最新的分析报告指出,可再生能源行业预计将在未来几年经历持续增长,投资者应该关注这一领域的投资机会"
    ]
    custom_settings = init_prompts()
    print(custom_settings)

    # 调用推理函数进行分类
    inference(sentences, custom_settings, model, tokenizer, console)
  • if __name__ == '__main__'::保证只有在直接运行该脚本时,下面的代码才会执行。
  • console = Console():创建一个 Console 对象,用于漂亮的控制台输出。
  • device = 'cuda' if torch.cuda.is_available() else 'cpu':检查是否有 GPU(CUDA)可用,如果有则使用 GPU,否则使用 CPU。
  • tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-2b", trust_remote_code=True):加载预训练的分词器 glm-2b,并允许从远程下载代码。
  • model = AutoModelForSequenceClassification.from_pretrained(...):加载一个预训练的分类模型,指定类别数为 class_examples 的长度。
  • model.to(device):将模型移动到之前检测到的设备(GPU 或 CPU)。
  • sentences = [...]:定义一个包含多个待分类句子的列表。
  • custom_settings = init_prompts():调用 init_prompts() 函数,获取用于分类的类别列表和与模型交互的历史记录。这会生成一个包含 class_list(类别列表)和 pre_history(分类示例对话历史)的字典,并将它赋值给 custom_settings 变量。
  • print(custom_settings):打印出 custom_settings 内容,便于查看类别和初始化的对话历史。此行代码有助于调试,确保 custom_settings 正确返回。

6.调用推理函数进行分类

inference(sentences, custom_settings, model, tokenizer, console)
  • inference(sentences, custom_settings, model, tokenizer, console):调用我们之前定义的 inference 函数,将待分类的句子列表 sentences 和其他必要的参数(如模型 model,分词器 tokenizer,控制台 console)传递给函数。这个函数会对每个句子进行推理,计算出它们的类别及置信度,并输出结果。

总结

  1. 模型和分词器加载:我们使用 transformers 库加载了一个预训练的文本分类模型和分词器。在这个例子中,模型是 THUDM/glm-2b,它被设计来处理自然语言的任务,比如文本分类。
  2. 初始化和设置:我们定义了一个 class_examples 字典,里面包含了不同类别的示例句子。然后,init_prompts 函数生成了一个包含类别信息和示例问答的字典,这些信息会帮助模型理解它的任务是什么。
  3. 推理过程
    • 对每个待分类的句子,使用分词器将其转换成模型可以理解的格式。
    • 将分词后的输入传入模型,获取模型输出(logits),并通过 softmax 转换为概率分布。
    • 选择概率最高的类别作为模型的预测结果,并计算预测的置信度(最大概率值)。
  4. 控制台输出:使用 rich 库输出美观的结果,包括:
    • 原始句子。
    • 模型预测的类别。
    • 模型的置信度(概率)。

仅供参考)