Skip to content
4 changes: 4 additions & 0 deletions athena/generators/graphnet_module_op_sample_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"is_entry_block",
"block_name",
"input_arg_names",
"input_original_names",
"input_tensor_descs",
"stmts",
"output_arg_names",
Expand Down Expand Up @@ -122,6 +123,9 @@ def GetUnusedTensorName(stmt):
block.owner_op, block.region_idx, block.block_idx
),
input_arg_names=[tensor.name for tensor in input_local_tensors],
input_original_names=[
tensor.arg_name_as_input for tensor in input_local_tensors
],
input_tensor_descs=[GetInputTensorDesc(t) for t in input_local_tensors],
stmts=stmts,
output_arg_names=[tensor.name for tensor in output_local_tensors],
Expand Down
25 changes: 20 additions & 5 deletions athena/generators/graphnet_sequence_sample_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def MakeSequenceFuncDesc(self, seq_stmts):
op_id2seq_stmt
),
tensor_name4tensor_id=self.MakeTensorName4TensorId(op_id2seq_stmt),
tensor_original_name4tensor_id=self.MakeTensorOriginalName4TensorId(
op_id2seq_stmt
),
tensor_name4operand_id=self.MakeTensorName4OperandId(op_id2seq_stmt),
input_spec_shape_dtype4tensor_id=self.MakeInputSpecShapeAndDtype4TensorId(
op_id2seq_stmt,
Expand Down Expand Up @@ -91,11 +94,17 @@ def GetUnusedTensorName(stmt):

def GetOutputTensorNames(self, seq_stmts):
    """Return the output tensor names of a statement sequence.

    A name is kept when it is produced by some statement in ``seq_stmts``,
    appears in the last statement's ``tensors_used_by_downstream``, and is
    not produced by a ``pd_op.full`` / ``pd_op.full_int_array`` statement.
    Names come back in production order: statement order first, then each
    statement's ``output_tensor_names`` order.

    Args:
        seq_stmts: Sequence of statement objects exposing ``op_name``,
            ``output_tensor_names`` and ``tensors_used_by_downstream``.

    Returns:
        List of tensor names; empty when ``seq_stmts`` is empty.
    """
    # An empty sequence has no last statement to consult; without this
    # guard ``seq_stmts[-1]`` raises IndexError.
    if not seq_stmts:
        return []

    excluded_op_names = ("pd_op.full_int_array", "pd_op.full")
    used_downstream = set(seq_stmts[-1].tensors_used_by_downstream)
    # Every output of an excluded op is dropped, even when a later
    # statement reports it as used downstream.
    excluded_names = {
        name
        for stmt in seq_stmts
        if stmt.op_name in excluded_op_names
        for name in stmt.output_tensor_names
    }
    return [
        name
        for stmt in seq_stmts
        for name in stmt.output_tensor_names
        if name in used_downstream and name not in excluded_names
    ]

def MakeImmediateValue4OperandId(
Expand Down Expand Up @@ -337,6 +346,17 @@ def TensorName4TensorId(tensor_id):

return TensorName4TensorId

def MakeTensorOriginalName4TensorId(
    self, op_id2seq_stmt: OrderedDict[int, PyCodeStmt]
):
    """Build a resolver from a tensor id to its original (source) name.

    Args:
        op_id2seq_stmt: Mapping from op id to the statement recorded for
            that op; each statement's ``input_tensor_original_names`` is
            what the resolver hands to the tensor id.

    Returns:
        A one-argument callable. Given ``tensor_id`` it returns
        ``tensor_id.get_source_name(lookup)``, where ``lookup(op_id)``
        yields ``op_id2seq_stmt[op_id].input_tensor_original_names``.
    """

    def LookupOriginalInputNames(op_id):
        stmt = op_id2seq_stmt[op_id]
        return stmt.input_tensor_original_names

    def ResolveOriginalName(tensor_id):
        # The tensor id decides which op/slot to query; this method only
        # supplies the per-op lookup.
        return tensor_id.get_source_name(LookupOriginalInputNames)

    return ResolveOriginalName

def MakeTensorListMemberIds4OperandId(
self, op_id2seq_stmt: OrderedDict[int, PyCodeStmt]
):
Expand Down Expand Up @@ -524,11 +544,6 @@ def GetCppOperandTypeName(self, op, input_idx):
type_name = pos_arg_type_names[input_idx]
return type_name

def _GetTemplate(self, template_name):
    """Load a template from a file located next to this module.

    Args:
        template_name: File name relative to this module's directory.

    Returns:
        The result of ``jinja_env.get_template`` called with the text
        read from that file.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    template_path = f"{module_dir}/{template_name}"
    with open(template_path, "r") as template_file:
        template_text = template_file.read()
    # NOTE(review): the file's contents (not its name) are handed to
    # get_template; whether that is intended depends on how jinja_env's
    # loader is configured -- confirm against its definition.
    return jinja_env.get_template(template_text)

def GetSha256sum(content):
m = hashlib.sha256()
Expand Down
8 changes: 4 additions & 4 deletions athena/generators/op_example_input_meta_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, ir_program, example_inputs_meta_getter):
self.block_name_gen = BlockNameGenerator()
self.unittest_stmts_gen = PaddleBlockUnittestStmtsGenerator(self.block_name_gen)

def Generate(self):
def Generate(self, eval_mode):
def GetShapeInstance(tensor):
if tensor.arg_name_as_input is not None:
tensor_meta = self.example_inputs_meta_getter.Get(
Expand Down Expand Up @@ -97,7 +97,7 @@ def MakeBlockDescriptor(block):
input_local_tensors,
stmts,
output_local_tensors,
) = self.unittest_stmts_gen.Generate(block)
) = self.unittest_stmts_gen.Generate(block, eval_mode)
input_local_tensors = [
ConvertToPaddleTensor(t) for t in input_local_tensors
]
Expand Down Expand Up @@ -144,12 +144,12 @@ def __init__(self, ir_programs, example_inputs_meta_getter):
self.name = "_".join(type(ir_program).__name__ for ir_program in ir_programs)
self.example_inputs_meta_getter = example_inputs_meta_getter

def Generate(self):
def Generate(self, eval_mode):
def MakeProgramBlocksDescriptor(ir_program):
generator = ProgramBlocksDescriptorGenerator(
ir_program, self.example_inputs_meta_getter
)
return generator.Generate()
return generator.Generate(eval_mode)

programs = [
MakeProgramBlocksDescriptor(ir_program) for ir_program in self.ir_programs
Expand Down
10 changes: 9 additions & 1 deletion athena/generators/paddle_func_body_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PyCodeStmt:
op_unique_local_name: str
pycode: List[IndentedPyCode]
input_tensor_names: List[str]
input_tensor_original_names: List[str]
output_tensor_names: List[str]
inputs_type_strs: List[str]
outputs_type_strs: List[str]
Expand Down Expand Up @@ -82,14 +83,15 @@ def get_local_name(tensor):
return self.tensor_converter.ConvertToLocalTensor(tensor).name

self.op_id2used_by_me_and_downstream = GetOpId2TensorNamesUsedByMeAndDownstream(
self.func, free_vars, args, get_local_name
self.func, free_vars, args, get_local_name, eval_mode
)
self.op_id2op_func_in_out_names_signature = GetOpId2OpPipeInOutNamesSignature(
self.op_id2used_by_me_and_downstream,
self.func,
free_vars,
args,
get_local_name,
eval_mode,
)
self.block_op_calls = BlockOpCallsExtractor().Extract(
self.func, free_vars, args
Expand Down Expand Up @@ -245,6 +247,9 @@ def CollectPyCodeStmt(self, GetStmtPyCode, op, *input_tensors, **kwargs):
def GetTensorName(tensor):
    """Return ``tensor.name``, or the string ``"None"`` when tensor is None."""
    if tensor is None:
        return "None"
    return tensor.name

def GetTensorOriginalName(tensor):
    """Return ``tensor.arg_name_as_input``, or ``"None"`` when tensor is None.

    The attribute is passed through unchanged, so a tensor whose
    ``arg_name_as_input`` is itself ``None`` yields ``None`` (not the
    string ``"None"``).
    """
    if tensor is None:
        return "None"
    return tensor.arg_name_as_input

self.stmts.append(
PyCodeStmt(
op=op,
Expand All @@ -253,6 +258,9 @@ def GetTensorName(tensor):
prefix=f"op_{op.GetNameSuffix()}",
),
input_tensor_names=[GetTensorName(t) for t in input_local_tensors],
input_tensor_original_names=[
GetTensorOriginalName(t) for t in input_local_tensors
],
output_tensor_names=local_output_tensor_names,
inputs_type_strs=inputs_type_strs,
outputs_type_strs=outputs_type_strs,
Expand Down
6 changes: 5 additions & 1 deletion athena/generators/template_graphnet_module_op_sample.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
{%- set max = example_tensor_meta.max -%}
{%- set mean = example_tensor_meta.mean -%}
{%- set std = example_tensor_meta.std -%}
{%- if data != None -%}
{%- if data is not none -%}
shape = {{shape}}
dtype = '{{dtype}}'
data = {{data}}
Expand Down Expand Up @@ -54,9 +54,11 @@
{%- for arg_name in block.input_arg_names %}
{%- if "parameter" not in arg_name %}
{%- set input_idx = loop.index0 %}
{%- set original_name = block.input_original_names[input_idx] %}
{%- set example_tensor_meta = block.input_tensor_descs[input_idx] %}
class Program_weight_tensor_{{arg_name}}:
name = "{{arg_name}}"
original_name = "{{original_name}}"
{{get_input_tensor_instance(example_tensor_meta)}}
{%- endif %}
{%- endfor %}
Expand All @@ -73,9 +75,11 @@ class Program_weight_tensor_{{arg_name}}:
{%- for arg_name in block.input_arg_names %}
{%- if "parameter" in arg_name %}
{%- set input_idx = loop.index0 %}
{%- set original_name = block.input_original_names[input_idx] %}
{%- set example_tensor_meta = block.input_tensor_descs[input_idx] %}
class Program_weight_tensor_{{arg_name}}:
name = "{{arg_name}}"
original_name = "{{original_name}}"
{{get_input_tensor_instance(example_tensor_meta)}}
{%- endif %}
{%- endfor %}
Expand Down
15 changes: 12 additions & 3 deletions athena/generators/template_graphnet_sequence_sample.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,31 @@
{%- set shape, dtype = sig.input_spec_shape_dtype4tensor_id(tensor_id) %}
class Program_weight_tensor_{{sig.tensor_name4tensor_id(tensor_id)}}:
name = "{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}"
original_name = "{{tensor_name_converter(sig.tensor_original_name4tensor_id(tensor_id))}}"
{{get_input_tensor_instance(example_tensor_meta)}}
{{ '\n' }}
{%- endif %}
{%- endfor %}

# --- seperate line ----

{%- set param_list = [] -%}
{%- for tensor_id in sig.tensor_ids %}
{%- if 'parameter' in sig.tensor_name4tensor_id(tensor_id) %}
{%- set tensor_name = sig.tensor_name4tensor_id(tensor_id) %}
{%- if 'parameter' in tensor_name %}
{%- set num = tensor_name.split("_")[1] | int %}
{%- set _ = param_list.append({'tensor_id': tensor_id, 'num': num}) %}
{%- endif %}
{%- endfor %}
{%- set param_list = param_list | sort(attribute="num") -%}
{%- for item in param_list %}
{%- set tensor_id = item.tensor_id %}
{%- set example_tensor_meta = sig.example_input_meta4tensor_id(tensor_id) %}
{%- set shape, dtype = sig.input_spec_shape_dtype4tensor_id(tensor_id) %}
class Program_weight_tensor_{{sig.tensor_name4tensor_id(tensor_id)}}:
name = "{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}"
original_name = "{{tensor_name_converter(sig.tensor_original_name4tensor_id(tensor_id))}}"
{{get_input_tensor_instance(example_tensor_meta)}}
{{ '\n' }}
{%- endif %}
{%- endfor %}

# --- seperate line ----
Expand Down
Loading