Custom Huggingface Pipeline

How to create a custom pipeline for a model.
transformers
huggingface
pytorch
Author

Thomas Capelle

Published

February 11, 2025

Packaging a model as an HF model

A practical example of repacking Vectara’s hallucination model

We are using the great hallucination model from Vectara, but its packaging is not optimal. I set out to repackage it as an HF model, specifically to show how to create a custom pipeline for a model. I have done this before for the great Celadon toxicity model, but this time I wanted to show how to do it for a text pair classification model.

Most of this post is based on the HF documentation

Quick intro: an HF model is a folder containing all the code and checkpoints needed to run the model:

  • The weights of the model

  • The tokenizer

  • All the configuration files

  • The Python files needed to run inference correctly (if you are doing something non-standard)
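To make the list concrete, here is a small helper that groups the files of a local model folder into those categories. `describe_model_folder` is hypothetical (not a transformers API), and the filename patterns are assumptions based on what `save_pretrained` typically writes; sharded checkpoints and other layouts exist.

```python
from pathlib import Path

# Assumed filename patterns; real folders may differ (sharded weights, extra files, ...).
CATEGORIES = {
    "weights": ("*.safetensors", "*.bin"),
    "tokenizer": ("tokenizer.json", "tokenizer_config.json",
                  "special_tokens_map.json", "spiece.model", "added_tokens.json"),
    "config": ("config.json", "generation_config.json"),
    "code": ("*.py",),
}


def describe_model_folder(folder: str) -> dict[str, list[str]]:
    """Group the files in `folder` by the role they play in an HF model folder."""
    root = Path(folder)
    found: dict[str, list[str]] = {}
    for category, patterns in CATEGORIES.items():
        # Collect every file matching any of the category's patterns.
        names = {p.name for pattern in patterns for p in root.glob(pattern)}
        found[category] = sorted(names)
    return found
```

For example, `describe_model_folder("models/vectara")` would show whether the tokenizer files actually ended up next to the weights.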

Let’s get started by loading the model:

from transformers import AutoModelForSequenceClassification

vectara_model = AutoModelForSequenceClassification.from_pretrained(
    "vectara/hallucination_evaluation_model", 
    trust_remote_code=True
    )
You are using a model of type HHEMv2Config to instantiate a model of type HHEMv2. This is not supported for all configurations of models and can yield errors.

First thing: the model config and model type don't match. This is probably because the model was trained with one config, then repackaged, and the model_type got renamed along the way. It is probably not an issue per se, but let's fix it.

vectara_model.config.model_type
'HHEMv2Config'

This is because the config.json file stores "model_type": "HHEMv2Config", while the config class itself reports HHEMv2 (see "model_type" in the printout below):

vectara_model.config
HHEMv2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "vectara/hallucination_evaluation_model",
  "architectures": [
    "HHEMv2ForSequenceClassification"
  ],
  "auto_map": {
    "AutoConfig": "vectara/hallucination_evaluation_model--configuration_hhem_v2.HHEMv2Config",
    "AutoModelForSequenceClassification": "vectara/hallucination_evaluation_model--modeling_hhem_v2.HHEMv2ForSequenceClassification"
  },
  "id2label": {
    "0": "hallucinated",
    "1": "consistent"
  },
  "label2id": null,
  "model_type": "HHEMv2",
  "torch_dtype": "float32",
  "transformers_version": "4.48.1"
}

If we save the model again, the config will be updated.

vectara_model.save_pretrained("models/vectara")

If we load the re-saved model back, the warning is gone:

vectara_model = AutoModelForSequenceClassification.from_pretrained(
    "models/vectara", trust_remote_code=True)
print(vectara_model.config.model_type)
HHEMv2
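To confirm the fix on disk rather than through the loaded object, you can read the model_type field straight from the re-saved config.json. `read_model_type` is a hypothetical helper, not a transformers API; it only assumes the folder contains a config.json.

```python
import json
from pathlib import Path
from typing import Optional


def read_model_type(folder: str) -> Optional[str]:
    """Return the "model_type" stored in <folder>/config.json, or None if absent."""
    config = json.loads((Path(folder) / "config.json").read_text())
    return config.get("model_type")


# Usage (after the re-save above this should match the HHEMv2 printed by the loaded config):
# print(read_model_type("models/vectara"))
```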

Model details

Next, let's look at (and fix) some of the custom logic happening inside this model.

Let's load the model again and check the details:

from transformers import AutoModelForSequenceClassification

vectara_model = AutoModelForSequenceClassification.from_pretrained(
    "models/vectara", trust_remote_code=True)

print(vectara_model)
HHEMv2ForSequenceClassification(
  (t5): T5ForTokenClassification(
    (transformer): T5EncoderModel(
      (shared): Embedding(32128, 768)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 768)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=768, out_features=768, bias=False)
                  (k): Linear(in_features=768, out_features=768, bias=False)
                  (v): Linear(in_features=768, out_features=768, bias=False)
                  (o): Linear(in_features=768, out_features=768, bias=False)
                  (relative_attention_bias): Embedding(32, 12)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseGatedActDense(
                  (wi_0): Linear(in_features=768, out_features=2048, bias=False)
                  (wi_1): Linear(in_features=768, out_features=2048, bias=False)
                  (wo): Linear(in_features=2048, out_features=768, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): NewGELUActivation()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
          (1-11): 11 x T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=768, out_features=768, bias=False)
                  (k): Linear(in_features=768, out_features=768, bias=False)
                  (v): Linear(in_features=768, out_features=768, bias=False)
                  (o): Linear(in_features=768, out_features=768, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseGatedActDense(
                  (wi_0): Linear(in_features=768, out_features=2048, bias=False)
                  (wi_1): Linear(in_features=768, out_features=2048, bias=False)
                  (wo): Linear(in_features=2048, out_features=768, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): NewGELUActivation()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (final_layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (classifier): Linear(in_features=768, out_features=2, bias=True)
  )
)

We can see that the model wraps a T5ForTokenClassification model (built on the encoder-only T5EncoderModel), and the classifier's out_features is 2: the number of classes, hallucinated and consistent (see id2label above).
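Because this is a token-classification head, the Linear(768, 2) classifier is applied at every position of the encoder output, giving one pair of logits per token rather than one per sequence. A minimal PyTorch sketch with random hidden states; the sizes come from the printout above, while the batch size and sequence length are arbitrary.

```python
import torch
from torch import nn

hidden_size, num_labels = 768, 2              # from the printed model above
classifier = nn.Linear(hidden_size, num_labels)

# Stand-in for the T5 encoder output: (batch, seq_len, hidden_size)
hidden_states = torch.randn(3, 10, hidden_size)

logits = classifier(hidden_states)            # applied independently at every position
print(logits.shape)                           # torch.Size([3, 10, 2])
```

This is why any scoring code on top of this model has to pick a single position to classify the whole pair.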

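Another thing to note is that the tokenizer is loaded inside the model's __init__! This is not recommended: the tokenizer belongs in the model folder as its own files, so it can be loaded independently with AutoTokenizer. Note: the attribute also has a typo, tokenzier instead of tokenizer.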

vectara_model.tokenzier
T5TokenizerFast(name_or_path='google/flan-t5-base', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', '<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']}, clean_up_tokenization_spaces=False, 
added_tokens_decoder={
    0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    1: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    2: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32000: AddedToken("<extra_id_99>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32001: AddedToken("<extra_id_98>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32002: AddedToken("<extra_id_97>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32003: AddedToken("<extra_id_96>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32004: AddedToken("<extra_id_95>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32005: AddedToken("<extra_id_94>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32006: AddedToken("<extra_id_93>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32007: AddedToken("<extra_id_92>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32008: AddedToken("<extra_id_91>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32009: AddedToken("<extra_id_90>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32010: AddedToken("<extra_id_89>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32011: AddedToken("<extra_id_88>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32012: AddedToken("<extra_id_87>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32013: AddedToken("<extra_id_86>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32014: AddedToken("<extra_id_85>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32015: AddedToken("<extra_id_84>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32016: AddedToken("<extra_id_83>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32017: AddedToken("<extra_id_82>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32018: AddedToken("<extra_id_81>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32019: AddedToken("<extra_id_80>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32020: AddedToken("<extra_id_79>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32021: AddedToken("<extra_id_78>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32022: AddedToken("<extra_id_77>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32023: AddedToken("<extra_id_76>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32024: AddedToken("<extra_id_75>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32025: AddedToken("<extra_id_74>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32026: AddedToken("<extra_id_73>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32027: AddedToken("<extra_id_72>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32028: AddedToken("<extra_id_71>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32029: AddedToken("<extra_id_70>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32030: AddedToken("<extra_id_69>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32031: AddedToken("<extra_id_68>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32032: AddedToken("<extra_id_67>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32033: AddedToken("<extra_id_66>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32034: AddedToken("<extra_id_65>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32035: AddedToken("<extra_id_64>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32036: AddedToken("<extra_id_63>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32037: AddedToken("<extra_id_62>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32038: AddedToken("<extra_id_61>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32039: AddedToken("<extra_id_60>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32040: AddedToken("<extra_id_59>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32041: AddedToken("<extra_id_58>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32042: AddedToken("<extra_id_57>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32043: AddedToken("<extra_id_56>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32044: AddedToken("<extra_id_55>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32045: AddedToken("<extra_id_54>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32046: AddedToken("<extra_id_53>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32047: AddedToken("<extra_id_52>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32048: AddedToken("<extra_id_51>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32049: AddedToken("<extra_id_50>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32050: AddedToken("<extra_id_49>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32051: AddedToken("<extra_id_48>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32052: AddedToken("<extra_id_47>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32053: AddedToken("<extra_id_46>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32054: AddedToken("<extra_id_45>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32055: AddedToken("<extra_id_44>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32056: AddedToken("<extra_id_43>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32057: AddedToken("<extra_id_42>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32058: AddedToken("<extra_id_41>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32059: AddedToken("<extra_id_40>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32060: AddedToken("<extra_id_39>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32061: AddedToken("<extra_id_38>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32062: AddedToken("<extra_id_37>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32063: AddedToken("<extra_id_36>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32064: AddedToken("<extra_id_35>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32065: AddedToken("<extra_id_34>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32066: AddedToken("<extra_id_33>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32067: AddedToken("<extra_id_32>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32068: AddedToken("<extra_id_31>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32069: AddedToken("<extra_id_30>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32070: AddedToken("<extra_id_29>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32071: AddedToken("<extra_id_28>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32072: AddedToken("<extra_id_27>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32073: AddedToken("<extra_id_26>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32074: AddedToken("<extra_id_25>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32075: AddedToken("<extra_id_24>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32076: AddedToken("<extra_id_23>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32077: AddedToken("<extra_id_22>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32078: AddedToken("<extra_id_21>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32079: AddedToken("<extra_id_20>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32080: AddedToken("<extra_id_19>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32081: AddedToken("<extra_id_18>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32082: AddedToken("<extra_id_17>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32083: AddedToken("<extra_id_16>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32084: AddedToken("<extra_id_15>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32085: AddedToken("<extra_id_14>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32086: AddedToken("<extra_id_13>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32087: AddedToken("<extra_id_12>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32088: AddedToken("<extra_id_11>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32089: AddedToken("<extra_id_10>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32090: AddedToken("<extra_id_9>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32091: AddedToken("<extra_id_8>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32092: AddedToken("<extra_id_7>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32093: AddedToken("<extra_id_6>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32094: AddedToken("<extra_id_5>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32095: AddedToken("<extra_id_4>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32096: AddedToken("<extra_id_3>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32097: AddedToken("<extra_id_2>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32098: AddedToken("<extra_id_1>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    32099: AddedToken("<extra_id_0>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

It is a standard T5 tokenizer (from google/flan-t5-base), so we can simply save it alongside the model:

vectara_model.tokenzier.save_pretrained("models/vectara")
('models/vectara/tokenizer_config.json',
 'models/vectara/special_tokens_map.json',
 'models/vectara/spiece.model',
 'models/vectara/added_tokens.json',
 'models/vectara/tokenizer.json')

The model card (README) recommends running inference through a custom predict method that ships with the model:

vectara_model.predict??
Signature: vectara_model.predict(text_pairs)
Docstring: <no docstring>
Source:   
    def predict(self, text_pairs):
        tokenizer = self.tokenzier
        pair_dict = [{'text1': pair[0], 'text2': pair[1]} for pair in text_pairs]
        inputs = tokenizer(
            [self.prompt.format(**pair) for pair in pair_dict], return_tensors='pt', padding=True).to(self.t5.device)
        self.t5.eval()
        with torch.no_grad():
            outputs = self.t5(**inputs)
        logits = outputs.logits    
        logits = logits[:, 0, :] # tok_cls
        transformed_probs = torch.softmax(logits, dim=-1)
        raw_scores = transformed_probs[:, 1] # the probability of class 1
        return raw_scores
File:      ~/.cache/huggingface/modules/transformers_modules/vectara/hallucination_evaluation_model/b3973afb9f9595a40bb8403b46c6dac9c26d16d5/modeling_hhem_v2.py
Type:      method

As you can see, predict is a custom text pair classifier that:

  1. Takes two text inputs (a premise and a hypothesis)

  2. Formats them with a prompt template

  3. Uses the logits at the first token position (marked tok_cls in the code) for the final classification
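The scoring half of predict can be isolated into a small function that works on any (batch, seq_len, 2) logits tensor. The name first_token_scores is chosen here for illustration; the body mirrors the last three steps of predict.

```python
import torch


def first_token_scores(logits: torch.Tensor) -> torch.Tensor:
    """Map token-classification logits (batch, seq_len, 2) to one score per pair."""
    first = logits[:, 0, :]                 # keep the first position only (tok_cls)
    probs = torch.softmax(first, dim=-1)    # normalize over the two labels
    return probs[:, 1]                      # probability of label 1 ("consistent")


logits = torch.zeros(2, 4, 2)               # equal logits everywhere -> score 0.5
logits[1, 0] = torch.tensor([0.0, 4.0])     # favor "consistent" at position 0 only
print(first_token_scores(logits))           # tensor([0.5000, 0.9820])
```

Only position 0 matters: changing logits at any other position leaves the scores unchanged.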

We are going to repackage this as a transformers pipeline. But first, let's extract the underlying T5 model and save it to the same folder. This overwrites the HHEMv2 weights and config that AutoModelForSequenceClassification loaded from models/vectara.

vectara_model.t5.save_pretrained("models/vectara")

Now we can load the model as a T5ForTokenClassification model. No more custom model!

from transformers import T5ForTokenClassification, AutoTokenizer

t5_model = T5ForTokenClassification.from_pretrained("models/vectara")
tokenizer = AutoTokenizer.from_pretrained("models/vectara")

If we publish this model as-is, it will not work as intended: a plain T5ForTokenClassification knows nothing about the prompt template or the first-token scoring. We need to port the logic of predict so that inference stays correct.
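The other half of predict is input formatting: each (premise, hypothesis) pair is slotted into a prompt template before tokenization. A plain-Python sketch; the template below uses the same wording as the pipeline's prompt, but its whitespace is simplified, so treat it as an assumption and check the model's own template before relying on it. format_pairs is a hypothetical helper.

```python
# Assumed template: same wording as the pipeline's prompt, whitespace simplified.
PROMPT = ("<pad> Determine if the hypothesis is true given the premise?\n\n"
          "Premise: {text1}\n\nHypothesis: {text2}")


def format_pairs(pairs: list[tuple[str, str]], template: str = PROMPT) -> list[str]:
    """Turn (premise, hypothesis) pairs into prompt strings, as predict does."""
    return [template.format(text1=premise, text2=hypothesis)
            for premise, hypothesis in pairs]


prompts = format_pairs([("I am in California", "I am in United States.")])
print(prompts[0])
```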

Writing a custom pipeline

We'll create a pipeline that:

  1. Handles text pair classification

  2. Uses the T5ForTokenClassification model

  3. Maintains compatibility with the original model's behavior

from transformers import AutoTokenizer, Pipeline
import torch

class PairTextClassificationPipeline(Pipeline):
    def __init__(self, model, tokenizer=None, **kwargs):
        # Initialize tokenizer first
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # Make sure we store the tokenizer before calling super().__init__
        self.tokenizer = tokenizer
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.prompt = """<pad> Determine if the hypothesis is true given the premise?
        
        Premise: {text1}
        
        Hypothesis: {text2}"""
        
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        # `inputs` is a single (premise, hypothesis) tuple; Pipeline.__call__ splits lists per item
        pair_dict = {'text1': inputs[0], 'text2': inputs[1]}
        formatted_prompt = self.prompt.format(**pair_dict)
        model_inputs = self.tokenizer(
            formatted_prompt,
            return_tensors='pt', 
            padding=True
        )
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs):
        logits = model_outputs.logits
        logits = logits[:, 0, :] # tok_cls
        transformed_probs = torch.softmax(logits, dim=-1)
        raw_scores = transformed_probs[:, 1] # probability of class 1
        return raw_scores.item()
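Under the hood, calling a Pipeline chains preprocess, _forward and postprocess for each input, and a list input is processed item by item. That is why the class above only has to handle a single (premise, hypothesis) tuple. The toy classes below are a heavily simplified, framework-free illustration of that control flow (no batching, devices, or parameter routing); they are hypothetical and not transformers' actual implementation.

```python
class ToyPipeline:
    """Simplified stand-in for transformers.Pipeline's call flow (illustration only)."""

    def __call__(self, inputs):
        if isinstance(inputs, list):                      # many inputs -> list of results
            return [self._run_single(x) for x in inputs]
        return self._run_single(inputs)                   # one input -> one result

    def _run_single(self, x):
        return self.postprocess(self._forward(self.preprocess(x)))


class ToyPairScorer(ToyPipeline):
    # Dummy stages: "score" a pair by word overlap instead of running a model.
    def preprocess(self, pair):
        premise, hypothesis = pair
        return set(premise.lower().split()), set(hypothesis.lower().split())

    def _forward(self, word_sets):
        premise_words, hypothesis_words = word_sets
        return len(premise_words & hypothesis_words), len(hypothesis_words)

    def postprocess(self, counts):
        shared, total = counts
        return shared / total


scorer = ToyPairScorer()
print(scorer(("a b c", "a b d")))              # single tuple -> one float
print(scorer([("a b", "a b"), ("x", "y")]))    # list of tuples -> list of floats
```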

We can manually test the pipeline:

pipe = PairTextClassificationPipeline(t5_model, tokenizer=tokenizer)
Device set to use mps:0
pairs = [ # Test data, List[Tuple[str, str]]
    ("The capital of France is Berlin.", "The capital of France is Paris."),
    ('I am in California', 'I am in United States.'),
    ('I am in United States', 'I am in California.'),
    ("A person on a horse jumps over a broken down airplane.", 
     "A person is outdoors, on a horse."),
    ("A boy is jumping on skateboard in the middle of a red bridge.", 
     "The boy skates down the sidewalk on a red bridge"),
    ("A man with blond-hair, and a brown shirt drinking out of a public water fountain.", 
     "A blond man wearing a brown shirt is reading a book."),
    ("Mark Wahlberg was a fan of Manny.", "Manny was a fan of Mark Wahlberg.")
]

Let's compare against the ground-truth (GT) scores, i.e. the reference scores produced by the original model:

scores = pipe(pairs)

# the ground truth scores
gt = [0.011061512865126133, 0.6473632454872131, 0.1290171593427658, 
      0.8969419002532959, 0.18462494015693665, 0.005031010136008263, 
      0.05432349815964699]

assert all(abs(s - g) < 1e-5 for s, g in zip(scores, gt))

Ok, but this pipeline is not yet saved with the model. Save the PairTextClassificationPipeline class to a file named custom_pipeline.py next to this notebook, then import it from there and register it. Once the pipeline is registered and saved, config.json gets this entry:

  "custom_pipelines": {
    "pair-classification": {
      "impl": "custom_pipeline.PairTextClassificationPipeline",
      "pt": [
        "AutoModelForTokenClassification"
      ],
      "tf": []
    }
  },
from custom_pipeline import PairTextClassificationPipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForTokenClassification


# register the task so that save_pretrained/push_to_hub write "custom_pipelines" into config.json
PIPELINE_REGISTRY.register_pipeline(
    "pair-classification",
    pipeline_class=PairTextClassificationPipeline,
    pt_model=AutoModelForTokenClassification,
)

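Now re-create the pipeline from the imported (registered) class and save it: custom_pipeline.py is saved along with the model! The pipe built earlier came from the notebook's own copy of the class, which is not the one we registered.

# re-instantiate from custom_pipeline so the registration is picked up on save
pipe = PairTextClassificationPipeline(t5_model, tokenizer=tokenizer)
pipe.save_pretrained("models/hallu_scorer")
pipe.push_to_hub("tcapelle/hallu_scorer") # optional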
No files have been modified since last commit. Skipping to prevent empty commit.
CommitInfo(commit_url='https://huggingface.co/tcapelle/hallu_scorer/commit/7da223761c65862d2c5222c19b97b66a1a50824a', commit_message='Upload PairTextClassificationPipeline', commit_description='', oid='7da223761c65862d2c5222c19b97b66a1a50824a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/tcapelle/hallu_scorer', endpoint='https://huggingface.co', repo_type='model', repo_id='tcapelle/hallu_scorer'), pr_revision=None, pr_num=None)
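To check what the saved folder declares, you can list the custom_pipelines entry from its config.json. list_custom_pipelines is a hypothetical helper that simply reads the JSON structure shown earlier; it assumes the folder was written by pipe.save_pretrained.

```python
import json
from pathlib import Path


def list_custom_pipelines(folder: str) -> dict[str, str]:
    """Map each custom task name in <folder>/config.json to its "impl" entry."""
    config = json.loads((Path(folder) / "config.json").read_text())
    return {task: info["impl"]
            for task, info in config.get("custom_pipelines", {}).items()}


# Usage (after the save above):
# print(list_custom_pipelines("models/hallu_scorer"))
```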

Loading back and using the model

from transformers import pipeline

pipe = pipeline(
    "pair-classification", 
    model="tcapelle/hallu_scorer", 
    trust_remote_code=True)

score = pipe(
    ("The capital of France is Berlin.", 
     "The capital of France is Paris.")
     )
print(score)
Device set to use mps:0
0.01106148399412632