Merged
2 changes: 2 additions & 0 deletions README.rst
@@ -54,6 +54,7 @@ It provides support for the following machine learning frameworks and packages:
* sklearn-crfsuite_. ELI5 allows to check weights of sklearn_crfsuite.CRF
  models.

* OpenAI_ Python client. ELI5 allows to explain LLM predictions with token probabilities.

ELI5 also implements several algorithms for inspecting black-box models
(see `Inspecting Black-Box Estimators`_):
@@ -81,6 +82,7 @@ and formatting on a client.
.. _Catboost: https://github.com/catboost/catboost
.. _Permutation importance: https://eli5.readthedocs.io/en/latest/blackbox/permutation_importance.html
.. _Inspecting Black-Box Estimators: https://eli5.readthedocs.io/en/latest/blackbox/index.html
.. _OpenAI: https://github.com/openai/openai-python

License is MIT.

553 changes: 553 additions & 0 deletions docs/source/_notebooks/explain_llm_logprobs.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/libraries/index.rst
@@ -12,5 +12,5 @@ Supported Libraries
   catboost
   lightning
   sklearn_crfsuite
   openai
   keras

71 changes: 71 additions & 0 deletions docs/source/libraries/openai.rst
@@ -0,0 +1,71 @@
.. _library-openai:

OpenAI
======

OpenAI_ provides a client library for calling Large Language Models (LLMs).

.. _OpenAI: https://github.com/openai/openai-python

eli5 supports :func:`eli5.explain_prediction` for
``ChatCompletion``, ``ChoiceLogprobs`` and ``openai.Client`` objects,
highlighting tokens proportionally to the log probability,
which can help to see where the model is less confident in its predictions.
More likely tokens are highlighted in green,
while unlikely tokens are highlighted in red:

.. image:: ../static/llm-explain-logprobs.png
:alt: LLM token probabilities visualized

Explaining with a client, invoking the model with ``logprobs`` enabled:
::

    import eli5
    import openai
    client = openai.Client()
    prompt = 'some string'  # or [{"role": "user", "content": "some string"}]
    explanation = eli5.explain_prediction(client, prompt, model='gpt-4o')
    explanation

You may pass any extra keyword arguments to :func:`eli5.explain_prediction`;
they are passed on to ``client.chat.completions.create``,
e.g. you may pass ``n=2`` to get multiple responses
and see explanations for each of them.
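
For instance, a sketch of requesting two completions at once (``n=2`` is an
illustration of such a forwarded argument), assuming a configured client and
the ``prompt`` from above:
::

    explanation = eli5.explain_prediction(client, prompt, model='gpt-4o', n=2)
    # one explanation target per returned choice
    len(explanation.targets)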

You'd normally want to run it in a Jupyter notebook to see the explanation
formatted as HTML.

You can access the ``Choice`` object as ``explanation.targets[0].target``:
::

    explanation.targets[0].target.message.content

If you have already obtained a chat completion with ``logprobs`` from the OpenAI client,
you may call :func:`eli5.explain_prediction` with
``ChatCompletion`` or ``ChoiceLogprobs`` like this:
::

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="gpt-4o",
        logprobs=True,
    )
    eli5.explain_prediction(chat_completion)  # or
    eli5.explain_prediction(chat_completion.choices[0].logprobs)
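
Under the hood, each response token becomes a weighted span whose weight is the
token probability, ``math.exp(logprob)``. A minimal sketch of that conversion,
using hypothetical ``(token, logprob)`` pairs in place of real ``logprobs.content``:

```python
import math

# Hypothetical (token, logprob) pairs, standing in for logprobs.content
tokens = [("Hello", -0.01), (",", -0.5), (" world", -2.3)]

spans = []
idx = 0
for token, logprob in tokens:
    token_len = len(token)
    spans.append((
        f"{idx}-{token}",           # unique feature name per token
        [(idx, idx + token_len)],   # character span in the response text
        math.exp(logprob),          # probability in [0, 1]
    ))
    idx += token_len

# The full response text is the concatenation of the tokens
text = "".join(t for t, _ in tokens)
print(text)      # "Hello, world"
print(spans[0])  # ('0-Hello', [(0, 5)], ~0.99)
```

Each span then gets colored on the green-red scale according to its probability.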


See the :ref:`tutorial <explain-llm-logprobs-tutorial>` for a more detailed usage
example.

.. note::
    While token probabilities reflect model uncertainty in many cases,
    they are not always indicative, e.g. when
    `Chain of Thought <https://arxiv.org/abs/2201.11903>`_ reasoning
    precedes the final response.

.. note::
    Top-level :func:`eli5.explain_prediction` calls are dispatched
    to :func:`eli5.llm.explain_prediction.explain_prediction_openai_client`,
    :func:`eli5.llm.explain_prediction.explain_prediction_openai_completion`,
    or :func:`eli5.llm.explain_prediction.explain_prediction_openai_logprobs`.
2 changes: 2 additions & 0 deletions docs/source/overview.rst
@@ -50,6 +50,8 @@ following machine learning frameworks and packages:
* :ref:`library-sklearn-crfsuite`. ELI5 allows to check weights of
  sklearn_crfsuite.CRF models.

* :ref:`library-openai`. ELI5 allows to explain LLM predictions with token probabilities.

ELI5 also implements several algorithms for inspecting black-box models
(see :ref:`eli5-black-box`):

Binary file added docs/source/static/llm-explain-logprobs.png
9 changes: 9 additions & 0 deletions docs/source/tutorials/explain_llm_logprobs.rst
@@ -0,0 +1,9 @@
.. _explain-llm-logprobs-tutorial:

.. note::

    This tutorial can be run as an IPython notebook_.

.. _notebook: https://github.com/eli5-org/eli5/blob/master/notebooks/explain_llm_logprobs.ipynb

.. include:: ../_notebooks/explain_llm_logprobs.rst
15 changes: 14 additions & 1 deletion docs/update-notebooks.sh
@@ -78,4 +78,17 @@ rm -r source/_notebooks/keras-image-classifiers_files
mv ../notebooks/keras-image-classifiers_files/ \
    source/_notebooks/
sed -i 's&.. image:: keras-image-classifiers_files/&.. image:: ../_notebooks/keras-image-classifiers_files/&g' \
    source/_notebooks/keras-image-classifiers.rst
    source/_notebooks/keras-image-classifiers.rst


# LLM logprobs explain prediction tutorial
jupyter nbconvert \
    --to rst \
    --stdout \
    '../notebooks/explain_llm_logprobs.ipynb' \
    > source/_notebooks/explain_llm_logprobs.rst

sed -i '' 's/``eli5.explain_prediction``/:func:`eli5.explain_prediction`/g' \
    source/_notebooks/explain_llm_logprobs.rst
sed -i '' 's/\/docs\/source//g' \
    source/_notebooks/explain_llm_logprobs.rst
9 changes: 9 additions & 0 deletions eli5/__init__.py
@@ -93,3 +93,12 @@
except ImportError:
    # keras is not available
    pass

try:
    from .llm.explain_prediction import (
        explain_prediction_openai_logprobs,
        explain_prediction_openai_client,
    )
except ImportError:
    # openai not available
    pass
10 changes: 7 additions & 3 deletions eli5/base.py
@@ -1,4 +1,4 @@
from typing import Union, Optional
from typing import Union, Optional, Sequence

import numpy as np

@@ -135,7 +135,7 @@ def __init__(self,
WeightedSpan = tuple[
    Feature,
    list[tuple[int, int]],  # list of spans (start, end) for this feature
    float,  # feature weight
    float,  # feature weight or probability
]


@@ -147,16 +147,20 @@ class DocWeightedSpans:
    :document:). :preserve_density: determines how features are colored
    when doing formatting - it is better set to True for char features
    and to False for word features.
    :with_probabilities: when true, weights are interpreted as probabilities
    from 0 to 1, and a more suitable color scheme is used.
    """
    def __init__(self,
                 document: str,
                 spans: list[WeightedSpan],
                 spans: Sequence[WeightedSpan],
                 preserve_density: Optional[bool] = None,
                 with_probabilities: Optional[bool] = None,
                 vec_name: Optional[str] = None,
                 ):
        self.document = document
        self.spans = spans
        self.preserve_density = preserve_density
        self.with_probabilities = with_probabilities
        self.vec_name = vec_name


3 changes: 1 addition & 2 deletions eli5/formatters/__init__.py
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Functions to convert explanations to human-digestible formats.

@@ -24,4 +23,4 @@
)
except ImportError:
    # Pillow or matplotlib not available
    pass
    pass
40 changes: 38 additions & 2 deletions eli5/formatters/html.py
@@ -4,6 +4,7 @@

import numpy as np
from jinja2 import Environment, PackageLoader
from scipy.special import betainc

from eli5 import _graphviz
from eli5.base import (Explanation, TargetExplanation, FeatureWeights,
@@ -164,7 +165,8 @@ def render_weighted_spans(pws: PreparedWeightedSpans) -> str:
    return ''.join(
        _colorize(''.join(t for t, _ in tokens_weights),
                  weight,
                  pws.weight_range)
                  pws.weight_range,
                  pws.doc_weighted_spans.with_probabilities)
        for weight, tokens_weights in groupby(
            zip(pws.doc_weighted_spans.document, pws.char_weights),
            key=lambda x: x[1]))
@@ -173,12 +175,24 @@ def render_weighted_spans(pws: PreparedWeightedSpans) -> str:
def _colorize(token: str,
              weight: float,
              weight_range: float,
              with_probabilities: Optional[bool],
              ) -> str:
    """ Return token wrapped in a span with some styles
    (calculated from weight and weight_range) applied.
    """
    token = html_escape(token)
    if np.isclose(weight, 0.):
    if with_probabilities:
        return (
            '<span '
            'style="background-color: {color}" '
            'title="{weight:.3f}"'
            '>{token}</span>'.format(
                color=format_hsl(
                    probability_color_hsl(weight, weight_range)),
                weight=weight,
                token=token)
        )
    elif np.isclose(weight, 0.):
        return (
            '<span '
            'style="opacity: {opacity}"'
@@ -225,6 +239,28 @@ def weight_color_hsl(weight: float, weight_range: float, min_lightness=0.8) -> _HSL_COLOR:
    return hue, saturation, lightness


def probability_color_hsl(probability: float, probability_range: float) -> _HSL_COLOR:
    """ Return HSL color components for a given probability,
    where the max absolute probability is given by probability_range
    (should always be 1 at the moment).
    """
    hue = transformed_probability(probability / probability_range) * 120
    saturation = 1
    lightness = 0.5
    return hue, saturation, lightness


def transformed_probability(prob: float, alpha: float = 0.4) -> float:
    """
    Transforms a probability in [0, 1] using the Beta(alpha, alpha) CDF.
    This function is symmetric about (0.5, 0.5) and rises sharply near 0 and 1,
    highlighting differences in low and high probabilities.
    The parameter 'alpha' controls the steepness.
    """
    prob = max(0.0, min(1.0, prob))
    return betainc(alpha, alpha, prob)


def format_hsl(hsl_color: _HSL_COLOR) -> str:
    """ Format hsl color as css color string.
    """
Empty file added eli5/llm/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions eli5/llm/explain_prediction.py
@@ -0,0 +1,99 @@
import math
from typing import Union

import openai
from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletion

from eli5.base import Explanation, TargetExplanation, WeightedSpans, DocWeightedSpans
from eli5.explain import explain_prediction


LOGPROBS_ESTIMATOR = 'llm_logprobs'


@explain_prediction.register(ChoiceLogprobs)
def explain_prediction_openai_logprobs(logprobs: ChoiceLogprobs, doc=None):
    """ Creates an explanation of the logprobs
    (available as ``.choices[idx].logprobs`` on a ChatCompletion object),
    highlighting them proportionally to the log probability.
    More likely tokens are highlighted in green,
    while unlikely tokens are highlighted in red.
    ``doc`` argument is ignored.
    """
    if logprobs.content is None:
        raise ValueError('Predictions must be obtained with logprobs enabled')
    text = ''.join(x.token for x in logprobs.content)
    spans = []
    idx = 0
    for lp in logprobs.content:
        token_len = len(lp.token)
        spans.append((
            f'{idx}-{lp.token}',  # each token is a unique feature with its own weight
            [(idx, idx + token_len)],
            math.exp(lp.logprob)))
        idx += token_len
    weighted_spans = WeightedSpans([
        DocWeightedSpans(
            document=text,
            spans=spans,
            preserve_density=False,
            with_probabilities=True,
        )
    ])
    target_explanation = TargetExplanation(target=text, weighted_spans=weighted_spans)
    return Explanation(
        estimator=LOGPROBS_ESTIMATOR,
        targets=[target_explanation],
    )


@explain_prediction.register(ChatCompletion)
def explain_prediction_openai_completion(
        chat_completion: ChatCompletion, doc=None):
    """ Creates an explanation of the ChatCompletion's logprobs,
    highlighting them proportionally to the log probability.
    More likely tokens are highlighted in green,
    while unlikely tokens are highlighted in red.
    ``doc`` argument is ignored.
    """
    targets = []
    for choice in chat_completion.choices:
        if choice.logprobs is None:
            raise ValueError('Predictions must be obtained with logprobs enabled')
        target, = explain_prediction_openai_logprobs(choice.logprobs).targets
        target.target = choice
        targets.append(target)
    explanation = Explanation(
        estimator=LOGPROBS_ESTIMATOR,
        targets=targets,
    )
    return explanation


@explain_prediction.register(openai.Client)
def explain_prediction_openai_client(
        client: openai.Client,
        doc: Union[str, list[dict]],
        *,
        model: str,
        **kwargs,
):
    """
    Calls OpenAI client, obtaining a response for ``doc`` (a string, or a list of messages)
    with logprobs enabled, and explains the prediction,
    highlighting tokens proportionally to the log probability.
    More likely tokens are highlighted in green,
    while unlikely tokens are highlighted in red.
    Other keyword arguments are passed to the OpenAI client,
    with the ``model`` keyword argument required.
    """
    if isinstance(doc, str):
        messages = [{"role": "user", "content": doc}]
    else:
        messages = doc
    kwargs['logprobs'] = True
    chat_completion = client.chat.completions.create(
        messages=messages,  # type: ignore
        model=model,
        **kwargs)
    return explain_prediction_openai_completion(chat_completion)
4 changes: 1 addition & 3 deletions eli5/templates/weighted_spans.html
@@ -6,7 +6,5 @@
{% endif %}

{% if rendered_ws %}
<p style="margin-bottom: 2.5em; margin-top:-0.5em;">
{{ rendered_ws }}
</p>
<p style="margin-bottom: 2.5em; margin-top:0; white-space: pre-wrap;">{{ rendered_ws }}</p>
{% endif %}