Transformer Encoder-based Forecasting Model¶
A ConvoKit Forecaster-adherent implementation of a conversational forecasting model based on a Transformer encoder model (e.g. BERT, RoBERTa, SpanBERT, DeBERTa). This class was first used in the paper “Conversations Gone Awry, But Then? Evaluating Conversational Forecasting Models” (Tran et al., 2025).
IMPORTANT NOTE: This implementation can, in fact, support any model compatible with HuggingFace’s AutoModelForSequenceClassification, including decoder-based models such as Gemma and LLaMA. However, we suggest using parameter-efficient fine-tuning techniques (e.g., LoRA) for large language models. To facilitate this, we provide a separate class specifically designed for decoder-based architectures.
class convokit.forecaster.TransformerEncoderModel.TransformerEncoderModel(model_name_or_path, config=TransformerForecasterConfig(output_dir='TransformerEncoderModel', per_device_batch_size=4, gradient_accumulation_steps=1, num_train_epochs=1, learning_rate=6.7e-06, random_seed=1, device='cuda', context_mode='normal'))¶

A ConvoKit Forecaster-adherent implementation of a conversational forecasting model based on a Transformer encoder model (e.g. BERT, RoBERTa, SpanBERT, DeBERTa). This class was first used in the paper “Conversations Gone Awry, But Then? Evaluating Conversational Forecasting Models” (Tran et al., 2025).
- Parameters
model_name_or_path – The name or local path of the pretrained transformer model to load.
config – (Optional) TransformerForecasterConfig object containing parameters for training and evaluation.
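For orientation, a construction sketch is shown below. The import paths and the checkpoint name "roberta-base" are illustrative assumptions (any AutoModelForSequenceClassification-compatible checkpoint should work); the configuration fields are those listed in the default signature above.

```python
# Illustrative sketch; assumes ConvoKit is installed and that
# TransformerForecasterConfig is importable from the forecaster package
# (import paths are an assumption, not confirmed by this page).
from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig
from convokit.forecaster.TransformerEncoderModel import TransformerEncoderModel

config = TransformerForecasterConfig(
    output_dir="TransformerEncoderModel",  # where checkpoints and predictions are saved
    per_device_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    learning_rate=6.7e-6,
    random_seed=1,
    device="cuda",          # or "cpu" if no GPU is available
    context_mode="normal",
)

# "roberta-base" is an illustrative choice of pretrained encoder.
model = TransformerEncoderModel("roberta-base", config=config)
```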
fit(contexts, val_contexts)¶

Fine-tune the TransformerEncoder model, and save the best model according to validation performance.
This method transforms the input contexts into model-compatible format, configures training parameters, and trains the model using HuggingFace’s Trainer API. It also tunes a decision threshold using a separate held-out validation set.
- Parameters
contexts – an iterator over context tuples, provided by the Forecaster framework
val_contexts – an iterator over context tuples to be used only for validation.
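The decision-threshold tuning step can be illustrated generically: sweep candidate cutoffs over the held-out validation probabilities and keep the one that maximizes a validation metric. The function below is a hedged sketch (accuracy as the metric, the function name, and the toy data are all illustrative; this is not the class's actual internal code).

```python
import numpy as np

def tune_threshold(val_probs, val_labels):
    """Pick the probability cutoff that maximizes validation accuracy.

    Illustrative helper: every observed score is tried as a candidate cutoff.
    """
    probs = np.asarray(val_probs, dtype=float)
    labels = np.asarray(val_labels, dtype=int)
    best_thresh, best_acc = 0.5, -1.0
    for t in np.unique(probs):
        preds = (probs >= t).astype(int)
        acc = float(np.mean(preds == labels))
        if acc > best_acc:
            best_thresh, best_acc = float(t), acc
    return best_thresh, best_acc

# Toy, perfectly separable validation scores: any cutoff in (0.4, 0.6]
# classifies all four examples correctly.
thresh, acc = tune_threshold([0.1, 0.4, 0.6, 0.9], [0, 0, 1, 1])
```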
transform(contexts, forecast_attribute_name, forecast_prob_attribute_name)¶

Generate forecasts using the fine-tuned TransformerEncoder model on the provided contexts, and save the predictions to the output directory specified in the configuration.
- Parameters
contexts – context tuples from the Forecaster framework
forecast_attribute_name – Forecaster will use this to look up the table column containing your model’s discretized predictions (see output specification below)
forecast_prob_attribute_name – Forecaster will use this to look up the table column containing your model’s raw forecast probabilities (see output specification below)
- Returns
a Pandas DataFrame, with one row for each context, indexed by the ID of that context’s current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name
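The returned table has the following shape. This is a hedged mock-up: the utterance IDs are hypothetical, and the column names assume "forecast" and "forecast_prob" were passed as the attribute names.

```python
import pandas as pd

# Mock-up of transform()'s return value: one row per context, indexed by
# the ID of that context's current utterance (IDs here are hypothetical).
forecasts_df = pd.DataFrame(
    {
        "forecast_prob": [0.92, 0.13],  # raw probabilities (forecast_prob_attribute_name)
        "forecast": [1, 0],             # discretized binary forecasts (forecast_attribute_name)
    },
    index=pd.Index(["utt_101", "utt_102"], name="id"),
)
```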