From f76a29aadac930609ab0053f391145a667969e83 Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 22:25:36 -0400 Subject: [PATCH 1/3] Add comprehensive input validation and error handling (#46) Introduce custom exception classes (PhiloteValidationError, PhiloteServerError), parameter validation in discipline base classes, gRPC error propagation via context.abort() in all server RPC methods, and client-side input validation with gRPC error wrapping. Fix missing space in RemoteImplicitComponent error message. --- CHANGELOG.md | 9 + philote_mdo/general/discipline.py | 45 +++++ philote_mdo/general/discipline_client.py | 20 +- philote_mdo/general/discipline_server.py | 85 ++++++--- philote_mdo/general/explicit_client.py | 35 ++-- philote_mdo/general/explicit_server.py | 130 +++++++------ philote_mdo/general/implicit_client.py | 77 +++++--- philote_mdo/general/implicit_server.py | 228 +++++++++++++---------- philote_mdo/openmdao/explicit.py | 4 + philote_mdo/openmdao/group.py | 39 ++++ philote_mdo/openmdao/implicit.py | 6 +- philote_mdo/openmdao/utils.py | 27 +-- philote_mdo/utils/__init__.py | 11 ++ philote_mdo/utils/helper.py | 18 ++ philote_mdo/utils/pair_dict.py | 12 ++ philote_mdo/utils/validation.py | 210 +++++++++++++++++++++ tests/test_discipline.py | 95 ++++++++++ tests/test_discipline_client.py | 19 ++ tests/test_discipline_server.py | 33 ++-- tests/test_edge_cases.py | 27 +-- tests/test_explicit_client.py | 32 ++++ tests/test_explicit_server.py | 68 +++++++ tests/test_implicit_client.py | 26 +++ tests/test_implicit_server.py | 79 ++++++++ tests/test_openmdao_group.py | 41 ++++ tests/test_openmdao_utils.py | 7 +- tests/test_utils.py | 36 +++- tests/test_validation.py | 184 ++++++++++++++++++ 28 files changed, 1331 insertions(+), 272 deletions(-) create mode 100644 philote_mdo/utils/validation.py create mode 100644 tests/test_validation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e160b23..0426c32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 the new `VariableMessage` wrapper. The OpenMDAO bindings (`RemoteExplicitComponent`, `RemoteImplicitComponent`) automatically discover and forward discrete variables. +- Added comprehensive input validation and error handling across the + framework. Introduces custom exception classes (`PhiloteValidationError`, + `PhiloteServerError`), parameter validation in discipline base classes + (`add_input`, `add_output`, `add_option`, `declare_partials`), proper + gRPC error propagation via `context.abort()` with appropriate status + codes in all server RPC methods, and client-side input validation with + gRPC error wrapping (#46). ### Bug Fixes - Fixed bare `except` to `except ImportError` in `examples/__init__.py`. +- Fixed missing space in `RemoteImplicitComponent` error message + ("will notbe" -> "will not be"). - Fixed `SellarMDA` promoted-input ambiguity that newer OpenMDAO releases reject during `final_setup`. The `x` and `z` defaults were being set on the inner `cycle` subgroup, but `obj_cmp` promoted the same variables diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index ae1f020..ed8414d 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -28,6 +28,13 @@ # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. import philote_mdo.generated.data_pb2 as data +from philote_mdo.utils.validation import ( + validate_name, + validate_shape, + validate_units, + validate_option_type, + PhiloteValidationError, +) class Discipline: @@ -71,6 +78,12 @@ def add_option(self, name, type): the data type of the option. acceptable types are 'bool', 'int', 'float', 'str', 'dict' """ + validate_name(name, "add_option") + validate_option_type(type, name) + if name in self.options_list: + raise PhiloteValidationError( + f"add_option: option '{name}' is already defined." + ) self.options_list[name] = type def add_input(self, name, shape=(1,), units=""): @@ -86,6 +99,13 @@ def add_input(self, name, shape=(1,), units=""): units : string the unit definition for the input variable """ + validate_name(name, "add_input") + validate_shape(shape, "add_input") + validate_units(units, "add_input") + if any(v.name == name and v.type == data.VariableType.kInput for v in self._var_meta): + raise PhiloteValidationError( + f"add_input: input '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kInput meta.name = name @@ -107,6 +127,14 @@ def add_discrete_input(self, name, default=None): default : object, optional the default value for the discrete input """ + validate_name(name, "add_discrete_input") + if any( + v.name == name and v.type == data.VariableType.kDiscreteInput + for v in self._discrete_var_meta + ): + raise PhiloteValidationError( + f"add_discrete_input: discrete input '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kDiscreteInput meta.name = name @@ -126,6 +154,14 @@ def add_discrete_output(self, name, default=None): default : object, optional the default value for the discrete output """ + validate_name(name, "add_discrete_output") + if any( + v.name == name and v.type == data.VariableType.kDiscreteOutput + for v in self._discrete_var_meta + ): + raise PhiloteValidationError( + f"add_discrete_output: discrete output '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kDiscreteOutput meta.name = name @@ -144,6 +180,13 @@ def add_output(self, name, shape=(1,), units=""): units : string the unit definition for the output variable """ + validate_name(name, "add_output") + validate_shape(shape, "add_output") + validate_units(units, "add_output") + if any(v.name == name and v.type == data.VariableType.kOutput for v in self._var_meta): + raise PhiloteValidationError( + f"add_output: output '{name}' is already defined." + ) out_meta = data.VariableMetaData() out_meta.type = data.VariableType.kOutput out_meta.name = name @@ -164,6 +207,8 @@ def declare_partials(self, func, var): """ Defines partials that will be determined using the analysis server. """ + validate_name(func, "declare_partials (func)") + validate_name(var, "declare_partials (var)") self._partials_meta += [data.PartialsMetaData(name=func, subname=var)] def initialize(self): diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index 13b206f..dfbed15 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -33,6 +33,11 @@ import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.utils as utils from philote_mdo.general.discipline_server import _python_to_value, _value_to_python +from philote_mdo.utils.validation import ( + PhiloteValidationError, + validate_is_dict, + validate_numpy_array, +) class DisciplineClient: @@ -114,6 +119,7 @@ def send_options(self, options): ------- None """ + validate_is_dict(options, "send_options") proto_options = data.DisciplineOptions() proto_options.options.update(options) self._disc_stub.SetOptions(proto_options) @@ -158,6 +164,14 @@ def _assemble_input_messages( Both continuous and discrete inputs are wrapped in ``VariableMessage`` envelopes. """ + validate_is_dict(inputs, "_assemble_input_messages (inputs)") + for input_name, value in inputs.items(): + validate_numpy_array(value, input_name) + if outputs is not None: + validate_is_dict(outputs, "_assemble_input_messages (outputs)") + for output_name, value in outputs.items(): + validate_numpy_array(value, output_name) + messages = [] # Continuous inputs @@ -251,7 +265,7 @@ def _recover_outputs(self, responses): if len(arr.data) > 0: flat_outputs[arr.name][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables, but array is empty." ) @@ -289,7 +303,7 @@ def _recover_residuals(self, responses): if len(arr.data) > 0: flat_residuals[arr.name][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables, but array is empty." ) @@ -336,7 +350,7 @@ def _recover_partials(self, responses): if len(arr.data) > 0: flat_p[(arr.name, arr.subname)][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous outputs for the " "partials, but array was empty." ) diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index d784b18..a7baee6 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import numpy as np import philote_mdo.generated.data_pb2 as data @@ -34,6 +35,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf import struct_pb2 from philote_mdo.utils import PairDict, get_flattened_view +from philote_mdo.utils.validation import PhiloteValidationError class DisciplineServer(disc.DisciplineService): @@ -85,48 +87,69 @@ def GetAvailableOptions(self, request, context): """ RPC that gets the names and types of all available discipline options. """ - opts_dict = self._discipline.options_list - opts = data.OptionsList() - - for name, val in opts_dict.items(): - opts.options.append(name) - - # assign the correct data type - if val == "bool": - type = data.kBool - elif val == "int": - type = data.kInt - elif val == "float": - type = data.kDouble - elif val == "str": - type = data.kString - elif val == "dict": - type = data.kStruct - else: - raise ValueError( - "Invalid value for discipline option '{}'".format(name) - ) + try: + opts_dict = self._discipline.options_list + opts = data.OptionsList() + + for name, val in opts_dict.items(): + opts.options.append(name) + + # assign the correct data type + if val == "bool": + type = data.kBool + elif val == "int": + type = data.kInt + elif val == "float": + type = data.kDouble + elif val == "str": + type = data.kString + elif val == "dict": + type = data.kStruct + else: + raise PhiloteValidationError( + "Invalid value for discipline option '{}'".format(name) + ) - opts.type.append(type) + opts.type.append(type) - return opts + return opts + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"GetAvailableOptions failed: {e}" + ) def SetOptions(self, request, context): """ RPC that sets the discipline options. """ - options = request.options - self._discipline.set_options(options) - return Empty() + try: + options = request.options + self._discipline.set_options(options) + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SetOptions failed: {e}" + ) def Setup(self, request, context): """ RPC that runs the setup function """ - self._discipline._clear_data() - self._discipline.setup() - self._discipline.setup_partials() - return Empty() + try: + self._discipline._clear_data() + self._discipline.setup() + self._discipline.setup_partials() + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"Setup failed: {e}" + ) def GetVariableDefinitions(self, request, context): """ @@ -237,7 +260,7 @@ def process_inputs( elif arr.type == data.VariableType.kOutput: flat_outputs[arr.name][b : e + 1] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables but arrays were" " empty for variable %s." % (arr.name) ) diff --git a/philote_mdo/general/explicit_client.py b/philote_mdo/general/explicit_client.py index 5b278cb..b42f8e7 100644 --- a/philote_mdo/general/explicit_client.py +++ b/philote_mdo/general/explicit_client.py @@ -29,6 +29,7 @@ # control over the information you may find at these locations. import grpc from philote_mdo.general.discipline_client import DisciplineClient +from philote_mdo.utils.validation import PhiloteServerError, validate_is_dict import philote_mdo.generated.disciplines_pb2_grpc as disc @@ -59,11 +60,17 @@ def run_compute(self, inputs, discrete_inputs=None): Continuous outputs, or (continuous outputs, discrete outputs) when the server returns discrete output data. """ - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._expl_stub.ComputeFunction(iter(messages)) - return self._recover_outputs(responses) + validate_is_dict(inputs, "run_compute (inputs)") + try: + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._expl_stub.ComputeFunction(iter(messages)) + return self._recover_outputs(responses) + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute: {e.details()}" + ) from e def run_compute_partials(self, inputs, discrete_inputs=None): """ @@ -77,10 +84,16 @@ def run_compute_partials(self, inputs, discrete_inputs=None): discrete_inputs : dict, optional Discrete input values. """ - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._expl_stub.ComputeGradient(iter(messages)) - partials = self._recover_partials(responses) + validate_is_dict(inputs, "run_compute_partials (inputs)") + try: + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._expl_stub.ComputeGradient(iter(messages)) + partials = self._recover_partials(responses) - return partials + return partials + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute_partials: {e.details()}" + ) from e diff --git a/philote_mdo/general/explicit_server.py b/philote_mdo/general/explicit_server.py index 049e0a2..e8b8fff 100644 --- a/philote_mdo/general/explicit_server.py +++ b/philote_mdo/general/explicit_server.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data from philote_mdo.general.discipline_server import ( @@ -34,6 +35,7 @@ _python_to_value, ) from philote_mdo.utils import get_chunk_indices +from philote_mdo.utils.validation import PhiloteValidationError class ExplicitServer(DisciplineServer, disc.ExplicitServiceServicer): @@ -55,76 +57,90 @@ def ComputeFunction(self, request_iterator, context): """ Computes the function evaluation and sends the result to the client. """ - inputs = {} - flat_inputs = {} - outputs = {} - discrete_inputs = {} - discrete_outputs = {} + try: + inputs = {} + flat_inputs = {} + outputs = {} + discrete_inputs = {} + discrete_outputs = {} - self.preallocate_inputs(inputs, flat_inputs) - discrete_inputs, _ = self.process_inputs( - request_iterator, flat_inputs, discrete_inputs=discrete_inputs - ) - - # Call compute with discrete data when discrete variables are present - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute( - inputs, outputs, discrete_inputs, discrete_outputs + self.preallocate_inputs(inputs, flat_inputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs ) - else: - self._discipline.compute(inputs, outputs) - # Stream continuous outputs - for output_name, value in outputs.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=output_name, - type=data.kOutput, - start=b, - end=e - 1, - data=value.ravel()[b:e], - ) + # Call compute with discrete data when discrete variables are present + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute( + inputs, outputs, discrete_inputs, discrete_outputs ) + else: + self._discipline.compute(inputs, outputs) - # Stream discrete outputs - for name, value in discrete_outputs.items(): - yield data.VariableMessage( - discrete=data.DiscreteVariable( - name=name, - type=data.VariableType.kDiscreteOutput, - value=_python_to_value(value), + # Stream continuous outputs + for output_name, value in outputs.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + type=data.kOutput, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) + ) + + # Stream discrete outputs + for name, value in discrete_outputs.items(): + yield data.VariableMessage( + discrete=data.DiscreteVariable( + name=name, + type=data.VariableType.kDiscreteOutput, + value=_python_to_value(value), + ) ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeFunction failed: {e}" ) def ComputeGradient(self, request_iterator, context): """ Computes the gradient evaluation and sends the result to the client. """ - inputs = {} - flat_inputs = {} - discrete_inputs = {} + try: + inputs = {} + flat_inputs = {} + discrete_inputs = {} - self.preallocate_inputs(inputs, flat_inputs) - jac = self.preallocate_partials() - discrete_inputs, _ = self.process_inputs( - request_iterator, flat_inputs, discrete_inputs=discrete_inputs - ) + self.preallocate_inputs(inputs, flat_inputs) + jac = self.preallocate_partials() + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs + ) - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute_partials(inputs, jac, discrete_inputs) - else: - self._discipline.compute_partials(inputs, jac) + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_partials(inputs, jac, discrete_inputs) + else: + self._discipline.compute_partials(inputs, jac) - for jac, value in jac.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e - 1, - data=value.ravel()[b:e], + for jac, value in jac.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) ) - ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeGradient failed: {e}" + ) diff --git a/philote_mdo/general/implicit_client.py b/philote_mdo/general/implicit_client.py index 81fb6f3..9ef8156 100644 --- a/philote_mdo/general/implicit_client.py +++ b/philote_mdo/general/implicit_client.py @@ -29,6 +29,7 @@ # control over the information you may find at these locations. import grpc from philote_mdo.general.discipline_client import DisciplineClient +from philote_mdo.utils.validation import PhiloteServerError, validate_is_dict import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc @@ -168,17 +169,24 @@ def run_compute_residuals( - Large arrays are automatically streamed for efficiency - This is typically used for residual evaluation during Newton iterations """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, - outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - responses = self._impl_stub.ComputeResiduals(iter(messages)) - residuals = self._recover_residuals(responses) - - return residuals + validate_is_dict(inputs, "run_compute_residuals (inputs)") + validate_is_dict(outputs, "run_compute_residuals (outputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) + responses = self._impl_stub.ComputeResiduals(iter(messages)) + residuals = self._recover_residuals(responses) + + return residuals + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute_residuals: {e.details()}" + ) from e def run_solve_residuals(self, inputs, discrete_inputs=None): """ @@ -231,13 +239,19 @@ def run_solve_residuals(self, inputs, discrete_inputs=None): - May raise exceptions for ill-conditioned or non-convergent problems - Solution quality depends on the server's implementation and input conditioning """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._impl_stub.SolveResiduals(iter(messages)) - outputs = self._recover_outputs(responses) - return outputs + validate_is_dict(inputs, "run_solve_residuals (inputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._impl_stub.SolveResiduals(iter(messages)) + outputs = self._recover_outputs(responses) + return outputs + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_solve_residuals: {e.details()}" + ) from e def run_residual_gradients( self, inputs, outputs, discrete_inputs=None, discrete_outputs=None @@ -299,13 +313,20 @@ def run_residual_gradients( - Used by optimization algorithms and sensitivity analysis tools - For large problems, consider matrix-free methods if available """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, - outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - responses = self._impl_stub.ComputeResidualGradients(iter(messages)) - partials = self._recover_partials(responses) - return partials + validate_is_dict(inputs, "run_residual_gradients (inputs)") + validate_is_dict(outputs, "run_residual_gradients (outputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) + responses = self._impl_stub.ComputeResidualGradients(iter(messages)) + partials = self._recover_partials(responses) + return partials + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_residual_gradients: {e.details()}" + ) from e diff --git a/philote_mdo/general/implicit_server.py b/philote_mdo/general/implicit_server.py index 973cfc8..ac7c1f7 100644 --- a/philote_mdo/general/implicit_server.py +++ b/philote_mdo/general/implicit_server.py @@ -27,11 +27,13 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data import philote_mdo.general as pmdo from philote_mdo.general.discipline_server import _python_to_value from philote_mdo.utils import get_chunk_indices +from philote_mdo.utils.validation import PhiloteValidationError class ImplicitServer(pmdo.DisciplineServer, disc.ImplicitServiceServicer): @@ -154,43 +156,50 @@ def ComputeResiduals(self, request_iterator, context): - Streams results back in chunks for efficiency - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - residuals = {} - discrete_inputs = {} - discrete_outputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - discrete_inputs, discrete_outputs = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - - # Call the user-defined compute_residuals function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute_residuals( - inputs, outputs, residuals, discrete_inputs, discrete_outputs + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + residuals = {} + discrete_inputs = {} + discrete_outputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, ) - else: - self._discipline.compute_residuals(inputs, outputs, residuals) - - for res_name, value in residuals.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=res_name, - start=b, - end=e, - type=data.kResidual, - data=value.ravel()[b:e], - ) + + # Call the user-defined compute_residuals function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_residuals( + inputs, outputs, residuals, discrete_inputs, discrete_outputs ) + else: + self._discipline.compute_residuals(inputs, outputs, residuals) + + for res_name, value in residuals.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=res_name, + start=b, + end=e, + type=data.kResidual, + data=value.ravel()[b:e], + ) + ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeResiduals failed: {e}" + ) def SolveResiduals(self, request_iterator, context): """ @@ -219,38 +228,45 @@ def SolveResiduals(self, request_iterator, context): - Outputs are streamed back in chunks for large arrays - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - discrete_inputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - discrete_inputs, _ = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - ) - - # Call the user-defined solve function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.solve_residuals(inputs, outputs, discrete_inputs) - else: - self._discipline.solve_residuals(inputs, outputs) - - for output_name, value in outputs.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=output_name, - start=b, - end=e, - type=data.kOutput, - data=value.ravel()[b:e], + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + discrete_inputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + ) + + # Call the user-defined solve function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.solve_residuals(inputs, outputs, discrete_inputs) + else: + self._discipline.solve_residuals(inputs, outputs) + + for output_name, value in outputs.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + start=b, + end=e, + type=data.kOutput, + data=value.ravel()[b:e], + ) ) - ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SolveResiduals failed: {e}" + ) def ComputeResidualGradients(self, request_iterator, context): """ @@ -279,44 +295,52 @@ def ComputeResidualGradients(self, request_iterator, context): - Used for gradient-based optimization and sensitivity analysis - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - discrete_inputs = {} - discrete_outputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - jac = self.preallocate_partials() - discrete_inputs, discrete_outputs = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - - # Call the user-defined residual partials function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.residual_partials( - inputs, outputs, jac, discrete_inputs, discrete_outputs + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + discrete_inputs = {} + discrete_outputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + jac = self.preallocate_partials() + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, ) - else: - self._discipline.residual_partials(inputs, outputs, jac) - - for jac, value in jac.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e, - data=value.ravel()[b:e], - ) + + # Call the user-defined residual partials function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.residual_partials( + inputs, outputs, jac, discrete_inputs, discrete_outputs ) + else: + self._discipline.residual_partials(inputs, outputs, jac) + + for jac, value in jac.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e, + data=value.ravel()[b:e], + ) + ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, + f"ComputeResidualGradients failed: {e}", + ) # def MatrixFreeGradients(self, request_iterator, context): # """ diff --git a/philote_mdo/openmdao/explicit.py b/philote_mdo/openmdao/explicit.py index 458b7e6..f80c9ef 100644 --- a/philote_mdo/openmdao/explicit.py +++ b/philote_mdo/openmdao/explicit.py @@ -128,6 +128,10 @@ def __init__(self, channel=None, num_par_fd=1, **kwargs): raise ValueError( "No channel provided, the Philote client will not be able to connect." ) + if not isinstance(num_par_fd, int) or num_par_fd < 1: + raise ValueError( + f"num_par_fd must be a positive integer, got {num_par_fd!r}." + ) # generic Philote client # The setting of OpenMDAO options requires the list of available diff --git a/philote_mdo/openmdao/group.py b/philote_mdo/openmdao/group.py index 4ca6c54..f27f9d9 100644 --- a/philote_mdo/openmdao/group.py +++ b/philote_mdo/openmdao/group.py @@ -29,6 +29,12 @@ # control over the information you may find at these locations. import openmdao.api as om import philote_mdo.general as pm +from philote_mdo.utils.validation import ( + PhiloteValidationError, + validate_name, + validate_shape, + validate_units, +) class OpenMdaoSubProblem(pm.ExplicitDiscipline): @@ -115,6 +121,11 @@ def add_group(self, group): >>> subprob = OpenMdaoSubProblem() >>> subprob.add_group(MyGroup()) """ + if not isinstance(group, om.Group): + raise PhiloteValidationError( + f"add_group: expected an om.Group instance, " + f"got {type(group).__name__}." + ) self._prob = om.Problem(model=group) self._model = self._prob.model @@ -142,6 +153,14 @@ def add_mapped_input(self, local_var, subprob_var, shape=(1,), units=""): >>> subprob.add_mapped_input('x_local', 'x', shape=(1,), units='m') >>> subprob.add_mapped_input('design_vars', 'z', shape=(2,), units='') """ + validate_name(local_var, "add_mapped_input (local_var)") + validate_name(subprob_var, "add_mapped_input (subprob_var)") + validate_shape(shape, "add_mapped_input") + validate_units(units, "add_mapped_input") + if local_var in self._input_map: + raise PhiloteValidationError( + f"add_mapped_input: '{local_var}' is already mapped." + ) self._input_map[local_var] = { "sub_prob_name": subprob_var, "shape": shape, @@ -172,6 +191,14 @@ def add_mapped_output(self, local_var, subprob_var, shape=(1,), units=""): >>> subprob.add_mapped_output('objective', 'obj', shape=(1,), units='') >>> subprob.add_mapped_output('constraint1', 'con1', shape=(1,), units='N') """ + validate_name(local_var, "add_mapped_output (local_var)") + validate_name(subprob_var, "add_mapped_output (subprob_var)") + validate_shape(shape, "add_mapped_output") + validate_units(units, "add_mapped_output") + if local_var in self._output_map: + raise PhiloteValidationError( + f"add_mapped_output: '{local_var}' is already mapped." + ) self._output_map[local_var] = { "sub_prob_name": subprob_var, "shape": shape, @@ -222,6 +249,18 @@ def declare_subproblem_partial(self, local_func, local_var): >>> subprob.add_mapped_output('y_local', 'y') >>> subprob.declare_subproblem_partial('y_local', 'x_local') """ + validate_name(local_func, "declare_subproblem_partial (local_func)") + validate_name(local_var, "declare_subproblem_partial (local_var)") + if local_func not in self._output_map: + raise PhiloteValidationError( + f"declare_subproblem_partial: output '{local_func}' has not " + f"been mapped. Call add_mapped_output first." + ) + if local_var not in self._input_map: + raise PhiloteValidationError( + f"declare_subproblem_partial: input '{local_var}' has not " + f"been mapped. Call add_mapped_input first." + ) self._partials_map[(local_func, local_var)] = ( self._output_map[local_func]["sub_prob_name"], self._input_map[local_var]["sub_prob_name"], diff --git a/philote_mdo/openmdao/implicit.py b/philote_mdo/openmdao/implicit.py index 1aa8843..2176bdc 100644 --- a/philote_mdo/openmdao/implicit.py +++ b/philote_mdo/openmdao/implicit.py @@ -132,9 +132,13 @@ def __init__(self, channel=None, num_par_fd=1, **kwargs): """ if not channel: raise ValueError( - "No channel provided, the Philote client will not" + "No channel provided, the Philote client will not " "be able to connect." ) + if not isinstance(num_par_fd, int) or num_par_fd < 1: + raise ValueError( + f"num_par_fd must be a positive integer, got {num_par_fd!r}." + ) # generic Philote client # The setting of OpenMDAO options requires the list of available diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index 3b95870..ea11503 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -28,6 +28,16 @@ # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. import philote_mdo.generated.data_pb2 as data +from philote_mdo.utils.validation import PhiloteValidationError + + +_TYPE_MAP = { + "bool": bool, + "int": int, + "float": float, + "str": str, + "dict": dict, +} def declare_options(opt_list, options): @@ -35,17 +45,12 @@ def declare_options(opt_list, options): Declares the options from the client options list. """ for name, type_str in opt_list: - opt_type = None - if type_str == "bool": - opt_type = bool - elif type_str == "int": - opt_type = int - elif type_str == "float": - opt_type = float - elif type_str == "str": - opt_type = str - elif type_str == "dict": - opt_type = dict + opt_type = _TYPE_MAP.get(type_str) + if opt_type is None: + raise PhiloteValidationError( + f"declare_options: unknown option type '{type_str}' " + f"for option '{name}'." + ) options.declare(name, types=opt_type) diff --git a/philote_mdo/utils/__init__.py b/philote_mdo/utils/__init__.py index 2e99b77..db801d8 100644 --- a/philote_mdo/utils/__init__.py +++ b/philote_mdo/utils/__init__.py @@ -29,3 +29,14 @@ # control over the information you may find at these locations. from .pair_dict import PairDict from .helper import get_chunk_indices, get_flattened_view +from .validation import ( + PhiloteError, + PhiloteValidationError, + PhiloteServerError, + validate_name, + validate_shape, + validate_units, + validate_option_type, + validate_is_dict, + validate_numpy_array, +) diff --git a/philote_mdo/utils/helper.py b/philote_mdo/utils/helper.py index 67bb39f..eb2a147 100644 --- a/philote_mdo/utils/helper.py +++ b/philote_mdo/utils/helper.py @@ -29,8 +29,21 @@ # control over the information you may find at these locations. import numpy as np +from philote_mdo.utils.validation import PhiloteValidationError + def get_chunk_indices(num_values, chunk_size): + if not isinstance(num_values, (int, np.integer)) or num_values < 0: + raise PhiloteValidationError( + f"get_chunk_indices: num_values must be a non-negative integer, " + f"got {num_values!r}." + ) + if not isinstance(chunk_size, (int, np.integer)) or chunk_size < 1: + raise PhiloteValidationError( + f"get_chunk_indices: chunk_size must be a positive integer, " + f"got {chunk_size!r}." + ) + beg_i = np.arange(0, num_values, chunk_size) if beg_i.size == 1: @@ -48,6 +61,11 @@ def get_flattened_view(arr): :param arr: Array to get a flattened view :return: A view of the input array, guaranteed to not be a copy """ + if not isinstance(arr, np.ndarray): + raise PhiloteValidationError( + f"get_flattened_view: expected a numpy ndarray, " + f"got {type(arr).__name__}." + ) flat_view = arr.view() flat_view.shape = -1 return flat_view diff --git a/philote_mdo/utils/pair_dict.py b/philote_mdo/utils/pair_dict.py index d72fdfc..bae5b9d 100644 --- a/philote_mdo/utils/pair_dict.py +++ b/philote_mdo/utils/pair_dict.py @@ -29,15 +29,27 @@ # control over the information you may find at these locations. +from philote_mdo.utils.validation import PhiloteValidationError + + class PairDict(dict): """ Jacobian dictionary for storing values with respect to two keys. """ + @staticmethod + def _validate_key(keys): + if not isinstance(keys, tuple) or len(keys) != 2: + raise PhiloteValidationError( + f"PairDict keys must be a 2-tuple, got {keys!r}." + ) + def __setitem__(self, keys, value): + self._validate_key(keys) key1, key2 = keys super().__setitem__((key1, key2), value) def __getitem__(self, keys): + self._validate_key(keys) key1, key2 = keys return super().__getitem__((key1, key2)) diff --git a/philote_mdo/utils/validation.py b/philote_mdo/utils/validation.py new file mode 100644 index 0000000..75395f2 --- /dev/null +++ b/philote_mdo/utils/validation.py @@ -0,0 +1,210 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import numpy as np + + +# --------------------------------------------------------------------------- +# Custom exception hierarchy +# --------------------------------------------------------------------------- + + +class PhiloteError(Exception): + """Base class for all Philote-specific errors.""" + + pass + + +class PhiloteValidationError(PhiloteError, ValueError): + """Raised when an input fails validation. + + Inherits from ``ValueError`` so that existing ``except ValueError`` + handlers in user code continue to work. + """ + + pass + + +class PhiloteServerError(PhiloteError, RuntimeError): + """Raised on the client side when a gRPC server call fails. + + Wraps the gRPC error details into a framework-specific exception so + that users do not need to catch ``grpc.RpcError`` directly. + """ + + pass + + +# --------------------------------------------------------------------------- +# Validation helpers +# --------------------------------------------------------------------------- + +VALID_OPTION_TYPES = {"bool", "int", "float", "str", "dict"} + + +def validate_name(name, context): + """Validate that *name* is a non-empty string. + + Parameters + ---------- + name : object + The name to validate. + context : str + Human-readable context for error messages (e.g. ``"add_input"``). + + Raises + ------ + PhiloteValidationError + If *name* is not a string or is empty. + """ + if not isinstance(name, str): + raise PhiloteValidationError( + f"{context}: 'name' must be a string, got {type(name).__name__}." + ) + if not name: + raise PhiloteValidationError(f"{context}: 'name' must not be empty.") + + +def validate_shape(shape, context): + """Validate that *shape* is a tuple of positive integers. + + Parameters + ---------- + shape : object + The shape to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *shape* is not a tuple, or contains non-positive / non-integer + elements. + """ + if not isinstance(shape, tuple): + raise PhiloteValidationError( + f"{context}: 'shape' must be a tuple, got {type(shape).__name__}." + ) + for i, dim in enumerate(shape): + if not isinstance(dim, int): + raise PhiloteValidationError( + f"{context}: all elements of 'shape' must be integers, " + f"but element {i} is {type(dim).__name__}." + ) + if dim <= 0: + raise PhiloteValidationError( + f"{context}: all elements of 'shape' must be positive, " + f"but element {i} is {dim}." + ) + + +def validate_units(units, context): + """Validate that *units* is a string. + + Parameters + ---------- + units : object + The units string to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *units* is not a string. + """ + if not isinstance(units, str): + raise PhiloteValidationError( + f"{context}: 'units' must be a string, got {type(units).__name__}." + ) + + +def validate_option_type(type_str, name): + """Validate that *type_str* is one of the allowed option types. + + Parameters + ---------- + type_str : object + The option type string to validate. + name : str + The option name (for error messages). + + Raises + ------ + PhiloteValidationError + If *type_str* is not in the allowed set. + """ + if type_str not in VALID_OPTION_TYPES: + raise PhiloteValidationError( + f"Invalid type '{type_str}' for option '{name}'. " + f"Allowed types are: {sorted(VALID_OPTION_TYPES)}." + ) + + +def validate_is_dict(obj, context): + """Validate that *obj* is a dictionary. + + Parameters + ---------- + obj : object + The object to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *obj* is not a ``dict``. + """ + if not isinstance(obj, dict): + raise PhiloteValidationError( + f"{context}: expected a dict, got {type(obj).__name__}." + ) + + +def validate_numpy_array(value, name): + """Validate that *value* is a NumPy ndarray. + + Parameters + ---------- + value : object + The value to validate. + name : str + Variable name (for error messages). + + Raises + ------ + PhiloteValidationError + If *value* is not a ``numpy.ndarray``. + """ + if not isinstance(value, np.ndarray): + raise PhiloteValidationError( + f"Variable '{name}' must be a numpy ndarray, " + f"got {type(value).__name__}." + ) diff --git a/tests/test_discipline.py b/tests/test_discipline.py index 6dbf5a6..a1b0e8e 100644 --- a/tests/test_discipline.py +++ b/tests/test_discipline.py @@ -31,6 +31,7 @@ from unittest.mock import Mock from philote_mdo.general import Discipline +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -172,6 +173,100 @@ def test_clear_data(self): self.assertEqual(len(disc._var_meta), 0) self.assertEqual(len(disc._partials_meta), 0) + # ------------------------------------------------------------------ + # Validation tests + # ------------------------------------------------------------------ + + def test_add_input_invalid_name_type(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input(123) + + def test_add_input_empty_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("") + + def test_add_input_invalid_shape(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", shape=[2, 3]) + + def test_add_input_invalid_units(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", units=42) + + def test_add_input_duplicate(self): + disc = Discipline() + disc.add_input("x", shape=(2,)) + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", shape=(3,)) + + def test_add_output_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_output(None) + + def test_add_output_invalid_shape(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_output("y", shape=(-1,)) + + def test_add_output_duplicate(self): + disc = Discipline() + disc.add_output("y") + with self.assertRaises(PhiloteValidationError): + disc.add_output("y") + + def test_add_option_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_option(42, "int") + + def test_add_option_invalid_type(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_option("opt", "unknown") + + def test_add_option_duplicate(self): + disc = Discipline() + disc.add_option("opt", "int") + with self.assertRaises(PhiloteValidationError): + disc.add_option("opt", "float") + + def test_add_discrete_input_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_input("") + + def test_add_discrete_input_duplicate(self): + disc = Discipline() + disc.add_discrete_input("d") + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_input("d") + + def test_add_discrete_output_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_output(123) + + def test_add_discrete_output_duplicate(self): + disc = Discipline() + disc.add_discrete_output("d") + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_output("d") + + def test_declare_partials_invalid_func(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.declare_partials("", "x") + + def test_declare_partials_invalid_var(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.declare_partials("f", 123) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_discipline_client.py b/tests/test_discipline_client.py index d7be559..62d2bae 100644 --- a/tests/test_discipline_client.py +++ b/tests/test_discipline_client.py @@ -33,6 +33,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf.struct_pb2 import Struct from philote_mdo.general import DisciplineClient +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.utils as utils @@ -556,6 +557,24 @@ def test_recover_partials_empty_array_raises_error(self): self.assertIn("Expected continuous outputs for the partials, but array was empty", str(context.exception)) + def test_send_options_non_dict_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.send_options("not a dict") + + def test_assemble_input_messages_non_dict_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client._assemble_input_messages("not a dict") + + def test_assemble_input_messages_non_array_value_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client._assemble_input_messages({"x": [1.0, 2.0]}) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index 2f30542..907b8d6 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -30,11 +30,13 @@ import unittest from unittest.mock import Mock +import grpc import numpy as np from google.protobuf.empty_pb2 import Empty from philote_mdo.general import Discipline, DisciplineServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -411,27 +413,32 @@ def test_set_options_with_nested_dict(self): } ) - def test_get_available_options_invalid_type_raises_error(self): + def test_get_available_options_invalid_type_aborts(self): """ - Tests that GetAvailableOptions raises ValueError for invalid option types. + Tests that GetAvailableOptions calls context.abort for invalid option + types. """ server = DisciplineServer() discipline = server._discipline = Discipline() - - # Add option with invalid type - discipline.add_option("invalid_option", "unknown_type") - + + # Add option with invalid type (bypasses add_option validation by + # writing directly to options_list) + discipline.options_list["invalid_option"] = "unknown_type" + request = Empty() context = Mock() - - with self.assertRaises(ValueError) as error_context: - server.GetAvailableOptions(request, context) - - self.assertIn("Invalid value for discipline option 'invalid_option'", str(error_context.exception)) + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("Invalid value for discipline option 'invalid_option'", args[0][1]) def test_process_inputs_empty_array_raises_error(self): """ - Tests that process_inputs raises ValueError when array data is empty. + Tests that process_inputs raises PhiloteValidationError when array + data is empty. """ server = DisciplineServer() @@ -451,7 +458,7 @@ def test_process_inputs_empty_array_raises_error(self): flat_inputs = {"x": np.zeros(3)} flat_outputs = {} - with self.assertRaises(ValueError) as context: + with self.assertRaises(PhiloteValidationError) as context: server.process_inputs(request_iterator, flat_inputs, flat_outputs) self.assertIn("Expected continuous variables but arrays were empty for variable x", str(context.exception)) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 0afc37b..87d70d7 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -29,7 +29,11 @@ # control over the information you may find at these locations. import unittest from unittest.mock import Mock, MagicMock + +import grpc + from philote_mdo.general import DisciplineServer, DisciplineClient, ExplicitDiscipline +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -95,26 +99,27 @@ def test_get_available_options_with_dict_type(self): def test_get_available_options_with_invalid_type(self): """ - Test GetAvailableOptions with invalid option type (covers lines 100-103). + Test GetAvailableOptions with invalid option type aborts with + INVALID_ARGUMENT. """ server = DisciplineServer() discipline = Mock() - + # Mock the options_list attribute to return a dict with invalid type discipline.options_list = {"invalid_option": "invalid_type"} - + server.attach_discipline(discipline) - + # Create a mock request and context request = Mock() context = Mock() - - # This should raise a ValueError - with self.assertRaises(ValueError) as context_err: - server.GetAvailableOptions(request, context) - - self.assertIn("Invalid value for discipline option", str(context_err.exception)) - self.assertIn("invalid_option", str(context_err.exception)) + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("Invalid value for discipline option", args[0][1]) def test_process_inputs_with_empty_continuous_data(self): """ diff --git a/tests/test_explicit_client.py b/tests/test_explicit_client.py index 59bd35b..4a4421c 100644 --- a/tests/test_explicit_client.py +++ b/tests/test_explicit_client.py @@ -30,9 +30,11 @@ import unittest from unittest.mock import Mock, patch +import grpc import numpy as np from philote_mdo.general import ExplicitClient +from philote_mdo.utils.validation import PhiloteValidationError, PhiloteServerError import philote_mdo.generated.data_pb2 as data import philote_mdo.utils as utils @@ -135,3 +137,33 @@ def test_compute_partials(self, mock_explicit_stub): for output_name, expected_data in expected_outputs.items(): self.assertTrue(output_name in outputs) np.testing.assert_array_equal(outputs[output_name], expected_data) + + def test_run_compute_non_dict_raises(self): + mock_channel = Mock() + client = ExplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute("not a dict") + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_grpc_error_wraps(self, mock_explicit_stub): + mock_channel = Mock() + mock_stub = mock_explicit_stub.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "server crashed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeFunction.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute({"x": np.array([1.0])}) + self.assertIn("server crashed", str(ctx.exception)) + + def test_run_compute_partials_non_dict_raises(self): + mock_channel = Mock() + client = ExplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_partials(42) diff --git a/tests/test_explicit_server.py b/tests/test_explicit_server.py index 0416ec0..cb1c5e1 100644 --- a/tests/test_explicit_server.py +++ b/tests/test_explicit_server.py @@ -30,6 +30,7 @@ import unittest from unittest.mock import Mock +import grpc import numpy as np from scipy.optimize import rosen, rosen_der @@ -146,5 +147,72 @@ def compute_partials(inputs, jac): ) + def test_compute_function_aborts_on_discipline_error(self): + """ + Tests that ComputeFunction calls context.abort when the discipline's + compute raises an exception. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_compute(inputs, outputs): + raise RuntimeError("division by zero in compute") + + server._discipline.compute = bad_compute + + # Exhaust the generator + list(server.ComputeFunction(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeFunction failed", args[0][1]) + + def test_compute_gradient_aborts_on_discipline_error(self): + """ + Tests that ComputeGradient calls context.abort when the discipline's + compute_partials raises an exception. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, jac): + raise RuntimeError("singular matrix") + + server._discipline.compute_partials = bad_partials + + list(server.ComputeGradient(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeGradient failed", args[0][1]) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_implicit_client.py b/tests/test_implicit_client.py index 3ec34a9..a652872 100644 --- a/tests/test_implicit_client.py +++ b/tests/test_implicit_client.py @@ -33,6 +33,7 @@ import numpy as np from philote_mdo.general import ImplicitClient +from philote_mdo.utils.validation import PhiloteValidationError, PhiloteServerError import philote_mdo.generated.data_pb2 as data import philote_mdo.utils as utils @@ -202,5 +203,30 @@ def test_residual_partials(self, mock_implicit_stub): np.testing.assert_array_equal(partials[key], expected_data) + def test_run_compute_residuals_non_dict_inputs_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_residuals("not a dict", {"f": np.array([1.0])}) + + def test_run_compute_residuals_non_dict_outputs_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_residuals({"x": np.array([1.0])}, "not a dict") + + def test_run_solve_residuals_non_dict_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_solve_residuals(42) + + def test_run_residual_gradients_non_dict_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_residual_gradients("bad", {"f": np.array([1.0])}) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_implicit_server.py b/tests/test_implicit_server.py index 21e7c82..901a588 100644 --- a/tests/test_implicit_server.py +++ b/tests/test_implicit_server.py @@ -29,6 +29,8 @@ # control over the information you may find at these locations. import unittest from unittest.mock import Mock + +import grpc import numpy as np import numpy.testing as np_testing from google.protobuf.empty_pb2 import Empty @@ -206,5 +208,82 @@ def residual_partials(inputs, residuals, jac): ) + def test_compute_residuals_aborts_on_discipline_error(self): + """ + Tests that ComputeResiduals calls context.abort when the discipline's + compute_residuals raises an exception. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_residuals(inputs, outputs, residuals): + raise RuntimeError("residual computation failed") + + server._discipline.compute_residuals = bad_residuals + + list(server.ComputeResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeResiduals failed", args[0][1]) + + def test_solve_residuals_aborts_on_discipline_error(self): + """ + Tests that SolveResiduals calls context.abort when the discipline's + solve_residuals raises an exception. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_solve(inputs, outputs): + raise RuntimeError("solver did not converge") + + server._discipline.solve_residuals = bad_solve + + list(server.SolveResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("SolveResiduals failed", args[0][1]) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_group.py b/tests/test_openmdao_group.py index 016831a..c78b0fa 100644 --- a/tests/test_openmdao_group.py +++ b/tests/test_openmdao_group.py @@ -31,6 +31,7 @@ import numpy as np import openmdao.api as om from philote_mdo.openmdao.group import OpenMdaoSubProblem +from philote_mdo.utils.validation import PhiloteValidationError class SimpleGroup(om.Group): @@ -257,5 +258,45 @@ def test_units_and_shapes(self): self.assertEqual(subprob._output_map['vector_out']['units'], 'kg') + def test_add_group_non_group_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_group("not a group") + + def test_add_mapped_input_invalid_name_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("", "x") + + def test_add_mapped_input_invalid_shape_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("x", "sub_x", shape=[2]) + + def test_add_mapped_input_duplicate_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_input("x", "sub_x") + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("x", "sub_x2") + + def test_add_mapped_output_duplicate_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_output("y", "sub_y") + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_output("y", "sub_y2") + + def test_declare_subproblem_partial_unmapped_output_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_input("x", "sub_x") + with self.assertRaises(PhiloteValidationError): + subprob.declare_subproblem_partial("unmapped_y", "x") + + def test_declare_subproblem_partial_unmapped_input_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_output("y", "sub_y") + with self.assertRaises(PhiloteValidationError): + subprob.declare_subproblem_partial("y", "unmapped_x") + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index 95a5acd..763504e 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -32,6 +32,7 @@ import numpy as np from philote_mdo.generated.data_pb2 import kInput, kOutput +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.openmdao.utils as utils @@ -206,11 +207,11 @@ def test_declare_options(self): self.assertEqual(options_mock.declare.call_count, 5) - # Test case 3: Unknown type (should result in None) + # Test case 3: Unknown type now raises PhiloteValidationError options_mock.reset_mock() opt_list = [("unknown_param", "unknown_type")] - declare_options(opt_list, options_mock) - options_mock.declare.assert_called_once_with("unknown_param", types=None) + with self.assertRaises(PhiloteValidationError): + declare_options(opt_list, options_mock) if __name__ == "__main__": diff --git a/tests/test_utils.py b/tests/test_utils.py index ab65ff2..d262bb8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,7 +29,8 @@ # control over the information you may find at these locations. import unittest import numpy as np -from philote_mdo.utils import get_chunk_indices, get_flattened_view +from philote_mdo.utils import get_chunk_indices, get_flattened_view, PairDict +from philote_mdo.utils.validation import PhiloteValidationError class TestUtils(unittest.TestCase): @@ -76,5 +77,38 @@ def test_get_flattened_view(self): self.assertEqual(result_empty.shape, (0,)) + def test_get_chunk_indices_negative_num_values_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(-1, 3)) + + def test_get_chunk_indices_zero_chunk_size_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(10, 0)) + + def test_get_chunk_indices_non_int_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(10.5, 3)) + + def test_get_flattened_view_non_array_raises(self): + with self.assertRaises(PhiloteValidationError): + get_flattened_view([1, 2, 3]) + + def test_pair_dict_invalid_key_raises(self): + pd = PairDict() + with self.assertRaises(PhiloteValidationError): + pd["single_key"] = 1.0 + + def test_pair_dict_three_tuple_raises(self): + pd = PairDict() + with self.assertRaises(PhiloteValidationError): + pd[("a", "b", "c")] = 1.0 + + def test_pair_dict_get_invalid_key_raises(self): + pd = PairDict() + pd[("a", "b")] = 1.0 + with self.assertRaises(PhiloteValidationError): + _ = pd["single_key"] + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..8d1d09c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,184 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import unittest + +import numpy as np + +from philote_mdo.utils.validation import ( + PhiloteError, + PhiloteValidationError, + PhiloteServerError, + validate_name, + validate_shape, + validate_units, + validate_option_type, + validate_is_dict, + validate_numpy_array, +) + + +class TestExceptionHierarchy(unittest.TestCase): + """Tests for the custom exception class hierarchy.""" + + def test_philote_validation_error_is_value_error(self): + with self.assertRaises(ValueError): + raise PhiloteValidationError("test") + + def test_philote_validation_error_is_philote_error(self): + with self.assertRaises(PhiloteError): + raise PhiloteValidationError("test") + + def test_philote_server_error_is_runtime_error(self): + with self.assertRaises(RuntimeError): + raise PhiloteServerError("test") + + def test_philote_server_error_is_philote_error(self): + with self.assertRaises(PhiloteError): + raise PhiloteServerError("test") + + def test_philote_error_is_exception(self): + with self.assertRaises(Exception): + raise PhiloteError("test") + + +class TestValidateName(unittest.TestCase): + """Tests for validate_name.""" + + def test_valid_name(self): + validate_name("x", "test") + + def test_non_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_name(123, "add_input") + self.assertIn("must be a string", str(ctx.exception)) + self.assertIn("add_input", str(ctx.exception)) + + def test_empty_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_name("", "add_output") + self.assertIn("must not be empty", str(ctx.exception)) + + def test_none_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_name(None, "test") + + +class TestValidateShape(unittest.TestCase): + """Tests for validate_shape.""" + + def test_valid_shape_1d(self): + validate_shape((3,), "test") + + def test_valid_shape_2d(self): + validate_shape((2, 4), "test") + + def test_non_tuple_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape([2, 3], "add_input") + self.assertIn("must be a tuple", str(ctx.exception)) + + def test_int_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_shape(3, "test") + + def test_non_integer_element_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape((2.0,), "test") + self.assertIn("must be integers", str(ctx.exception)) + + def test_zero_element_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape((0,), "test") + self.assertIn("must be positive", str(ctx.exception)) + + def test_negative_element_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_shape((-1, 3), "test") + + +class TestValidateUnits(unittest.TestCase): + """Tests for validate_units.""" + + def test_valid_units(self): + validate_units("m**2", "test") + + def test_empty_string_is_valid(self): + validate_units("", "test") + + def test_non_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_units(42, "add_input") + self.assertIn("must be a string", str(ctx.exception)) + + +class TestValidateOptionType(unittest.TestCase): + """Tests for validate_option_type.""" + + def test_valid_types(self): + for t in ("bool", "int", "float", "str", "dict"): + validate_option_type(t, "opt") + + def test_invalid_type_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_option_type("unknown", "my_opt") + self.assertIn("Invalid type", str(ctx.exception)) + self.assertIn("my_opt", str(ctx.exception)) + + +class TestValidateIsDict(unittest.TestCase): + """Tests for validate_is_dict.""" + + def test_valid_dict(self): + validate_is_dict({"a": 1}, "test") + + def test_non_dict_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_is_dict([1, 2], "send_options") + self.assertIn("expected a dict", str(ctx.exception)) + + +class TestValidateNumpyArray(unittest.TestCase): + """Tests for validate_numpy_array.""" + + def test_valid_array(self): + validate_numpy_array(np.array([1.0, 2.0]), "x") + + def test_list_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_numpy_array([1.0, 2.0], "x") + self.assertIn("must be a numpy ndarray", str(ctx.exception)) + + def test_scalar_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_numpy_array(1.0, "x") + + +if __name__ == "__main__": + unittest.main(verbosity=2) From c78f53ccb42f35d65771702833607a887fd767cd Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Wed, 8 Apr 2026 22:31:51 -0400 Subject: [PATCH 2/3] Add tests to restore 100% coverage Cover all previously uncovered error handling paths: PhiloteValidationError and general Exception branches in server RPC methods, gRPC error wrapping in all client methods, and num_par_fd validation in OpenMDAO components. --- tests/test_discipline_server.py | 105 ++++++++++++++++++ tests/test_explicit_client.py | 20 ++++ tests/test_explicit_server.py | 66 ++++++++++++ tests/test_implicit_client.py | 65 +++++++++++ tests/test_implicit_server.py | 143 +++++++++++++++++++++++++ tests/test_openmdao_explicit_client.py | 9 ++ tests/test_openmdao_implicit_client.py | 9 ++ 7 files changed, 417 insertions(+) diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index 907b8d6..9b5c930 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -463,6 +463,111 @@ def test_process_inputs_empty_array_raises_error(self): self.assertIn("Expected continuous variables but arrays were empty for variable x", str(context.exception)) + def test_get_available_options_general_exception_aborts(self): + """ + Tests that GetAvailableOptions calls context.abort with INTERNAL + for unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + # options_list property raises an unexpected error + type(discipline).options_list = property( + lambda self: (_ for _ in ()).throw(RuntimeError("unexpected")) + ) + server._discipline = discipline + + request = Mock() + context = Mock() + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("GetAvailableOptions failed", args[0][1]) + + def test_set_options_validation_error_aborts(self): + """ + Tests that SetOptions calls context.abort with INVALID_ARGUMENT + for PhiloteValidationError. + """ + server = DisciplineServer() + discipline = Mock() + discipline.set_options.side_effect = PhiloteValidationError("bad option") + server._discipline = discipline + + request = Mock() + request.options = {} + context = Mock() + + server.SetOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad option", args[0][1]) + + def test_set_options_general_exception_aborts(self): + """ + Tests that SetOptions calls context.abort with INTERNAL for + unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + discipline.set_options.side_effect = RuntimeError("boom") + server._discipline = discipline + + request = Mock() + request.options = {} + context = Mock() + + server.SetOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("SetOptions failed", args[0][1]) + + def test_setup_validation_error_aborts(self): + """ + Tests that Setup calls context.abort with INVALID_ARGUMENT + for PhiloteValidationError. + """ + server = DisciplineServer() + discipline = Mock() + discipline.setup.side_effect = PhiloteValidationError("bad setup") + server._discipline = discipline + + request = Mock() + context = Mock() + + server.Setup(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad setup", args[0][1]) + + def test_setup_general_exception_aborts(self): + """ + Tests that Setup calls context.abort with INTERNAL for + unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + discipline._clear_data.side_effect = RuntimeError("crash") + server._discipline = discipline + + request = Mock() + context = Mock() + + server.Setup(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("Setup failed", args[0][1]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_explicit_client.py b/tests/test_explicit_client.py index 4a4421c..558ff01 100644 --- a/tests/test_explicit_client.py +++ b/tests/test_explicit_client.py @@ -162,6 +162,26 @@ def test_run_compute_grpc_error_wraps(self, mock_explicit_stub): client.run_compute({"x": np.array([1.0])}) self.assertIn("server crashed", str(ctx.exception)) + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_partials_grpc_error_wraps(self, mock_explicit_stub): + mock_channel = Mock() + mock_stub = mock_explicit_stub.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + ] + client._partials_meta = [data.PartialsMetaData(name="f", subname="x")] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "gradient computation failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeGradient.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute_partials({"x": np.array([1.0])}) + self.assertIn("gradient computation failed", str(ctx.exception)) + def test_run_compute_partials_non_dict_raises(self): mock_channel = Mock() client = ExplicitClient(mock_channel) diff --git a/tests/test_explicit_server.py b/tests/test_explicit_server.py index cb1c5e1..f4110b7 100644 --- a/tests/test_explicit_server.py +++ b/tests/test_explicit_server.py @@ -37,6 +37,7 @@ from google.protobuf.empty_pb2 import Empty from philote_mdo.general import ExplicitDiscipline, ExplicitServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -147,6 +148,71 @@ def compute_partials(inputs, jac): ) + def test_compute_function_aborts_on_validation_error(self): + """ + Tests that ComputeFunction calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_compute(inputs, outputs): + raise PhiloteValidationError("bad input data") + + server._discipline.compute = bad_compute + + list(server.ComputeFunction(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad input data", args[0][1]) + + def test_compute_gradient_aborts_on_validation_error(self): + """ + Tests that ComputeGradient calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, jac): + raise PhiloteValidationError("invalid partials") + + server._discipline.compute_partials = bad_partials + + list(server.ComputeGradient(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("invalid partials", args[0][1]) + def test_compute_function_aborts_on_discipline_error(self): """ Tests that ComputeFunction calls context.abort when the discipline's diff --git a/tests/test_implicit_client.py b/tests/test_implicit_client.py index a652872..d9333c9 100644 --- a/tests/test_implicit_client.py +++ b/tests/test_implicit_client.py @@ -30,6 +30,7 @@ import unittest from unittest.mock import Mock, patch +import grpc import numpy as np from philote_mdo.general import ImplicitClient @@ -203,6 +204,70 @@ def test_residual_partials(self, mock_implicit_stub): np.testing.assert_array_equal(partials[key], expected_data) + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_compute_residuals_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kResidual, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "residual failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeResiduals.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute_residuals( + {"x": np.array([1.0])}, {"f": np.array([1.0])} + ) + self.assertIn("residual failed", str(ctx.exception)) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_solve_residuals_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "solve failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.SolveResiduals.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_solve_residuals({"x": np.array([1.0])}) + self.assertIn("solve failed", str(ctx.exception)) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_residual_gradients_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kResidual, shape=(1,)), + ] + client._partials_meta = [data.PartialsMetaData(name="f", subname="x")] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "gradient failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeResidualGradients.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_residual_gradients( + {"x": np.array([1.0])}, {"f": np.array([1.0])} + ) + self.assertIn("gradient failed", str(ctx.exception)) + def test_run_compute_residuals_non_dict_inputs_raises(self): mock_channel = Mock() client = ImplicitClient(mock_channel) diff --git a/tests/test_implicit_server.py b/tests/test_implicit_server.py index 901a588..6f21d68 100644 --- a/tests/test_implicit_server.py +++ b/tests/test_implicit_server.py @@ -35,6 +35,7 @@ import numpy.testing as np_testing from google.protobuf.empty_pb2 import Empty from philote_mdo.general import ImplicitDiscipline, ImplicitServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -208,6 +209,148 @@ def residual_partials(inputs, residuals, jac): ) + def test_compute_residuals_aborts_on_validation_error(self): + """ + Tests that ComputeResiduals calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_residuals(inputs, outputs, residuals): + raise PhiloteValidationError("bad residual input") + + server._discipline.compute_residuals = bad_residuals + + list(server.ComputeResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad residual input", args[0][1]) + + def test_solve_residuals_aborts_on_validation_error(self): + """ + Tests that SolveResiduals calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_solve(inputs, outputs): + raise PhiloteValidationError("bad solve input") + + server._discipline.solve_residuals = bad_solve + + list(server.SolveResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad solve input", args[0][1]) + + def test_compute_residual_gradients_aborts_on_validation_error(self): + """ + Tests that ComputeResidualGradients calls context.abort with + INVALID_ARGUMENT when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, outputs, jac): + raise PhiloteValidationError("bad partials input") + + server._discipline.residual_partials = bad_partials + + list(server.ComputeResidualGradients(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad partials input", args[0][1]) + + def test_compute_residual_gradients_aborts_on_discipline_error(self): + """ + Tests that ComputeResidualGradients calls context.abort with INTERNAL + when an unexpected exception is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, outputs, jac): + raise RuntimeError("unexpected crash") + + server._discipline.residual_partials = bad_partials + + list(server.ComputeResidualGradients(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeResidualGradients failed", args[0][1]) + def test_compute_residuals_aborts_on_discipline_error(self): """ Tests that ComputeResiduals calls context.abort when the discipline's diff --git a/tests/test_openmdao_explicit_client.py b/tests/test_openmdao_explicit_client.py index c8339e5..b97551f 100644 --- a/tests/test_openmdao_explicit_client.py +++ b/tests/test_openmdao_explicit_client.py @@ -305,6 +305,15 @@ def test_constructor_no_channel_raises_error(self, om_explicit_component_patch): self.assertIn("No channel provided", str(context.exception)) + def test_invalid_num_par_fd_raises(self, mock_explicit_component): + """ + Tests that an invalid num_par_fd raises a ValueError. + """ + mock_channel = Mock() + with self.assertRaises(ValueError) as context: + RemoteExplicitComponent(channel=mock_channel, num_par_fd=0) + self.assertIn("num_par_fd must be a positive integer", str(context.exception)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_implicit_client.py b/tests/test_openmdao_implicit_client.py index 2895e7f..045ddcb 100644 --- a/tests/test_openmdao_implicit_client.py +++ b/tests/test_openmdao_implicit_client.py @@ -367,6 +367,15 @@ def test_constructor_no_channel_raises_error(self, om_implicit_component_patch): self.assertIn("No channel provided", str(context.exception)) + def test_invalid_num_par_fd_raises(self, mock_implicit_component): + """ + Tests that an invalid num_par_fd raises a ValueError. + """ + mock_channel = Mock() + with self.assertRaises(ValueError) as context: + RemoteImplicitComponent(channel=mock_channel, num_par_fd=-1) + self.assertIn("num_par_fd must be a positive integer", str(context.exception)) + if __name__ == "__main__": unittest.main(verbosity=2) From 9da99d72915ccef4d359ff3193fe3753c089394e Mon Sep 17 00:00:00 2001 From: Christopher Lupp Date: Thu, 9 Apr 2026 13:39:58 -0400 Subject: [PATCH 3/3] Add dynamic shapes for inputs and outputs (MDO-Standards/Philote-MDO#6) Disciplines can now declare variables with `dynamic_shape=True` in `add_input`/`add_output`, indicating that the client is allowed to set the variable's shape at runtime. A new `SetVariableShapes` gRPC RPC lets clients send resolved shapes after querying variable definitions. Protocol changes: - Added `bool dynamic_shape` field to `VariableMetaData` in data.proto - Added `SetVariableShapes` streaming RPC to DisciplineService Server: - `add_input`/`add_output` accept `dynamic_shape` parameter - `SetVariableShapes` handler validates and updates shapes, including implicit residual entries - `preallocate_inputs` raises if a dynamic variable has no shape set Client: - `set_variable_shape`, `send_variable_shapes`, `get_dynamic_variables` helper methods OpenMDAO bindings: - Dynamic-shape variables map to `shape_by_conn=True` automatically - Resolved shapes are sent to the server before partials setup Includes FlexibleDiscipline example and 22 new tests (100% coverage). --- CHANGELOG.md | 8 + philote_mdo/examples/__init__.py | 1 + philote_mdo/examples/flexible.py | 55 ++ philote_mdo/general/discipline.py | 32 +- philote_mdo/general/discipline_client.py | 65 +++ philote_mdo/general/discipline_server.py | 61 ++- philote_mdo/generated/data_pb2.py | 28 +- philote_mdo/generated/data_pb2.pyi | 6 +- philote_mdo/generated/disciplines_pb2.py | 12 +- philote_mdo/generated/disciplines_pb2_grpc.py | 15 +- philote_mdo/openmdao/explicit.py | 5 + philote_mdo/openmdao/implicit.py | 5 + philote_mdo/openmdao/utils.py | 40 +- tests/test_dynamic_shapes.py | 509 ++++++++++++++++++ tests/test_openmdao_explicit_client.py | 1 + tests/test_openmdao_implicit_client.py | 1 + tests/test_openmdao_utils.py | 2 + 17 files changed, 809 insertions(+), 37 deletions(-) create mode 100644 philote_mdo/examples/flexible.py create mode 100644 tests/test_dynamic_shapes.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0426c32..f86ac6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Features +- Added dynamic shapes for inputs and outputs. Disciplines can now + declare variables with `dynamic_shape=True` in `add_input` / + `add_output`, indicating that the client is allowed to set the + variable's shape at runtime. A new `SetVariableShapes` gRPC RPC + lets clients send resolved shapes after querying variable definitions. + The OpenMDAO bindings automatically map dynamic-shape variables to + `shape_by_conn=True` and send resolved shapes back to the server + (MDO-Standards/Philote-MDO#6). - Added support for struct (dict) options via the new `kStruct` DataType enum value, enabling complex nested data to be declared and passed as discipline options (#49). diff --git a/philote_mdo/examples/__init__.py b/philote_mdo/examples/__init__.py index 414b3e5..ae2f477 100644 --- a/philote_mdo/examples/__init__.py +++ b/philote_mdo/examples/__init__.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +from .flexible import FlexibleDiscipline from .paraboloid import Paraboloid from .quadratic import QuadradicImplicit from .rosenbrock import Rosenbrock diff --git a/philote_mdo/examples/flexible.py b/philote_mdo/examples/flexible.py new file mode 100644 index 0000000..a58c59e --- /dev/null +++ b/philote_mdo/examples/flexible.py @@ -0,0 +1,55 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import numpy as np +import philote_mdo.general as pmdo + + +class FlexibleDiscipline(pmdo.ExplicitDiscipline): + """ + Example explicit discipline with dynamic shapes. + + This discipline doubles every element of the input vector. The input + and output shapes are not fixed by the server — the client is + expected to set them via ``SetVariableShapes`` before computation. + """ + + def setup(self): + self.add_input("x", dynamic_shape=True, units="m") + self.add_output("y", dynamic_shape=True, units="m") + + def setup_partials(self): + self.declare_partials("y", "x") + + def compute(self, inputs, outputs): + outputs["y"] = 2.0 * inputs["x"] + + def compute_partials(self, inputs, partials): + n = inputs["x"].size + partials["y", "x"] = 2.0 * np.eye(n) diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index ed8414d..010975f 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -86,7 +86,7 @@ def add_option(self, name, type): ) self.options_list[name] = type - def add_input(self, name, shape=(1,), units=""): + def add_input(self, name, shape=(1,), units="", dynamic_shape=False): """ Define a continuous input. @@ -95,12 +95,16 @@ def add_input(self, name, shape=(1,), units=""): name : string the name of the input variable shape : tuple - the shape of the input variable + the shape of the input variable (ignored when dynamic_shape + is True) units : string the unit definition for the input variable + dynamic_shape : bool + when True, the client is allowed to set this variable's shape """ validate_name(name, "add_input") - validate_shape(shape, "add_input") + if not dynamic_shape: + validate_shape(shape, "add_input") validate_units(units, "add_input") if any(v.name == name and v.type == data.VariableType.kInput for v in self._var_meta): raise PhiloteValidationError( @@ -109,8 +113,10 @@ def add_input(self, name, shape=(1,), units=""): meta = data.VariableMetaData() meta.type = data.VariableType.kInput meta.name = name - meta.shape.extend(shape) + if not dynamic_shape: + meta.shape.extend(shape) meta.units = units + meta.dynamic_shape = dynamic_shape self._var_meta += [meta] def add_discrete_input(self, name, default=None): @@ -167,7 +173,7 @@ def add_discrete_output(self, name, default=None): meta.name = name self._discrete_var_meta += [meta] - def add_output(self, name, shape=(1,), units=""): + def add_output(self, name, shape=(1,), units="", dynamic_shape=False): """ Defines a continuous output. @@ -176,12 +182,16 @@ def add_output(self, name, shape=(1,), units=""): name : string the name of the output variable shape : tuple - the shape of the output variable + the shape of the output variable (ignored when dynamic_shape + is True) units : string the unit definition for the output variable + dynamic_shape : bool + when True, the client is allowed to set this variable's shape """ validate_name(name, "add_output") - validate_shape(shape, "add_output") + if not dynamic_shape: + validate_shape(shape, "add_output") validate_units(units, "add_output") if any(v.name == name and v.type == data.VariableType.kOutput for v in self._var_meta): raise PhiloteValidationError( @@ -190,17 +200,21 @@ def add_output(self, name, shape=(1,), units=""): out_meta = data.VariableMetaData() out_meta.type = data.VariableType.kOutput out_meta.name = name - out_meta.shape.extend(shape) + if not dynamic_shape: + out_meta.shape.extend(shape) out_meta.units = units + out_meta.dynamic_shape = dynamic_shape self._var_meta += [out_meta] if self._is_implicit: res_meta = data.VariableMetaData() res_meta.type = data.VariableType.kOutput res_meta.name = name - res_meta.shape.extend(shape) + if not dynamic_shape: + res_meta.shape.extend(shape) res_meta.units = units res_meta.type = data.VariableType.kResidual + res_meta.dynamic_shape = dynamic_shape self._var_meta += [res_meta] def declare_partials(self, func, var): diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index dfbed15..684f5e2 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -154,6 +154,71 @@ def get_partials_definitions(self): if message.name not in self._partials_meta: self._partials_meta += [message] + def get_dynamic_variables(self): + """ + Returns a list of variable metadata entries that have + ``dynamic_shape`` set to ``True``. + """ + return [v for v in self._var_meta if v.dynamic_shape] + + def set_variable_shape(self, name, shape, var_type=data.VariableType.kInput): + """ + Creates a ``VariableMetaData`` message for setting a dynamic + variable's shape. + + Parameters + ---------- + name : str + the name of the variable + shape : tuple + the desired shape + var_type : VariableType + the variable type (kInput or kOutput) + + Returns + ------- + VariableMetaData + protobuf message ready for ``send_variable_shapes`` + """ + meta = data.VariableMetaData() + meta.type = var_type + meta.name = name + meta.shape.extend(shape) + return meta + + def send_variable_shapes(self, variable_metadata): + """ + Sends shapes for variables flagged as ``dynamic_shape``. + + Call after ``get_variable_definitions()`` and before compute + calls. + + Parameters + ---------- + variable_metadata : list of VariableMetaData + shapes for dynamic variables + """ + self._disc_stub.SetVariableShapes(iter(variable_metadata)) + + # update local metadata to reflect the new shapes + for meta in variable_metadata: + for var in self._var_meta: + if var.name == meta.name and var.type == meta.type: + var.shape[:] = [] + var.shape.extend(meta.shape) + break + + # for implicit outputs, also update the matching residual + if meta.type == data.VariableType.kOutput: + for var in self._var_meta: + if ( + var.name == meta.name + and var.type == data.VariableType.kResidual + ): + var.shape[:] = [] + var.shape.extend(meta.shape) + break + def _assemble_input_messages( self, inputs, outputs=None, discrete_inputs=None, discrete_outputs=None ): diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index a7baee6..2a0d956 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -35,7 +35,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf import struct_pb2 from philote_mdo.utils import PairDict, get_flattened_view -from philote_mdo.utils.validation import PhiloteValidationError +from philote_mdo.utils.validation import PhiloteValidationError, validate_shape class DisciplineServer(disc.DisciplineService): @@ -170,6 +170,57 @@ def GetPartialDefinitions(self, request, context): for jac in self._discipline._partials_meta: yield jac + def SetVariableShapes(self, request_iterator, context): + """ + Receives client-defined shapes for variables flagged as + dynamic_shape. + + The client must call this RPC after GetVariableDefinitions and + before any compute RPCs for disciplines that contain variables + with dynamic shapes. + """ + try: + for meta in request_iterator: + validate_shape(tuple(meta.shape), "SetVariableShapes") + + # find the matching variable and update its shape + for var in self._discipline._var_meta: + if var.name == meta.name and var.type == meta.type: + if not var.dynamic_shape: + raise PhiloteValidationError( + f"Variable '{meta.name}' does not allow " + f"dynamic shapes." + ) + var.shape[:] = [] + var.shape.extend(meta.shape) + break + else: + raise PhiloteValidationError( + f"SetVariableShapes: variable '{meta.name}' " + f"not found." + ) + + # if the variable is an output on an implicit discipline, + # also update the matching residual entry + if meta.type == data.VariableType.kOutput: + for var in self._discipline._var_meta: + if ( + var.name == meta.name + and var.type == data.VariableType.kResidual + and var.dynamic_shape + ): + var.shape[:] = [] + var.shape.extend(meta.shape) + break + + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SetVariableShapes failed: {e}" + ) + def preallocate_inputs(self, inputs, flat_inputs, outputs=None, flat_outputs=None): """ Preallocates the inputs before receiving data from the client. @@ -178,6 +229,14 @@ def preallocate_inputs(self, inputs, flat_inputs, outputs=None, flat_outputs=Non inputs to evaluate the residuals and the partials of the residuals. """ for var in self._discipline._var_meta: + # validate that dynamic-shape variables have been resolved + if var.dynamic_shape and len(var.shape) == 0: + raise PhiloteValidationError( + f"Variable '{var.name}' has dynamic_shape=True but " + f"no shape has been set. Call SetVariableShapes " + f"before computing." + ) + if var.type == data.kInput: inputs[var.name] = np.zeros(var.shape) flat_inputs[var.name] = get_flattened_view(inputs[var.name]) diff --git a/philote_mdo/generated/data_pb2.py b/philote_mdo/generated/data_pb2.py index b6029b2..668e1ab 100644 --- a/philote_mdo/generated/data_pb2.py +++ b/philote_mdo/generated/data_pb2.py @@ -7,17 +7,17 @@ _runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, '', 'data.proto') _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"c\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x07philote\x1a\x1cgoogle/protobuf/struct.proto"}\n\x14DisciplineProperties\x12\x12\n\ncontinuous\x18\x01 \x01(\x08\x12\x16\n\x0edifferentiable\x18\x02 \x01(\x08\x12\x1a\n\x12provides_gradients\x18\x03 \x01(\x08\x12\x0c\n\x04name\x18\x04 \x01(\t\x12\x0f\n\x07version\x18\x05 \x01(\t"#\n\rStreamOptions\x12\x12\n\nnum_double\x18\x01 \x01(\x03"?\n\x0bOptionsList\x12\x0f\n\x07options\x18\x01 \x03(\t\x12\x1f\n\x04type\x18\x02 \x03(\x0e2\x11.philote.DataType"=\n\x11DisciplineOptions\x12(\n\x07options\x18\x01 \x01(\x0b2\x17.google.protobuf.Struct"z\n\x10VariableMetaData\x12#\n\x04type\x18\x01 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05units\x18\x05 \x01(\t\x12\x15\n\rdynamic_shape\x18\x06 \x01(\x08"@\n\x10PartialsMetaData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"u\n\x05Array\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07subname\x18\x02 \x01(\t\x12\r\n\x05start\x18\x03 \x01(\x03\x12\x0b\n\x03end\x18\x04 \x01(\x03\x12#\n\x04type\x18\x05 \x01(\x0e2\x15.philote.VariableType\x12\x0c\n\x04data\x18\x06 \x03(\x01"l\n\x10DiscreteVariable\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x04type\x18\x02 \x01(\x0e2\x15.philote.VariableType\x12%\n\x05value\x18\x03 \x01(\x0b2\x16.google.protobuf.Value"q\n\x0fVariableMessage\x12$\n\ncontinuous\x18\x01 \x01(\x0b2\x0e.philote.ArrayH\x00\x12-\n\x08discrete\x18\x02 \x01(\x0b2\x19.philote.DiscreteVariableH\x00B\t\n\x07payload*F\n\x08DataType\x12\t\n\x05kBool\x10\x00\x12\x08\n\x04kInt\x10\x01\x12\x0b\n\x07kDouble\x10\x02\x12\x0b\n\x07kString\x10\x03\x12\x0b\n\x07kStruct\x10\x04*m\n\x0cVariableType\x12\n\n\x06kInput\x10\x00\x12\x12\n\x0ekDiscreteInput\x10\x01\x12\r\n\tkResidual\x10\x02\x12\x0b\n\x07kOutput\x10\x03\x12\x13\n\x0fkDiscreteOutput\x10\x04\x12\x0c\n\x08kPartial\x10\x05B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' - _globals['_DATATYPE']._serialized_start = 856 - _globals['_DATATYPE']._serialized_end = 926 - _globals['_VARIABLETYPE']._serialized_start = 928 - _globals['_VARIABLETYPE']._serialized_end = 1037 + _globals['_DATATYPE']._serialized_start = 879 + _globals['_DATATYPE']._serialized_end = 949 + _globals['_VARIABLETYPE']._serialized_start = 951 + _globals['_VARIABLETYPE']._serialized_end = 1060 _globals['_DISCIPLINEPROPERTIES']._serialized_start = 53 _globals['_DISCIPLINEPROPERTIES']._serialized_end = 178 _globals['_STREAMOPTIONS']._serialized_start = 180 @@ -27,12 +27,12 @@ _globals['_DISCIPLINEOPTIONS']._serialized_start = 282 _globals['_DISCIPLINEOPTIONS']._serialized_end = 343 _globals['_VARIABLEMETADATA']._serialized_start = 345 - _globals['_VARIABLEMETADATA']._serialized_end = 444 - _globals['_PARTIALSMETADATA']._serialized_start = 446 - _globals['_PARTIALSMETADATA']._serialized_end = 510 - _globals['_ARRAY']._serialized_start = 512 - _globals['_ARRAY']._serialized_end = 629 - _globals['_DISCRETEVARIABLE']._serialized_start = 631 - _globals['_DISCRETEVARIABLE']._serialized_end = 739 - _globals['_VARIABLEMESSAGE']._serialized_start = 741 - _globals['_VARIABLEMESSAGE']._serialized_end = 854 \ No newline at end of file + _globals['_VARIABLEMETADATA']._serialized_end = 467 + _globals['_PARTIALSMETADATA']._serialized_start = 469 + _globals['_PARTIALSMETADATA']._serialized_end = 533 + _globals['_ARRAY']._serialized_start = 535 + _globals['_ARRAY']._serialized_end = 652 + _globals['_DISCRETEVARIABLE']._serialized_start = 654 + _globals['_DISCRETEVARIABLE']._serialized_end = 762 + _globals['_VARIABLEMESSAGE']._serialized_start = 764 + _globals['_VARIABLEMESSAGE']._serialized_end = 877 \ No newline at end of file diff --git a/philote_mdo/generated/data_pb2.pyi b/philote_mdo/generated/data_pb2.pyi index 13c800f..500c23f 100644 --- a/philote_mdo/generated/data_pb2.pyi +++ b/philote_mdo/generated/data_pb2.pyi @@ -77,17 +77,19 @@ class DisciplineOptions(_message.Message): ... class VariableMetaData(_message.Message): - __slots__ = ('type', 'name', 'shape', 'units') + __slots__ = ('type', 'name', 'shape', 'units', 'dynamic_shape') TYPE_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] SHAPE_FIELD_NUMBER: _ClassVar[int] UNITS_FIELD_NUMBER: _ClassVar[int] + DYNAMIC_SHAPE_FIELD_NUMBER: _ClassVar[int] type: VariableType name: str shape: _containers.RepeatedScalarFieldContainer[int] units: str + dynamic_shape: bool - def __init__(self, type: _Optional[_Union[VariableType, str]]=..., name: _Optional[str]=..., shape: _Optional[_Iterable[int]]=..., units: _Optional[str]=...) -> None: + def __init__(self, type: _Optional[_Union[VariableType, str]]=..., name: _Optional[str]=..., shape: _Optional[_Iterable[int]]=..., units: _Optional[str]=..., dynamic_shape: bool=...) -> None: ... class PartialsMetaData(_message.Message): diff --git a/philote_mdo/generated/disciplines_pb2.py b/philote_mdo/generated/disciplines_pb2.py index a78cc88..1425ba1 100644 --- a/philote_mdo/generated/disciplines_pb2.py +++ b/philote_mdo/generated/disciplines_pb2.py @@ -8,7 +8,7 @@ _sym_db = _symbol_database.Default() from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 from . import data_pb2 as data__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\x84\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x012\xab\x01\n\x0fExplicitService\x12K\n\x0fComputeFunction\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12K\n\x0fComputeGradient\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x012\x81\x02\n\x0fImplicitService\x12L\n\x10ComputeResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12J\n\x0eSolveResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12T\n\x18ComputeResidualGradients\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11disciplines.proto\x12\x07philote\x1a\x1bgoogle/protobuf/empty.proto\x1a\ndata.proto2\xd0\x04\n\x11DisciplineService\x12B\n\x07GetInfo\x12\x16.google.protobuf.Empty\x1a\x1d.philote.DisciplineProperties"\x00\x12D\n\x10SetStreamOptions\x12\x16.philote.StreamOptions\x1a\x16.google.protobuf.Empty"\x00\x12E\n\x13GetAvailableOptions\x12\x16.google.protobuf.Empty\x1a\x14.philote.OptionsList"\x00\x12B\n\nSetOptions\x12\x1a.philote.DisciplineOptions\x1a\x16.google.protobuf.Empty"\x00\x129\n\x05Setup\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12O\n\x16GetVariableDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.VariableMetaData"\x000\x01\x12N\n\x15GetPartialDefinitions\x12\x16.google.protobuf.Empty\x1a\x19.philote.PartialsMetaData"\x000\x01\x12J\n\x11SetVariableShapes\x12\x19.philote.VariableMetaData\x1a\x16.google.protobuf.Empty"\x00(\x012\xab\x01\n\x0fExplicitService\x12K\n\x0fComputeFunction\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12K\n\x0fComputeGradient\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x012\x81\x02\n\x0fImplicitService\x12L\n\x10ComputeResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12J\n\x0eSolveResiduals\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01\x12T\n\x18ComputeResidualGradients\x12\x18.philote.VariableMessage\x1a\x18.philote.VariableMessage"\x00(\x010\x01B\x11\n\x0forg.philote.mdob\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'disciplines_pb2', _globals) @@ -16,8 +16,8 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\x0forg.philote.mdo' _globals['_DISCIPLINESERVICE']._serialized_start = 72 - _globals['_DISCIPLINESERVICE']._serialized_end = 588 - _globals['_EXPLICITSERVICE']._serialized_start = 591 - _globals['_EXPLICITSERVICE']._serialized_end = 762 - _globals['_IMPLICITSERVICE']._serialized_start = 765 - _globals['_IMPLICITSERVICE']._serialized_end = 1022 \ No newline at end of file + _globals['_DISCIPLINESERVICE']._serialized_end = 664 + _globals['_EXPLICITSERVICE']._serialized_start = 667 + _globals['_EXPLICITSERVICE']._serialized_end = 838 + _globals['_IMPLICITSERVICE']._serialized_start = 841 + _globals['_IMPLICITSERVICE']._serialized_end = 1098 \ No newline at end of file diff --git a/philote_mdo/generated/disciplines_pb2_grpc.py b/philote_mdo/generated/disciplines_pb2_grpc.py index 1b23738..638bd69 100644 --- a/philote_mdo/generated/disciplines_pb2_grpc.py +++ b/philote_mdo/generated/disciplines_pb2_grpc.py @@ -34,6 +34,7 @@ def __init__(self, channel): self.Setup = channel.unary_unary('/philote.DisciplineService/Setup', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, _registered_method=True) self.GetVariableDefinitions = channel.unary_stream('/philote.DisciplineService/GetVariableDefinitions', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=data__pb2.VariableMetaData.FromString, _registered_method=True) self.GetPartialDefinitions = channel.unary_stream('/philote.DisciplineService/GetPartialDefinitions', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=data__pb2.PartialsMetaData.FromString, _registered_method=True) + self.SetVariableShapes = channel.stream_unary('/philote.DisciplineService/SetVariableShapes', request_serializer=data__pb2.VariableMetaData.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, _registered_method=True) class DisciplineServiceServicer(object): """Generic Discipline Definition @@ -91,8 +92,16 @@ def GetPartialDefinitions(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SetVariableShapes(self, request_iterator, context): + """Sets shapes for variables flagged as dynamic_shape. + Must be called after GetVariableDefinitions and before compute RPCs. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_DisciplineServiceServicer_to_server(servicer, server): - rpc_method_handlers = {'GetInfo': grpc.unary_unary_rpc_method_handler(servicer.GetInfo, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.DisciplineProperties.SerializeToString), 'SetStreamOptions': grpc.unary_unary_rpc_method_handler(servicer.SetStreamOptions, request_deserializer=data__pb2.StreamOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetAvailableOptions': grpc.unary_unary_rpc_method_handler(servicer.GetAvailableOptions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.OptionsList.SerializeToString), 'SetOptions': grpc.unary_unary_rpc_method_handler(servicer.SetOptions, request_deserializer=data__pb2.DisciplineOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'Setup': grpc.unary_unary_rpc_method_handler(servicer.Setup, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetVariableDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetVariableDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.VariableMetaData.SerializeToString), 'GetPartialDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetPartialDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.PartialsMetaData.SerializeToString)} + rpc_method_handlers = {'GetInfo': grpc.unary_unary_rpc_method_handler(servicer.GetInfo, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.DisciplineProperties.SerializeToString), 'SetStreamOptions': grpc.unary_unary_rpc_method_handler(servicer.SetStreamOptions, request_deserializer=data__pb2.StreamOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetAvailableOptions': grpc.unary_unary_rpc_method_handler(servicer.GetAvailableOptions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.OptionsList.SerializeToString), 'SetOptions': grpc.unary_unary_rpc_method_handler(servicer.SetOptions, request_deserializer=data__pb2.DisciplineOptions.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'Setup': grpc.unary_unary_rpc_method_handler(servicer.Setup, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString), 'GetVariableDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetVariableDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.VariableMetaData.SerializeToString), 'GetPartialDefinitions': grpc.unary_stream_rpc_method_handler(servicer.GetPartialDefinitions, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=data__pb2.PartialsMetaData.SerializeToString), 'SetVariableShapes': grpc.stream_unary_rpc_method_handler(servicer.SetVariableShapes, request_deserializer=data__pb2.VariableMetaData.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString)} generic_handler = grpc.method_handlers_generic_handler('philote.DisciplineService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) server.add_registered_method_handlers('philote.DisciplineService', rpc_method_handlers) @@ -132,6 +141,10 @@ def GetVariableDefinitions(request, target, options=(), channel_credentials=None def GetPartialDefinitions(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): return grpc.experimental.unary_stream(request, target, '/philote.DisciplineService/GetPartialDefinitions', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, data__pb2.PartialsMetaData.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + @staticmethod + def SetVariableShapes(request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/philote.DisciplineService/SetVariableShapes', data__pb2.VariableMetaData.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, _registered_method=True) + class ExplicitServiceStub(object): """Definition of the generic Explicit Component RPC """ diff --git a/philote_mdo/openmdao/explicit.py b/philote_mdo/openmdao/explicit.py index f80c9ef..2e02b9f 100644 --- a/philote_mdo/openmdao/explicit.py +++ b/philote_mdo/openmdao/explicit.py @@ -200,6 +200,10 @@ def setup_partials(self): the server's partial derivative metadata. The component can compute partials either analytically (if supported by the server) or via finite differencing. + If any variables were declared with ``dynamic_shape=True`` on the server, + the shapes resolved by OpenMDAO (e.g. via ``shape_by_conn``) are sent + back to the server before querying partial definitions. + The method is called automatically by OpenMDAO during component setup and should not be called directly by users. @@ -209,6 +213,7 @@ def setup_partials(self): - Both analytic and finite difference methods are supported - Sparsity patterns are preserved when available from server metadata """ + utils.send_resolved_shapes(self) utils.client_setup_partials(self) def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None): diff --git a/philote_mdo/openmdao/implicit.py b/philote_mdo/openmdao/implicit.py index 2176bdc..70c8c17 100644 --- a/philote_mdo/openmdao/implicit.py +++ b/philote_mdo/openmdao/implicit.py @@ -209,6 +209,10 @@ def setup_partials(self): pairs based on the server's partial derivative metadata. For implicit components, this includes both dR/dinputs and dR/doutputs terms needed for Newton-type solvers. + If any variables were declared with ``dynamic_shape=True`` on the server, + the shapes resolved by OpenMDAO (e.g. via ``shape_by_conn``) are sent + back to the server before querying partial definitions. + The method is called automatically by OpenMDAO during component setup and should not be called directly by users. @@ -218,6 +222,7 @@ def setup_partials(self): - Both dR/dinputs and dR/doutputs partials are declared - Sparsity patterns are preserved when available from server metadata """ + utils.send_resolved_shapes(self) utils.client_setup_partials(self) def apply_nonlinear(self, inputs, outputs, residuals, discrete_inputs=None, discrete_outputs=None): diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index ea11503..d7a599e 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -74,11 +74,19 @@ def client_setup(comp): else: units = var.units - if var.type == data.kInput: - comp.add_input(var.name, shape=tuple(var.shape), units=units) + if var.dynamic_shape: + # let OpenMDAO resolve the shape from connections + if var.type == data.kInput: + comp.add_input(var.name, shape_by_conn=True, units=units) - if var.type == data.kOutput: - comp.add_output(var.name, shape=tuple(var.shape), units=units) + if var.type == data.kOutput: + comp.add_output(var.name, shape_by_conn=True, units=units) + else: + if var.type == data.kInput: + comp.add_input(var.name, shape=tuple(var.shape), units=units) + + if var.type == data.kOutput: + comp.add_output(var.name, shape=tuple(var.shape), units=units) # define discrete inputs and outputs for var in comp._client._discrete_var_meta: @@ -89,6 +97,30 @@ def client_setup(comp): comp.add_discrete_output(var.name, val=None) +def send_resolved_shapes(comp): + """ + Sends resolved shapes for dynamic-shape variables back to the server. + + After OpenMDAO resolves shapes (e.g. via ``shape_by_conn``), this + function reads the resolved metadata from the component and transmits + the shapes to the remote discipline server. + """ + dynamic_shapes = [] + for var in comp._client._var_meta: + if not var.dynamic_shape: + continue + + resolved_meta = comp._var_rel2meta[var.name] + dynamic_shapes.append( + comp._client.set_variable_shape( + var.name, resolved_meta["shape"], var.type + ) + ) + + if dynamic_shapes: + comp._client.send_variable_shapes(dynamic_shapes) + + def client_setup_partials(comp): """ Sets up the partials for the OpenMDAO component. diff --git a/tests/test_dynamic_shapes.py b/tests/test_dynamic_shapes.py new file mode 100644 index 0000000..6ec15f5 --- /dev/null +++ b/tests/test_dynamic_shapes.py @@ -0,0 +1,509 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +from concurrent import futures +import unittest +from unittest.mock import Mock + +import grpc +import numpy as np +import numpy.testing as npt + +from philote_mdo.general import ( + Discipline, + DisciplineServer, + ExplicitDiscipline, + ExplicitServer, + ExplicitClient, + ImplicitDiscipline, + ImplicitServer, + ImplicitClient, +) +from philote_mdo.utils.validation import PhiloteValidationError +from philote_mdo.examples import FlexibleDiscipline +import philote_mdo.generated.data_pb2 as data + + +# --------------------------------------------------------------- +# Unit tests: Discipline base class +# --------------------------------------------------------------- +class TestDynamicShapeDiscipline(unittest.TestCase): + """Unit tests for dynamic_shape flag on the Discipline base class.""" + + def test_add_input_dynamic_shape(self): + """add_input with dynamic_shape=True stores the flag and omits shape.""" + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + + self.assertEqual(len(disc._var_meta), 1) + meta = disc._var_meta[0] + self.assertEqual(meta.name, "x") + self.assertTrue(meta.dynamic_shape) + self.assertEqual(list(meta.shape), []) + + def test_add_output_dynamic_shape(self): + """add_output with dynamic_shape=True stores the flag and omits shape.""" + disc = Discipline() + disc.add_output("y", dynamic_shape=True) + + self.assertEqual(len(disc._var_meta), 1) + meta = disc._var_meta[0] + self.assertEqual(meta.name, "y") + self.assertTrue(meta.dynamic_shape) + self.assertEqual(list(meta.shape), []) + + def test_add_input_static_shape_unchanged(self): + """add_input without dynamic_shape behaves as before.""" + disc = Discipline() + disc.add_input("x", shape=(3, 2), units="m") + + meta = disc._var_meta[0] + self.assertFalse(meta.dynamic_shape) + self.assertEqual(list(meta.shape), [3, 2]) + + def test_add_output_static_shape_unchanged(self): + """add_output without dynamic_shape behaves as before.""" + disc = Discipline() + disc.add_output("y", shape=(4,)) + + meta = disc._var_meta[0] + self.assertFalse(meta.dynamic_shape) + self.assertEqual(list(meta.shape), [4]) + + def test_dynamic_shape_skips_shape_validation(self): + """dynamic_shape=True should not validate the default shape arg.""" + disc = Discipline() + # Should not raise even though we pass no valid shape + disc.add_input("x", dynamic_shape=True) + disc.add_output("y", dynamic_shape=True) + self.assertEqual(len(disc._var_meta), 2) + + def test_implicit_output_dynamic_shape_creates_residual(self): + """For implicit disciplines, dynamic output also creates a residual entry.""" + disc = Discipline() + disc._is_implicit = True + disc.add_output("y", dynamic_shape=True, units="m") + + # Should have output and residual + self.assertEqual(len(disc._var_meta), 2) + out = disc._var_meta[0] + res = disc._var_meta[1] + self.assertEqual(out.type, data.VariableType.kOutput) + self.assertTrue(out.dynamic_shape) + self.assertEqual(res.type, data.VariableType.kResidual) + self.assertTrue(res.dynamic_shape) + + +# --------------------------------------------------------------- +# Unit tests: DisciplineServer.SetVariableShapes +# --------------------------------------------------------------- +class TestSetVariableShapesRPC(unittest.TestCase): + """Unit tests for the SetVariableShapes RPC handler.""" + + def _make_server_with_dynamic_disc(self): + server = DisciplineServer() + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + disc.add_output("y", dynamic_shape=True) + disc.add_input("z", shape=(2,)) # static + server._discipline = disc + return server + + def test_set_shapes_for_dynamic_variables(self): + """SetVariableShapes updates shapes on dynamic variables.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + x_meta = data.VariableMetaData( + name="x", type=data.VariableType.kInput, shape=[5] + ) + y_meta = data.VariableMetaData( + name="y", type=data.VariableType.kOutput, shape=[5] + ) + server.SetVariableShapes(iter([x_meta, y_meta]), context) + + # verify shapes were updated + for var in server._discipline._var_meta: + if var.name == "x": + self.assertEqual(list(var.shape), [5]) + if var.name == "y" and var.type == data.VariableType.kOutput: + self.assertEqual(list(var.shape), [5]) + + def test_reject_shape_for_static_variable(self): + """SetVariableShapes aborts when targeting a non-dynamic variable.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + z_meta = data.VariableMetaData( + name="z", type=data.VariableType.kInput, shape=[10] + ) + server.SetVariableShapes(iter([z_meta]), context) + context.abort.assert_called_once() + + def test_reject_unknown_variable(self): + """SetVariableShapes aborts when the variable name is not found.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + meta = data.VariableMetaData( + name="nope", type=data.VariableType.kInput, shape=[3] + ) + server.SetVariableShapes(iter([meta]), context) + context.abort.assert_called_once() + + def test_reject_invalid_shape(self): + """SetVariableShapes aborts on invalid (non-positive) shape.""" + server = self._make_server_with_dynamic_disc() + context = Mock() + + meta = data.VariableMetaData( + name="x", type=data.VariableType.kInput, shape=[-1] + ) + server.SetVariableShapes(iter([meta]), context) + context.abort.assert_called_once() + + def test_preallocate_raises_when_shape_unset(self): + """preallocate_inputs raises if a dynamic variable has no shape.""" + server = self._make_server_with_dynamic_disc() + + with self.assertRaises(PhiloteValidationError): + server.preallocate_inputs({}, {}) + + +# --------------------------------------------------------------- +# Unit tests: DisciplineClient helpers +# --------------------------------------------------------------- +class TestDynamicShapeClient(unittest.TestCase): + """Unit tests for client-side dynamic shape helpers.""" + + def test_set_variable_shape(self): + """set_variable_shape creates correct VariableMetaData.""" + client = ExplicitClient.__new__(ExplicitClient) + client._var_meta = [] + + meta = client.set_variable_shape("x", (3, 2), data.VariableType.kInput) + + self.assertEqual(meta.name, "x") + self.assertEqual(list(meta.shape), [3, 2]) + self.assertEqual(meta.type, data.VariableType.kInput) + + def test_get_dynamic_variables(self): + """get_dynamic_variables filters correctly.""" + client = ExplicitClient.__new__(ExplicitClient) + + static_var = data.VariableMetaData( + name="a", type=data.kInput, shape=[1], dynamic_shape=False + ) + dynamic_var = data.VariableMetaData( + name="b", type=data.kInput, dynamic_shape=True + ) + client._var_meta = [static_var, dynamic_var] + + result = client.get_dynamic_variables() + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "b") + + +# --------------------------------------------------------------- +# Integration test: FlexibleDiscipline round-trip +# --------------------------------------------------------------- +class TestFlexibleDisciplineIntegration(unittest.TestCase): + """End-to-end test: dynamic shapes over gRPC.""" + + def test_flexible_compute(self): + """Client sets shapes, then computes with the FlexibleDiscipline.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=FlexibleDiscipline()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + # verify dynamic_shape flags came through + dynamic = client.get_dynamic_variables() + self.assertEqual(len(dynamic), 2) + + # set shapes + shapes = [ + client.set_variable_shape("x", (4,), data.VariableType.kInput), + client.set_variable_shape( + "y", (4,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + # compute + inputs = {"x": np.array([1.0, 2.0, 3.0, 4.0])} + outputs = client.run_compute(inputs) + + npt.assert_array_almost_equal( + outputs["y"], [2.0, 4.0, 6.0, 8.0] + ) + finally: + server.stop(0) + + def test_flexible_compute_partials(self): + """Client sets shapes, then computes partials.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=FlexibleDiscipline()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + shapes = [ + client.set_variable_shape("x", (3,), data.VariableType.kInput), + client.set_variable_shape( + "y", (3,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + inputs = {"x": np.array([1.0, 2.0, 3.0])} + jac = client.run_compute_partials(inputs) + + expected = 2.0 * np.eye(3) + npt.assert_array_almost_equal(jac["y", "x"], expected) + finally: + server.stop(0) + + def test_backward_compat_static_shapes(self): + """Existing static-shape disciplines work without calling SetVariableShapes.""" + from philote_mdo.examples import Paraboloid + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ExplicitServer(discipline=Paraboloid()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ExplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + client.get_partials_definitions() + + # no dynamic variables + self.assertEqual(len(client.get_dynamic_variables()), 0) + + inputs = {"x": np.array([1.0]), "y": np.array([2.0])} + outputs = client.run_compute(inputs) + self.assertAlmostEqual(outputs["f_xy"][0], 39.0) + finally: + server.stop(0) + + +# --------------------------------------------------------------- +# Integration test: implicit discipline with dynamic shapes +# --------------------------------------------------------------- +class DynamicImplicit(ImplicitDiscipline): + """Implicit discipline with dynamic shapes for residual testing.""" + + def setup(self): + self.add_input("a", shape=(1,)) + self.add_output("x", dynamic_shape=True) + + def setup_partials(self): + self.declare_partials("x", "a") + self.declare_partials("x", "x") + + def compute_residuals(self, inputs, outputs, residuals): + residuals["x"] = outputs["x"] ** 2 - inputs["a"] + + def solve_residuals(self, inputs, outputs): + outputs["x"] = np.sqrt(np.abs(inputs["a"])) + + def compute_residual_partials(self, inputs, outputs, partials): + partials["x", "a"] = -np.ones(inputs["a"].shape) + partials["x", "x"] = 2.0 * np.diag(outputs["x"].ravel()) + + +class TestDynamicImplicitIntegration(unittest.TestCase): + """Integration test for implicit discipline with dynamic shapes.""" + + def test_implicit_dynamic_shape_residual(self): + """SetVariableShapes updates residual entries for implicit disciplines.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + discipline = ImplicitServer(discipline=DynamicImplicit()) + discipline.attach_to_server(server) + + server.add_insecure_port("[::]:50051") + server.start() + + try: + client = ImplicitClient( + channel=grpc.insecure_channel("localhost:50051") + ) + client.send_stream_options() + client.run_setup() + client.get_variable_definitions() + + # set shape for the dynamic output (and its residual) + shapes = [ + client.set_variable_shape( + "x", (2,), data.VariableType.kOutput + ), + ] + client.send_variable_shapes(shapes) + client.get_partials_definitions() + + inputs = {"a": np.array([4.0])} + outputs = {"x": np.array([3.0, 1.0])} + residuals = client.run_compute_residuals(inputs, outputs) + + # x**2 - a = [9-4, 1-4] = [5, -3] + npt.assert_array_almost_equal(residuals["x"], [5.0, -3.0]) + finally: + server.stop(0) + + +# --------------------------------------------------------------- +# Unit test: SetVariableShapes generic exception path +# --------------------------------------------------------------- +class TestSetVariableShapesGenericException(unittest.TestCase): + """Tests the generic exception handler in SetVariableShapes.""" + + def test_generic_exception_aborts(self): + server = DisciplineServer() + disc = Discipline() + disc.add_input("x", dynamic_shape=True) + server._discipline = disc + + context = Mock() + + # Craft an iterator that raises a non-validation exception + def bad_iterator(): + raise RuntimeError("unexpected error") + yield # pragma: no cover + + server.SetVariableShapes(bad_iterator(), context) + context.abort.assert_called_once() + self.assertEqual( + context.abort.call_args[0][0], grpc.StatusCode.INTERNAL + ) + + +# --------------------------------------------------------------- +# Unit test: OpenMDAO utils dynamic shape paths +# --------------------------------------------------------------- +class TestOpenMdaoUtilsDynamicShapes(unittest.TestCase): + """Tests the OpenMDAO utils functions with dynamic-shape variables.""" + + def test_client_setup_with_dynamic_input(self): + """client_setup passes shape_by_conn=True for dynamic inputs.""" + comp = Mock() + var = Mock() + var.name = "x" + var.units = "m" + var.type = data.kInput + var.shape = [] + var.dynamic_shape = True + + comp._client._var_meta = [var] + comp._client._discrete_var_meta = [] + + from philote_mdo.openmdao.utils import client_setup + + client_setup(comp) + + comp.add_input.assert_called_once_with( + "x", shape_by_conn=True, units="m" + ) + + def test_client_setup_with_dynamic_output(self): + """client_setup passes shape_by_conn=True for dynamic outputs.""" + comp = Mock() + var = Mock() + var.name = "y" + var.units = "" + var.type = data.kOutput + var.shape = [] + var.dynamic_shape = True + + comp._client._var_meta = [var] + comp._client._discrete_var_meta = [] + + from philote_mdo.openmdao.utils import client_setup + + client_setup(comp) + + comp.add_output.assert_called_once_with( + "y", shape_by_conn=True, units=None + ) + + def test_send_resolved_shapes_with_dynamic_vars(self): + """send_resolved_shapes reads OpenMDAO metadata and sends shapes.""" + comp = Mock() + var = data.VariableMetaData( + name="x", type=data.kInput, dynamic_shape=True + ) + comp._client._var_meta = [var] + comp._var_rel2meta = {"x": {"shape": (5,)}} + comp._client.set_variable_shape.return_value = data.VariableMetaData( + name="x", type=data.kInput, shape=[5] + ) + + from philote_mdo.openmdao.utils import send_resolved_shapes + + send_resolved_shapes(comp) + + comp._client.set_variable_shape.assert_called_once_with( + "x", (5,), data.kInput + ) + comp._client.send_variable_shapes.assert_called_once() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_openmdao_explicit_client.py b/tests/test_openmdao_explicit_client.py index b97551f..b50bc59 100644 --- a/tests/test_openmdao_explicit_client.py +++ b/tests/test_openmdao_explicit_client.py @@ -167,6 +167,7 @@ def test_setup_partials( component = RemoteExplicitComponent(channel=mock_channel) component._client = Mock() component._client._partials_meta = [par1, par2] + component._client._var_meta = [] # call the function component.setup_partials() diff --git a/tests/test_openmdao_implicit_client.py b/tests/test_openmdao_implicit_client.py index 045ddcb..82255c1 100644 --- a/tests/test_openmdao_implicit_client.py +++ b/tests/test_openmdao_implicit_client.py @@ -171,6 +171,7 @@ def test_setup_partials( component = RemoteImplicitComponent(channel=mock_channel) component._client = Mock() component._client._partials_meta = [par1, par2] + component._client._var_meta = [] # call the function component.setup_partials() diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index 763504e..bc65eba 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -48,12 +48,14 @@ def test_openmdao_client_setup(self): var1.units = "m" var1.type = kInput var1.shape = [2] + var1.dynamic_shape = False var2 = Mock() var2.name = "var2" var2.units = None var2.type = kOutput var2.shape = [1] + var2.dynamic_shape = False comp._client._var_meta = [var1, var2] comp._client._discrete_var_meta = []