Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 78 additions & 42 deletions gd_protobuf_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def field_type_name(self)->str:
def field_define(self, indent: str, define_name: str = None, flag: int = 0)->str:
if define_name is None:
define_name = self.field_name()
return f"{indent}var {define_name}: {self.field_type_name()} = {self.default_value}"
# Use an untyped variable declaration to avoid invalid or overly
# specific type hints (especially for message/enum types).
return f"{indent}var {define_name} = {self.default_value}"

def field_clear(self, indent: str) -> str:
return f"{indent}self.{self.field_name()} = {self.default_value}"
Expand Down Expand Up @@ -95,13 +97,16 @@ class GDMessageField(GDField):
def field_define(self, indent: str, define_name: str = None, flag: int = 0)->str:
if define_name is None:
define_name = self.field_name()
# For message fields, declare without a static type hint so that
# Godot's type system doesn't have to resolve package-qualified names.
if flag == 0:
return f"{indent}var {define_name}: {self.field_type_name()} = {self.default_value}"
return f"{indent}var {define_name} = {self.default_value}"
else:
return f"{indent}var {define_name}: {self.field_type_name()} = {self.field_type_name()}.new()"
return f"{indent}var {define_name} = {self.field_type_name()}.new()"

def field_clear(self, indent: str) -> str:
content = f"{indent}if self.{self.field_name()} != null:"
# Clear nested message fields with a properly formatted multiline if-block.
content = f"{indent}if self.{self.field_name()} != null:\n"
content += f"{indent}\tself.{self.field_name()}.clear()"
return content

Expand Down Expand Up @@ -179,10 +184,13 @@ def create(self, name: str, number: int, field_type: int, field: GDField):
self.sub_field = field
# Make field name private
# self.name = "_" + name
super().create(name, number, field_type, f"Array[{self.sub_field.field_type_name()}]", "[]", "")
# Use a generic Array container type; element type hints will be kept
# loose (Variant) to avoid invalid qualified type names.
super().create(name, number, field_type, "Array", "[]", "")

def field_clear(self, indent: str) -> str:
return f"{indent}self.{self.method_field_clear_name()}"
# Call the generated clear_* method to reset the repeated field.
return f"{indent}self.{self.method_field_clear_name()}()"


def field_merge(self, indent: str, other: str) -> str:
Expand Down Expand Up @@ -231,15 +239,18 @@ def field_define(self, indent: str, define_name: str = None, flag: int = 0) -> s
content += f"{indent}func {self.method_field_size_name()}() -> int:\n"
content += f"{indent}\treturn self.{self.field_size_name()}\n"
content += f"{indent}## Get {self.field_name()}\n"
content += f"{indent}func {self.method_field_get_array_name()}() -> {self.field_type_name()}:\n"
# Expose the underlying Array without an element type annotation.
content += f"{indent}func {self.method_field_get_array_name()}() -> Array:\n"
content += f"{indent}\treturn self.{self.field_name()}.slice(0, self.{self.field_size_name()})\n"
content += f"{indent}## Get {self.field_name()} item \n"
content += f"{indent}func {self.method_field_get_name()}(index: int) -> {self.sub_field.field_type_name()}: # index begin from 1\n"
# Individual items are treated as Variant to avoid invalid qualified
# type names in function signatures.
content += f"{indent}func {self.method_field_get_name()}(index: int) -> Variant: # index begin from 1\n"
content += f"{indent}\tif {self._index_check_content()}:\n"
content += f"{indent}\t\treturn self.{self.field_name()}[index - 1]\n"
content += f"{indent}\treturn {self.sub_field.default_value}\n"
content += f"{indent}## Add {self.field_name()}\n"
content += f"{indent}func {self.method_field_add_name()}(item: {self.sub_field.field_type_name()}) -> {self.sub_field.field_type_name()}:\n"
content += f"{indent}func {self.method_field_add_name()}(item: Variant) -> Variant:\n"
content += f"{indent}\tif self.{self.field_size_name()} >= 0 and self.{self.field_size_name()} < self.{self.field_name()}.size():\n"
content += f"{indent}\t\tself.{self.field_name()}[self.{self.field_size_name()}] = item\n"
content += f"{indent}\telse:\n"
Expand Down Expand Up @@ -375,10 +386,15 @@ def field_parse(self,
return content

class GDMessageType:
def __init__(self, descriptor: Descriptor, package_name: str = ""):
def __init__(self, descriptor: Descriptor, package_name: str = "", package_aliases: dict | None = None):
self.descriptor = descriptor
self.package_name = package_name
# self.field_dic = {}
# Mapping from foreign protobuf package names to the local preload
# aliases used in the generated GDScript (e.g. "ats2.types.messages"
# -> "messages"). This allows us to resolve fully-qualified
# protobuf type names to the correct alias-based GDScript references.
self.package_aliases = package_aliases or {}
# self.field_dic = {}
self.field_list = []

def add_field(self, field: GDField):
Expand All @@ -393,17 +409,37 @@ def get_field(self, number: int)->GDField:
# return self.field_list[number]

def real_type_name(self, type_full_name: str)->str:
if len(type_full_name) <= 0:
"""Resolve a protobuf type name to the GDScript reference to use.

This prefers local (same-package) names, but for foreign packages it
uses the preload aliases passed in from the generator so that we emit
expressions like `messages.MessageHeader` instead of a
package-qualified protobuf name such as
`ats2.types.messages.MessageHeader`, which GDScript cannot resolve.
"""
if not type_full_name:
return type_full_name

# Strip leading dot from fully-qualified type names.
if type_full_name[0] == '.':
type_full_name = type_full_name[1:]

if len(self.package_name) <= 0:
return type_full_name
# Same-package types: drop the package prefix entirely so we refer to
# the local class name (and nested names, if any).
if self.package_name and type_full_name.startswith(self.package_name + "."):
return type_full_name[len(self.package_name) + 1 :]

# Imported types: look for a matching foreign package and rewrite to
# use its preload alias plus the remaining type path.
for foreign_pkg, alias in self.package_aliases.items():
if type_full_name.startswith(foreign_pkg + "."):
remainder = type_full_name[len(foreign_pkg) + 1 :]
return f"{alias}.{remainder}"

# Remove package name
return type_full_name.replace(self.package_name + ".", "")
# Fallback: return the original (minus any leading dot). This keeps
# behaviour unchanged for types we cannot resolve via aliases, though
# it may still produce package-qualified names.
return type_full_name

def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldDescriptor) ->GDField:
# default_value_func = lambda default : field.default_value if hasattr(field, 'default_value') else default
Expand All @@ -412,40 +448,32 @@ def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldD
is_map: bool = False
#create_field_func = lambda f: GDMapField() if is_map else GDRepeatedField() if f.label == FieldDescriptor.LABEL_REPEATED else GDField()

# First, detect and resolve real map fields (which compile to nested *Entry messages).
if field.type == FieldDescriptor.TYPE_MESSAGE and field.label == FieldDescriptor.LABEL_REPEATED and field.type_name:
type_name = field.type_name
if type_name.startswith("."):
type_name = type_name[1:]
parts = type_name.split(".")
if len(parts) > 1 and parts[-1].endswith("Entry"):
is_map = True
# Only treat as a map if we can resolve the nested Entry type on this descriptor.
map_type = None
for nested_type in descriptor.nested_type:
if nested_type.name == parts[-1]:
map_type = nested_type
break

if map_type and len(map_type.field) >= 2:
is_map = True
gd_field = GDMapField()
key_field = create_gd_field(gd_msg, descriptor, map_type.field[0])
value_field = create_gd_field(gd_msg, descriptor, map_type.field[1])
gd_field.create(field.name, field.number, field.type, key_field, value_field)
return gd_field

gd_field: GDField = None #create_field_func(field)
real_type = gd_msg.real_type_name(field.type_name)

if is_map:
map_type = None

type_name = field.type_name
if type_name.startswith("."):
type_name = type_name[1:]
parts = type_name.split(".")

for nested_type in descriptor.nested_type:
if nested_type.name == parts[-1]:
map_type = nested_type
break

if map_type and len(map_type.field) >= 2:
gd_field = GDMapField()
key_field = create_gd_field(gd_msg, descriptor, map_type.field[0])
value_field = create_gd_field(gd_msg, descriptor, map_type.field[1])
gd_field.create(field.name, field.number, field.type, key_field, value_field)
else:
gd_field = GDField()
gd_field.create(f"m_unknown_{field.name}", field.number, field.type, f"m_unknown_{map_type}", f"unknown_{na}", "unknown")
return gd_field
elif field.type == FieldDescriptor.TYPE_STRING:
if field.type == FieldDescriptor.TYPE_STRING:
gd_field = GDField()
default_value = default_value_func("")
gd_field.create(field.name, field.number, field.type, "String", f"\"{default_value}\"", "string")
Expand Down Expand Up @@ -503,19 +531,27 @@ def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldD
gd_field = GDMessageField()
gd_field.create(field.name, field.number, field.type, real_type, "null", "message")
else:
# Unknown / unsupported field type: fall back to a basic GDField so we
# never return or wrap a None gd_field.
gd_field = GDField()
gd_field.create("unknown", field.number, field.type, f"unknown_{field.type}", "unknown", "unknown")
return gd_field


if field.label == FieldDescriptor.LABEL_REPEATED and field.type != FieldDescriptor.TYPE_BYTES and is_map == False:
# Defensive: ensure we have a concrete sub_field to wrap.
if gd_field is None:
gd_field = GDField()
gd_field.create(field.name, field.number, field.type, "Variant", "null", "unknown")

repeated_field = GDRepeatedField()
repeated_field.create(field.name, field.number, field.type, gd_field)
return repeated_field

return gd_field

def init_message_type( descriptor: Descriptor, package_name: str) -> GDMessageType:
gd_message_type = GDMessageType(descriptor, package_name)
def init_message_type( descriptor: Descriptor, package_name: str, package_aliases: dict | None = None) -> GDMessageType:
gd_message_type = GDMessageType(descriptor, package_name, package_aliases)

for field in descriptor.field:
gd_field = create_gd_field(gd_message_type, descriptor, field)
Expand Down
82 changes: 53 additions & 29 deletions generate_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,16 @@ def get_import_path(proto_file_path: str, import_path: str) -> str:
"""
# Get the directory of the current proto file
proto_dir = os.path.dirname(proto_file_path)

# Calculate the absolute path of the import file
import_abs_path = os.path.normpath(os.path.join(proto_dir, import_path))

# Determine how to resolve the import:
# - If the import path starts with './' or '../', treat it as *relative* to
# the current proto file directory.
# - Otherwise, treat it as rooted relative to the proto include roots, as
# protoc normally does for imports like "ats2/types/identifier.proto".
if import_path.startswith("./") or import_path.startswith("../"):
import_abs_path = os.path.normpath(os.path.join(proto_dir, import_path))
else:
import_abs_path = os.path.normpath(import_path)

# Replace the .proto extension with .proto.gd
import_gd_path = os.path.splitext(import_abs_path)[0] + ".proto.gd"
Expand Down Expand Up @@ -151,27 +158,15 @@ def generate_imports(proto_file) -> str:
# Add imports for other proto files
for dependency in proto_file.dependency:
import_path = get_import_path(proto_file.name, dependency)
content += f'const {os.path.splitext(os.path.basename(dependency))[0]} = preload("{import_path}")\n'
alias = os.path.splitext(os.path.basename(dependency))[0]
content += f'const {alias} = preload("{import_path}")\n'

if content:
content += "\n"

return content

package_name = ""
def real_type_name(type_full_name: string):
if len(type_full_name) <= 0:
return type_full_name

if type_full_name[0] == '.':
type_full_name = type_full_name[1:]

if len(package_name) <= 0:
return type_full_name

# Remove package name
return type_full_name.replace(package_name + ".", "")

def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.CodeGeneratorResponse:
"""Generate GDScript code from the request."""
response = plugin_pb2.CodeGeneratorResponse()
Expand All @@ -186,14 +181,29 @@ def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.Co
global package_name
package_name = proto_file.package

proto_file_name = os.path.splitext(os.path.basename(proto_file.name))[0]
file_name = f"{proto_file_name}.proto.gd"
# Preserve the original proto file path (relative to the proto include roots)
# so that protoc will create the same folder structure under --gdscript_out.
proto_path_no_ext = os.path.splitext(proto_file.name)[0]
file_name = f"{proto_path_no_ext}.proto.gd"
file = response.file.add()
file.name = file_name

# Initialize content with package name
file.content = f"# Package: {package_name}\n\n"

# Build a mapping from imported protobuf package names to the preload
# aliases we use in this generated file. We derive the foreign package
# names from the request's proto_file descriptors.
package_aliases: dict[str, str] = {}
dep_pkg_by_name: dict[str, str] = {pf.name: pf.package for pf in request.proto_file}

for dependency in proto_file.dependency:
dep_pkg = dep_pkg_by_name.get(dependency, "")
if not dep_pkg:
continue
alias = os.path.splitext(os.path.basename(dependency))[0]
package_aliases[dep_pkg] = alias

# Add imports
file.content += generate_imports(proto_file)

Expand All @@ -207,18 +217,32 @@ def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.Co
if not message_type.name:
continue

# Generate message class line by line
file.content += generate_message_class(message_type, 0, package_name)
# Generate message class line by line, passing package_aliases so
# that field generation can resolve foreign types via preload
# aliases rather than package-qualified names.
file.content += generate_message_class(message_type, 0, package_name, package_aliases)
# Add separator between message types
file.content += "# =========================================\n\n"

return response

# Advertise support for proto3 optional fields so protoc accepts optional in proto3 syntax.
# This relies on the CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL flag
# being present in the version of google.protobuf used at runtime.
try:
response.supported_features |= plugin_pb2.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL
except AttributeError:
# Older protobuf versions may not define FEATURE_PROTO3_OPTIONAL.
# In that case we simply don't set the flag; protoc will emit a warning
# or error, but the generator itself can still function for other cases.
pass

return response

def generate_message_class(message_type: MessageType, indent_level: int = 0, msg_package_name="") -> str:

# global package_name
gd_msg = gd_protobuf_info.init_message_type(message_type, msg_package_name)
def generate_message_class(message_type: MessageType, indent_level: int = 0, msg_package_name: str = "", package_aliases: dict | None = None) -> str:
# Build a GDMessageType enriched with package alias information so that
# field generation can resolve foreign message/enum types to the correct
# preload alias references (e.g. messages.MessageHeader).
gd_msg = gd_protobuf_info.init_message_type(message_type, msg_package_name, package_aliases)

"""Generate a message class."""
content = ""
Expand All @@ -241,7 +265,7 @@ def generate_message_class(message_type: MessageType, indent_level: int = 0, msg
if nested_type.options.map_entry:
# This is a map field
continue
content += generate_message_class(nested_type, indent_level + 1, msg_package_name)
content += generate_message_class(nested_type, indent_level + 1, msg_package_name, package_aliases)

# Generate Init method
content += generate_init_method(message_type, gd_msg, indent + "\t")
Expand Down Expand Up @@ -429,11 +453,11 @@ def generate_init_method(message_type: MessageType, gd_message_type: GDMessageTy

if len(message_type.field) <= 0:
content += f"{indent}\tpass"
return

for gd_field in gd_message_type.field_list:
if isinstance(gd_field, GDField):
content += f"{gd_field.field_clear(indent + '\t')}\n"
# Each clear call should be on its own indented line in GDScript.
content += gd_field.field_clear(indent + "\t") + "\n"
content += "\n"
return content

Expand Down