diff --git a/athena/generators/graphnet_module_op_sample_generator.py b/athena/generators/graphnet_module_op_sample_generator.py index 0631cc9..4257371 100644 --- a/athena/generators/graphnet_module_op_sample_generator.py +++ b/athena/generators/graphnet_module_op_sample_generator.py @@ -16,6 +16,7 @@ "is_entry_block", "block_name", "input_arg_names", + "input_original_names", "input_tensor_descs", "stmts", "output_arg_names", @@ -122,6 +123,9 @@ def GetUnusedTensorName(stmt): block.owner_op, block.region_idx, block.block_idx ), input_arg_names=[tensor.name for tensor in input_local_tensors], + input_original_names=[ + tensor.arg_name_as_input for tensor in input_local_tensors + ], input_tensor_descs=[GetInputTensorDesc(t) for t in input_local_tensors], stmts=stmts, output_arg_names=[tensor.name for tensor in output_local_tensors], diff --git a/athena/generators/graphnet_sequence_sample_generator.py b/athena/generators/graphnet_sequence_sample_generator.py index bca33aa..1ce317c 100644 --- a/athena/generators/graphnet_sequence_sample_generator.py +++ b/athena/generators/graphnet_sequence_sample_generator.py @@ -55,6 +55,9 @@ def MakeSequenceFuncDesc(self, seq_stmts): op_id2seq_stmt ), tensor_name4tensor_id=self.MakeTensorName4TensorId(op_id2seq_stmt), + tensor_original_name4tensor_id=self.MakeTensorOriginalName4TensorId( + op_id2seq_stmt + ), tensor_name4operand_id=self.MakeTensorName4OperandId(op_id2seq_stmt), input_spec_shape_dtype4tensor_id=self.MakeInputSpecShapeAndDtype4TensorId( op_id2seq_stmt, @@ -91,11 +94,17 @@ def GetUnusedTensorName(stmt): def GetOutputTensorNames(self, seq_stmts): tensors_used_by_downstream = set(seq_stmts[-1].tensors_used_by_downstream) + tensor_names_to_remove = set() + for stmt in seq_stmts: + if stmt.op_name in ["pd_op.full_int_array", "pd_op.full"]: + tensor_names_to_remove.update(stmt.output_tensor_names) + return [ tensor_name for stmt in seq_stmts for tensor_name in stmt.output_tensor_names if tensor_name in tensors_used_by_downstream + if tensor_name not in tensor_names_to_remove ] def MakeImmediateValue4OperandId( @@ -337,6 +346,17 @@ def TensorName4TensorId(tensor_id): return TensorName4TensorId + def MakeTensorOriginalName4TensorId( + self, op_id2seq_stmt: OrderedDict[int, PyCodeStmt] + ): + def GetSourceNames(op_id): + return op_id2seq_stmt[op_id].input_tensor_original_names + + def TensorOriginalName4TensorId(tensor_id): + return tensor_id.get_source_name(GetSourceNames) + + return TensorOriginalName4TensorId + def MakeTensorListMemberIds4OperandId( self, op_id2seq_stmt: OrderedDict[int, PyCodeStmt] ): @@ -524,11 +544,6 @@ def GetCppOperandTypeName(self, op, input_idx): type_name = pos_arg_type_names[input_idx] return type_name - def _GetTemplate(self, template_name): - dir_path = os.path.dirname(os.path.realpath(__file__)) - with open(f"{dir_path}/{template_name}", "r") as f: - return jinja_env.get_template(f.read()) - def GetSha256sum(content): m = hashlib.sha256() diff --git a/athena/generators/op_example_input_meta_script_generator.py b/athena/generators/op_example_input_meta_script_generator.py index 975e551..0412bb7 100644 --- a/athena/generators/op_example_input_meta_script_generator.py +++ b/athena/generators/op_example_input_meta_script_generator.py @@ -44,7 +44,7 @@ def __init__(self, ir_program, example_inputs_meta_getter): self.block_name_gen = BlockNameGenerator() self.unittest_stmts_gen = PaddleBlockUnittestStmtsGenerator(self.block_name_gen) - def Generate(self): + def Generate(self, eval_mode): def GetShapeInstance(tensor): if tensor.arg_name_as_input is not None: tensor_meta = self.example_inputs_meta_getter.Get( @@ -97,7 +97,7 @@ def MakeBlockDescriptor(block): input_local_tensors, stmts, output_local_tensors, - ) = self.unittest_stmts_gen.Generate(block) + ) = self.unittest_stmts_gen.Generate(block, eval_mode) input_local_tensors = [ ConvertToPaddleTensor(t) for t in input_local_tensors ] @@ -144,12 +144,12 @@ def __init__(self, ir_programs, example_inputs_meta_getter): self.name = "_".join(type(ir_program).__name__ for ir_program in ir_programs) self.example_inputs_meta_getter = example_inputs_meta_getter - def Generate(self): + def Generate(self, eval_mode): def MakeProgramBlocksDescriptor(ir_program): generator = ProgramBlocksDescriptorGenerator( ir_program, self.example_inputs_meta_getter ) - return generator.Generate() + return generator.Generate(eval_mode) programs = [ MakeProgramBlocksDescriptor(ir_program) for ir_program in self.ir_programs diff --git a/athena/generators/paddle_func_body_generator.py b/athena/generators/paddle_func_body_generator.py index a2e1ea7..6c0308d 100644 --- a/athena/generators/paddle_func_body_generator.py +++ b/athena/generators/paddle_func_body_generator.py @@ -27,6 +27,7 @@ class PyCodeStmt: op_unique_local_name: str pycode: List[IndentedPyCode] input_tensor_names: List[str] + input_tensor_original_names: List[str] output_tensor_names: List[str] inputs_type_strs: List[str] outputs_type_strs: List[str] @@ -82,7 +83,7 @@ def get_local_name(tensor): return self.tensor_converter.ConvertToLocalTensor(tensor).name self.op_id2used_by_me_and_downstream = GetOpId2TensorNamesUsedByMeAndDownstream( - self.func, free_vars, args, get_local_name + self.func, free_vars, args, get_local_name, eval_mode ) self.op_id2op_func_in_out_names_signature = GetOpId2OpPipeInOutNamesSignature( self.op_id2used_by_me_and_downstream, @@ -90,6 +91,7 @@ def get_local_name(tensor): free_vars, args, get_local_name, + eval_mode, ) self.block_op_calls = BlockOpCallsExtractor().Extract( self.func, free_vars, args @@ -245,6 +247,9 @@ def CollectPyCodeStmt(self, GetStmtPyCode, op, *input_tensors, **kwargs): def GetTensorName(tensor): return tensor.name if tensor is not None else "None" + def GetTensorOriginalName(tensor): + return tensor.arg_name_as_input if tensor is not None else "None" + self.stmts.append( PyCodeStmt( op=op, @@ -253,6 +258,9 @@ def GetTensorName(tensor): prefix=f"op_{op.GetNameSuffix()}", ), input_tensor_names=[GetTensorName(t) for t in input_local_tensors], + input_tensor_original_names=[ + GetTensorOriginalName(t) for t in input_local_tensors + ], output_tensor_names=local_output_tensor_names, inputs_type_strs=inputs_type_strs, outputs_type_strs=outputs_type_strs, diff --git a/athena/generators/template_graphnet_module_op_sample.jinja b/athena/generators/template_graphnet_module_op_sample.jinja index c313bc3..4d69d4a 100644 --- a/athena/generators/template_graphnet_module_op_sample.jinja +++ b/athena/generators/template_graphnet_module_op_sample.jinja @@ -7,7 +7,7 @@ {%- set max = example_tensor_meta.max -%} {%- set mean = example_tensor_meta.mean -%} {%- set std = example_tensor_meta.std -%} -{%- if data != None -%} +{%- if data is not none -%} shape = {{shape}} dtype = '{{dtype}}' data = {{data}} @@ -54,9 +54,11 @@ {%- for arg_name in block.input_arg_names %} {%- if "parameter" not in arg_name %} {%- set input_idx = loop.index0 %} +{%- set original_name = block.input_original_names[input_idx] %} {%- set example_tensor_meta = block.input_tensor_descs[input_idx] %} class Program_weight_tensor_{{arg_name}}: name = "{{arg_name}}" + original_name = "{{original_name}}" {{get_input_tensor_instance(example_tensor_meta)}} {%- endif %} {%- endfor %} @@ -73,9 +75,11 @@ class Program_weight_tensor_{{arg_name}}: {%- for arg_name in block.input_arg_names %} {%- if "parameter" in arg_name %} {%- set input_idx = loop.index0 %} +{%- set original_name = block.input_original_names[input_idx] %} {%- set example_tensor_meta = block.input_tensor_descs[input_idx] %} class Program_weight_tensor_{{arg_name}}: name = "{{arg_name}}" + original_name = "{{original_name}}" {{get_input_tensor_instance(example_tensor_meta)}} {%- endif %} {%- endfor %} diff --git a/athena/generators/template_graphnet_sequence_sample.jinja b/athena/generators/template_graphnet_sequence_sample.jinja index 43d9f85..bed3f08 100644 --- a/athena/generators/template_graphnet_sequence_sample.jinja +++ b/athena/generators/template_graphnet_sequence_sample.jinja @@ -75,6 +75,7 @@ {%- set shape, dtype = sig.input_spec_shape_dtype4tensor_id(tensor_id) %} class Program_weight_tensor_{{sig.tensor_name4tensor_id(tensor_id)}}: name = "{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}" + original_name = "{{tensor_name_converter(sig.tensor_original_name4tensor_id(tensor_id))}}" {{get_input_tensor_instance(example_tensor_meta)}} {{ '\n' }} {%- endif %} @@ -82,15 +83,23 @@ class Program_weight_tensor_{{sig.tensor_name4tensor_id(tensor_id)}}: # --- seperate line ---- +{%- set param_list = [] -%} {%- for tensor_id in sig.tensor_ids %} -{%- if 'parameter' in sig.tensor_name4tensor_id(tensor_id) %} + {%- set tensor_name = sig.tensor_name4tensor_id(tensor_id) %} + {%- if 'parameter' in tensor_name %} + {%- set num = tensor_name.split("_")[1] | int %} + {%- set _ = param_list.append({'tensor_id': tensor_id, 'num': num}) %} + {%- endif %} +{%- endfor %} +{%- set param_list = param_list | sort(attribute="num") -%} +{%- for item in param_list %} +{%- set tensor_id = item.tensor_id %} {%- set example_tensor_meta = sig.example_input_meta4tensor_id(tensor_id) %} {%- set shape, dtype = sig.input_spec_shape_dtype4tensor_id(tensor_id) %} class Program_weight_tensor_{{sig.tensor_name4tensor_id(tensor_id)}}: name = "{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}" + original_name = "{{tensor_name_converter(sig.tensor_original_name4tensor_id(tensor_id))}}" {{get_input_tensor_instance(example_tensor_meta)}} -{{ '\n' }} -{%- endif %} {%- endfor %} # --- seperate line ---- diff --git a/athena/generators/template_op_example_input_meta_script.jinja b/athena/generators/template_op_example_input_meta_script.jinja index 8dd66c7..d255f62 100644 --- a/athena/generators/template_op_example_input_meta_script.jinja +++ b/athena/generators/template_op_example_input_meta_script.jinja @@ -6,6 +6,9 @@ from absl import app from absl import flags import traceback import datetime +from scipy.stats import truncnorm +from typing import Union, List +from dataclasses import dataclass flags.DEFINE_integer("max_try_cnt", 10, "max try cnt") flags.DEFINE_integer("random_seed", 123, "output file") @@ -14,6 +17,17 @@ flags.DEFINE_string("output_file", "", "output file") FLAGS = flags.FLAGS +@dataclass +class TensorMeta: + dtype: str + shape: List[int] + max: Union[int, float] = None + min: Union[int, float] = None + mean: Union[int, float] = None + std: Union[int, float] = None + data: Union[List[int], List[float]] = None + + def SetSeed(random_seed): paddle.seed(random_seed) random.seed(random_seed) @@ -28,10 +42,11 @@ def InitIntegerTensor(shape, dtype, min_val, max_val): def InitFloatTensor(shape, dtype, mean, std, min_val, max_val): - if mean is not None and std is not None: - # NumPy does not support truncated normal, we simulate it here. - array = np.random.normal(0, 1, shape) * std * 0.2 + mean - array = np.clip(array, min_val, max_val) + if mean is not None and std is not None and std >= 1e-5: + # truncated normal + a = (min_val - mean) / std + b = (max_val - mean) / std + array = truncnorm.rvs(a, b, loc=mean, scale=std, size=shape) else: array = np.random.uniform(low=min_val, high=max_val, size=shape) return paddle.to_tensor(array).cast(dtype) @@ -46,7 +61,7 @@ def InitTensorShape(tensor): return [InitTensorShape(t) for t in tensor] if not hasattr(tensor, 'shape'): raise NotImplementedError(f"type(tensor): {type(tensor)}") - return tensor.shape + return list(tensor.shape) def CalculateTensorMeta(tensor, meta_name): @@ -62,21 +77,27 @@ def CalculateTensorMeta(tensor, meta_name): raise NotImplementedError(f"meta_name: {meta_name}") -def InitTensorMeta(tensor, meta_name): +def InitTensorMeta(tensor, meta_name, tensor_meta): + if tensor_meta: + return getattr(tensor_meta, meta_name) if tensor is None: return None if isinstance(tensor, list) and len(tensor) > 0 and isinstance(tensor[0], int): return None if isinstance(tensor, list): - return [InitTensorMeta(t, meta_name) for t in tensor] + return [InitTensorMeta(t, meta_name, tensor_meta) for t in tensor] if not hasattr(tensor, meta_name): raise NotImplementedError(f"type(tensor): {type(tensor)}, meta_name: {meta_name}") + kLimit = 64 + if tensor.numel().item() < kLimit: + return None if tensor.dtype in [paddle.float16, paddle.bfloat16]: return float(CalculateTensorMeta(tensor.cast("float32"), meta_name)) elif tensor.dtype in [paddle.float32, paddle.float64]: return float(CalculateTensorMeta(tensor, meta_name)) else: - return CalculateTensorMeta(tensor, meta_name) if meta_name in ["max", "min"] else None + tensor_copy = tensor.cast("int32") if tensor.dtype != paddle.int64 else tensor + return CalculateTensorMeta(tensor_copy, meta_name) if meta_name in ["max", "min"] else None def InitTensorData(tensor): @@ -109,30 +130,51 @@ def MetaValueToString(value): return f"float('{value}')" if isinstance(value, float) else str(value) -def GetRecordClass(program_id, op_id, op_name, input_idx, shape, mean, std, max, min, data): +def GetRecordClass(program_id, op_id, op_name, input_idx, name, shape, mean, std, max, min, data): return f""" class PirProgram_op_input_tensor_meta_{random.randint(0, sys.maxsize)}: program_id = {program_id} op_id = {op_id} op_name = "{op_name}" input_idx = {input_idx} + name = "{name}" shape = {str(shape)} + min_val = {MetaValueToString(min)} + max_val = {MetaValueToString(max)} mean = {MetaValueToString(mean)} std = {MetaValueToString(std)} - max_val = {MetaValueToString(max)} - min_val = {MetaValueToString(min)} data = {str(data)} """ def AppendRecordClassToOutputFile(content): with open(FLAGS.output_file, 'a') as f: - f.write(content) + f.write(content) {% macro get_input_shape_instance(block, block_idx, input_idx) -%} {{block.input_tensor_descs[input_idx].shape}} {%- endmacro %} +{% macro get_input_tensor_meta(block, block_idx, input_idx) -%} +{%- set shape = get_input_shape_instance(block, block_idx, input_idx) -%} +{%- set dtype = block.input_tensor_descs[input_idx].dtype -%} +{%- set big_dtype = block.input_tensor_descs[input_idx].big_dtype -%} +{%- set data = block.input_tensor_descs[input_idx].data -%} +{%- set min = block.input_tensor_descs[input_idx].min -%} +{%- set max = block.input_tensor_descs[input_idx].max -%} +{%- set mean = block.input_tensor_descs[input_idx].mean -%} +{%- set std = block.input_tensor_descs[input_idx].std -%} +{%- if data is not none -%} + TensorMeta(dtype='{{dtype}}', shape={{shape}}, data={{data}}) +{%- elif big_dtype == "bool" -%} + TensorMeta(dtype='{{bool}}', shape={{shape}}, max=2, min=0) +{%- elif big_dtype == "int64" -%} + TensorMeta(dtype='{{dtype}}', shape={{shape}}, max={{max}}, min={{min}}) +{%- elif big_dtype == "float64" -%} + TensorMeta(dtype='{{dtype}}', shape={{shape}}, max={{max}}, min={{min}}, mean={{mean}}, std={{std}}) +{%- endif -%} +{%- endmacro -%} + {% macro get_input_tensor_instance(block, block_idx, input_idx) -%} {%- set shape = get_input_shape_instance(block, block_idx, input_idx) -%} {%- set dtype = block.input_tensor_descs[input_idx].dtype -%} @@ -142,7 +184,7 @@ def AppendRecordClassToOutputFile(content): {%- set max = block.input_tensor_descs[input_idx].max -%} {%- set mean = block.input_tensor_descs[input_idx].mean -%} {%- set std = block.input_tensor_descs[input_idx].std -%} -{%- if data != None -%} +{%- if data is not none -%} paddle.to_tensor({{data}}, dtype='{{dtype}}').reshape({{shape}}) {%- elif big_dtype == "bool" -%} paddle.cast(paddle.randint(low=0, high=2, shape={{shape}}, dtype='int32'), 'bool') @@ -153,17 +195,28 @@ def AppendRecordClassToOutputFile(content): {%- endif -%} {%- endmacro -%} - {% for program_id, blocks in programs %} def InferAndSaveOpInputDims_Program{{program_id}}(): + tensor_name2meta = {} + op_input2name = {} op_input2shape = {} op_input2mean = {} op_input2std = {} op_input2max = {} op_input2min = {} op_input2data = {} + {{ "\n" }} + {%- for block in blocks %} + {%- if block.is_entry_block %} + {%- set block_idx = loop.index0 %} + {%- for arg_name in block.input_arg_names %} + {%- set input_idx = loop.index0 %} + tensor_name2meta['{{arg_name}}'] = {{get_input_tensor_meta(block, block_idx, input_idx)}} + {%- endfor %} + {%- endif %} + {%- endfor %} - def GetInputMetaRecorder(op_name, op_id, *inputs): + def GetInputMetaRecorder(op_name, op_id, input_names, *inputs): def AllInitialized(): return all( (op_id, input_idx) in op_input2shape @@ -173,24 +226,29 @@ def InferAndSaveOpInputDims_Program{{program_id}}(): if AllInitialized(): return lambda: None - for input_idx in range(len(inputs)): - op_input2shape[(op_id, input_idx)] = InitTensorShape(inputs[input_idx]) - op_input2mean[(op_id, input_idx)] = InitTensorMeta(inputs[input_idx], "mean") - op_input2std[(op_id, input_idx)] = InitTensorMeta(inputs[input_idx], "std") - op_input2max[(op_id, input_idx)] = InitTensorMeta(inputs[input_idx], "max") - op_input2min[(op_id, input_idx)] = InitTensorMeta(inputs[input_idx], "min") - op_input2data[(op_id, input_idx)] = InitTensorData(inputs[input_idx]) + for input_idx, tensor in enumerate(inputs): + tensor_name = input_names[input_idx] + tensor_meta = tensor_name2meta.get(tensor_name, None) if tensor_name else None + op_input2name[(op_id, input_idx)] = tensor_name + op_input2shape[(op_id, input_idx)] = InitTensorShape(tensor) + op_input2mean[(op_id, input_idx)] = InitTensorMeta(tensor, "mean", tensor_meta) + op_input2std[(op_id, input_idx)] = InitTensorMeta(tensor, "std", tensor_meta) + op_input2max[(op_id, input_idx)] = InitTensorMeta(tensor, "max", tensor_meta) + op_input2min[(op_id, input_idx)] = InitTensorMeta(tensor, "min", tensor_meta) + op_input2data[(op_id, input_idx)] = InitTensorData(tensor) def Record(): for input_idx in range(len(inputs)): shape = op_input2shape.get((op_id, input_idx), None) if shape is None: continue + AppendRecordClassToOutputFile(GetRecordClass( program_id={{program_id}}, op_id=op_id, op_name=op_name, input_idx=input_idx, + name=op_input2name.get((op_id, input_idx), None), shape=shape, mean=op_input2mean.get((op_id, input_idx), None), std=op_input2std.get((op_id, input_idx), None), @@ -221,7 +279,7 @@ def InferAndSaveOpInputDims_Program{{program_id}}(): {%- for stmt in block.stmts %} {%- set stmt_idx = loop.index0 %} # ({{stmt.outputs_type_strs|join(", ")}}) <- ({{stmt.inputs_type_strs|join(", ")}}) - recorder = GetInputMetaRecorder("{{stmt.op_name}}", {{stmt.op_id}}, {{stmt.input_tensor_names | join(", ")}}) + recorder = GetInputMetaRecorder("{{stmt.op_name}}", {{stmt.op_id}}, {{stmt.input_tensor_names}}, {{stmt.input_tensor_names | join(", ")}}) {%- for pycode in stmt.pycode %} {%- if pycode.num_tabs == 0 %} {{pycode.pycode(tensor_name_converter)}} @@ -241,6 +299,7 @@ def InferAndSaveOpInputDims_Program{{program_id}}(): {%- endfor %} return {{block.output_arg_names | join(", ")}}, {%- endfor %} + extractor = OpInputShapesExtractor() for _ in range(FLAGS.max_try_cnt): {%- for block in blocks %} diff --git a/athena/graphnet_samples.py b/athena/graphnet_samples.py index 51812f6..a6dab6e 100644 --- a/athena/graphnet_samples.py +++ b/athena/graphnet_samples.py @@ -68,12 +68,229 @@ class GraphnetSample: unique_name: str subgraph_idx: int + program_id: int metadata: Dict[str, str] input_meta: str weight_meta: str model: str +def ConvertOutputStringToSample( + model_name, unique_name, subgraph_idx, program_id, sample_str +): + metadata = { + "framework": "paddle", + "model_name": model_name, + "num_devices_required": 1, + "num_nodes_required": 1, + } + + input_meta, weight_meta, model = sample_str.split("# --- seperate line ----\n") + sample = GraphnetSample( + unique_name=unique_name, + subgraph_idx=subgraph_idx, + program_id=program_id, + metadata=metadata, + input_meta=input_meta.strip("\n\n\n") + "\n", + weight_meta=weight_meta.rstrip("\n\n\n") + "\n", + model=model, + ) + # PrintToTerminal(unique_name, sample_str) + return sample + + +class GraphGenerator: + def __init__(self, model_name, programs_file, example_inputs_file, eval_mode=True): + self.model_name = model_name + self.programs_file = programs_file + self.example_inputs_file = example_inputs_file + self.eval_mode = eval_mode + + self.example_inputs_meta_getter = MakeExampleInputsMetaGetter( + GetClasses(example_inputs_file) + ) + self.ir_programs = GetValidIrPrograms(programs_file) + + def GetOutputSampleStrings(self): + def MakeModuleOpSampleGenerator(ir_program, example_inputs_meta_getter): + return GraphnetModuleOpSampleGenerator( + ir_program, + example_inputs_meta_getter, + eval_mode=self.eval_mode, + ) + + for subgraph_idx, ir_program in enumerate(self.ir_programs): + program_id = GetProgramId(ir_program) + op_names = GetOpNames(ir_program) + program_hash = GetOpNamesHash(op_names) + generator = MakeModuleOpSampleGenerator( + ir_program, self.example_inputs_meta_getter + ) + sample_str = generator.Generate() + yield (subgraph_idx, program_id, program_hash, sample_str) + + def __call__(self): + graphnet_sample_results = [] + seg_counter = defaultdict(lambda: itertools.count()) + for _, (subgraph_idx, program_id, uid, sample_str) in enumerate( + self.GetOutputSampleStrings() + ): + unique_name = f"{uid}_{next(seg_counter[uid])}" + sample = ConvertOutputStringToSample( + self.model_name, unique_name, subgraph_idx, program_id, sample_str + ) + graphnet_sample_results.append(sample) + print( + f"[GraphGenerator] Generate {len(graphnet_sample_results)} graphnet samples." + ) + return graphnet_sample_results + + +class SubgraphGenerator: + def __init__( + self, + model_name, + programs_file, + example_inputs_file, + op_example_inputs_file, + eval_mode, + tmp_dir, + ): + self.model_name = model_name + self.programs_file = programs_file + self.example_inputs_file = example_inputs_file + self.eval_mode = eval_mode + + if tmp_dir: + self.GenerateOpExampleInputFile(op_example_inputs_file, tmp_dir) + else: + with tempfile.TemporaryDirectory(prefix="athena_op_example_") as tmp_dir: + self.GenerateOpExampleInputFile(op_example_inputs_file, tmp_dir) + self.op_example_inputs_meta_getter = MakeOpExampleInputsMetaGetter( + GetClasses(op_example_inputs_file) + ) + + self.ir_programs = GetValidIrPrograms(self.programs_file) + self.program_seq_stmts_list = self.ConvertToSequenceStmts() + + def GenerateOpExampleInputFile(self, op_example_inputs_file, tmp_dir): + if os.path.isfile(op_example_inputs_file): + print(f"Remove the existing {op_example_inputs_file}") + os.remove(op_example_inputs_file) + + assert os.path.isdir(tmp_dir), f"Directory {tmp_dir=} does not exist." + + print(f"Generate {op_example_inputs_file} ...") + tmp_output_file_prefix = "tmp_op_example_input_" + for name, unittest in GetOpExampleInputMetaUnittests( + self.programs_file, + self.example_inputs_file, + bucket_size=128, + eval_mode=self.eval_mode, + ): + sha256sum = GetSha256sum(unittest) + tmp_output_filepath = os.path.join( + tmp_dir, f"{tmp_output_file_prefix}{sha256sum[0:32]}.py" + ) + WriteToFile(tmp_output_filepath, unittest) + + # Execute the generated tmp file + generate_op_example_inputs_cmd = f"ATHENA_WHILE_LOOP_LIMIT=8 {sys.executable} {tmp_output_filepath} --max_try_cnt=10 --output_file={op_example_inputs_file}" + System(generate_op_example_inputs_cmd) + + def ExtractSeqStmts(self, stmts, program_id, op_example_inputs_meta_getter): + def IsValidPrimitive(stmt): + return op_example_inputs_meta_getter.HasAllInputs( + program_id, stmt.op + ) and IsPrimitive(stmt) + + yield from ( + seq_stmts + for is_primitive, stmt_group in groupby(stmts, key=IsValidPrimitive) + if is_primitive + for seq_stmts in [list(stmt_group)] + ) + + def ConvertToSequenceStmts(self): + unittest_stmts_gen = PaddleBlockUnittestStmtsGenerator(BlockNameGenerator()) + program_seq_stmts_list = [ + (program_id, seq_stmts) + for ir_program in self.ir_programs + for program_id in [GetProgramId(ir_program)] + for block in BlocksGenerator(ir_program).Generate() + if AllInputOutputTypesSupported(block) + for _, stmts, _ in [unittest_stmts_gen.Generate(block, self.eval_mode)] + for seq_stmts in self.ExtractSeqStmts( + stmts, program_id, self.op_example_inputs_meta_getter + ) + if len(seq_stmts) > 1 + if self.op_example_inputs_meta_getter.HasAllInputs( + program_id, seq_stmts[0].op + ) + ] + return program_seq_stmts_list + + def ExtendHeadAndTail(self, seq_stmts, split_positions, group_head_and_tail): + split_positions_for_seq_stmts = ( + [0, *split_positions, len(seq_stmts)] + if group_head_and_tail + else split_positions + ) + split_positions_for_seq_stmts = [ + min(x, len(seq_stmts)) for x in split_positions_for_seq_stmts + ] + split_positions_for_seq_stmts = list(sorted(set(split_positions_for_seq_stmts))) + print(f"split_positions_for_seq_stmts: {split_positions_for_seq_stmts}") + return split_positions_for_seq_stmts + + def GetOutputSampleStrings(self, split_positions, group_head_and_tail=True): + def MakeSequenceSampleGenerator( + program_id, seq_stmts, op_example_inputs_meta_getter + ): + generator = GraphnetSequenceSampleGenerator( + program_id, op_example_inputs_meta_getter + ) + return generator.Generate(seq_stmts) + + print(f"origin split_positions: {split_positions}") + generated_sample_strs = set() + for subgraph_idx, (program_id, seq_stmts) in enumerate( + self.program_seq_stmts_list + ): + split_positions_for_seq_stmts = self.ExtendHeadAndTail( + seq_stmts, split_positions, group_head_and_tail + ) + for i in range(len(split_positions_for_seq_stmts) - 1): + seq_stmts_slice = seq_stmts[ + split_positions_for_seq_stmts[i] : split_positions_for_seq_stmts[ + i + 1 + ] + ] + sample_str = MakeSequenceSampleGenerator( + program_id, seq_stmts_slice, self.op_example_inputs_meta_getter + ) + if sample_str not in generated_sample_strs: + generated_sample_strs.add(sample_str) + stmt_hash = GetSeqStmtsHash(seq_stmts_slice) + yield (subgraph_idx, program_id, stmt_hash, sample_str) + + def __call__(self, split_positions, group_head_and_tail=True): + graphnet_sample_results = [] + seg_counter = defaultdict(lambda: itertools.count()) + for _, (subgraph_idx, program_id, uid, sample_str) in enumerate( + self.GetOutputSampleStrings(split_positions, group_head_and_tail) + ): + unique_name = f"{uid}_{next(seg_counter[uid])}" + sample = ConvertOutputStringToSample( + self.model_name, unique_name, subgraph_idx, program_id, sample_str + ) + graphnet_sample_results.append(sample) + print( + f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=})." + ) + return graphnet_sample_results + + def RunGeneration( model_name, ir_programs, @@ -84,39 +301,19 @@ def RunGeneration( eval_mode, tmp_dir=None, ): - metadata = { - "framework": "paddle", - "model_name": model_name, - "num_devices_required": 1, - "num_nodes_required": 1, - } - - graphnet_sample_results = [] - seg_counter = defaultdict(lambda: itertools.count()) - for module_id, (subgraph_idx, uid, sample_str) in enumerate( - GetOutputSampleStrings( + if not split_positions: + generator = GraphGenerator(model_name, ir_programs, example_inputs, eval_mode) + graphnet_sample_results = generator() + else: + generator = SubgraphGenerator( + model_name, ir_programs, example_inputs, op_example_inputs, - split_positions, - group_head_and_tail, eval_mode, tmp_dir, ) - ): - unique_name = f"{uid}_{next(seg_counter[uid])}" - input_meta, weight_meta, model = sample_str.split("# --- seperate line ----\n") - sample = GraphnetSample( - unique_name=unique_name, - subgraph_idx=subgraph_idx, - metadata=metadata, - input_meta=input_meta.strip("\n\n\n") + "\n", - weight_meta=weight_meta.rstrip("\n\n\n") + "\n", - model=model, - ) - graphnet_sample_results.append(sample) - # PrintToTerminal(unique_name, sample_str) - print(f"Generate {len(graphnet_sample_results)} graphnet samples.") + graphnet_sample_results = generator(split_positions, group_head_and_tail) return graphnet_sample_results @@ -131,7 +328,6 @@ def main(argv): split_positions=split_positions, group_head_and_tail=FLAGS.group_head_and_tail, eval_mode=FLAGS.eval_mode, - tmp_dir=FLAGS.tmp_dir, ) subgraph_idx2samples = {} @@ -140,9 +336,10 @@ def main(argv): subgraph_idx2samples[sample.subgraph_idx] = [] subgraph_idx2samples[sample.subgraph_idx].append(sample) + program_id2subgraph_path = {} num_samples = len(graphnet_sample_results) for subgraph_idx, samples in subgraph_idx2samples.items(): - for sample_idx in range(len(samples)): + for sample_idx, sample in enumerate(samples): if num_samples == 1 and len(samples) == 1: subgraph_path = FLAGS.output_dir elif len(samples) == 1: @@ -153,17 +350,19 @@ def main(argv): subgraph_path = os.path.join( FLAGS.output_dir, f"subgraph_{subgraph_idx}_{sample_idx}" ) + program_id2subgraph_path[sample.program_id] = subgraph_path if not os.path.exists(subgraph_path): os.makedirs(subgraph_path) - WriteToFile(f"{subgraph_path}/model.py", samples[sample_idx].model) - WriteToFile( - f"{subgraph_path}/weight_meta.py", samples[sample_idx].weight_meta - ) - WriteToFile( - f"{subgraph_path}/input_meta.py", samples[sample_idx].input_meta - ) + WriteToFile(f"{subgraph_path}/model.py", sample.model) + WriteToFile(f"{subgraph_path}/weight_meta.py", sample.weight_meta) + WriteToFile(f"{subgraph_path}/input_meta.py", sample.input_meta) with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f: - json.dump(samples[sample_idx].metadata, f, indent=4) + json.dump(sample.metadata, f, indent=4) + program_ids_content = [f"{k}: {v}" for k, v in program_id2subgraph_path.items()] + WriteToFile( + os.path.join(FLAGS.output_dir, "program_ids.txt"), + "\n".join(program_ids_content), + ) def GetSha256sum(content): @@ -231,130 +430,6 @@ def GetValidIrPrograms(programs_file): return ir_programs -def GenerateOpExampleInputFile( - programs_file, example_inputs_file, op_example_inputs_file, tmp_dir -): - if os.path.isfile(op_example_inputs_file): - print(f"Remove the existing {op_example_inputs_file}") - os.remove(op_example_inputs_file) - - if tmp_dir is None: - tmp_dir = tempfile.gettempdir() - - print(f"Generate {op_example_inputs_file} ...") - tmp_output_file_prefix = "tmp_op_example_input_" - for name, unittest in GetOpExampleInputMetaUnittests( - programs_file, example_inputs_file, bucket_size=128 - ): - sha256sum = GetSha256sum(unittest) - tmp_output_filepath = os.path.join( - tmp_dir, f"{tmp_output_file_prefix}{sha256sum[0:32]}.py" - ) - WriteToFile(tmp_output_filepath, unittest) - - # Execute the generated tmp file - generate_op_example_inputs_cmd = f"ATHENA_WHILE_LOOP_LIMIT=8 {sys.executable} {tmp_output_filepath} --max_try_cnt=10 --output_file={op_example_inputs_file}" - System(generate_op_example_inputs_cmd) - - -def GetOutputSampleStrings( - programs_file, - example_inputs_file, - op_example_inputs_file, - split_positions, - group_head_and_tail=True, - eval_mode=True, - tmp_dir=None, -): - def MakeModuleOpSampleGenerator(ir_program, example_inputs_meta_getter): - return GraphnetModuleOpSampleGenerator( - ir_program, - example_inputs_meta_getter, - eval_mode=eval_mode, - ) - - def MakeSequenceSampleGenerator( - program_id, seq_stmts, op_example_inputs_meta_getter - ): - generator = GraphnetSequenceSampleGenerator( - program_id, op_example_inputs_meta_getter - ) - return generator.Generate(seq_stmts) - - ir_programs = GetValidIrPrograms(programs_file) - if not split_positions: - example_inputs_meta_getter = MakeExampleInputsMetaGetter( - GetClasses(example_inputs_file) - ) - subgraph_idx = 0 - for ir_program in ir_programs: - op_names = GetOpNames(ir_program) - program_hash = GetOpNamesHash(op_names) - generator = MakeModuleOpSampleGenerator( - ir_program, example_inputs_meta_getter - ) - sample_str = generator.Generate() - yield (subgraph_idx, program_hash, sample_str) - subgraph_idx += 1 - else: - print(f"origin split_positions: {split_positions}") - GenerateOpExampleInputFile( - programs_file, example_inputs_file, op_example_inputs_file, tmp_dir - ) - op_example_inputs_meta_getter = MakeOpExampleInputsMetaGetter( - GetClasses(op_example_inputs_file) - ) - unittest_stmts_gen = PaddleBlockUnittestStmtsGenerator(BlockNameGenerator()) - program_seq_stmts_list = [ - (program_id, seq_stmts) - for ir_program in ir_programs - for program_id in [GetProgramId(ir_program)] - for block in BlocksGenerator(ir_program).Generate() - if AllInputOutputTypesSupported(block) - for _, stmts, _ in [unittest_stmts_gen.Generate(block)] - for seq_stmts in ExtractSeqStmts( - stmts, program_id, op_example_inputs_meta_getter - ) - if len(seq_stmts) > 1 - if op_example_inputs_meta_getter.HasAllInputs(program_id, seq_stmts[0].op) - ] - - generated_sample_strs = set() - subgraph_idx = 0 - for program_id, seq_stmts in program_seq_stmts_list: - split_positions_for_seq_stmts = ExtendHeadAndTail( - seq_stmts, split_positions, group_head_and_tail - ) - for i in range(len(split_positions_for_seq_stmts) - 1): - seq_stmts_slice = seq_stmts[ - split_positions_for_seq_stmts[i] : split_positions_for_seq_stmts[ - i + 1 - ] - ] - sample_str = MakeSequenceSampleGenerator( - program_id, seq_stmts_slice, op_example_inputs_meta_getter - ) - if sample_str not in generated_sample_strs: - generated_sample_strs.add(sample_str) - stmt_hash = GetSeqStmtsHash(seq_stmts_slice) - yield (subgraph_idx, stmt_hash, sample_str) - subgraph_idx += 1 - - -def ExtendHeadAndTail(seq_stmts, split_positions, group_head_and_tail): - split_positions_for_seq_stmts = ( - [0, *split_positions, len(seq_stmts)] - if group_head_and_tail - else split_positions - ) - split_positions_for_seq_stmts = [ - x for x in split_positions_for_seq_stmts if x <= len(seq_stmts) - ] - split_positions_for_seq_stmts = list(dict.fromkeys(split_positions_for_seq_stmts)) - print(f"split_positions_for_seq_stmts: {split_positions_for_seq_stmts}") - return split_positions_for_seq_stmts - - def IsPrimitive(stmt): op = stmt.op return all( @@ -368,20 +443,6 @@ def IsPrimitive(stmt): ) -def ExtractSeqStmts(stmts, program_id, op_example_inputs_meta_getter): - def IsValidPrimitive(stmt): - return op_example_inputs_meta_getter.HasAllInputs( - program_id, stmt.op - ) and IsPrimitive(stmt) - - yield from ( - seq_stmts - for is_primitive, stmt_group in groupby(stmts, key=IsValidPrimitive) - if is_primitive - for seq_stmts in [list(stmt_group)] - ) - - def GetOpNames(ir_program): primitive_op_extractor = PrimitiveOpExtractor() return [op.name for op in primitive_op_extractor.Extract(ir_program)] diff --git a/athena/op_example_input_meta_script.py b/athena/op_example_input_meta_script.py index 9d6079d..d822cc8 100644 --- a/athena/op_example_input_meta_script.py +++ b/athena/op_example_input_meta_script.py @@ -98,15 +98,15 @@ def IncOpCount(op, *args, **kwargs): return op_count == 0 -def ExtractInputTensors(ir_program): +def ExtractInputTensors(ir_program, eval_mode): module_block_func = GetModuleBlockFunc(ir_program) extractor = InputOutputTensorsExtractor(module_block_func) - input_tensors, _ = extractor.Extract(free_vars=[], args=[]) + input_tensors, _ = extractor.Extract(free_vars=[], args=[], eval_mode=eval_mode) return input_tensors -def HasExampleInputs(ir_program, example_inputs_meta_getter): - input_tensors = ExtractInputTensors(ir_program) +def HasExampleInputs(ir_program, example_inputs_meta_getter, eval_mode): + input_tensors = ExtractInputTensors(ir_program, eval_mode) return example_inputs_meta_getter.HasAllInputExamples( program_id=int(type(ir_program).__name__[len("PirProgram_") :]), input_tensors=input_tensors, @@ -138,7 +138,9 @@ def AllValidOutputTypes(op): ) -def GetOutputUnittests(original_programs_file, example_inputs_file, bucket_size): +def GetOutputUnittests( + original_programs_file, example_inputs_file, bucket_size, eval_mode=False +): example_inputs_meta_getter = MakeExampleInputsMetaGetter( GetClasses(example_inputs_file) ) @@ -149,7 +151,7 @@ def GetOutputUnittests(original_programs_file, example_inputs_file, bucket_size) for ir_program in [cls()] if not IsBackwardProgram(ir_program) if not IsProgramEmpty(ir_program) - if HasExampleInputs(ir_program, example_inputs_meta_getter) + if HasExampleInputs(ir_program, example_inputs_meta_getter, eval_mode) if OnlyValidTypes(ir_program) ) @@ -168,7 +170,7 @@ def GetBucket(elem): ", ".join(type(x).__name__ for x in ir_program_group), file=sys.stderr, ) - name, unittest = generator.Generate() + name, unittest = generator.Generate(eval_mode) print( "OpExampleInputMetaScriptGenerator Generated pir_programs:", ", ".join(type(x).__name__ for x in ir_program_group), diff --git a/athena/util/ops_func_signature.py b/athena/util/ops_func_signature.py index 58d29a9..58b1834 100644 --- a/athena/util/ops_func_signature.py +++ b/athena/util/ops_func_signature.py @@ -92,6 +92,7 @@ class OpsFuncSignature: OperandId, t.Optional[t.List[TensorListMemberId]] ] tensor_name4tensor_id: t.Callable[TensorId, str] + tensor_original_name4tensor_id: t.Callable[TensorId, str] tensor_name4operand_id: t.Callable[OperandId, str] input_spec_shape_dtype4tensor_id: t.Callable[TensorId, InputSpecDesc] example_input_meta4tensor_id: t.Callable[TensorId, InputTensorDesc] diff --git a/athena/util/tensor_topo.py b/athena/util/tensor_topo.py index 9579953..d3e7d4b 100644 --- a/athena/util/tensor_topo.py +++ b/athena/util/tensor_topo.py @@ -1,7 +1,6 @@ from typing import Dict, List from collections import OrderedDict from dataclasses import dataclass -import sys from athena.util.input_output_tensors_extractor import InputOutputTensorsExtractor from athena.util.block_op_calls_extractor import BlockOpCallsExtractor import itertools @@ -27,12 +26,13 @@ def GetOpId2OpPipeInOutNamesSignature( free_vars, args, get_local_name, + eval_mode, ) -> Dict[int, OpPipeInOutNamesSignature]: op_id2used = op_id2used_by_me_and_downstream if len(op_id2used) == 0: return {} extractor = InputOutputTensorsExtractor(func) - input_tensors, output_tensors = extractor.Extract(free_vars, args) + input_tensors, output_tensors = extractor.Extract(free_vars, args, eval_mode) input_tensor_names = [get_local_name(t) for t in input_tensors] def get_in_names_list(): @@ -72,11 +72,12 @@ def GetOpId2TensorNamesUsedByMeAndDownstream( free_vars, args, get_local_name, + eval_mode, ) -> Dict[int, List[str]]: in_out_name_sig_extractor = OpInOutNameSignatureExtractor(get_local_name) in_out_names_sigs = in_out_name_sig_extractor.Extract(func, free_vars, args) input_tensors, output_tensors = InputOutputTensorsExtractor(func).Extract( - free_vars, args + free_vars, args, eval_mode ) input_tensor_names = [get_local_name(tensor) for tensor in input_tensors] output_tensor_names = [get_local_name(tensor) for tensor in output_tensors] @@ -138,7 +139,6 @@ def _GetTensorName2ProducerIdx(in_out_names_sigs, input_tensors, get_local_name) class OpInOutNameSignatureExtractor: - def __init__(self, get_local_name): self.in_out_names_sigs = [] self.get_local_name = get_local_name