Forecaster
The Forecaster class provides a generic interface to conversational forecasting models, a class of models designed to computationally capture the trajectory of conversations in order to predict future events. Though individual conversational forecasting models can get quite complex, the Forecaster API abstracts away the implementation details into a standard fit-transform interface.
For end users of Forecaster: see the demo notebook, which uses Forecaster to fine-tune the CRAFT forecasting model on the CGA-CMV corpus.
For developers of conversational forecasting models: Forecaster also represents a common framework for conversational forecasting that you can use, in conjunction with other ML/NLP ecosystems like PyTorch and Huggingface, to streamline the development of your models! You can create your conversational forecasting model as a subclass of ForecasterModel, which can then be directly “plugged in” to the Forecaster wrapper which will provide a standard fit-transform interface to your model. At runtime, Forecaster will feed a temporally-ordered stream of conversational data to your ForecasterModel in the form of “context tuples”. Context tuples are generated in chronological order, simulating the notion that the model is following the conversation as it develops in real time and generating a new prediction every time a new utterance appears (e.g., in a social media setting, every time a new comment is posted). Each context tuple, in turn, is defined as a NamedTuple with the following fields:
context: a chronological list of Utterances up to and including the most recent Utterance at the time this context was generated. Beyond the chronological ordering, no structure of any kind is imposed on the Utterances, so developers of conversational forecasting models are free to perform any structuring of their own that they desire (so yes, if you want, you can build conversational graphs on top of the provided context!)
current_utterance: the most recent utterance at the time this context tuple was generated. In the vast majority of cases, this will be identical to the last utterance in the context, except in cases where that utterance might have gotten filtered out of the context by the preprocessor (in those cases, current_utterance still reflects the “missing” most recent utterance, in order to provide a reference point for where we currently are in the conversation)
future_context: during training only (i.e., in the fit function), the context tuple also includes this additional field that lists all future Utterances; that is, all Utterances chronologically after the current utterance (or an empty list if this Utterance is the last one). This is meant only to help with data preprocessing and selection during training; for example, CRAFT trains only on the last context in each conversation, so we need to look at future_context to know whether we are at the end of the conversation. It should not be used as input to the model, as that would be “cheating” - in fact, to enforce this, future_context is not available during evaluation (i.e. in the transform function) so that any model that improperly made use of future_context would crash during evaluation!
conversation_id: the Conversation that this context-reply pair came from. ForecasterModel also has access to Forecaster’s labeler function and can use that together with the conversation_id to look up the label
As an illustrative example, a conversation containing utterances [a, b, c, d] (in temporal order) will produce the following four context tuples, in this exact order:
1. (context=[a], current_utterance=a, future_context=[b,c,d])
2. (context=[a,b], current_utterance=b, future_context=[c,d])
3. (context=[a,b,c], current_utterance=c, future_context=[d])
4. (context=[a,b,c,d], current_utterance=d, future_context=[])
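The ordering above can be reproduced with a short, library-free sketch. The `ContextTuple` below is a local stand-in that mirrors the four documented fields, and `iter_context_tuples` is a hypothetical helper rather than part of ConvoKit; plain strings stand in for Utterance objects.

```python
from typing import Iterator, List, NamedTuple


class ContextTuple(NamedTuple):
    """Local stand-in mirroring the four documented fields."""
    context: List[str]
    current_utterance: str
    future_context: List[str]
    conversation_id: str


def iter_context_tuples(utterances: List[str], conversation_id: str) -> Iterator[ContextTuple]:
    """Yield one context tuple per utterance, in chronological order."""
    for i, current in enumerate(utterances):
        yield ContextTuple(
            context=utterances[: i + 1],         # everything up to and including `current`
            current_utterance=current,
            future_context=utterances[i + 1:],   # empty list for the last utterance
            conversation_id=conversation_id,
        )


tuples = list(iter_context_tuples(["a", "b", "c", "d"], "convo_1"))
for t in tuples:
    print(t.context, t.current_utterance, t.future_context)
```

This sketch always populates future_context, which corresponds to the training-time behavior described above; it does not model the preprocessor case where current_utterance is filtered out of context.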
Belief estimation and decision policies
ConvoKit splits conversational forecasting into two steps:
Belief estimation: the forecaster model assigns a continuous score to each conversational context (typically a probability that a target event will occur). This is implemented by ForecasterModel.score().
Decision policy: a separate component converts that score into a binary intervention decision (intervene now, or wait). This is implemented by DecisionPolicy.decide().
Separating belief from action lets you change when to intervene (for example, via simple thresholding, look-ahead deferral, or simulation-based voting) without retraining or modifying the underlying forecaster. Each ForecasterModel owns one decision policy (default: ThresholdDecisionPolicy), exposed via the decision_policy property.
At inference time, the model scores the current context, then the policy decides whether to act. During training, fit() can train both components; you can also call fit_belief_estimator() and fit_decision_policy() separately.
For policy types, API details, and implementation notes, see Decision Policy.
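To make the split concrete, here is a minimal, library-free sketch of the two steps. `ToyBeliefEstimator` and `ToyThresholdPolicy` are hypothetical stand-ins, not ConvoKit's ForecasterModel or ThresholdDecisionPolicy; the punctuation-based score, the `decide(score)` signature, and the threshold values are assumptions chosen only for illustration.

```python
from typing import List


class ToyBeliefEstimator:
    """Hypothetical stand-in for the belief-estimation step."""

    def score(self, context: List[str]) -> float:
        # Toy score: fraction of utterances in the context that contain "!".
        if not context:
            return 0.0
        return sum("!" in utt for utt in context) / len(context)


class ToyThresholdPolicy:
    """Hypothetical stand-in for a simple thresholding decision policy."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def decide(self, score: float) -> bool:
        # Intervene (True) once the score reaches the threshold, otherwise wait.
        return score >= self.threshold


estimator = ToyBeliefEstimator()
context = ["hello", "that is wrong!", "no, you are wrong!"]
score = estimator.score(context)

# Swapping the policy changes the decision without touching the estimator.
lenient = ToyThresholdPolicy(threshold=0.5)
strict = ToyThresholdPolicy(threshold=0.9)
print(score, lenient.decide(score), strict.decide(score))
```

The same score leads to different decisions under the two policies, which is the point of keeping the components separate.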
-
class
convokit.forecaster.forecaster.ContextTuple(context, current_utterance, future_context, conversation_id)¶ -
context¶ Alias for field number 0
-
conversation_id¶ Alias for field number 3
-
current_utterance¶ Alias for field number 1
-
future_context¶ Alias for field number 2
-
-
class
convokit.forecaster.forecaster.Forecaster(forecaster_model: convokit.forecaster.forecasterModel.ForecasterModel, labeler: Union[Callable[[convokit.model.conversation.Conversation], int], str], context_preprocessor: Optional[Callable[[List[convokit.model.utterance.Utterance]], List[convokit.model.utterance.Utterance]]] = None, forecast_attribute_name: str = 'forecast', forecast_prob_attribute_name: str = 'forecast_prob')¶ A wrapper class that provides a consistent, Transformer-style interface to any conversational forecasting model. From a user perspective, this makes it easy to apply forecasting models to ConvoKit corpora and evaluate them without having to know a lot about the inner workings of conversational forecasting, and to swap between different kinds of models without having to change a lot of code. From a developer perspective, this provides a prebuilt foundation upon which new conversational forecasting models can be easily developed, as the Forecaster class handles the complicated work of iterating over conversational contexts in temporal fashion, allowing the developer to focus only on writing the code to handle each conversational context.
- Parameters
forecaster_model – An instance of a ForecasterModel subclass that implements the conversational forecasting model you want to use. ConvoKit provides CRAFT and BERT implementations.
labeler – A function that specifies where/how to find the label for any given conversation. Alternatively, a string can be provided, in which case it will be interpreted as the name of a Conversation metadata field containing the label.
context_preprocessor – An optional function that allows simple preprocessing of conversational contexts. Note that this should NOT be used to perform any restructuring or feature engineering on the data (that work is considered the exclusive purview of the underlying ForecasterModel); instead, it is intended to perform simple Corpus-specific data cleaning steps (e.g., removing utterances that lack key metadata required by the model).
forecast_attribute_name – metadata feature name to use in annotation for forecast result, default: “forecast”
forecast_prob_attribute_name – metadata feature name to use in annotation for forecast result probability, default: “forecast_prob”
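As a rough illustration of the two callable parameters, the sketch below defines a labeler and a context preprocessor as plain functions. The function names and the metadata field name "derailed" are hypothetical, and `SimpleNamespace` objects stand in for Conversation and Utterance instances, exposing only the `meta`, `id`, and `text` attributes used here.

```python
from types import SimpleNamespace
from typing import List


def label_from_meta(conversation) -> int:
    """Labeler: read a binary label from conversation metadata.

    The field name "derailed" is a hypothetical example.
    """
    return int(conversation.meta["derailed"])


def drop_empty_utterances(context: List) -> List:
    """Context preprocessor: simple data cleaning only, no feature engineering."""
    return [utt for utt in context if utt.text and utt.text.strip()]


# Stand-in objects exposing only the attributes used above.
convo = SimpleNamespace(meta={"derailed": True})
context = [
    SimpleNamespace(id="u1", text="First comment"),
    SimpleNamespace(id="u2", text="   "),
    SimpleNamespace(id="u3", text="Third comment"),
]

label = label_from_meta(convo)
cleaned_ids = [utt.id for utt in drop_empty_utterances(context)]
print(label, cleaned_ids)
```

Under these assumptions, the functions would be passed as `labeler=label_from_meta` and `context_preprocessor=drop_empty_utterances`; passing the string "derailed" as the labeler would be the documented shorthand for the first.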
-
fit(corpus: convokit.model.corpus.Corpus, context_selector: Callable[[convokit.forecaster.forecaster.ContextTuple], bool] = <function Forecaster.<lambda>>, val_context_selector: Optional[Callable[[convokit.forecaster.forecaster.ContextTuple], bool]] = None)¶ Wrapper method for training the underlying conversational forecasting model. Forecaster itself does not implement any actual training logic. Instead, it handles the job of selecting and iterating over context tuples. The resulting iterator is presented as a parameter to the fit method of the underlying model, which can process the tuples however it sees fit. Within each tuple, context is unstructured - it contains all utterances temporally preceding the most recent utterance, plus that most recent utterance itself, but does not impose any particular structure beyond that, allowing each conversational forecasting model to decide how it wants to define “context”.
- Parameters
corpus – The Corpus containing the data to train on
context_selector – A function that takes in a context tuple and returns a boolean indicator of whether it should be included in training data. This can be used both to select data based on splits (e.g., keep only those in the “train” split) and to specify special behavior of what contexts are looked at in training (e.g., in CRAFT, where only the last context, directly preceding the toxic comment, is used in training).
val_context_selector – An optional function that mirrors context_selector but is used to create a separate held-out validation set
- Returns
fitted Forecaster Transformer
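A context selector combining both uses described above (split filtering and last-context-only training) might look like the following sketch. `make_train_selector` and the `splits` dictionary are hypothetical; `ContextTuple` is a local stand-in with the documented fields, and strings stand in for Utterances. How split membership is actually stored in a given Corpus is an assumption here.

```python
from typing import Callable, Dict, List, NamedTuple


class ContextTuple(NamedTuple):
    """Local stand-in mirroring the documented context tuple fields."""
    context: List[str]
    current_utterance: str
    future_context: List[str]
    conversation_id: str


def make_train_selector(split_by_convo: Dict[str, str]) -> Callable[[ContextTuple], bool]:
    """Keep only the last context of each conversation in the "train" split."""
    def selector(ct: ContextTuple) -> bool:
        in_train = split_by_convo.get(ct.conversation_id) == "train"
        is_last = len(ct.future_context) == 0  # no future utterances remain
        return in_train and is_last
    return selector


splits = {"c1": "train", "c2": "test"}  # hypothetical split assignment
selector = make_train_selector(splits)

candidates = [
    ContextTuple(["a"], "a", ["b"], "c1"),
    ContextTuple(["a", "b"], "b", [], "c1"),
    ContextTuple(["x"], "x", [], "c2"),
]
selected = [ct.current_utterance for ct in candidates if selector(ct)]
print(selected)
```

Because this selector reads future_context, it is only suitable for fit; the field is not available in transform.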
-
fit_transform(corpus: convokit.model.corpus.Corpus, context_selector: Callable[[convokit.forecaster.forecaster.ContextTuple], bool] = <function Forecaster.<lambda>>) → convokit.model.corpus.Corpus¶ Convenience method for running fit and transform on the same data
- Parameters
corpus – the Corpus containing the data to run on
context_selector – A function that takes in a context tuple and returns a boolean indicator of whether it should be included. Excluded contexts will simply not have a forecast.
- Returns
annotated Corpus
-
summarize(corpus: convokit.model.corpus.Corpus, selector: Callable[[convokit.model.conversation.Conversation], bool] = <function Forecaster.<lambda>>)¶ Compute and display conversation-level performance metrics over a Corpus that has already been annotated by transform
- Parameters
corpus – the Corpus containing the forecasts to evaluate
selector – A filtering function to limit the conversations the metrics are computed over. Note that unlike the context_selectors used in fit and transform, this selector operates on conversations (since evaluation is conversation-level).
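For intuition about what conversation-level evaluation involves, the sketch below aggregates utterance-level forecasts into one prediction per conversation and computes standard binary metrics. The aggregation rule (a conversation counts as predicted positive if any of its forecasts is positive) is an assumption for illustration and is not confirmed by this page; `conversation_metrics` is a hypothetical helper, not the implementation of summarize.

```python
from typing import Dict, List


def conversation_metrics(forecasts: Dict[str, List[int]], labels: Dict[str, int]) -> Dict[str, float]:
    """Aggregate utterance-level forecasts per conversation and score them.

    Assumption: a conversation is predicted positive if ANY of its forecasts is 1.
    """
    tp = fp = tn = fn = 0
    for convo_id, label in labels.items():
        pred = int(any(forecasts.get(convo_id, [])))
        if pred and label:
            tp += 1
        elif pred and not label:
            fp += 1
        elif not pred and label:
            fn += 1
        else:
            tn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / max(len(labels), 1),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "fpr": fp / (fp + tn) if fp + tn else 0.0,
    }


# Toy data: one true positive, one true negative, one false positive, one false negative.
forecasts = {"c1": [0, 0, 1], "c2": [0, 0], "c3": [0, 1], "c4": [0, 0, 0]}
labels = {"c1": 1, "c2": 0, "c3": 0, "c4": 1}
metrics = conversation_metrics(forecasts, labels)
print(metrics)
```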
-
transform(corpus: convokit.model.corpus.Corpus, context_selector: Callable[[convokit.forecaster.forecaster.ContextTuple], bool] = <function Forecaster.<lambda>>, **kwargs) → convokit.model.corpus.Corpus¶ Wrapper method for applying the underlying conversational forecasting model to make forecasts over the Conversations in a given Corpus. Like the fit method, this simply acts to create an iterator over context tuples to be transformed, and forwards the iterator to the underlying conversational forecasting model to do the actual forecasting.
- Parameters
corpus – the Corpus containing the data to run on
context_selector – A function that takes in a context tuple and returns a boolean indicator of whether it should be included. Excluded contexts will simply not have a forecast.
- Returns
annotated Corpus
-
class
convokit.forecaster.forecasterModel.ForecasterModel(decision_policy=None, **kwargs)¶ An abstract class defining an interface that Forecaster can call into to invoke a conversational forecasting algorithm. The “contract” between Forecaster and ForecasterModel means that ForecasterModel can expect to receive conversational data in a consistent format, defined above.
-
abstract
fit(contexts, val_contexts=None)¶ Train this conversational forecasting model on the given data by fitting both the belief estimator and the decision policy.
- Parameters
contexts – an iterator over context tuples
val_contexts – an optional second iterator over context tuples to be used as a separate held-out validation set. Concrete ForecasterModel implementations may choose to ignore this, or conversely even enforce its presence.
-
abstract
fit_belief_estimator(contexts, val_contexts=None)¶ Fit only the belief estimator component that produces continuous scores.
-
fit_decision_policy(contexts, val_contexts=None, score_fn: Callable = None)¶ Fit only the decision policy component.
-
abstract
score(context) → float¶ Produce the belief estimator score for a context.
-
abstract
transform(contexts, forecast_attribute_name, forecast_prob_attribute_name)¶ Apply this trained conversational forecasting model to the given data, and return its forecasts in the form of a DataFrame indexed by (current) utterance ID
- Parameters
contexts – an iterator over context tuples
- Returns
a Pandas DataFrame, with one row for each context, indexed by the ID of that context’s current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name. Subclass implementations of ForecasterModel MUST adhere to this return value specification!
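The return contract can be sketched without any dependencies by building the rows as a plain dictionary keyed by the current utterance's ID. Everything here is illustrative: `Utt`, `ContextTuple`, `build_forecast_rows`, `toy_score`, and the 0.5 threshold are hypothetical stand-ins, and passing `None` for future_context is just one way of modeling its unavailability at transform time.

```python
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional


class Utt(NamedTuple):
    """Minimal stand-in for an Utterance (only `id` and `text` are used)."""
    id: str
    text: str


class ContextTuple(NamedTuple):
    """Local stand-in mirroring the documented context tuple fields."""
    context: List[Utt]
    current_utterance: Utt
    future_context: Optional[List[Utt]]  # not available at transform time
    conversation_id: str


def build_forecast_rows(
    contexts: Iterable[ContextTuple],
    score_fn: Callable[[List[Utt]], float],
    forecast_attribute_name: str,
    forecast_prob_attribute_name: str,
    threshold: float = 0.5,  # assumed cutoff for the binary forecast
) -> Dict[str, dict]:
    """Return one row per context, keyed by the current utterance's ID."""
    rows = {}
    for ct in contexts:
        prob = score_fn(ct.context)
        rows[ct.current_utterance.id] = {
            forecast_prob_attribute_name: prob,                # raw probability column
            forecast_attribute_name: int(prob >= threshold),   # discretized forecast column
        }
    return rows


def toy_score(context: List[Utt]) -> float:
    # Toy belief score that grows with context length.
    return min(1.0, len(context) / 4)


utts = [Utt("u1", "a"), Utt("u2", "b"), Utt("u3", "c")]
contexts = [ContextTuple(utts[: i + 1], utts[i], None, "c1") for i in range(len(utts))]
rows = build_forecast_rows(contexts, toy_score, "forecast", "forecast_prob")
print(rows)
```

A real transform implementation would need to return a DataFrame rather than a dictionary; with pandas installed, `pandas.DataFrame.from_dict(rows, orient="index")` would produce one row per utterance ID with the two named columns.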
-
abstract
This two-stage design is introduced in Wait! There’s a Way Out.
Forecaster Model
These are subclasses of ForecasterModel, each implementing forecasting models using different model architectures or families.
The following table is the current leaderboard comparing the performance of different forecaster models following a uniform evaluation framework described in Tran et al., 2025. If you want to include the performance of another model in this leaderboard, make a pull request with the respective ForecasterModel class and with the version of this demo that generates the respective new leaderboard line.
The DeferralDecisionPolicy is based upon the forecasting approach described in Wait! There’s a Way Out. The SimulationAverageDecisionPolicy is based upon the forecasting approach described in Simulation-based Decision Making for Dialogue Intervention in the static forecasting task, and is adapted for the non-static forecasting task in Wait! There’s a Way Out.
Unless otherwise specified, the performance is reported using the ThresholdDecisionPolicy.
| Model | Decision Policy | Acc ↑ | P ↑ | R ↑ | F1 ↑ | FPR ↓ | Mean H ↑ | Recovery ↑ |
|---|---|---|---|---|---|---|---|---|
| Gemma2 9B | ThresholdDecisionPolicy | 71.0 | 69.1 | 76.1 | 72.3 | 34.2 | 3.9 | +1.8 (8.4 - 6.6) |
| Gemma2 9B | DeferralDecisionPolicy | 70.9 | 72.0 | 68.4 | 70.1 | 26.7 | 3.8 | -0.1 (7.0 - 7.1) |
| Gemma2 9B | SimulationAverageDecisionPolicy | 70.2 | 68.1 | 76.6 | 72.0 | 36.1 | 4.0 | -1.2 (9.3 - 10.5) |
| Mistral 7B | ThresholdDecisionPolicy | 70.7 | 68.8 | 76.0 | 72.1 | 34.6 | 4.0 | +2.9 (8.1 - 5.2) |
| Phi4 14B | ThresholdDecisionPolicy | 70.5 | 67.7 | 78.4 | 72.6 | 37.5 | 4.0 | +2.0 (7.7 - 5.7) |
| LlaMa3.1 8B | ThresholdDecisionPolicy | 70.0 | 68.8 | 73.2 | 70.9 | 33.2 | 4.0 | +1.7 (7.3 - 5.6) |
| DeBERTaV3-large | ThresholdDecisionPolicy | 68.9 | 67.3 | 73.7 | 70.3 | 36.0 | 4.2 | +1.1 (7.6 - 6.5) |
| RoBERTa-large | ThresholdDecisionPolicy | 68.6 | 67.1 | 73.4 | 70.0 | 36.1 | 4.2 | +1.6 (7.5 - 5.9) |
| RoBERTa-base | ThresholdDecisionPolicy | 68.1 | 67.3 | 70.6 | 68.8 | 34.4 | 4.2 | +0.7 (7.4 - 6.7) |
| DeBERTaV3-base | ThresholdDecisionPolicy | 67.9 | 66.7 | 71.4 | 69.0 | 35.7 | 4.2 | +1.5 (7.2 - 5.7) |
| SpanBERT-large | ThresholdDecisionPolicy | 67.0 | 65.8 | 70.5 | 68.1 | 36.6 | 4.2 | +1.3 (8.3 - 7.0) |
| SpanBERT-base | ThresholdDecisionPolicy | 66.4 | 64.7 | 72.0 | 68.2 | 39.3 | 4.4 | +1.7 (9.6 - 8.0) |
| BERT-large | ThresholdDecisionPolicy | 65.7 | 66.0 | 65.4 | 65.5 | 34.1 | 4.2 | +0.4 (7.8 - 7.3) |
| BERT-base | ThresholdDecisionPolicy | 65.3 | 64.1 | 70.1 | 66.9 | 39.5 | 4.4 | +1.9 (9.7 - 7.8) |
| CRAFT | ThresholdDecisionPolicy | 62.8 | 59.4 | 81.1 | 68.5 | 55.5 | 4.7 | +4.9 (12.0 - 7.1) |
Table 1: Forecasting derailment on CGA-CMV-large conversations. The performance is measured in accuracy (Acc), precision (P), recall (R), F1, false positive rate (FPR), mean horizon (Mean H), and Forecast Recovery (Recovery) along with the correct and incorrect recovery rates. Results are reported as averages over five runs with different random seeds.
| Model | Decision Policy | Acc ↑ | P ↑ | R ↑ | F1 ↑ | FPR ↓ | Mean H ↑ | Recovery ↑ |
|---|---|---|---|---|---|---|---|---|
| Gemma2 9B | ThresholdDecisionPolicy | 69.2 | 67.5 | 75.3 | 70.9 | 36.9 | 3.6 | +0.9 (4.1 - 3.2) |
| Phi4 14B | ThresholdDecisionPolicy | 68.8 | 69.5 | 67.1 | 68.2 | 29.6 | 3.3 | +0.8 (3.7 - 2.9) |
| LlaMa3.1 8B | ThresholdDecisionPolicy | 68.5 | 66.3 | 75.6 | 70.5 | 38.7 | 3.6 | +1.8 (5.5 - 3.7) |
| RoBERTa-large | ThresholdDecisionPolicy | 68.2 | 67.8 | 69.7 | 68.6 | 33.3 | 3.6 | +0.3 (3.9 - 3.5) |
| SpanBERT-large | ThresholdDecisionPolicy | 67.9 | 66.5 | 72.6 | 69.3 | 36.7 | 3.6 | +0.1 (4.9 - 4.8) |
| Mistral 7B | ThresholdDecisionPolicy | 67.8 | 65.9 | 74.4 | 69.8 | 38.8 | 3.8 | +1.1 (5.1 - 4.0) |
| DeBERTaV3-large | ThresholdDecisionPolicy | 67.8 | 66.9 | 70.9 | 68.7 | 35.3 | 3.7 | +0.8 (3.8 - 3.0) |
| RoBERTa-base | ThresholdDecisionPolicy | 67.6 | 65.7 | 73.9 | 69.5 | 38.6 | 3.6 | +0.5 (3.4 - 2.8) |
| DeBERTaV3-base | ThresholdDecisionPolicy | 67.5 | 67.0 | 69.2 | 68.0 | 34.3 | 3.6 | +0.5 (2.7 - 2.3) |
| SpanBERT-base | ThresholdDecisionPolicy | 66.7 | 66.1 | 68.7 | 67.3 | 35.2 | 3.3 | -0.7 (4.5 - 5.2) |
| BERT-base | ThresholdDecisionPolicy | 66.5 | 66.5 | 66.3 | 66.4 | 33.4 | 3.6 | -1.6 (5.6 - 7.2) |
| BERT-large | ThresholdDecisionPolicy | 65.7 | 65.6 | 67.0 | 66.0 | 35.6 | 3.6 | +0.0 (5.6 - 5.6) |
| CRAFT | ThresholdDecisionPolicy | 64.8 | 63.4 | 70.1 | 66.5 | 40.5 | 3.5 | +0.4 (3.7 - 2.9) |
Table 2: Forecasting derailment on CGA-Wikiconv conversations. The performance is measured in accuracy (Acc), precision (P), recall (R), F1, false positive rate (FPR), mean horizon (Mean H), and Forecast Recovery (Recovery) along with the correct and incorrect recovery rates. Results are reported as averages over five runs with different random seeds.
For more information on how to produce a leaderboard string, see the Run Transformer Fine-tuned Models.ipynb notebook. Forecaster.evaluate() prints a Leaderboard String via format_leaderboard_row(); paste that row into Table 1 above (model names up to 44 characters). If a name is longer, widen the first column separator in every row of Table 1 to match. If you would like to include your model in the leaderboard, please make a pull request adding the respective ForecasterModel and the version of the demo generating the leaderboard line. Please contact us on Discord for assistance.