From 26528bc297a1eacbbe129a7451e6661420ba7ca8 Mon Sep 17 00:00:00 2001 From: unseenme Date: Tue, 16 Apr 2024 00:50:50 +0900 Subject: [PATCH 1/3] [added] gen and check script slow --- tools/cinn/monkey/check.py | 57 ++++++++++ tools/cinn/monkey/gen.py | 166 +++++++++++++++++++++++++++++ tools/cinn/monkey/gen_and_check.sh | 19 ++++ 3 files changed, 242 insertions(+) create mode 100644 tools/cinn/monkey/check.py create mode 100644 tools/cinn/monkey/gen.py create mode 100644 tools/cinn/monkey/gen_and_check.sh diff --git a/tools/cinn/monkey/check.py b/tools/cinn/monkey/check.py new file mode 100644 index 00000000000000..a175617af53ece --- /dev/null +++ b/tools/cinn/monkey/check.py @@ -0,0 +1,57 @@ +import os +import shutil +import subprocess +import time + +def process_py_files_until_file_exists(directory, flag_file): + """ + Continuously process Python files in the specified directory until the flag file exists. + + Args: + - directory (str): The directory to search for Python files. + - flag_file (str): The name of the flag file to check for existence. + + Returns: + - int: The number of times the test function was called. + """ + selection_dir = os.path.join(directory, "selection") + if not os.path.exists(selection_dir): + os.makedirs(selection_dir) + cnt = 0 + cnt_slc = 0 + while not check_file_existence(directory, flag_file): + py_files = [file for file in os.listdir(directory) if file.endswith(".py")] + + if py_files: + for py_file in py_files: + file_path = os.path.join(directory, py_file) + try: + subprocess.run(["python", file_path], check=True) + # print("rm") + os.remove(file_path) + except subprocess.CalledProcessError: + # print("move") + shutil.move(file_path, selection_dir) + cnt_slc += 1 + cnt += 1 + print(f"\rTested: {cnt}. Selection: {cnt_slc}.", end="") + # os.remove(file_path) + else: + time.sleep(0.05) # Sleep for 50 milliseconds (adjust as needed) + +def check_file_existence(directory, filename): + """ + Check if the specified file exists in the given directory. + + Args: + - directory (str): The directory to search for the file. + - filename (str): The name of the file to check for. + + Returns: + - bool: True if the file exists, False otherwise. + """ + file_path = os.path.join(directory, filename) + return os.path.isfile(file_path) + +if __name__ == "__main__": + call_count = process_py_files_until_file_exists("/dev/shm/test", "stop") diff --git a/tools/cinn/monkey/gen.py b/tools/cinn/monkey/gen.py new file mode 100644 index 00000000000000..c2aa43a1ec212a --- /dev/null +++ b/tools/cinn/monkey/gen.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass +from collections import namedtuple +from typing import Generator, Union +from defensive_list import DList +from dag_generator import PickWeight +from script import Script +import dag_generator +import dim_eq1_generator as dim_eq1_generator +import dims_eq1_generator as dims_eq1_generator +import op_name_generator as op_name_generator +import tensor_name_generator as tensor_name_generator +from tensor_name_generator import TensorNameGenRequirement +import shape_signature_inferer as shape_signature_inferer +from shape_signature_inferer import StaticDim +import instruction_util as instruction_util +import op_call_code_gen +from paddle_eager_generator import PaddleEagerGenerator +from numpy_generator import NumpyGenerator +from unit_test_case_spec import ( + UnitTestCaseRequirement, + UnitTestCaseSpec, + GenerateRandomUnitTestCaseSpec, + GetAblatedUnitTestCaseSpec +) +import os +import sys +import time + +def GenerateUnitTestCaseSpec( + unit_test_case_requirement: UnitTestCaseRequirement +) -> UnitTestCaseSpec: + return GenerateRandomUnitTestCaseSpec( + requirement=unit_test_case_requirement + ) + +def get_content(): + unit_test_case_requirement=UnitTestCaseRequirement( + dag_gen_requirement=dag_generator.DAGGenRequirement( + min_num_sources=0, + max_num_sources=0, + pick_probability=dag_generator.DAGGenTypePickProbability() + ), + dims_eq1_gen_requirement=dims_eq1_generator.DimsEq1GenRequirement( + dims_eq1_probability=[0.1, 0.15, 0.2] + ), + op_name_gen_requirement=op_name_generator.OpNameGenRequirement(), + tensor_name_gen_requirement=TensorNameGenRequirement(), + dim_size_requirement=shape_signature_inferer.DimSizeRequirement( + dim_size=[StaticDim(128), StaticDim(64), StaticDim(32)] + ) + ) + unit_test_case_spec = GenerateUnitTestCaseSpec( + unit_test_case_requirement=unit_test_case_requirement + ) + + # numpy_gen = NumpyGenerator() + # script = numpy_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) + + paddle_eager_gen = PaddleEagerGenerator() + script = paddle_eager_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) + # print("import numpy") + # print(script.file_content) + return script.file_content + +counter = 0 + +def check_file_existence(directory, filename): + """ + Check if the specified file exists in the given directory. + + Args: + - directory (str): The directory to search for the file. + - filename (str): The name of the file to check for. + + Returns: + - bool: True if the file exists, False otherwise. + """ + file_path = os.path.join(directory, filename) + return os.path.isfile(file_path) + +def count_python_files(directory): + """ + Count the number of Python files in the specified directory. + + Args: + - directory (str): The directory to search for Python files. + + Returns: + - int: The number of Python files in the directory. + """ + python_files = [file for file in os.listdir(directory) if file.endswith(".py")] + return len(python_files) + +def generate_test_cases_until_file_exists(directory, flag_file, max_python_files): + """ + Continuously generate test cases until the specified file exists. + + Args: + - directory (str): The directory to search for Python files. + - flag_file (str): The name of the file to check for existence. + - max_python_files (int): The maximum number of Python files allowed before pausing. + + Returns: + - None + """ + while not check_file_existence(directory, flag_file): + python_file_count = count_python_files(directory) + + if python_file_count < max_python_files: + print("py", python_file_count) + gen(directory) # Call your existing gen() function to generate test cases + time.sleep(0.05) + else: + time.sleep(0.05) + +def modify_assert_line(text): + lines = text.split("\n") + text = "" + for line in lines: + if line.startswith("assert"): + parts = line.split(" ") + parts[1] = "tuple(" + parts[1] + ")" + line = " ".join(parts) + text += line + "\n" + return text + +def gen(path): + # Placeholder for the gen() function. Implement your test case generation logic here. + # print("Generating test case...") + global counter + while True: + file_name = os.path.join(path, f"test_case_b_{counter}.py") + counter += 1 + if not os.path.isfile(file_name): + break + print("gen", counter-1) + with open(file_name, "w") as f: + f.write("import paddle") + f.write(modify_assert_line(get_content())) + # f.write("import numpy") + # f.write(get_content()) + +def process_arguments(): + if len(sys.argv) < 4: + arg_dir = "/dev/shm/test" + arg_flag = "stop" + arg_max = 10 + else: + arg_dir = sys.argv[1] + arg_flag = sys.argv[2] + arg_max = int(sys.argv[3]) + + return arg_dir, arg_flag, arg_max + +def main(): + directory, flag_file, max_python_files = process_arguments() + print("dir:", directory) + print("flag:", flag_file) + print("max:", max_python_files) + + if not os.path.exists(directory): + os.makedirs(directory) + generate_test_cases_until_file_exists(directory, flag_file, max_python_files) + +if __name__ == "__main__": + main() diff --git a/tools/cinn/monkey/gen_and_check.sh b/tools/cinn/monkey/gen_and_check.sh new file mode 100644 index 00000000000000..5684fc99bfafe2 --- /dev/null +++ b/tools/cinn/monkey/gen_and_check.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Copyright (c) 2021 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +nohup python3 gen.py & + +python3 check.py From 1b7a190684fae3276b0f129d422832625c5af0ce Mon Sep 17 00:00:00 2001 From: unseenme Date: Fri, 19 Apr 2024 09:23:46 +0900 Subject: [PATCH 2/3] [improved] generate cinn test style case by edit string --- tools/cinn/monkey/gen.py | 235 ++++++++++++++++++++++----------------- 1 file changed, 130 insertions(+), 105 deletions(-) diff --git a/tools/cinn/monkey/gen.py b/tools/cinn/monkey/gen.py index c2aa43a1ec212a..147c91d0b513ea 100644 --- a/tools/cinn/monkey/gen.py +++ b/tools/cinn/monkey/gen.py @@ -22,9 +22,61 @@ GenerateRandomUnitTestCaseSpec, GetAblatedUnitTestCaseSpec ) -import os -import sys -import time +import datetime + +content_head = """ +import unittest +import numpy as np +import paddle + + +class CINNCosSubGraphNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + tensor1 = x +""" + +content_body = """ +class TestCinnCos(unittest.TestCase): + def setUp(self): + paddle.seed(2024) + self.prepare_data() + + def prepare_data(self): +""" + +content_tail = """ + self.x.stop_gradient = True + + def apply_to_static(self, net, use_cinn, input_spec=None): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static( + net, + input_spec=input_spec, + build_strategy=build_strategy, + full_graph=True, + ) + + def train(self, use_cinn): + net = CINNCosSubGraphNet() + net.eval() + net = self.apply_to_static(net, use_cinn) + out = net(self.x) + return out + + def test_train(self): + cinn_out = self.train(use_cinn=True) + dy_out = self.train(use_cinn=False) + + np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-6) + + +if __name__ == '__main__': + unittest.main() +""" def GenerateUnitTestCaseSpec( unit_test_case_requirement: UnitTestCaseRequirement @@ -33,7 +85,18 @@ def GenerateUnitTestCaseSpec( requirement=unit_test_case_requirement ) -def get_content(): +def modify_assert_line(text): + lines = text.split("\n") + text = "" + for line in lines: + if line.startswith("assert"): + parts = line.split(" ") + parts[1] = "tuple(" + parts[1] + ")" + line = " ".join(parts) + text += line + "\n" + return text + +if __name__ == '__main__': unit_test_case_requirement=UnitTestCaseRequirement( dag_gen_requirement=dag_generator.DAGGenRequirement( min_num_sources=0, @@ -54,113 +117,75 @@ def get_content(): ) # numpy_gen = NumpyGenerator() - # script = numpy_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) paddle_eager_gen = PaddleEagerGenerator() - script = paddle_eager_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) + # script = numpy_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) # print("import numpy") # print(script.file_content) - return script.file_content - -counter = 0 - -def check_file_existence(directory, filename): - """ - Check if the specified file exists in the given directory. - - Args: - - directory (str): The directory to search for the file. - - filename (str): The name of the file to check for. - - Returns: - - bool: True if the file exists, False otherwise. - """ - file_path = os.path.join(directory, filename) - return os.path.isfile(file_path) - -def count_python_files(directory): - """ - Count the number of Python files in the specified directory. - - Args: - - directory (str): The directory to search for Python files. - - Returns: - - int: The number of Python files in the directory. - """ - python_files = [file for file in os.listdir(directory) if file.endswith(".py")] - return len(python_files) - -def generate_test_cases_until_file_exists(directory, flag_file, max_python_files): - """ - Continuously generate test cases until the specified file exists. - - Args: - - directory (str): The directory to search for Python files. - - flag_file (str): The name of the file to check for existence. - - max_python_files (int): The maximum number of Python files allowed before pausing. - - Returns: - - None - """ - while not check_file_existence(directory, flag_file): - python_file_count = count_python_files(directory) - - if python_file_count < max_python_files: - print("py", python_file_count) - gen(directory) # Call your existing gen() function to generate test cases - time.sleep(0.05) - else: - time.sleep(0.05) -def modify_assert_line(text): - lines = text.split("\n") - text = "" + script = paddle_eager_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) + # print("import paddle") + # print(script.file_content) + net_content = modify_assert_line(script.file_content) + + outpu_dir = "/home/aistudio/data" + current_datetime = datetime.datetime.now() + curr_time = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") + with open(outpu_dir + "/" + curr_time + "_content.py", "w") as f: + f.write(net_content) + + lines = net_content.strip().split('\n') + input_shape = None for line in lines: - if line.startswith("assert"): - parts = line.split(" ") - parts[1] = "tuple(" + parts[1] + ")" - line = " ".join(parts) - text += line + "\n" - return text + if line.startswith('tensor1'): + start_index = line.find("((") + end_index = line.find("))") + if start_index != -1 and end_index != -1: + input_shape = line[start_index + 2:end_index] + break -def gen(path): - # Placeholder for the gen() function. Implement your test case generation logic here. - # print("Generating test case...") - global counter - while True: - file_name = os.path.join(path, f"test_case_b_{counter}.py") - counter += 1 - if not os.path.isfile(file_name): + out_tensor_name = None + for line in reversed(lines): + if not line.startswith('assert'): + out_tensor_name = line.split()[0] break - print("gen", counter-1) + + # current_datetime = datetime.datetime.now() + file_name = outpu_dir + "/" + curr_time + "_testcase.py" with open(file_name, "w") as f: - f.write("import paddle") - f.write(modify_assert_line(get_content())) - # f.write("import numpy") - # f.write(get_content()) - -def process_arguments(): - if len(sys.argv) < 4: - arg_dir = "/dev/shm/test" - arg_flag = "stop" - arg_max = 10 - else: - arg_dir = sys.argv[1] - arg_flag = sys.argv[2] - arg_max = int(sys.argv[3]) - - return arg_dir, arg_flag, arg_max - -def main(): - directory, flag_file, max_python_files = process_arguments() - print("dir:", directory) - print("flag:", flag_file) - print("max:", max_python_files) - - if not os.path.exists(directory): - os.makedirs(directory) - generate_test_cases_until_file_exists(directory, flag_file, max_python_files) - -if __name__ == "__main__": - main() + f.write(content_head) + for line in lines: + if not line.startswith('tensor1 '): + f.write(" " + line + "\n") + f.write(" return " + out_tensor_name + "\n\n") + f.write(content_body) + f.write(' self.x = paddle.uniform([' + input_shape + '], dtype="float32", min=-0.5, max=0.5)') + f.write(content_tail) + + +# ablated_unit_test_case_spec = GetAblatedUnitTestCaseSpec( +# instructions=unit_test_case_spec.instructions, +# requirement=unit_test_case_requirement, +# bottom_up_ablation_size=-1, +# component_ablation_size=-1, +# ) +# print("#", "*"*80) +# print("# full ablated") +# print("#", "*"*80) +# script = numpy_gen.Generate(ablated_unit_test_case_spec.patched_instruction_code_gen_spec) +# print("import paddle") +# print(script.file_content) +# +# +# ablated_unit_test_case_spec = GetAblatedUnitTestCaseSpec( +# instructions=unit_test_case_spec.instructions, +# requirement=unit_test_case_requirement, +# bottom_up_ablation_size=len(unit_test_case_spec.instructions)/2, +# component_ablation_size=-1, +# ) +# print("#", "*"*80) +# print("# half ablated") +# print("#", "*"*80) +# script = numpy_gen.Generate(ablated_unit_test_case_spec.patched_instruction_code_gen_spec) +# print("import paddle") +# print(script.file_content) From 8ad279f215d8dc2fa8fa6adf9a3e2a9299a89f7c Mon Sep 17 00:00:00 2001 From: unseenme Date: Wed, 24 Apr 2024 08:46:03 +0900 Subject: [PATCH 3/3] [improved] input tensor; added env flags; no output when check --- tools/cinn/monkey/check.py | 2 +- tools/cinn/monkey/gen.py | 217 +++++++++++++++++++++++++------------ 2 files changed, 151 insertions(+), 68 deletions(-) diff --git a/tools/cinn/monkey/check.py b/tools/cinn/monkey/check.py index a175617af53ece..66fb8bf85ea4f7 100644 --- a/tools/cinn/monkey/check.py +++ b/tools/cinn/monkey/check.py @@ -26,7 +26,7 @@ def process_py_files_until_file_exists(directory, flag_file): for py_file in py_files: file_path = os.path.join(directory, py_file) try: - subprocess.run(["python", file_path], check=True) + subprocess.run(["python", file_path], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # print("rm") os.remove(file_path) except subprocess.CalledProcessError: diff --git a/tools/cinn/monkey/gen.py b/tools/cinn/monkey/gen.py index 147c91d0b513ea..44d73d24d49a1f 100644 --- a/tools/cinn/monkey/gen.py +++ b/tools/cinn/monkey/gen.py @@ -23,23 +23,33 @@ GetAblatedUnitTestCaseSpec ) import datetime +import os +import sys +import time content_head = """ +import os +os.environ['FLAGS_cinn_new_group_scheduler'] = '1' +os.environ['FLAGS_group_schedule_tiling_first'] = '1' +os.environ['FLAGS_prim_all'] = 'true' +os.environ['FLAGS_prim_enable_dynamic'] = '1' +os.environ['FLAGS_print_ir'] = '1' +os.environ['FLAGS_enable_pir_api'] = '1' +os.environ['FLAGS_cinn_bucket_compile'] = '1' + import unittest import numpy as np import paddle -class CINNCosSubGraphNet(paddle.nn.Layer): +class CinnMonkeyNet(paddle.nn.Layer): def __init__(self): super().__init__() - def forward(self, x): - tensor1 = x """ content_body = """ -class TestCinnCos(unittest.TestCase): +class TestCinnMonkey(unittest.TestCase): def setUp(self): paddle.seed(2024) self.prepare_data() @@ -48,8 +58,6 @@ def prepare_data(self): """ content_tail = """ - self.x.stop_gradient = True - def apply_to_static(self, net, use_cinn, input_spec=None): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = use_cinn @@ -61,10 +69,12 @@ def apply_to_static(self, net, use_cinn, input_spec=None): ) def train(self, use_cinn): - net = CINNCosSubGraphNet() + net = CinnMonkeyNet() net.eval() net = self.apply_to_static(net, use_cinn) - out = net(self.x) +""" + +content_main = """ return out def test_train(self): @@ -85,18 +95,34 @@ def GenerateUnitTestCaseSpec( requirement=unit_test_case_requirement ) -def modify_assert_line(text): - lines = text.split("\n") - text = "" - for line in lines: - if line.startswith("assert"): - parts = line.split(" ") - parts[1] = "tuple(" + parts[1] + ")" - line = " ".join(parts) - text += line + "\n" - return text +def generate_content_neck(tensors_list): + num_tensors = len(tensors_list) + forward_function = " def forward(self" + for i in range(num_tensors): + forward_function += f", x{i+1}" + forward_function += "):\n" -if __name__ == '__main__': + for i, tensor_name in enumerate(tensors_list): + forward_function += f" {tensor_name} = x{i+1}\n" + forward_function += "\n" + + return forward_function + +def generate_content_prepare(shape_list): + content = '' + for i, shape in enumerate(shape_list): + content += f' self.x{i+1} = paddle.uniform([{shape}], dtype="float32", min=-0.5, max=0.5)\n' + content += f' self.x{i+1}.stop_gradient = True\n' + return content + +def generate_content_train(num): + content = ' out = net(' + for i in range(num): + content += f'self.x{i+1}, ' + content = content[:-2] + ')' + return content + +def get_content_and_write(dir): unit_test_case_requirement=UnitTestCaseRequirement( dag_gen_requirement=dag_generator.DAGGenRequirement( min_num_sources=0, @@ -126,66 +152,123 @@ def modify_assert_line(text): script = paddle_eager_gen.Generate(unit_test_case_spec.patched_instruction_code_gen_spec) # print("import paddle") # print(script.file_content) - net_content = modify_assert_line(script.file_content) + lines = script.file_content.strip().split('\n') - outpu_dir = "/home/aistudio/data" - current_datetime = datetime.datetime.now() - curr_time = current_datetime.strftime("%Y-%m-%d_%H-%M-%S") - with open(outpu_dir + "/" + curr_time + "_content.py", "w") as f: - f.write(net_content) - - lines = net_content.strip().split('\n') - input_shape = None + input_names = [] + input_shapes = [] + net_lines = [] + is_input_line = True for line in lines: - if line.startswith('tensor1'): - start_index = line.find("((") - end_index = line.find("))") - if start_index != -1 and end_index != -1: - input_shape = line[start_index + 2:end_index] - break + if is_input_line: + if line.startswith('#'): + is_input_line = False + net_lines.append(line) + if line.startswith('tensor'): + input_names.append(line.split()[0]) + start_index = line.find("((") + end_index = line.find("))") + if start_index != -1 and end_index != -1: + input_shapes.append(line[start_index + 2:end_index]) + else: + net_lines.append(line) out_tensor_name = None - for line in reversed(lines): + for line in reversed(net_lines): if not line.startswith('assert'): out_tensor_name = line.split()[0] break - # current_datetime = datetime.datetime.now() - file_name = outpu_dir + "/" + curr_time + "_testcase.py" + # dir = "/home/aistudio/data" + # dir = "/dev/shm/test" + current_datetime = datetime.datetime.now() + curr_time = current_datetime.strftime("%Y-%m-%d_%H-%M-%S-%f") + file_name = dir + "/testcase_" + curr_time + ".py" with open(file_name, "w") as f: f.write(content_head) - for line in lines: - if not line.startswith('tensor1 '): - f.write(" " + line + "\n") + f.write(generate_content_neck(input_names)) + for line in net_lines: + if line.startswith("assert"): + parts = line.split(" ") + parts[1] = "tuple(" + parts[1] + ")" + line = " ".join(parts) + f.write(" " + line + "\n") f.write(" return " + out_tensor_name + "\n\n") f.write(content_body) - f.write(' self.x = paddle.uniform([' + input_shape + '], dtype="float32", min=-0.5, max=0.5)') + f.write(generate_content_prepare(input_shapes)) f.write(content_tail) + f.write(generate_content_train(len(input_names))) + f.write(content_main) + +def check_file_existence(directory, filename): + """ + Check if the specified file exists in the given directory. + + Args: + - directory (str): The directory to search for the file. + - filename (str): The name of the file to check for. + + Returns: + - bool: True if the file exists, False otherwise. + """ + file_path = os.path.join(directory, filename) + return os.path.isfile(file_path) + +def count_python_files(directory): + """ + Count the number of Python files in the specified directory. + + Args: + - directory (str): The directory to search for Python files. + + Returns: + - int: The number of Python files in the directory. + """ + python_files = [file for file in os.listdir(directory) if file.endswith(".py")] + return len(python_files) + +def generate_test_cases_until_file_exists(directory, flag_file, max_python_files): + """ + Continuously generate test cases until the specified file exists. + + Args: + - directory (str): The directory to search for Python files. + - flag_file (str): The name of the file to check for existence. + - max_python_files (int): The maximum number of Python files allowed before pausing. + + Returns: + - None + """ + while not check_file_existence(directory, flag_file): + python_file_count = count_python_files(directory) + + if python_file_count < max_python_files: + print("py", python_file_count) + get_content_and_write(directory) + # time.sleep(0.05) + else: + time.sleep(0.5) + +def process_arguments(): + if len(sys.argv) < 4: + arg_dir = "/dev/shm/test" + arg_flag = "stop" + arg_max = 10 + else: + arg_dir = sys.argv[1] + arg_flag = sys.argv[2] + arg_max = int(sys.argv[3]) + + return arg_dir, arg_flag, arg_max + +def main(): + directory, flag_file, max_python_files = process_arguments() + print("dir:", directory) + print("flag:", flag_file) + print("max:", max_python_files) + if not os.path.exists(directory): + os.makedirs(directory) + generate_test_cases_until_file_exists(directory, flag_file, max_python_files) -# ablated_unit_test_case_spec = GetAblatedUnitTestCaseSpec( -# instructions=unit_test_case_spec.instructions, -# requirement=unit_test_case_requirement, -# bottom_up_ablation_size=-1, -# component_ablation_size=-1, -# ) -# print("#", "*"*80) -# print("# full ablated") -# print("#", "*"*80) -# script = numpy_gen.Generate(ablated_unit_test_case_spec.patched_instruction_code_gen_spec) -# print("import paddle") -# print(script.file_content) -# -# -# ablated_unit_test_case_spec = GetAblatedUnitTestCaseSpec( -# instructions=unit_test_case_spec.instructions, -# requirement=unit_test_case_requirement, -# bottom_up_ablation_size=len(unit_test_case_spec.instructions)/2, -# component_ablation_size=-1, -# ) -# print("#", "*"*80) -# print("# half ablated") -# print("#", "*"*80) -# script = numpy_gen.Generate(ablated_unit_test_case_spec.patched_instruction_code_gen_spec) -# print("import paddle") -# print(script.file_content) +if __name__ == "__main__": + main()