diff --git a/CHANGELOG.md b/CHANGELOG.md index 0426c32..f86ac6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Features +- Added dynamic shapes for inputs and outputs. Disciplines can now + declare variables with `dynamic_shape=True` in `add_input` / + `add_output`, indicating that the client is allowed to set the + variable's shape at runtime. A new `SetVariableShapes` gRPC RPC + lets clients send resolved shapes after querying variable definitions. + The OpenMDAO bindings automatically map dynamic-shape variables to + `shape_by_conn=True` and send resolved shapes back to the server + (MDO-Standards/Philote-MDO#6). - Added support for struct (dict) options via the new `kStruct` DataType enum value, enabling complex nested data to be declared and passed as discipline options (#49). diff --git a/philote_mdo/examples/__init__.py b/philote_mdo/examples/__init__.py index 414b3e5..ae2f477 100644 --- a/philote_mdo/examples/__init__.py +++ b/philote_mdo/examples/__init__.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +from .flexible import FlexibleDiscipline from .paraboloid import Paraboloid from .quadratic import QuadradicImplicit from .rosenbrock import Rosenbrock diff --git a/philote_mdo/examples/flexible.py b/philote_mdo/examples/flexible.py new file mode 100644 index 0000000..a58c59e --- /dev/null +++ b/philote_mdo/examples/flexible.py @@ -0,0 +1,55 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import numpy as np +import philote_mdo.general as pmdo + + +class FlexibleDiscipline(pmdo.ExplicitDiscipline): + """ + Example explicit discipline with dynamic shapes. + + This discipline doubles every element of the input vector. The input + and output shapes are not fixed by the server — the client is + expected to set them via ``SetVariableShapes`` before computation. + """ + + def setup(self): + self.add_input("x", dynamic_shape=True, units="m") + self.add_output("y", dynamic_shape=True, units="m") + + def setup_partials(self): + self.declare_partials("y", "x") + + def compute(self, inputs, outputs): + outputs["y"] = 2.0 * inputs["x"] + + def compute_partials(self, inputs, partials): + n = inputs["x"].size + partials["y", "x"] = 2.0 * np.eye(n) diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index ed8414d..010975f 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -86,7 +86,7 @@ def add_option(self, name, type): ) self.options_list[name] = type - def add_input(self, name, shape=(1,), units=""): + def add_input(self, name, shape=(1,), units="", dynamic_shape=False): """ Define a continuous input. @@ -95,12 +95,16 @@ def add_input(self, name, shape=(1,), units=""): name : string the name of the input variable shape : tuple - the shape of the input variable + the shape of the input variable (ignored when dynamic_shape + is True) units : string the unit definition for the input variable + dynamic_shape : bool + when True, the client is allowed to set this variable's shape """ validate_name(name, "add_input") - validate_shape(shape, "add_input") + if not dynamic_shape: + validate_shape(shape, "add_input") validate_units(units, "add_input") if any(v.name == name and v.type == data.VariableType.kInput for v in self._var_meta): raise PhiloteValidationError( @@ -109,8 +113,10 @@ def add_input(self, name, shape=(1,), units=""): meta = data.VariableMetaData() meta.type = data.VariableType.kInput meta.name = name - meta.shape.extend(shape) + if not dynamic_shape: + meta.shape.extend(shape) meta.units = units + meta.dynamic_shape = dynamic_shape self._var_meta += [meta] def add_discrete_input(self, name, default=None): @@ -167,7 +173,7 @@ def add_discrete_output(self, name, default=None): meta.name = name self._discrete_var_meta += [meta] - def add_output(self, name, shape=(1,), units=""): + def add_output(self, name, shape=(1,), units="", dynamic_shape=False): """ Defines a continuous output. @@ -176,12 +182,16 @@ def add_output(self, name, shape=(1,), units=""): name : string the name of the output variable shape : tuple - the shape of the output variable + the shape of the output variable (ignored when dynamic_shape + is True) units : string the unit definition for the output variable + dynamic_shape : bool + when True, the client is allowed to set this variable's shape """ validate_name(name, "add_output") - validate_shape(shape, "add_output") + if not dynamic_shape: + validate_shape(shape, "add_output") validate_units(units, "add_output") if any(v.name == name and v.type == data.VariableType.kOutput for v in self._var_meta): raise PhiloteValidationError( @@ -190,17 +200,21 @@ def add_output(self, name, shape=(1,), units=""): out_meta = data.VariableMetaData() out_meta.type = data.VariableType.kOutput out_meta.name = name - out_meta.shape.extend(shape) + if not dynamic_shape: + out_meta.shape.extend(shape) out_meta.units = units + out_meta.dynamic_shape = dynamic_shape self._var_meta += [out_meta] if self._is_implicit: res_meta = data.VariableMetaData() res_meta.type = data.VariableType.kOutput res_meta.name = name - res_meta.shape.extend(shape) + if not dynamic_shape: + res_meta.shape.extend(shape) res_meta.units = units res_meta.type = data.VariableType.kResidual + res_meta.dynamic_shape = dynamic_shape self._var_meta += [res_meta] def declare_partials(self, func, var): diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index dfbed15..684f5e2 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -154,6 +154,71 @@ def get_partials_definitions(self): if message.name not in self._partials_meta: self._partials_meta += [message] + def get_dynamic_variables(self): + """ + Returns a list of variable metadata entries that have + ``dynamic_shape`` set to ``True``. + """ + return [v for v in self._var_meta if v.dynamic_shape] + + def set_variable_shape(self, name, shape, var_type=data.VariableType.kInput): + """ + Creates a ``VariableMetaData`` message for setting a dynamic + variable's shape. + + Parameters + ---------- + name : str + the name of the variable + shape : tuple + the desired shape + var_type : VariableType + the variable type (kInput or kOutput) + + Returns + ------- + VariableMetaData + protobuf message ready for ``send_variable_shapes`` + """ + meta = data.VariableMetaData() + meta.type = var_type + meta.name = name + meta.shape.extend(shape) + return meta + + def send_variable_shapes(self, variable_metadata): + """ + Sends shapes for variables flagged as ``dynamic_shape``. + + Call after ``get_variable_definitions()`` and before compute + calls. + + Parameters + ---------- + variable_metadata : list of VariableMetaData + shapes for dynamic variables + """ + self._disc_stub.SetVariableShapes(iter(variable_metadata)) + + # update local metadata to reflect the new shapes + for meta in variable_metadata: + for var in self._var_meta: + if var.name == meta.name and var.type == meta.type: + var.shape[:] = [] + var.shape.extend(meta.shape) + break + + # for implicit outputs, also update the matching residual + if meta.type == data.VariableType.kOutput: + for var in self._var_meta: + if ( + var.name == meta.name + and var.type == data.VariableType.kResidual + ): + var.shape[:] = [] + var.shape.extend(meta.shape) + break + def _assemble_input_messages( self, inputs, outputs=None, discrete_inputs=None, discrete_outputs=None ): diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index a7baee6..2a0d956 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -35,7 +35,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf import struct_pb2 from philote_mdo.utils import PairDict, get_flattened_view -from philote_mdo.utils.validation import PhiloteValidationError +from philote_mdo.utils.validation import PhiloteValidationError, validate_shape class DisciplineServer(disc.DisciplineService): @@ -170,6 +170,57 @@ def GetPartialDefinitions(self, request, context): for jac in self._discipline._partials_meta: yield jac + def SetVariableShapes(self, request_iterator, context): + """ + Receives client-defined shapes for variables flagged as + dynamic_shape. + + The client must call this RPC after GetVariableDefinitions and + before any compute RPCs for disciplines that contain variables + with dynamic shapes. + """ + try: + for meta in request_iterator: + validate_shape(tuple(meta.shape), "SetVariableShapes") + + # find the matching variable and update its shape + for var in self._discipline._var_meta: + if var.name == meta.name and var.type == meta.type: + if not var.dynamic_shape: + raise PhiloteValidationError( + f"Variable '{meta.name}' does not allow " + f"dynamic shapes." + ) + var.shape[:] = [] + var.shape.extend(meta.shape) + break + else: + raise PhiloteValidationError( + f"SetVariableShapes: variable '{meta.name}' " + f"not found." + ) + + # if the variable is an output on an implicit discipline, + # also update the matching residual entry + if meta.type == data.VariableType.kOutput: + for var in self._discipline._var_meta: + if ( + var.name == meta.name + and var.type == data.VariableType.kResidual + and var.dynamic_shape + ): + var.shape[:] = [] + var.shape.extend(meta.shape) + break + + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SetVariableShapes failed: {e}" + ) + def preallocate_inputs(self, inputs, flat_inputs, outputs=None, flat_outputs=None): """ Preallocates the inputs before receiving data from the client. @@ -178,6 +229,14 @@ def preallocate_inputs(self, inputs, flat_inputs, outputs=None, flat_outputs=Non inputs to evaluate the residuals and the partials of the residuals. """ for var in self._discipline._var_meta: + # validate that dynamic-shape variables have been resolved + if var.dynamic_shape and len(var.shape) == 0: + raise PhiloteValidationError( + f"Variable '{var.name}' has dynamic_shape=True but " + f"no shape has been set. Call SetVariableShapes " + f"before computing." + ) + if var.type == data.kInput: inputs[var.name] = np.zeros(var.shape) flat_inputs[var.name] = get_flattened_view(inputs[var.name]) diff --git a/philote_mdo/generated/data_pb2.py b/philote_mdo/generated/data_pb2.py index b6029b2..668e1ab 100644 --- a/philote_mdo/generated/data_pb2.py +++ b/philote_mdo/generated/data_pb2.py @@ -7,17 +7,17 @@ _runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, '', 'data.proto') _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"z\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t\x12\x15\n\rdynamic_shape\x18\x06 \x01(\x08"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' - _globals['_DATATYPE']._serialized_start = 856 - _globals['_DATATYPE']._serialized_end = 926 - _globals['_VARIABLETYPE']._serialized_start = 928 - _globals['_VARIABLETYPE']._serialized_end = 1037 + _globals['_DATATYPE']._serialized_start = 879 + _globals['_DATATYPE']._serialized_end = 949 + _globals['_VARIABLETYPE']._serialized_start = 951 + _globals['_VARIABLETYPE']._serialized_end = 1060 _globals['_DISCIPLINEPROPERTIES']._serialized_start = 53 _globals['_DISCIPLINEPROPERTIES']._serialized_end = 178 _globals['_STREAMOPTIONS']._serialized_start = 180 @@ -27,12 +27,12 @@ _globals['_DISCIPLINEOPTIONS']._serialized_start = 282 _globals['_DISCIPLINEOPTIONS']._serialized_end = 343 _globals['_VARIABLEMETADATA']._serialized_start = 345 - _globals['_VARIABLEMETADATA']._serialized_end = 444 - _globals['_PARTIALSMETADATA']._serialized_start = 446 - _globals['_PARTIALSMETADATA']._serialized_end = 510 - _globals['_ARRAY']._serialized_start = 512 - _globals['_ARRAY']._serialized_end = 629 - _globals['_DISCRETEVARIABLE']._serialized_start = 631 - _globals['_DISCRETEVARIABLE']._serialized_end = 739 - _globals['_VARIABLEMESSAGE']._serialized_start = 741 - _globals['_VARIABLEMESSAGE']._serialized_end = 854 \ No newline at end of file + _globals['_VARIABLEMETADATA']._serialized_end = 467 + _globals['_PARTIALSMETADATA']._serialized_start = 469 + _globals['_PARTIALSMETADATA']._serialized_end = 533 + _globals['_ARRAY']._serialized_start = 535 + _globals['_ARRAY']._serialized_end = 652 + _globals['_DISCRETEVARIABLE']._serialized_start = 654 + _globals['_DISCRETEVARIABLE']._serialized_end = 762 + _globals['_VARIABLEMESSAGE']._serialized_start = 764 + _globals['_VARIABLEMESSAGE']._serialized_end = 877 \ No newline at end of file diff --git a/philote_mdo/generated/data_pb2.pyi b/philote_mdo/generated/data_pb2.pyi index 13c800f..500c23f 100644 --- a/philote_mdo/generated/data_pb2.pyi +++ b/philote_mdo/generated/data_pb2.pyi @@ -77,17 +77,19 @@ class DisciplineOptions(_message.Message): ... class VariableMetaData(_message.Message): - __slots__ = ('type', 'name', 'shape', 'units') + __slots__ = ('type', 'name', 'shape', 'units', 'dynamic_shape') TYPE_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] SHAPE_FIELD_NUMBER: _ClassVar[int] UNITS_FIELD_NUMBER: _ClassVar[int] + DYNAMIC_SHAPE_FIELD_NUMBER: _ClassVar[int] type: VariableType name: str shape: _containers.RepeatedScalarFieldContainer[int] units: str + dynamic_shape: bool - def __init__(self, type: _Optional[_Union[VariableType, str]]=..., name: _Optional[str]=..., shape: _Optional[_Iterable[int]]=..., units: _Optional[str]=...) -> None: + def __init__(self, type: _Optional[_Union[VariableType, str]]=..., name: _Optional[str]=..., shape: _Optional[_Iterable[int]]=..., units: _Optional[str]=..., dynamic_shape: bool=...) -> None: ... class PartialsMetaData(_message.Message): diff --git a/philote_mdo/generated/disciplines_pb2.py b/philote_mdo/generated/disciplines_pb2.py index a78cc88..1425ba1 100644 --- a/philote_mdo/generated/disciplines_pb2.py +++ b/philote_mdo/generated/disciplines_pb2.py @@ -8,7 +8,7 @@ _sym_db = _symbol_database.Default() from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 from . import data_pb2 as data__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\x84\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x012\xab\x01\n\x0fExplicitService\x12K\n\x0fComputeFunction\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12K\n\x0fComputeGradient\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x012\x81\x02\n\x0fImplicitService\x12L\n\x10ComputeResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12J\n\x0eSolveResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12T\n\x18ComputeResidualGradients\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\xd0\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x01\x12J\n\x11SetVariableShapes\x12\x19.philote.VariableMetaData\x1a\x16.google.protobuf.Empty"\x00(\x012\xab\x01\n\x0fExplicitService\x12K\n\x0fComputeFunction\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12K\n\x0fComputeGradient\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x012\x81\x02\n\x0fImplicitService\x12L\n\x10ComputeResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12J\n\x0eSolveResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12T\n\x18ComputeResidualGradients\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'disciplines_pb2', _globals) @@ -16,8 +16,8 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' _globals['_DISCIPLINESERVICE']._serialized_start = 72 - _globals['_DISCIPLINESERVICE']._serialized_end = 588 - _globals['_EXPLICITSERVICE']._serialized_start = 591 - _globals['_EXPLICITSERVICE']._serialized_end = 762 - _globals['_IMPLICITSERVICE']._serialized_start = 765 - _globals['_IMPLICITSERVICE']._serialized_end = 1022 \ No newline at end of file + _globals['_DISCIPLINESERVICE']._serialized_end = 664 + _globals['_EXPLICITSERVICE']._serialized_start = 667 + _globals['_EXPLICITSERVICE']._serialized_end = 838 + _globals['_IMPLICITSERVICE']._serialized_start = 841 + _globals['_IMPLICITSERVICE']._serialized_end = 1098 \ No newline at end of file diff --git a/philote_mdo/generated/disciplines_pb2_grpc.py b/philote_mdo/generated/disciplines_pb2_grpc.py index 1b23738..638bd69 100644 --- a/philote_mdo/generated/disciplines_pb2_grpc.py +++ b/philote_mdo/generated/disciplines_pb2_grpc.py @@ -34,6 +34,7 @@ def __init__(self, channel): self.Setup = channel.unary_unary('/philote.DisciplineService/Setup', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, _registered_method=True) self.GetVariableDefinitions = channel.unary_stream('/philote.DisciplineService/GetVariableDefinitions', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=data__pb2.VariableMetaData.FromString, _registered_method=True) self.GetPartialDefinitions = channel.unary_stream('/philote.DisciplineService/GetPartialDefinitions', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=data__pb2.PartialsMetaData.FromString, _registered_method=True) + self.SetVariableShapes = channel.stream_unary('/philote.DisciplineService/SetVariableShapes', request_serializer=data__pb2.VariableMetaData.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, _registered_method=True) class DisciplineServiceServicer(object): """Generic Discipline Definition @@ -91,8 +92,16 @@ def GetPartialDefinitions(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SetVariableShapes(self, request_iterator, context): + """Sets shapes for variables flagged as dynamic_shape. + Must be called after GetVariableDefinitions and before compute RPCs. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_DisciplineServiceServicer_to_server(servicer, server): - rpc_method_handlers = {'GetInfo': grpc.unary_unary_rpc_method_handler(servicer.GetInfo, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.DisciplineProperties.SerializeToString), 'SetStreamOptions': grpc.unary_unary_rpc_method_handler(servicer.SetStreamOptions, request_deserializer=data__pb2.StreamOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetAvailableOptions': grpc.unary_unary_rpc_method_handler(servicer.GetAvailableOptions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.OptionsList.SerializeToString), 'SetOptions': grpc.unary_unary_rpc_method_handler(servicer.SetOptions, request_deserializer=data__pb2.DisciplineOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'Setup': grpc.unary_unary_rpc_method_handler(servicer.Setup, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetVariableDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetVariableDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.VariableMetaData.SerializeToString), 'GetPartialDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetPartialDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.PartialsMetaData.SerializeToString)} + rpc_method_handlers = {'GetInfo': grpc.unary_unary_rpc_method_handler(servicer.GetInfo, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.DisciplineProperties.SerializeToString), 'SetStreamOptions': grpc.unary_unary_rpc_method_handler(servicer.SetStreamOptions, request_deserializer=data__pb2.StreamOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetAvailableOptions': grpc.unary_unary_rpc_method_handler(servicer.GetAvailableOptions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.OptionsList.SerializeToString), 'SetOptions': grpc.unary_unary_rpc_method_handler(servicer.SetOptions, request_deserializer=data__pb2.DisciplineOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'Setup': grpc.unary_unary_rpc_method_handler(servicer.Setup, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetVariableDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetVariableDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.VariableMetaData.SerializeToString), 'GetPartialDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetPartialDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.PartialsMetaData.SerializeToString), 'SetVariableShapes': grpc.stream_unary_rpc_method_handler(servicer.SetVariableShapes, request_deserializer=data__pb2.VariableMetaData.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString)} generic_handler = grpc.method_handlers_generic_handler('philote.DisciplineService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) server.add_registered_method_handlers('philote.DisciplineService', rpc_method_handlers) @@ -132,6 +141,10 @@ def GetVariableDefinitions(request, target, options=(), channel_credentials=None def GetPartialDefinitions(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): return grpc.experimental.unary_stream(request, target, '/philote.DisciplineService/GetPartialDefinitions', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, data__pb2.PartialsMetaData.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + @staticmethod + def SetVariableShapes(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/philote.DisciplineService/SetVariableShapes', data__pb2.VariableMetaData.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + class ExplicitServiceStub(object): """Definition of the generic Explicit Component RPC """ diff --git a/philote_mdo/openmdao/explicit.py b/philote_mdo/openmdao/explicit.py index f80c9ef..2e02b9f 100644 --- a/philote_mdo/openmdao/explicit.py +++ b/philote_mdo/openmdao/explicit.py @@ -200,6 +200,10 @@ def setup_partials(self): the server's partial derivative metadata. The component can compute partials either analytically (if supported by the server) or via finite differencing. + If any variables were declared with ``dynamic_shape=True`` on the server, + the shapes resolved by OpenMDAO (e.g. via ``shape_by_conn``) are sent + back to the server before querying partial definitions. + The method is called automatically by OpenMDAO during component setup and should not be called directly by users. @@ -209,6 +213,7 @@ def setup_partials(self): - Both analytic and finite difference methods are supported - Sparsity patterns are preserved when available from server metadata """ + utils.send_resolved_shapes(self) utils.client_setup_partials(self) def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): diff --git a/philote_mdo/openmdao/implicit.py b/philote_mdo/openmdao/implicit.py index 2176bdc..70c8c17 100644 --- a/philote_mdo/openmdao/implicit.py +++ b/philote_mdo/openmdao/implicit.py @@ -209,6 +209,10 @@ def setup_partials(self): pairs based on the server's partial derivative metadata. For implicit components, this includes both dR/dinputs and dR/doutputs terms needed for Newton-type solvers. + If any variables were declared with ``dynamic_shape=True`` on the server, + the shapes resolved by OpenMDAO (e.g. via ``shape_by_conn``) are sent + back to the server before querying partial definitions. + The method is called automatically by OpenMDAO during component setup and should not be called directly by users. @@ -218,6 +222,7 @@ def setup_partials(self): - Both dR/dinputs and dR/doutputs partials are declared - Sparsity patterns are preserved when available from server metadata """ + utils.send_resolved_shapes(self) utils.client_setup_partials(self) def apply_nonlinear(self, inputs, outputs, residuals, discrete_inputs=None, discrete_outputs=None): diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index ea11503..d7a599e 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -74,11 +74,19 @@ def client_setup(comp): else: units = var.units - if var.type == data.kInput: - comp.add_input(var.name, shape=tuple(var.shape), units=units) + if var.dynamic_shape: + # let OpenMDAO resolve the shape from connections + if var.type == data.kInput: + comp.add_input(var.name, shape_by_conn=True, units=units) - if var.type == data.kOutput: - comp.add_output(var.name, shape=tuple(var.shape), units=units) + if var.type == data.kOutput: + comp.add_output(var.name, shape_by_conn=True, units=units) + else: + if var.type == data.kInput: + comp.add_input(var.name, shape=tuple(var.shape), units=units) + + if var.type == data.kOutput: + comp.add_output(var.name, shape=tuple(var.shape), units=units) # define discrete inputs and outputs for var in comp._client._discrete_var_meta: @@ -89,6 +97,30 @@ def client_setup(comp): comp.add_discrete_output(var.name, val=None) +def send_resolved_shapes(comp): + """ + Sends resolved shapes for dynamic-shape variables back to the server. + + After OpenMDAO resolves shapes (e.g. via ``shape_by_conn``), this + function reads the resolved metadata from the component and transmits + the shapes to the remote discipline server. + """ + dynamic_shapes = [] + for var in comp._client._var_meta: + if not var.dynamic_shape: + continue + + resolved_meta = comp._var_rel2meta[var.name] + dynamic_shapes.append( + comp._client.set_variable_shape( + var.name, resolved_meta["shape"], var.type + ) + ) + + if dynamic_shapes: + comp._client.send_variable_shapes(dynamic_shapes) + + def client_setup_partials(comp): """ Sets up the partials for the OpenMDAO component. diff --git a/tests/test_dynamic_shapes.py b/tests/test_dynamic_shapes.py new file mode 100644 index 0000000..6ec15f5 --- /dev/null +++ b/tests/test_dynamic_shapes.py @@ -0,0 +1,509 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +from concurrent import futures +import unittest +from unittest.mock import Mock + +import grpc +import numpy as np +import numpy.testing as npt + +from philote_mdo.general import ( + Discipline, + DisciplineServer, + ExplicitDiscipline, + ExplicitServer, + ExplicitClient, + ImplicitDiscipline, + ImplicitServer, + ImplicitClient, +) +from philote_mdo.utils.validation import PhiloteValidationError +from philote_mdo.examples import FlexibleDiscipline +import philote_mdo.generated.data_pb2 as data + + +# --------------------------------------------------------------- +# Unit tests: Discipline base class +# --------------------------------------------------------------- +class TestDynamicShapeDiscipline(unittest.TestCase): + """Unit tests for dynamic_shape flag on the Discipline base class.""" + + def test_add_input_dynamic_shape(self): + """add_input with dynamic_shape=True stores the flag and omits shape.""" + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + + self.assertEqual(len(disc._var_meta), 1) + meta = disc._var_meta[0] + self.assertEqual(meta.name, "x") + self.assertTrue(meta.dynamic_shape) + self.assertEqual(list(meta.shape), []) + + def test_add_output_dynamic_shape(self): + """add_output with dynamic_shape=True stores the flag and omits shape.""" + disc = Discipline() + disc.add_output("y", dynamic_shape=True) + + self.assertEqual(len(disc._var_meta), 1) + meta = disc._var_meta[0] + self.assertEqual(meta.name, "y") + self.assertTrue(meta.dynamic_shape) + self.assertEqual(list(meta.shape), []) + + def test_add_input_static_shape_unchanged(self): + """add_input without dynamic_shape behaves as before.""" + disc = Discipline() + disc.add_input("x", shape=(3, 2), units="m") + + meta = disc._var_meta[0] + self.assertFalse(meta.dynamic_shape) + self.assertEqual(list(meta.shape), [3, 2]) + + def test_add_output_static_shape_unchanged(self): + """add_output without dynamic_shape behaves as before.""" + disc = Discipline() + disc.add_output("y", shape=(4,)) + + meta = disc._var_meta[0] + self.assertFalse(meta.dynamic_shape) + self.assertEqual(list(meta.shape), [4]) + + def test_dynamic_shape_skips_shape_validation(self): + """dynamic_shape=True should not validate the default shape arg.""" + disc = Discipline() + # Should not raise even though we pass no valid shape + disc.add_input("x", dynamic_shape=True) + disc.add_output("y", dynamic_shape=True) + self.assertEqual(len(disc._var_meta), 2) + + def test_implicit_output_dynamic_shape_creates_residual(self): + """For implicit disciplines, dynamic output also creates a residual entry.""" + disc = Discipline() + disc._is_implicit = True + disc.add_output("y", dynamic_shape=True, units="m") + + # Should have output and residual + self.assertEqual(len(disc._var_meta), 2) + out = disc._var_meta[0] + res = disc._var_meta[1] + self.assertEqual(out.type, data.VariableType.kOutput) + self.assertTrue(out.dynamic_shape) + self.assertEqual(res.type, data.VariableType.kResidual) + self.assertTrue(res.dynamic_shape) + + +# --------------------------------------------------------------- +# Unit tests: DisciplineServer.SetVariableShapes +# --------------------------------------------------------------- +class TestSetVariableShapesRPC(unittest.TestCase): + """Unit tests for the SetVariableShapes RPC handler.""" + + def _make_server_with_dynamic_disc(self): + server = DisciplineServer() + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + disc.add_output("y", dynamic_shape=True) + disc.add_input("z", shape=(2,)) # static + server._discipline = disc + return server + + def test_set_shapes_for_dynamic_variables(self): + """SetVariableShapes updates shapes on dynamic variables.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + x_meta = data.VariableMetaData( + name="x", type=data.VariableType.kInput, shape=[5] + ) + y_meta = data.VariableMetaData( + name="y", type=data.VariableType.kOutput, shape=[5] + ) + server.SetVariableShapes(iter([x_meta, y_meta]), context) + + # verify shapes were updated + for var in server._discipline._var_meta: + if var.name == "x": + self.assertEqual(list(var.shape), [5]) + if var.name == "y" and var.type == data.VariableType.kOutput: + self.assertEqual(list(var.shape), [5]) + + def test_reject_shape_for_static_variable(self): + """SetVariableShapes aborts when targeting a non-dynamic variable.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + z_meta = data.VariableMetaData( + name="z", type=data.VariableType.kInput, shape=[10] + ) + server.SetVariableShapes(iter([z_meta]), context) + context.abort.assert_called_once() + + def test_reject_unknown_variable(self): + """SetVariableShapes aborts when the variable name is not found.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + meta = data.VariableMetaData( + name="nope", type=data.VariableType.kInput, shape=[3] + ) + server.SetVariableShapes(iter([meta]), context) + context.abort.assert_called_once() + + def test_reject_invalid_shape(self): + """SetVariableShapes aborts on invalid (non-positive) shape.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + meta = data.VariableMetaData( + name="x", type=data.VariableType.kInput, shape=[-1] + ) + server.SetVariableShapes(iter([meta]), context) + context.abort.assert_called_once() + + def test_preallocate_raises_when_shape_unset(self): + """preallocate_inputs raises if a dynamic variable has no shape.""" + server = self._make_server_with_dynamic_disc() + + with self.assertRaises(PhiloteValidationError): + server.preallocate_inputs({}, {}) + + +# --------------------------------------------------------------- +# Unit tests: DisciplineClient helpers +# --------------------------------------------------------------- +class TestDynamicShapeClient(unittest.TestCase): + """Unit tests for client-side dynamic shape helpers.""" + + def test_set_variable_shape(self): + """set_variable_shape creates correct VariableMetaData.""" + client = ExplicitClient.__new__(ExplicitClient) + client._var_meta = [] + + meta = client.set_variable_shape("x", (3, 2), data.VariableType.kInput) + + self.assertEqual(meta.name, "x") + self.assertEqual(list(meta.shape), [3, 2]) + self.assertEqual(meta.type, data.VariableType.kInput) + + def test_get_dynamic_variables(self): + """get_dynamic_variables filters correctly.""" + client = ExplicitClient.__new__(ExplicitClient) + + static_var = data.VariableMetaData( + name="a", type=data.kInput, shape=[1], dynamic_shape=False + ) + dynamic_var = data.VariableMetaData( + name="b", type=data.kInput, dynamic_shape=True + ) + client._var_meta = [static_var, dynamic_var] + + result = client.get_dynamic_variables() + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "b") + + +# --------------------------------------------------------------- +# Integration test: FlexibleDiscipline round-trip +# --------------------------------------------------------------- +class TestFlexibleDisciplineIntegration(unittest.TestCase): + """End-to-end test: dynamic shapes over gRPC.""" + + def test_flexible_compute(self): + """Client sets shapes, then computes with the FlexibleDiscipline.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=FlexibleDiscipline()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + # verify dynamic_shape flags came through + dynamic = client.get_dynamic_variables() + self.assertEqual(len(dynamic), 2) + + # set shapes + shapes = [ + client.set_variable_shape("x", (4,), data.VariableType.kInput), + client.set_variable_shape( + "y", (4,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + # compute + inputs = {"x": np.array([1.0, 2.0, 3.0, 4.0])} + outputs = client.run_compute(inputs) + + npt.assert_array_almost_equal( + outputs["y"], [2.0, 4.0, 6.0, 8.0] + ) + finally: + server.stop(0) + + def test_flexible_compute_partials(self): + """Client sets shapes, then computes partials.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=FlexibleDiscipline()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + shapes = [ + client.set_variable_shape("x", (3,), data.VariableType.kInput), + client.set_variable_shape( + "y", (3,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + inputs = {"x": np.array([1.0, 2.0, 3.0])} + jac = client.run_compute_partials(inputs) + + expected = 2.0 * np.eye(3) + npt.assert_array_almost_equal(jac["y", "x"], expected) + finally: + server.stop(0) + + def test_backward_compat_static_shapes(self): + """Existing static-shape disciplines work without calling SetVariableShapes.""" + from philote_mdo.examples import Paraboloid + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=Paraboloid()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + client.get_partials_definitions() + + # no dynamic variables + self.assertEqual(len(client.get_dynamic_variables()), 0) + + inputs = {"x": np.array([1.0]), "y": np.array([2.0])} + outputs = client.run_compute(inputs) + self.assertAlmostEqual(outputs["f_xy"][0], 39.0) + finally: + server.stop(0) + + +# --------------------------------------------------------------- +# Integration test: implicit discipline with dynamic shapes +# --------------------------------------------------------------- +class DynamicImplicit(ImplicitDiscipline): + """Implicit discipline with dynamic shapes for residual testing.""" + + def setup(self): + self.add_input("a", shape=(1,)) + self.add_output("x", dynamic_shape=True) + + def setup_partials(self): + self.declare_partials("x", "a") + self.declare_partials("x", "x") + + def compute_residuals(self, inputs, outputs, residuals): + residuals["x"] = outputs["x"] ** 2 - inputs["a"] + + def solve_residuals(self, inputs, outputs): + outputs["x"] = np.sqrt(np.abs(inputs["a"])) + + def compute_residual_partials(self, inputs, outputs, partials): + partials["x", "a"] = -np.ones(inputs["a"].shape) + partials["x", "x"] = 2.0 * np.diag(outputs["x"].ravel()) + + +class TestDynamicImplicitIntegration(unittest.TestCase): + """Integration test for implicit discipline with dynamic shapes.""" + + def test_implicit_dynamic_shape_residual(self): + """SetVariableShapes updates residual entries for implicit disciplines.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ImplicitServer(discipline=DynamicImplicit()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ImplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + # set shape for the dynamic output (and its residual) + shapes = [ + client.set_variable_shape( + "x", (2,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + inputs = {"a": np.array([4.0])} + outputs = {"x": np.array([3.0, 1.0])} + residuals = client.run_compute_residuals(inputs, outputs) + + # x**2 - a = [9-4, 1-4] = [5, -3] + npt.assert_array_almost_equal(residuals["x"], [5.0, -3.0]) + finally: + server.stop(0) + + +# --------------------------------------------------------------- +# Unit test: SetVariableShapes generic exception path +# --------------------------------------------------------------- +class TestSetVariableShapesGenericException(unittest.TestCase): + """Tests the generic exception handler in SetVariableShapes.""" + + def test_generic_exception_aborts(self): + server = DisciplineServer() + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + server._discipline = disc + + context = Mock() + + # Craft an iterator that raises a non-validation exception + def bad_iterator(): + raise RuntimeError("unexpected error") + yield # pragma: no cover + + server.SetVariableShapes(bad_iterator(), context) + context.abort.assert_called_once() + self.assertEqual( + context.abort.call_args[0][0], grpc.StatusCode.INTERNAL + ) + + +# --------------------------------------------------------------- +# Unit test: OpenMDAO utils dynamic shape paths +# --------------------------------------------------------------- +class TestOpenMdaoUtilsDynamicShapes(unittest.TestCase): + """Tests the OpenMDAO utils functions with dynamic-shape variables.""" + + def test_client_setup_with_dynamic_input(self): + """client_setup passes shape_by_conn=True for dynamic inputs.""" + comp = Mock() + var = Mock() + var.name = "x" + var.units = "m" + var.type = data.kInput + var.shape = [] + var.dynamic_shape = True + + comp._client._var_meta = [var] + comp._client._discrete_var_meta = [] + + from philote_mdo.openmdao.utils import client_setup + + client_setup(comp) + + comp.add_input.assert_called_once_with( + "x", shape_by_conn=True, units="m" + ) + + def test_client_setup_with_dynamic_output(self): + """client_setup passes shape_by_conn=True for dynamic outputs.""" + comp = Mock() + var = Mock() + var.name = "y" + var.units = "" + var.type = data.kOutput + var.shape = [] + var.dynamic_shape = True + + comp._client._var_meta = [var] + comp._client._discrete_var_meta = [] + + from philote_mdo.openmdao.utils import client_setup + + client_setup(comp) + + comp.add_output.assert_called_once_with( + "y", shape_by_conn=True, units=None + ) + + def test_send_resolved_shapes_with_dynamic_vars(self): + """send_resolved_shapes reads OpenMDAO metadata and sends shapes.""" + comp = Mock() + var = data.VariableMetaData( + name="x", type=data.kInput, dynamic_shape=True + ) + comp._client._var_meta = [var] + comp._var_rel2meta = {"x": {"shape": (5,)}} + comp._client.set_variable_shape.return_value = data.VariableMetaData( + name="x", type=data.kInput, shape=[5] + ) + + from philote_mdo.openmdao.utils import send_resolved_shapes + + send_resolved_shapes(comp) + + comp._client.set_variable_shape.assert_called_once_with( + "x", (5,), data.kInput + ) + comp._client.send_variable_shapes.assert_called_once() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_openmdao_explicit_client.py b/tests/test_openmdao_explicit_client.py index b97551f..b50bc59 100644 --- a/tests/test_openmdao_explicit_client.py +++ b/tests/test_openmdao_explicit_client.py @@ -167,6 +167,7 @@ def test_setup_partials( component = RemoteExplicitComponent(channel=mock_channel) component._client = Mock() component._client._partials_meta = [par1, par2] + component._client._var_meta = [] # call the function component.setup_partials() diff --git a/tests/test_openmdao_implicit_client.py b/tests/test_openmdao_implicit_client.py index 045ddcb..82255c1 100644 --- a/tests/test_openmdao_implicit_client.py +++ b/tests/test_openmdao_implicit_client.py @@ -171,6 +171,7 @@ def test_setup_partials( component = RemoteImplicitComponent(channel=mock_channel) component._client = Mock() component._client._partials_meta = [par1, par2] + component._client._var_meta = [] # call the function component.setup_partials() diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index 763504e..bc65eba 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -48,12 +48,14 @@ def test_openmdao_client_setup(self): var1.units = "m" var1.type = kInput var1.shape = [2] + var1.dynamic_shape = False var2 = Mock() var2.name = "var2" var2.units = None var2.type = kOutput var2.shape = [1] + var2.dynamic_shape = False comp._client._var_meta = [var1, var2] comp._client._discrete_var_meta = []