|
1 | 1 | Custom Dataset Example |
2 | 2 | ====================== |
3 | 3 |
|
4 | | -This example mirrors the dataset-generation style used in the test suite: mix |
5 | | -continuous features, ordered discrete features, binary flags, and unordered |
6 | | -categorical features, parse them with OCEAN, train a tree ensemble, and |
7 | | -explain one instance. |
| 4 | +This page turns the synthetic custom-dataset example into a notebook-style |
| 5 | +walkthrough. It starts from a raw pandas dataframe, parses the feature types |
| 6 | +with OCEAN, trains a random forest, and explains one class-``0`` prediction |
| 7 | +with the constraint-programming backend. |
8 | 8 |
|
9 | 9 | Why this example matters |
10 | 10 | ------------------------ |
11 | 11 |
|
12 | | -The packaged dataset loaders are convenient, but most real integrations start |
13 | | -from a custom pandas dataframe. This example shows the full path from raw data |
14 | | -to a readable counterfactual without depending on any external dataset. |
| 12 | +The packaged dataset loaders are useful when you want a quick start, but most |
| 13 | +real integrations begin with a dataframe that you already own. This example |
| 14 | +shows the full workflow on mixed feature types: |
15 | 15 |
|
16 | | -Running the example |
17 | | -------------------- |
| 16 | +- ordered discrete values through ``credit_lines``, |
| 17 | +- binary flags through ``owns_home`` and ``has_guarantor``, |
| 18 | +- continuous ratios through ``income_ratio``, ``debt_ratio``, and |
| 19 | + ``savings_ratio``, |
| 20 | +- unordered nominal features through ``job_type`` and ``region``. |
18 | 21 |
|
19 | | -.. code-block:: bash |
| 22 | +Cell 1: Build a mixed-type dataframe |
| 23 | +------------------------------------ |
20 | 24 |
|
21 | | - python examples/custom_dataset.py |
| 25 | +.. code-block:: python |
22 | 26 |
|
23 | | -What it does |
24 | | ------------- |
| 27 | + import numpy as np |
| 28 | + import pandas as pd |
25 | 29 |
|
26 | | -1. Generates a synthetic credit-style dataset with multiple feature types. |
27 | | -2. Uses :func:`ocean.feature.parse_features` to build the processed matrix and |
28 | | - mapper. |
29 | | -3. Trains a random forest on that processed matrix. |
30 | | -4. Selects a query that the model predicts as class ``0``. |
31 | | -5. Uses ``ocean.ConstraintProgrammingExplainer`` to search for the closest |
32 | | - class-``1`` counterfactual. |
33 | | -6. Prints both the original raw instance and the decoded explanation. |
| 30 | + rng = np.random.default_rng(42) |
| 31 | + raw = pd.DataFrame({ |
| 32 | + "credit_lines": rng.choice([0, 1, 2, 4], size=300), |
| 33 | + "owns_home": rng.integers(0, 2, size=300), |
| 34 | + "has_guarantor": rng.integers(0, 2, size=300), |
| 35 | + "income_ratio": rng.uniform(-0.4, 0.8, size=300), |
| 36 | + "debt_ratio": rng.uniform(0.0, 1.0, size=300), |
| 37 | + "savings_ratio": rng.uniform(-0.5, 0.6, size=300), |
| 38 | + "job_type": rng.choice( |
| 39 | + ["office", "manual", "service", "student"], |
| 40 | + size=300, |
| 41 | + ), |
| 42 | + "region": rng.choice( |
| 43 | + ["north", "south", "east", "west"], |
| 44 | + size=300, |
| 45 | + ), |
| 46 | + }) |
34 | 47 |
|
35 | | -In this example, ``credit_lines`` is treated as an ordered discrete feature, |
36 | | -while ``job_type`` and ``region`` are treated as unordered categories and are |
37 | | -therefore one-hot encoded. |
| 48 | + score = ( |
| 49 | + (raw["credit_lines"] >= 2).astype(int) |
| 50 | + + raw["owns_home"].astype(int) |
| 51 | + + raw["has_guarantor"].astype(int) |
| 52 | + + (raw["income_ratio"] > 0.1).astype(int) |
| 53 | + + (raw["savings_ratio"] > 0.0).astype(int) |
| 54 | + + raw["job_type"].isin(["office", "service"]).astype(int) |
| 55 | + + raw["region"].isin(["north", "east"]).astype(int) |
| 56 | + - (raw["debt_ratio"] > 0.55).astype(int) |
| 57 | + ) |
| 58 | + target = (score >= 4).astype(int).rename("approved") |
38 | 59 |
|
39 | | -Source |
40 | | ------- |
| 60 | +Cell 2: Parse the features with OCEAN |
| 61 | +------------------------------------- |
| 62 | + |
| 63 | +.. code-block:: python |
| 64 | +
|
| 65 | + from ocean.feature import parse_features |
| 66 | +
|
| 67 | + data, mapper = parse_features(raw, discretes=("credit_lines",)) |
| 68 | + print(data.columns) |
| 69 | +
|
| 70 | +.. code-block:: text |
| 71 | +
|
| 72 | + MultiIndex([( 'credit_lines', ''), |
| 73 | + ( 'owns_home', ''), |
| 74 | + ( 'has_guarantor', ''), |
| 75 | + ( 'income_ratio', ''), |
| 76 | + ( 'debt_ratio', ''), |
| 77 | + ( 'savings_ratio', ''), |
| 78 | + ( 'job_type', 'manual'), |
| 79 | + ( 'job_type', 'office'), |
| 80 | + ( 'job_type', 'service'), |
| 81 | + ( 'job_type', 'student'), |
| 82 | + ( 'region', 'east'), |
| 83 | + ( 'region', 'north'), |
| 84 | + ( 'region', 'south'), |
| 85 | + ( 'region', 'west')], |
| 86 | + ) |
| 87 | +
|
| 88 | +The important part is that ``credit_lines`` stays ordered and numeric, while |
| 89 | +``job_type`` and ``region`` expand into one-hot blocks. |
| 90 | + |
| 91 | +Cell 3: Fit a classifier and choose a query |
| 92 | +------------------------------------------- |
| 93 | + |
| 94 | +.. code-block:: python |
| 95 | +
|
| 96 | + import pandas as pd |
| 97 | + from sklearn.ensemble import RandomForestClassifier |
| 98 | +
|
| 99 | + model = RandomForestClassifier( |
| 100 | + n_estimators=40, |
| 101 | + max_depth=4, |
| 102 | + random_state=42, |
| 103 | + ) |
| 104 | + model.fit(data, target) |
| 105 | +
|
| 106 | + predictions = pd.Series(model.predict(data), index=data.index) |
| 107 | + query_index = predictions[predictions == 0].index[0] |
| 108 | + query = data.loc[query_index].to_numpy(dtype=float).flatten() |
| 109 | + query_frame = data.loc[[query_index]] |
| 110 | + raw_query = raw.loc[query_index] |
| 111 | +
|
| 112 | + print(raw_query) |
| 113 | + print() |
| 114 | + print("Model prediction:", int(model.predict(query_frame).item())) |
| 115 | +
|
| 116 | +.. code-block:: text |
| 117 | +
|
| 118 | + credit_lines 0 |
| 119 | + owns_home 0 |
| 120 | + has_guarantor 1 |
| 121 | + income_ratio -0.349179 |
| 122 | + debt_ratio 0.260349 |
| 123 | + savings_ratio 0.234634 |
| 124 | + job_type student |
| 125 | + region west |
| 126 | + Name: 0, dtype: object |
| 127 | +
|
| 128 | + Model prediction: 0 |
| 129 | +
|
| 130 | +Cell 4: Explain the query |
| 131 | +------------------------- |
| 132 | + |
| 133 | +.. code-block:: python |
| 134 | +
|
| 135 | + from ocean import ConstraintProgrammingExplainer |
| 136 | +
|
| 137 | + explainer = ConstraintProgrammingExplainer(model, mapper=mapper) |
| 138 | + explanation = explainer.explain( |
| 139 | + query, |
| 140 | + y=1, |
| 141 | + norm=1, |
| 142 | + max_time=10, |
| 143 | + num_workers=1, |
| 144 | + random_seed=42, |
| 145 | + ) |
| 146 | + if explanation is None: |
| 147 | + raise RuntimeError("No counterfactual was found for the synthetic example.") |
| 148 | +
|
| 149 | + counterfactual_frame = pd.DataFrame( |
| 150 | + [explanation.to_numpy()], |
| 151 | + columns=data.columns, |
| 152 | + ) |
| 153 | +
|
| 154 | + print("Target class:", 1) |
| 155 | + print("Counterfactual prediction:", int(model.predict(counterfactual_frame).item())) |
| 156 | +
|
| 157 | +.. code-block:: text |
| 158 | +
|
| 159 | + Target class: 1 |
| 160 | + Counterfactual prediction: 1 |
| 161 | +
|
| 162 | +Cell 5: Inspect the decoded explanation |
| 163 | +--------------------------------------- |
| 164 | + |
| 165 | +.. code-block:: python |
| 166 | +
|
| 167 | + print(explanation) |
| 168 | +
|
| 169 | +.. code-block:: text |
| 170 | +
|
| 171 | + Explanation: |
| 172 | + credit_lines : 0.0 |
| 173 | + owns_home : 0 |
| 174 | + has_guarantor : 1 |
| 175 | + income_ratio : -0.2833683341741562 |
| 176 | + debt_ratio : -0.1887158378958702 |
| 177 | + savings_ratio : 0.29842646420001984 |
| 178 | + job_type : student |
| 179 | + region : north |
| 180 | +
|
| 181 | +This decoded view is usually the most readable one: categorical one-hot blocks |
| 182 | +are mapped back to labels, and the keys match the original dataframe columns. |
| 183 | + |
| 184 | +Cell 6: Inspect the processed vector and the final distance |
| 185 | +----------------------------------------------------------- |
| 186 | + |
| 187 | +.. code-block:: python |
| 188 | +
|
| 189 | + print(explanation.to_series()) |
| 190 | + print() |
| 191 | + print("Distance:", explainer.get_distance()) |
| 192 | +
|
| 193 | +.. code-block:: text |
| 194 | +
|
| 195 | + credit_lines 0.000000 |
| 196 | + owns_home 0.000000 |
| 197 | + has_guarantor 1.000000 |
| 198 | + income_ratio -0.283368 |
| 199 | + debt_ratio -0.188716 |
| 200 | + savings_ratio 0.298426 |
| 201 | + job_type manual 0.000000 |
| 202 | + office 0.000000 |
| 203 | + service 0.000000 |
| 204 | + student 1.000000 |
| 205 | + region east 0.000000 |
| 206 | + north 1.000000 |
| 207 | + south 0.000000 |
| 208 | + west 0.000000 |
| 209 | + dtype: float64 |
| 210 | +
|
| 211 | + Distance: 1.3566914800296987 |
| 212 | +
|
| 213 | +``get_distance()`` is the user-facing metric to report here: it reconstructs |
| 214 | +the post-processed :math:`L_1` distance between the original query and the |
| 215 | +decoded counterfactual, including the half-weight treatment for one-hot blocks. |
| 216 | + |
| 217 | +Full script |
| 218 | +----------- |
| 219 | + |
| 220 | +If you want the exact runnable version behind this page: |
41 | 221 |
|
42 | 222 | .. literalinclude:: ../examples/custom_dataset.py |
43 | 223 | :language: python |
|
0 commit comments