diff --git a/gd_protobuf_info.py b/gd_protobuf_info.py index b23e96b..a9e7f0a 100644 --- a/gd_protobuf_info.py +++ b/gd_protobuf_info.py @@ -30,7 +30,9 @@ def field_type_name(self)->str: def field_define(self, indent: str, define_name: str = None, flag: int = 0)->str: if define_name is None: define_name = self.field_name() - return f"{indent}var {define_name}: {self.field_type_name()} = {self.default_value}" + # Use an untyped variable declaration to avoid invalid or overly + # specific type hints (especially for message/enum types). + return f"{indent}var {define_name} = {self.default_value}" def field_clear(self, indent: str) -> str: return f"{indent}self.{self.field_name()} = {self.default_value}" @@ -95,13 +97,16 @@ class GDMessageField(GDField): def field_define(self, indent: str, define_name: str = None, flag: int = 0)->str: if define_name is None: define_name = self.field_name() + # For message fields, declare without a static type hint so that + # Godot's type system doesn't have to resolve package-qualified names. if flag == 0: - return f"{indent}var {define_name}: {self.field_type_name()} = {self.default_value}" + return f"{indent}var {define_name} = {self.default_value}" else: - return f"{indent}var {define_name}: {self.field_type_name()} = {self.field_type_name()}.new()" + return f"{indent}var {define_name} = {self.field_type_name()}.new()" def field_clear(self, indent: str) -> str: - content = f"{indent}if self.{self.field_name()} != null:" + # Clear nested message fields with a properly formatted multiline if-block. + content = f"{indent}if self.{self.field_name()} != null:\n" content += f"{indent}\tself.{self.field_name()}.clear()" return content @@ -179,10 +184,13 @@ def create(self, name: str, number: int, field_type: int, field: GDField): self.sub_field = field # Make field name private # self.name = "_" + name - super().create(name, number, field_type, f"Array[{self.sub_field.field_type_name()}]", "[]", "") + # Use a generic Array container type; element type hints will be kept + # loose (Variant) to avoid invalid qualified type names. + super().create(name, number, field_type, "Array", "[]", "") def field_clear(self, indent: str) -> str: - return f"{indent}self.{self.method_field_clear_name()}" + # Call the generated clear_* method to reset the repeated field. + return f"{indent}self.{self.method_field_clear_name()}()" def field_merge(self, indent: str, other: str) -> str: @@ -231,15 +239,18 @@ def field_define(self, indent: str, define_name: str = None, flag: int = 0) -> s content += f"{indent}func {self.method_field_size_name()}() -> int:\n" content += f"{indent}\treturn self.{self.field_size_name()}\n" content += f"{indent}## Get {self.field_name()}\n" - content += f"{indent}func {self.method_field_get_array_name()}() -> {self.field_type_name()}:\n" + # Expose the underlying Array without an element type annotation. + content += f"{indent}func {self.method_field_get_array_name()}() -> Array:\n" content += f"{indent}\treturn self.{self.field_name()}.slice(0, self.{self.field_size_name()})\n" content += f"{indent}## Get {self.field_name()} item \n" - content += f"{indent}func {self.method_field_get_name()}(index: int) -> {self.sub_field.field_type_name()}: # index begin from 1\n" + # Individual items are treated as Variant to avoid invalid qualified + # type names in function signatures. + content += f"{indent}func {self.method_field_get_name()}(index: int) -> Variant: # index begin from 1\n" content += f"{indent}\tif {self._index_check_content()}:\n" content += f"{indent}\t\treturn self.{self.field_name()}[index - 1]\n" content += f"{indent}\treturn {self.sub_field.default_value}\n" content += f"{indent}## Add {self.field_name()}\n" - content += f"{indent}func {self.method_field_add_name()}(item: {self.sub_field.field_type_name()}) -> {self.sub_field.field_type_name()}:\n" + content += f"{indent}func {self.method_field_add_name()}(item: Variant) -> Variant:\n" content += f"{indent}\tif self.{self.field_size_name()} >= 0 and self.{self.field_size_name()} < self.{self.field_name()}.size():\n" content += f"{indent}\t\tself.{self.field_name()}[self.{self.field_size_name()}] = item\n" content += f"{indent}\telse:\n" @@ -375,10 +386,15 @@ def field_parse(self, return content class GDMessageType: - def __init__(self, descriptor: Descriptor, package_name: str = ""): + def __init__(self, descriptor: Descriptor, package_name: str = "", package_aliases: dict | None = None): self.descriptor = descriptor self.package_name = package_name -# self.field_dic = {} + # Mapping from foreign protobuf package names to the local preload + # aliases used in the generated GDScript (e.g. "ats2.types.messages" + # -> "messages"). This allows us to resolve fully-qualified + # protobuf type names to the correct alias-based GDScript references. + self.package_aliases = package_aliases or {} + # self.field_dic = {} self.field_list = [] def add_field(self, field: GDField): @@ -393,17 +409,37 @@ def get_field(self, number: int)->GDField: # return self.field_list[number] def real_type_name(self, type_full_name: str)->str: - if len(type_full_name) <= 0: + """Resolve a protobuf type name to the GDScript reference to use. + + This prefers local (same-package) names, but for foreign packages it + uses the preload aliases passed in from the generator so that we emit + expressions like `messages.MessageHeader` instead of a + package-qualified protobuf name such as + `ats2.types.messages.MessageHeader`, which GDScript cannot resolve. + """ + if not type_full_name: return type_full_name + # Strip leading dot from fully-qualified type names. if type_full_name[0] == '.': type_full_name = type_full_name[1:] - if len(self.package_name) <= 0: - return type_full_name + # Same-package types: drop the package prefix entirely so we refer to + # the local class name (and nested names, if any). + if self.package_name and type_full_name.startswith(self.package_name + "."): + return type_full_name[len(self.package_name) + 1 :] + + # Imported types: look for a matching foreign package and rewrite to + # use its preload alias plus the remaining type path. + for foreign_pkg, alias in self.package_aliases.items(): + if type_full_name.startswith(foreign_pkg + "."): + remainder = type_full_name[len(foreign_pkg) + 1 :] + return f"{alias}.{remainder}" - # Remove package name - return type_full_name.replace(self.package_name + ".", "") + # Fallback: return the original (minus any leading dot). This keeps + # behaviour unchanged for types we cannot resolve via aliases, though + # it may still produce package-qualified names. + return type_full_name def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldDescriptor) ->GDField: # default_value_func = lambda default : field.default_value if hasattr(field, 'default_value') else default @@ -412,40 +448,32 @@ def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldD is_map: bool = False #create_field_func = lambda f: GDMapField() if is_map else GDRepeatedField() if f.label == FieldDescriptor.LABEL_REPEATED else GDField() + # First, detect and resolve real map fields (which compile to nested *Entry messages). if field.type == FieldDescriptor.TYPE_MESSAGE and field.label == FieldDescriptor.LABEL_REPEATED and field.type_name: type_name = field.type_name if type_name.startswith("."): type_name = type_name[1:] parts = type_name.split(".") if len(parts) > 1 and parts[-1].endswith("Entry"): - is_map = True + # Only treat as a map if we can resolve the nested Entry type on this descriptor. + map_type = None + for nested_type in descriptor.nested_type: + if nested_type.name == parts[-1]: + map_type = nested_type + break + + if map_type and len(map_type.field) >= 2: + is_map = True + gd_field = GDMapField() + key_field = create_gd_field(gd_msg, descriptor, map_type.field[0]) + value_field = create_gd_field(gd_msg, descriptor, map_type.field[1]) + gd_field.create(field.name, field.number, field.type, key_field, value_field) + return gd_field gd_field: GDField = None #create_field_func(field) real_type = gd_msg.real_type_name(field.type_name) - if is_map: - map_type = None - - type_name = field.type_name - if type_name.startswith("."): - type_name = type_name[1:] - parts = type_name.split(".") - - for nested_type in descriptor.nested_type: - if nested_type.name == parts[-1]: - map_type = nested_type - break - - if map_type and len(map_type.field) >= 2: - gd_field = GDMapField() - key_field = create_gd_field(gd_msg, descriptor, map_type.field[0]) - value_field = create_gd_field(gd_msg, descriptor, map_type.field[1]) - gd_field.create(field.name, field.number, field.type, key_field, value_field) - else: - gd_field = GDField() - gd_field.create(f"m_unknown_{field.name}", field.number, field.type, f"m_unknown_{map_type}", f"unknown_{na}", "unknown") - return gd_field - elif field.type == FieldDescriptor.TYPE_STRING: + if field.type == FieldDescriptor.TYPE_STRING: gd_field = GDField() default_value = default_value_func("") gd_field.create(field.name, field.number, field.type, "String", f"\"{default_value}\"", "string") @@ -503,19 +531,27 @@ def create_gd_field(gd_msg: GDMessageType, descriptor: Descriptor, field: FieldD gd_field = GDMessageField() gd_field.create(field.name, field.number, field.type, real_type, "null", "message") else: + # Unknown / unsupported field type: fall back to a basic GDField so we + # never return or wrap a None gd_field. + gd_field = GDField() gd_field.create("unknown", field.number, field.type, f"unknown_{field.type}", "unknown", "unknown") return gd_field if field.label == FieldDescriptor.LABEL_REPEATED and field.type != FieldDescriptor.TYPE_BYTES and is_map == False: + # Defensive: ensure we have a concrete sub_field to wrap. + if gd_field is None: + gd_field = GDField() + gd_field.create(field.name, field.number, field.type, "Variant", "null", "unknown") + repeated_field = GDRepeatedField() repeated_field.create(field.name, field.number, field.type, gd_field) return repeated_field return gd_field -def init_message_type( descriptor: Descriptor, package_name: str) -> GDMessageType: - gd_message_type = GDMessageType(descriptor, package_name) +def init_message_type( descriptor: Descriptor, package_name: str, package_aliases: dict | None = None) -> GDMessageType: + gd_message_type = GDMessageType(descriptor, package_name, package_aliases) for field in descriptor.field: gd_field = create_gd_field(gd_message_type, descriptor, field) diff --git a/generate_message.py b/generate_message.py index c9af543..94ae96f 100644 --- a/generate_message.py +++ b/generate_message.py @@ -117,9 +117,16 @@ def get_import_path(proto_file_path: str, import_path: str) -> str: """ # Get the directory of the current proto file proto_dir = os.path.dirname(proto_file_path) - - # Calculate the absolute path of the import file - import_abs_path = os.path.normpath(os.path.join(proto_dir, import_path)) + + # Determine how to resolve the import: + # - If the import path starts with './' or '../', treat it as *relative* to + # the current proto file directory. + # - Otherwise, treat it as rooted relative to the proto include roots, as + # protoc normally does for imports like "ats2/types/identifier.proto". + if import_path.startswith("./") or import_path.startswith("../"): + import_abs_path = os.path.normpath(os.path.join(proto_dir, import_path)) + else: + import_abs_path = os.path.normpath(import_path) # Replace the .proto extension with .proto.gd import_gd_path = os.path.splitext(import_abs_path)[0] + ".proto.gd" @@ -151,7 +158,8 @@ def generate_imports(proto_file) -> str: # Add imports for other proto files for dependency in proto_file.dependency: import_path = get_import_path(proto_file.name, dependency) - content += f'const {os.path.splitext(os.path.basename(dependency))[0]} = preload("{import_path}")\n' + alias = os.path.splitext(os.path.basename(dependency))[0] + content += f'const {alias} = preload("{import_path}")\n' if content: content += "\n" @@ -159,19 +167,6 @@ def generate_imports(proto_file) -> str: return content package_name = "" -def real_type_name(type_full_name: string): - if len(type_full_name) <= 0: - return type_full_name - - if type_full_name[0] == '.': - type_full_name = type_full_name[1:] - - if len(package_name) <= 0: - return type_full_name - - # Remove package name - return type_full_name.replace(package_name + ".", "") - def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.CodeGeneratorResponse: """Generate GDScript code from the request.""" response = plugin_pb2.CodeGeneratorResponse() @@ -186,14 +181,29 @@ def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.Co global package_name package_name = proto_file.package - proto_file_name = os.path.splitext(os.path.basename(proto_file.name))[0] - file_name = f"{proto_file_name}.proto.gd" + # Preserve the original proto file path (relative to the proto include roots) + # so that protoc will create the same folder structure under --gdscript_out. + proto_path_no_ext = os.path.splitext(proto_file.name)[0] + file_name = f"{proto_path_no_ext}.proto.gd" file = response.file.add() file.name = file_name # Initialize content with package name file.content = f"# Package: {package_name}\n\n" + # Build a mapping from imported protobuf package names to the preload + # aliases we use in this generated file. We derive the foreign package + # names from the request's proto_file descriptors. + package_aliases: dict[str, str] = {} + dep_pkg_by_name: dict[str, str] = {pf.name: pf.package for pf in request.proto_file} + + for dependency in proto_file.dependency: + dep_pkg = dep_pkg_by_name.get(dependency, "") + if not dep_pkg: + continue + alias = os.path.splitext(os.path.basename(dependency))[0] + package_aliases[dep_pkg] = alias + # Add imports file.content += generate_imports(proto_file) @@ -207,18 +217,32 @@ def generate_gdscript(request: plugin_pb2.CodeGeneratorRequest) -> plugin_pb2.Co if not message_type.name: continue - # Generate message class line by line - file.content += generate_message_class(message_type, 0, package_name) + # Generate message class line by line, passing package_aliases so + # that field generation can resolve foreign types via preload + # aliases rather than package-qualified names. + file.content += generate_message_class(message_type, 0, package_name, package_aliases) # Add separator between message types file.content += "# =========================================\n\n" - - return response + # Advertise support for proto3 optional fields so protoc accepts optional in proto3 syntax. + # This relies on the CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL flag + # being present in the version of google.protobuf used at runtime. + try: + response.supported_features |= plugin_pb2.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL + except AttributeError: + # Older protobuf versions may not define FEATURE_PROTO3_OPTIONAL. + # In that case we simply don't set the flag; protoc will emit a warning + # or error, but the generator itself can still function for other cases. + pass + + return response -def generate_message_class(message_type: MessageType, indent_level: int = 0, msg_package_name="") -> str: -# global package_name - gd_msg = gd_protobuf_info.init_message_type(message_type, msg_package_name) +def generate_message_class(message_type: MessageType, indent_level: int = 0, msg_package_name: str = "", package_aliases: dict | None = None) -> str: + # Build a GDMessageType enriched with package alias information so that + # field generation can resolve foreign message/enum types to the correct + # preload alias references (e.g. messages.MessageHeader). + gd_msg = gd_protobuf_info.init_message_type(message_type, msg_package_name, package_aliases) """Generate a message class.""" content = "" @@ -241,7 +265,7 @@ def generate_message_class(message_type: MessageType, indent_level: int = 0, msg if nested_type.options.map_entry: # This is a map field continue - content += generate_message_class(nested_type, indent_level + 1, msg_package_name) + content += generate_message_class(nested_type, indent_level + 1, msg_package_name, package_aliases) # Generate Init method content += generate_init_method(message_type, gd_msg, indent + "\t") @@ -429,11 +453,11 @@ def generate_init_method(message_type: MessageType, gd_message_type: GDMessageTy if len(message_type.field) <= 0: content += f"{indent}\tpass" - return for gd_field in gd_message_type.field_list: if isinstance(gd_field, GDField): - content += f"{gd_field.field_clear(indent + '\t')}\n" + # Each clear call should be on its own indented line in GDScript. + content += gd_field.field_clear(indent + "\t") + "\n" content += "\n" return content