diff --git a/actions/explanation/global_topk.py b/actions/explanation/global_topk.py index 13342e77..c7b7485e 100644 --- a/actions/explanation/global_topk.py +++ b/actions/explanation/global_topk.py @@ -28,7 +28,14 @@ def global_top_k(conversation, parse_text, i, **kwargs): if first_argument in list(inverse_class_names.keys()): class_idx = inverse_class_names[first_argument] + reverse = True + + if "least" in parse_text: + reverse = False + + print(reverse) + return topk(conversation, "ig_explainer", k, data_path=f"./cache/{dataset_name}/ig_explainer_{dataset_name}_explanation.json", res_path=f"./cache/{dataset_name}/ig_explainer_{dataset_name}_attribution.json", - print_with_pattern=True, class_idx=class_idx), 1 + print_with_pattern=True, class_idx=class_idx, reverse=reverse), 1 diff --git a/actions/explanation/topk.py b/actions/explanation/topk.py index d8a94fd1..a7d738b8 100644 --- a/actions/explanation/topk.py +++ b/actions/explanation/topk.py @@ -54,7 +54,7 @@ def get_results(explainer, data_path): return results, model -def results_with_pattern(results): +def results_with_pattern(results, reverse): """ Output the results with certain pattern @@ -65,18 +65,24 @@ def results_with_pattern(results): """ # example: dumb, fucking, and ugly are the most attributed for the hate speech label if len(results) == 1: - return results[0][0] + " is the most attributed" + if reverse: + return results[0][0] + " is the most attributed" + else: + return results[0][0] + " is the least attributed" else: string = "" for i in range(len(results) - 1): string += results[i][0] + ", " string += "and " string += results[len(results) - 1][0] - return string + " are the most attributed." + if not reverse: + return string + " are the least attributed." + else: + return string + " are the most attributed." def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/ig_explainer_boolq_explanation.json", - res_path="../../cache/boolq/ig_explainer_boolq_attribution.json", print_with_pattern=True, class_idx=None): + res_path="../../cache/boolq/ig_explainer_boolq_attribution.json", print_with_pattern=True, class_idx=None, reverse=True): """ The operation to get most k important tokens @@ -99,15 +105,24 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/ if len(result_list) >= k: if print_with_pattern: - return results_with_pattern(result_list[:k]) + if reverse: + return results_with_pattern(result_list[:k], reverse) + else: + return results_with_pattern(result_list[::-1][:k], reverse) else: - return result_list[:k] + if not reverse: + return result_list[::-1][:k] + else: + return result_list[:k] else: print("[Info] The length of score is smaller than k") if print_with_pattern: - return results_with_pattern(result_list) + return results_with_pattern(result_list, reverse) else: - return result_list + if not reverse: + return result_list[::-1] + else: + return result_list if "boolq" in data_path: results, model = get_results(explainer=explainer, data_path=data_path) @@ -171,7 +186,7 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/ if word_counter[word] >= threshold: scores[word] = (word_attributions[word] / word_counter[word]) - sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) + sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=reverse) if not os.path.exists(res_path): jsonString = json.dumps(sorted_scores) @@ -181,12 +196,12 @@ def topk(conversation, explainer, k, threshold=-1, data_path="../../cache/boolq/ if len(sorted_scores) >= k: if print_with_pattern: - return results_with_pattern(sorted_scores[:k]) + return results_with_pattern(sorted_scores[:k], reverse) else: return sorted_scores[:k] else: print("[Info] The length of score is smaller than k") if print_with_pattern: - return results_with_pattern(sorted_scores) + return results_with_pattern(sorted_scores, reverse) else: return sorted_scores diff --git a/logic/grammar.py b/logic/grammar.py index 393e6e64..054351ce 100644 --- a/logic/grammar.py +++ b/logic/grammar.py @@ -27,7 +27,7 @@ featureattribution: featureattributionword (allfeaturenames | allfeaturesword | topk | attrsentence) featureattributionword: " nlpattribute" allfeaturesword: " all" -topk: topkword ( {topkvalues} ) +topk: topkword ( {topkvalues} ) ( reverse ) topkword: " topk" attrsentence: " sentence" diff --git a/logic/sample_prompts_by_action.py b/logic/sample_prompts_by_action.py index b7eb707b..72258131 100644 --- a/logic/sample_prompts_by_action.py +++ b/logic/sample_prompts_by_action.py @@ -38,6 +38,12 @@ "cfe": "cfe.txt", "augment": ["augmentation_chatgpt.txt", "augmentation.txt"], "adversarial": ["adversarial_chatgpt.txt", "adversarial.txt"], + "include_prediction": "includes_+_predict.txt", + "include_countdata": "includes_+_countdata.txt", + "include_mistake": "includes_+_mistake.txt", + "include_score": "includes_+_score.txt", + "include_show": "includes_+_show.txt", + "include_label": "includes_+_label.txt" } diff --git a/prompts/explanation/global_feature_importance.txt b/prompts/explanation/global_feature_importance.txt index 200a48e9..0539a079 100644 --- a/prompts/explanation/global_feature_importance.txt +++ b/prompts/explanation/global_feature_importance.txt @@ -92,4 +92,7 @@ User: thirty most important features Parsed: important topk 30 [E] User: display the 42 most important features -Parsed: important topk 42 [E] \ No newline at end of file +Parsed: important topk 42 [E] + +User: 3 least important features +Parsed: important topk 3 least [E] \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index 0dc621f4..3f846329 100644 --- a/templates/index.html +++ b/templates/index.html @@ -182,6 +182,71 @@

🗃️ {{ + + +
About
@@ -456,11 +521,19 @@

🗃️ {{