Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion actions/explanation/global_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ def global_top_k(conversation, parse_text, i, **kwargs):
if first_argument in list(inverse_class_names.keys()):
class_idx = inverse_class_names[first_argument]

reverse = True

if "least" in parse_text:
reverse = False

print(reverse)

return topk(conversation, "ig_explainer", k,
data_path=f"./cache/{dataset_name}/ig_explainer_{dataset_name}_explanation.json",
res_path=f"./cache/{dataset_name}/ig_explainer_{dataset_name}_attribution.json",
print_with_pattern=True, class_idx=class_idx), 1
print_with_pattern=True, class_idx=class_idx, reverse=reverse), 1
37 changes: 26 additions & 11 deletions actions/explanation/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_results(explainer, data_path):
return results, model


def results_with_pattern(results):
def results_with_pattern(results, reverse):
"""
Output the results with certain pattern

Expand All @@ -65,18 +65,24 @@ def results_with_pattern(results):
"""
# example: dumb, fucking, and ugly are the most attributed for the hate speech label
if len(results) == 1:
return results[0][0] + " is the most attributed"
if reverse:
return results[0][0] + " is the most attributed"
else:
return results[0][0] + " is the least attributed"
else:
string = ""
for i in range(len(results) - 1):
string += results[i][0] + ", "
string += "and "
string += results[len(results) - 1][0]
return string + " are the most attributed."
if not reverse:
return string + " are the least attributed."
else:
return string + " are the most attributed."


def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/ig_explainer_boolq_explanation.json",
res_path="../../cache/boolq/ig_explainer_boolq_attribution.json", print_with_pattern=True, class_idx=None):
res_path="../../cache/boolq/ig_explainer_boolq_attribution.json", print_with_pattern=True, class_idx=None, reverse=True):
"""
The operation to get most k important tokens

Expand All @@ -99,15 +105,24 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/

if len(result_list) >= k:
if print_with_pattern:
return results_with_pattern(result_list[:k])
if reverse:
return results_with_pattern(result_list[:k], reverse)
else:
return results_with_pattern(result_list[::-1][:k], reverse)
else:
return result_list[:k]
if not reverse:
return result_list[::-1][:k]
else:
return result_list[:k]
else:
print("[Info] The length of score is smaller than k")
if print_with_pattern:
return results_with_pattern(result_list)
return results_with_pattern(result_list, reverse)
else:
return result_list
if not reverse:
return result_list[::-1]
else:
return result_list

if "boolq" in data_path:
results, model = get_results(explainer=explainer, data_path=data_path)
Expand Down Expand Up @@ -171,7 +186,7 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/
if word_counter[word] >= threshold:
scores[word] = (word_attributions[word] / word_counter[word])

sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=reverse)

if not os.path.exists(res_path):
jsonString = json.dumps(sorted_scores)
Expand All @@ -181,12 +196,12 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/

if len(sorted_scores) >= k:
if print_with_pattern:
return results_with_pattern(sorted_scores[:k])
return results_with_pattern(sorted_scores[:k], reverse)
else:
return sorted_scores[:k]
else:
print("[Info] The length of score is smaller than k")
if print_with_pattern:
return results_with_pattern(sorted_scores)
return results_with_pattern(sorted_scores, reverse)
else:
return sorted_scores
2 changes: 1 addition & 1 deletion logic/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
featureattribution: featureattributionword (allfeaturenames | allfeaturesword | topk | attrsentence)
featureattributionword: " nlpattribute"
allfeaturesword: " all"
topk: topkword ( {topkvalues} )
topk: topkword ( {topkvalues} ) ( reverse )
topkword: " topk"
attrsentence: " sentence"

Expand Down
6 changes: 6 additions & 0 deletions logic/sample_prompts_by_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
"cfe": "cfe.txt",
"augment": ["augmentation_chatgpt.txt", "augmentation.txt"],
"adversarial": ["adversarial_chatgpt.txt", "adversarial.txt"],
"include_prediction": "includes_+_predict.txt",
"include_countdata": "includes_+_countdata.txt",
"include_mistake": "includes_+_mistake.txt",
"include_score": "includes_+_score.txt",
"include_show": "includes_+_show.txt",
"include_label": "includes_+_label.txt"
}


Expand Down
5 changes: 4 additions & 1 deletion prompts/explanation/global_feature_importance.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,7 @@ User: thirty most important features
Parsed: important topk 30 [E]

User: display the 42 most important features
Parsed: important topk 42 [E]
Parsed: important topk 42 [E]

User: 3 least important features
Parsed: important topk 3 least [E]
110 changes: 104 additions & 6 deletions templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,71 @@ <h1 style="font: small-caps bold 20px/1 sans-serif; margin-top: 10px">🗃️ {{
</span>
</div>
</div>
<!--include buttons-->
<div class="rowc" id="include_btns" style="display: none">
<div class="control" style="display: table-cell; text-align: right; padding: 10px;"><b>Include</b></div>
<div class="tooltip-wrapper">
<button id="i_pred" class="idea-button" style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludePrediction()">
Dataset prediction
</button>
<span class="tooltip-text">
Lets the model predict the <b>filtered dataset</b> or any subset with <b>multiple instances</b>.
</span>
</div>
<div class="tooltip-wrapper">
<button id="i_label" class="idea-button" style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludeLabel()">
True label
</button>
<span class="tooltip-text">
Shows the <b>distribution</b> of the <b>true labels</b> in the <b>filtered</b> dataset
</span>
</div>
<div class="tooltip-wrapper">
<button id="i_mistake" class="idea-button"
style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludeMistake()">
Count mistake
</button>
<span class="tooltip-text">
Shows <b>how many</b> examples the model predicts a <b>wrong label</b> for.
</span>
</div>
<div class="tooltip-wrapper">
<button id="i_countdata" class="idea-button"
style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludeCountdata()">
Count data
</button>
<span class="tooltip-text">
Shows <b>how many</b> examples there are in the <b>filtered</b> data and what the range of <b>IDs</b> is.
</span>
</div>

<div class="tooltip-wrapper">
<button id="i_score" class="idea-button"
style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludeScore()">
Performance
</button>
<span class="tooltip-text">
Gets an overall <b>score</b> for the model performance on the <b>filtered </b> dataset (Accuracy, Precision, Recall, F1).
</span>
</div>

<div class="tooltip-wrapper">
<button id="i_show" class="idea-button"
style="background: rgba(253, 186, 54, 0.96);"
onclick="sampleIncludeShow()">
Show example
</button>
<span class="tooltip-text">
Displays a <b>single example</b> from the <b>filtered</b> dataset based on the <b>ID</b> you specify.
</span>
</div>
</div>

<div class="row row1" data-sort="abt_" style="display: table-row;">
<div class="control" style="display: table-cell; text-align: right; padding: 10px;"><b>About</b>
</div>
Expand Down Expand Up @@ -456,11 +521,19 @@ <h1 style="font: small-caps bold 20px/1 sans-serif; margin-top: 10px">🗃️ {{
</script>

<script>
function disableBtn(flag) {
function disableBtn(flag, type) {
if (!flag) {
document.getElementById("custom_input_btns").style.display = "table-row";
if(type === "custom_input"){
document.getElementById("custom_input_btns").style.display = "table-row";
} else{
document.getElementById("include_btns").style.display = "table-row";
}
} else {
document.getElementById("custom_input_btns").style.display = "none";
if(type === "custom_input"){
document.getElementById("custom_input_btns").style.display = "none";
} else{
document.getElementById("include_btns").style.display = "none";
}
}
}

Expand Down Expand Up @@ -498,6 +571,29 @@ <h1 style="font: small-caps bold 20px/1 sans-serif; margin-top: 10px">🗃️ {{
}
});
}
function sampleIncludeCountdata(){
doSample("include_countdata")
}

function sampleIncludeMistake(){
doSample("include_mistake")
}

function sampleIncludeLabel(){
doSample("include_label")
}

function sampleIncludeScore(){
doSample("include_score")
}

function sampleIncludeShow(){
doSample("include_show")
}

function sampleIncludePrediction(){
doSample("include_prediction")
}

function sampleKeyword() {
doSample("keyword");
Expand Down Expand Up @@ -683,21 +779,23 @@ <h1 style="font: small-caps bold 20px/1 sans-serif; margin-top: 10px">🗃️ {{

if (msgText == "quit") {
changeBtnColor(false);
disableBtn(true);
disableBtn(true, "custom_input");
}
} else if (document.getElementsByTagName("option")[idx].value === '1') {
selectBox.selectedIndex = "0";
botResponse(msgText, custom_input = '1');
changeBtnColor(true);
disableBtn(false);
disableBtn(false, "custom_input");
} else {
selectBox.selectedIndex = "0";
botResponse(msgText, custom_input = '2');
changeBtnColor(true);
disableBtn(false, "include");
}
});

function changeBtnColor(flag) {
let ids = ["c_feat-sent", "c_feat", "c_pred"]
let ids = ["c_feat-sent", "c_feat", "c_pred", "i_pred", "i_label", "i_mistake", "i_countdata", "i_score", "i_show"]

if (flag) {
ids.forEach(id => {
Expand Down