diff --git a/CHANGELOG.md b/CHANGELOG.md index e160b23..0426c32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 the new `VariableMessage` wrapper. The OpenMDAO bindings (`RemoteExplicitComponent`, `RemoteImplicitComponent`) automatically discover and forward discrete variables. +- Added comprehensive input validation and error handling across the + framework. Introduces custom exception classes (`PhiloteValidationError`, + `PhiloteServerError`), parameter validation in discipline base classes + (`add_input`, `add_output`, `add_option`, `declare_partials`), proper + gRPC error propagation via `context.abort()` with appropriate status + codes in all server RPC methods, and client-side input validation with + gRPC error wrapping (#46). ### Bug Fixes - Fixed bare `except` to `except ImportError` in `examples/__init__.py`. +- Fixed missing space in `RemoteImplicitComponent` error message + ("will notbe" -> "will not be"). - Fixed `SellarMDA` promoted-input ambiguity that newer OpenMDAO releases reject during `final_setup`. The `x` and `z` defaults were being set on the inner `cycle` subgroup, but `obj_cmp` promoted the same variables diff --git a/philote_mdo/general/discipline.py b/philote_mdo/general/discipline.py index ae1f020..ed8414d 100644 --- a/philote_mdo/general/discipline.py +++ b/philote_mdo/general/discipline.py @@ -28,6 +28,13 @@ # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. import philote_mdo.generated.data_pb2 as data +from philote_mdo.utils.validation import ( + validate_name, + validate_shape, + validate_units, + validate_option_type, + PhiloteValidationError, +) class Discipline: @@ -71,6 +78,12 @@ def add_option(self, name, type): the data type of the option. acceptable types are 'bool', 'int', 'float', 'str', 'dict' """ + validate_name(name, "add_option") + validate_option_type(type, name) + if name in self.options_list: + raise PhiloteValidationError( + f"add_option: option '{name}' is already defined." + ) self.options_list[name] = type def add_input(self, name, shape=(1,), units=""): @@ -86,6 +99,13 @@ def add_input(self, name, shape=(1,), units=""): units : string the unit definition for the input variable """ + validate_name(name, "add_input") + validate_shape(shape, "add_input") + validate_units(units, "add_input") + if any(v.name == name and v.type == data.VariableType.kInput for v in self._var_meta): + raise PhiloteValidationError( + f"add_input: input '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kInput meta.name = name @@ -107,6 +127,14 @@ def add_discrete_input(self, name, default=None): default : object, optional the default value for the discrete input """ + validate_name(name, "add_discrete_input") + if any( + v.name == name and v.type == data.VariableType.kDiscreteInput + for v in self._discrete_var_meta + ): + raise PhiloteValidationError( + f"add_discrete_input: discrete input '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kDiscreteInput meta.name = name @@ -126,6 +154,14 @@ def add_discrete_output(self, name, default=None): default : object, optional the default value for the discrete output """ + validate_name(name, "add_discrete_output") + if any( + v.name == name and v.type == data.VariableType.kDiscreteOutput + for v in self._discrete_var_meta + ): + raise PhiloteValidationError( + f"add_discrete_output: discrete output '{name}' is already defined." + ) meta = data.VariableMetaData() meta.type = data.VariableType.kDiscreteOutput meta.name = name @@ -144,6 +180,13 @@ def add_output(self, name, shape=(1,), units=""): units : string the unit definition for the output variable """ + validate_name(name, "add_output") + validate_shape(shape, "add_output") + validate_units(units, "add_output") + if any(v.name == name and v.type == data.VariableType.kOutput for v in self._var_meta): + raise PhiloteValidationError( + f"add_output: output '{name}' is already defined." + ) out_meta = data.VariableMetaData() out_meta.type = data.VariableType.kOutput out_meta.name = name @@ -164,6 +207,8 @@ def declare_partials(self, func, var): """ Defines partials that will be determined using the analysis server. """ + validate_name(func, "declare_partials (func)") + validate_name(var, "declare_partials (var)") self._partials_meta += [data.PartialsMetaData(name=func, subname=var)] def initialize(self): diff --git a/philote_mdo/general/discipline_client.py b/philote_mdo/general/discipline_client.py index 13b206f..dfbed15 100644 --- a/philote_mdo/general/discipline_client.py +++ b/philote_mdo/general/discipline_client.py @@ -33,6 +33,11 @@ import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.utils as utils from philote_mdo.general.discipline_server import _python_to_value, _value_to_python +from philote_mdo.utils.validation import ( + PhiloteValidationError, + validate_is_dict, + validate_numpy_array, +) class DisciplineClient: @@ -114,6 +119,7 @@ def send_options(self, options): ------- None """ + validate_is_dict(options, "send_options") proto_options = data.DisciplineOptions() proto_options.options.update(options) self._disc_stub.SetOptions(proto_options) @@ -158,6 +164,14 @@ def _assemble_input_messages( Both continuous and discrete inputs are wrapped in ``VariableMessage`` envelopes. """ + validate_is_dict(inputs, "_assemble_input_messages (inputs)") + for input_name, value in inputs.items(): + validate_numpy_array(value, input_name) + if outputs is not None: + validate_is_dict(outputs, "_assemble_input_messages (outputs)") + for output_name, value in outputs.items(): + validate_numpy_array(value, output_name) + messages = [] # Continuous inputs @@ -251,7 +265,7 @@ def _recover_outputs(self, responses): if len(arr.data) > 0: flat_outputs[arr.name][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables, but array is empty." ) @@ -289,7 +303,7 @@ def _recover_residuals(self, responses): if len(arr.data) > 0: flat_residuals[arr.name][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables, but array is empty." ) @@ -336,7 +350,7 @@ def _recover_partials(self, responses): if len(arr.data) > 0: flat_p[(arr.name, arr.subname)][b:e] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous outputs for the " "partials, but array was empty." ) diff --git a/philote_mdo/general/discipline_server.py b/philote_mdo/general/discipline_server.py index d784b18..a7baee6 100644 --- a/philote_mdo/general/discipline_server.py +++ b/philote_mdo/general/discipline_server.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import numpy as np import philote_mdo.generated.data_pb2 as data @@ -34,6 +35,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf import struct_pb2 from philote_mdo.utils import PairDict, get_flattened_view +from philote_mdo.utils.validation import PhiloteValidationError class DisciplineServer(disc.DisciplineService): @@ -85,48 +87,69 @@ def GetAvailableOptions(self, request, context): """ RPC that gets the names and types of all available discipline options. """ - opts_dict = self._discipline.options_list - opts = data.OptionsList() - - for name, val in opts_dict.items(): - opts.options.append(name) - - # assign the correct data type - if val == "bool": - type = data.kBool - elif val == "int": - type = data.kInt - elif val == "float": - type = data.kDouble - elif val == "str": - type = data.kString - elif val == "dict": - type = data.kStruct - else: - raise ValueError( - "Invalid value for discipline option '{}'".format(name) - ) + try: + opts_dict = self._discipline.options_list + opts = data.OptionsList() + + for name, val in opts_dict.items(): + opts.options.append(name) + + # assign the correct data type + if val == "bool": + type = data.kBool + elif val == "int": + type = data.kInt + elif val == "float": + type = data.kDouble + elif val == "str": + type = data.kString + elif val == "dict": + type = data.kStruct + else: + raise PhiloteValidationError( + "Invalid value for discipline option '{}'".format(name) + ) - opts.type.append(type) + opts.type.append(type) - return opts + return opts + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"GetAvailableOptions failed: {e}" + ) def SetOptions(self, request, context): """ RPC that sets the discipline options. """ - options = request.options - self._discipline.set_options(options) - return Empty() + try: + options = request.options + self._discipline.set_options(options) + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SetOptions failed: {e}" + ) def Setup(self, request, context): """ RPC that runs the setup function """ - self._discipline._clear_data() - self._discipline.setup() - self._discipline.setup_partials() - return Empty() + try: + self._discipline._clear_data() + self._discipline.setup() + self._discipline.setup_partials() + return Empty() + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"Setup failed: {e}" + ) def GetVariableDefinitions(self, request, context): """ @@ -237,7 +260,7 @@ def process_inputs( elif arr.type == data.VariableType.kOutput: flat_outputs[arr.name][b : e + 1] = arr.data else: - raise ValueError( + raise PhiloteValidationError( "Expected continuous variables but arrays were" " empty for variable %s." % (arr.name) ) diff --git a/philote_mdo/general/explicit_client.py b/philote_mdo/general/explicit_client.py index 5b278cb..b42f8e7 100644 --- a/philote_mdo/general/explicit_client.py +++ b/philote_mdo/general/explicit_client.py @@ -29,6 +29,7 @@ # control over the information you may find at these locations. import grpc from philote_mdo.general.discipline_client import DisciplineClient +from philote_mdo.utils.validation import PhiloteServerError, validate_is_dict import philote_mdo.generated.disciplines_pb2_grpc as disc @@ -59,11 +60,17 @@ def run_compute(self, inputs, discrete_inputs=None): Continuous outputs, or (continuous outputs, discrete outputs) when the server returns discrete output data. """ - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._expl_stub.ComputeFunction(iter(messages)) - return self._recover_outputs(responses) + validate_is_dict(inputs, "run_compute (inputs)") + try: + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._expl_stub.ComputeFunction(iter(messages)) + return self._recover_outputs(responses) + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute: {e.details()}" + ) from e def run_compute_partials(self, inputs, discrete_inputs=None): """ @@ -77,10 +84,16 @@ def run_compute_partials(self, inputs, discrete_inputs=None): discrete_inputs : dict, optional Discrete input values. """ - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._expl_stub.ComputeGradient(iter(messages)) - partials = self._recover_partials(responses) + validate_is_dict(inputs, "run_compute_partials (inputs)") + try: + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._expl_stub.ComputeGradient(iter(messages)) + partials = self._recover_partials(responses) - return partials + return partials + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute_partials: {e.details()}" + ) from e diff --git a/philote_mdo/general/explicit_server.py b/philote_mdo/general/explicit_server.py index 049e0a2..e8b8fff 100644 --- a/philote_mdo/general/explicit_server.py +++ b/philote_mdo/general/explicit_server.py @@ -27,6 +27,7 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data from philote_mdo.general.discipline_server import ( @@ -34,6 +35,7 @@ _python_to_value, ) from philote_mdo.utils import get_chunk_indices +from philote_mdo.utils.validation import PhiloteValidationError class ExplicitServer(DisciplineServer, disc.ExplicitServiceServicer): @@ -55,76 +57,90 @@ def ComputeFunction(self, request_iterator, context): """ Computes the function evaluation and sends the result to the client. """ - inputs = {} - flat_inputs = {} - outputs = {} - discrete_inputs = {} - discrete_outputs = {} + try: + inputs = {} + flat_inputs = {} + outputs = {} + discrete_inputs = {} + discrete_outputs = {} - self.preallocate_inputs(inputs, flat_inputs) - discrete_inputs, _ = self.process_inputs( - request_iterator, flat_inputs, discrete_inputs=discrete_inputs - ) - - # Call compute with discrete data when discrete variables are present - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute( - inputs, outputs, discrete_inputs, discrete_outputs + self.preallocate_inputs(inputs, flat_inputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs ) - else: - self._discipline.compute(inputs, outputs) - # Stream continuous outputs - for output_name, value in outputs.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=output_name, - type=data.kOutput, - start=b, - end=e - 1, - data=value.ravel()[b:e], - ) + # Call compute with discrete data when discrete variables are present + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute( + inputs, outputs, discrete_inputs, discrete_outputs ) + else: + self._discipline.compute(inputs, outputs) - # Stream discrete outputs - for name, value in discrete_outputs.items(): - yield data.VariableMessage( - discrete=data.DiscreteVariable( - name=name, - type=data.VariableType.kDiscreteOutput, - value=_python_to_value(value), + # Stream continuous outputs + for output_name, value in outputs.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + type=data.kOutput, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) + ) + + # Stream discrete outputs + for name, value in discrete_outputs.items(): + yield data.VariableMessage( + discrete=data.DiscreteVariable( + name=name, + type=data.VariableType.kDiscreteOutput, + value=_python_to_value(value), + ) ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeFunction failed: {e}" ) def ComputeGradient(self, request_iterator, context): """ Computes the gradient evaluation and sends the result to the client. """ - inputs = {} - flat_inputs = {} - discrete_inputs = {} + try: + inputs = {} + flat_inputs = {} + discrete_inputs = {} - self.preallocate_inputs(inputs, flat_inputs) - jac = self.preallocate_partials() - discrete_inputs, _ = self.process_inputs( - request_iterator, flat_inputs, discrete_inputs=discrete_inputs - ) + self.preallocate_inputs(inputs, flat_inputs) + jac = self.preallocate_partials() + discrete_inputs, _ = self.process_inputs( + request_iterator, flat_inputs, discrete_inputs=discrete_inputs + ) - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute_partials(inputs, jac, discrete_inputs) - else: - self._discipline.compute_partials(inputs, jac) + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_partials(inputs, jac, discrete_inputs) + else: + self._discipline.compute_partials(inputs, jac) - for jac, value in jac.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e - 1, - data=value.ravel()[b:e], + for jac, value in jac.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e - 1, + data=value.ravel()[b:e], + ) ) - ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeGradient failed: {e}" + ) diff --git a/philote_mdo/general/implicit_client.py b/philote_mdo/general/implicit_client.py index 81fb6f3..9ef8156 100644 --- a/philote_mdo/general/implicit_client.py +++ b/philote_mdo/general/implicit_client.py @@ -29,6 +29,7 @@ # control over the information you may find at these locations. import grpc from philote_mdo.general.discipline_client import DisciplineClient +from philote_mdo.utils.validation import PhiloteServerError, validate_is_dict import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc @@ -168,17 +169,24 @@ def run_compute_residuals( - Large arrays are automatically streamed for efficiency - This is typically used for residual evaluation during Newton iterations """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, - outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - responses = self._impl_stub.ComputeResiduals(iter(messages)) - residuals = self._recover_residuals(responses) - - return residuals + validate_is_dict(inputs, "run_compute_residuals (inputs)") + validate_is_dict(outputs, "run_compute_residuals (outputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) + responses = self._impl_stub.ComputeResiduals(iter(messages)) + residuals = self._recover_residuals(responses) + + return residuals + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_compute_residuals: {e.details()}" + ) from e def run_solve_residuals(self, inputs, discrete_inputs=None): """ @@ -231,13 +239,19 @@ def run_solve_residuals(self, inputs, discrete_inputs=None): - May raise exceptions for ill-conditioned or non-convergent problems - Solution quality depends on the server's implementation and input conditioning """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, discrete_inputs=discrete_inputs - ) - responses = self._impl_stub.SolveResiduals(iter(messages)) - outputs = self._recover_outputs(responses) - return outputs + validate_is_dict(inputs, "run_solve_residuals (inputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, discrete_inputs=discrete_inputs + ) + responses = self._impl_stub.SolveResiduals(iter(messages)) + outputs = self._recover_outputs(responses) + return outputs + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_solve_residuals: {e.details()}" + ) from e def run_residual_gradients( self, inputs, outputs, discrete_inputs=None, discrete_outputs=None @@ -299,13 +313,20 @@ def run_residual_gradients( - Used by optimization algorithms and sensitivity analysis tools - For large problems, consider matrix-free methods if available """ - # Assemble input messages and call server - messages = self._assemble_input_messages( - inputs, - outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - responses = self._impl_stub.ComputeResidualGradients(iter(messages)) - partials = self._recover_partials(responses) - return partials + validate_is_dict(inputs, "run_residual_gradients (inputs)") + validate_is_dict(outputs, "run_residual_gradients (outputs)") + try: + # Assemble input messages and call server + messages = self._assemble_input_messages( + inputs, + outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, + ) + responses = self._impl_stub.ComputeResidualGradients(iter(messages)) + partials = self._recover_partials(responses) + return partials + except grpc.RpcError as e: + raise PhiloteServerError( + f"Server error during run_residual_gradients: {e.details()}" + ) from e diff --git a/philote_mdo/general/implicit_server.py b/philote_mdo/general/implicit_server.py index 973cfc8..ac7c1f7 100644 --- a/philote_mdo/general/implicit_server.py +++ b/philote_mdo/general/implicit_server.py @@ -27,11 +27,13 @@ # the linked websites, of the information, products, or services contained # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. +import grpc import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.generated.data_pb2 as data import philote_mdo.general as pmdo from philote_mdo.general.discipline_server import _python_to_value from philote_mdo.utils import get_chunk_indices +from philote_mdo.utils.validation import PhiloteValidationError class ImplicitServer(pmdo.DisciplineServer, disc.ImplicitServiceServicer): @@ -154,43 +156,50 @@ def ComputeResiduals(self, request_iterator, context): - Streams results back in chunks for efficiency - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - residuals = {} - discrete_inputs = {} - discrete_outputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - discrete_inputs, discrete_outputs = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - - # Call the user-defined compute_residuals function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.compute_residuals( - inputs, outputs, residuals, discrete_inputs, discrete_outputs + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + residuals = {} + discrete_inputs = {} + discrete_outputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, ) - else: - self._discipline.compute_residuals(inputs, outputs, residuals) - - for res_name, value in residuals.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=res_name, - start=b, - end=e, - type=data.kResidual, - data=value.ravel()[b:e], - ) + + # Call the user-defined compute_residuals function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.compute_residuals( + inputs, outputs, residuals, discrete_inputs, discrete_outputs ) + else: + self._discipline.compute_residuals(inputs, outputs, residuals) + + for res_name, value in residuals.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=res_name, + start=b, + end=e, + type=data.kResidual, + data=value.ravel()[b:e], + ) + ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"ComputeResiduals failed: {e}" + ) def SolveResiduals(self, request_iterator, context): """ @@ -219,38 +228,45 @@ def SolveResiduals(self, request_iterator, context): - Outputs are streamed back in chunks for large arrays - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - discrete_inputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - discrete_inputs, _ = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - ) - - # Call the user-defined solve function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.solve_residuals(inputs, outputs, discrete_inputs) - else: - self._discipline.solve_residuals(inputs, outputs) - - for output_name, value in outputs.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=output_name, - start=b, - end=e, - type=data.kOutput, - data=value.ravel()[b:e], + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + discrete_inputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + discrete_inputs, _ = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + ) + + # Call the user-defined solve function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.solve_residuals(inputs, outputs, discrete_inputs) + else: + self._discipline.solve_residuals(inputs, outputs) + + for output_name, value in outputs.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=output_name, + start=b, + end=e, + type=data.kOutput, + data=value.ravel()[b:e], + ) ) - ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, f"SolveResiduals failed: {e}" + ) def ComputeResidualGradients(self, request_iterator, context): """ @@ -279,44 +295,52 @@ def ComputeResidualGradients(self, request_iterator, context): - Used for gradient-based optimization and sensitivity analysis - This method is called automatically by the gRPC framework """ - # inputs and outputs - inputs = {} - flat_inputs = {} - outputs = {} - flat_outputs = {} - discrete_inputs = {} - discrete_outputs = {} - - self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) - jac = self.preallocate_partials() - discrete_inputs, discrete_outputs = self.process_inputs( - request_iterator, - flat_inputs, - flat_outputs, - discrete_inputs=discrete_inputs, - discrete_outputs=discrete_outputs, - ) - - # Call the user-defined residual partials function - if discrete_inputs or self._discipline._discrete_var_meta: - self._discipline.residual_partials( - inputs, outputs, jac, discrete_inputs, discrete_outputs + try: + # inputs and outputs + inputs = {} + flat_inputs = {} + outputs = {} + flat_outputs = {} + discrete_inputs = {} + discrete_outputs = {} + + self.preallocate_inputs(inputs, flat_inputs, outputs, flat_outputs) + jac = self.preallocate_partials() + discrete_inputs, discrete_outputs = self.process_inputs( + request_iterator, + flat_inputs, + flat_outputs, + discrete_inputs=discrete_inputs, + discrete_outputs=discrete_outputs, ) - else: - self._discipline.residual_partials(inputs, outputs, jac) - - for jac, value in jac.items(): - for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): - yield data.VariableMessage( - continuous=data.Array( - name=jac[0], - subname=jac[1], - type=data.kPartial, - start=b, - end=e, - data=value.ravel()[b:e], - ) + + # Call the user-defined residual partials function + if discrete_inputs or self._discipline._discrete_var_meta: + self._discipline.residual_partials( + inputs, outputs, jac, discrete_inputs, discrete_outputs ) + else: + self._discipline.residual_partials(inputs, outputs, jac) + + for jac, value in jac.items(): + for b, e in get_chunk_indices(value.size, self._stream_opts.num_double): + yield data.VariableMessage( + continuous=data.Array( + name=jac[0], + subname=jac[1], + type=data.kPartial, + start=b, + end=e, + data=value.ravel()[b:e], + ) + ) + except PhiloteValidationError as e: + context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except Exception as e: + context.abort( + grpc.StatusCode.INTERNAL, + f"ComputeResidualGradients failed: {e}", + ) # def MatrixFreeGradients(self, request_iterator, context): # """ diff --git a/philote_mdo/openmdao/explicit.py b/philote_mdo/openmdao/explicit.py index 458b7e6..f80c9ef 100644 --- a/philote_mdo/openmdao/explicit.py +++ b/philote_mdo/openmdao/explicit.py @@ -128,6 +128,10 @@ def __init__(self, channel=None, num_par_fd=1, **kwargs): raise ValueError( "No channel provided, the Philote client will not be able to connect." ) + if not isinstance(num_par_fd, int) or num_par_fd < 1: + raise ValueError( + f"num_par_fd must be a positive integer, got {num_par_fd!r}." + ) # generic Philote client # The setting of OpenMDAO options requires the list of available diff --git a/philote_mdo/openmdao/group.py b/philote_mdo/openmdao/group.py index 4ca6c54..f27f9d9 100644 --- a/philote_mdo/openmdao/group.py +++ b/philote_mdo/openmdao/group.py @@ -29,6 +29,12 @@ # control over the information you may find at these locations. import openmdao.api as om import philote_mdo.general as pm +from philote_mdo.utils.validation import ( + PhiloteValidationError, + validate_name, + validate_shape, + validate_units, +) class OpenMdaoSubProblem(pm.ExplicitDiscipline): @@ -115,6 +121,11 @@ def add_group(self, group): >>> subprob = OpenMdaoSubProblem() >>> subprob.add_group(MyGroup()) """ + if not isinstance(group, om.Group): + raise PhiloteValidationError( + f"add_group: expected an om.Group instance, " + f"got {type(group).__name__}." + ) self._prob = om.Problem(model=group) self._model = self._prob.model @@ -142,6 +153,14 @@ def add_mapped_input(self, local_var, subprob_var, shape=(1,), units=""): >>> subprob.add_mapped_input('x_local', 'x', shape=(1,), units='m') >>> subprob.add_mapped_input('design_vars', 'z', shape=(2,), units='') """ + validate_name(local_var, "add_mapped_input (local_var)") + validate_name(subprob_var, "add_mapped_input (subprob_var)") + validate_shape(shape, "add_mapped_input") + validate_units(units, "add_mapped_input") + if local_var in self._input_map: + raise PhiloteValidationError( + f"add_mapped_input: '{local_var}' is already mapped." + ) self._input_map[local_var] = { "sub_prob_name": subprob_var, "shape": shape, @@ -172,6 +191,14 @@ def add_mapped_output(self, local_var, subprob_var, shape=(1,), units=""): >>> subprob.add_mapped_output('objective', 'obj', shape=(1,), units='') >>> subprob.add_mapped_output('constraint1', 'con1', shape=(1,), units='N') """ + validate_name(local_var, "add_mapped_output (local_var)") + validate_name(subprob_var, "add_mapped_output (subprob_var)") + validate_shape(shape, "add_mapped_output") + validate_units(units, "add_mapped_output") + if local_var in self._output_map: + raise PhiloteValidationError( + f"add_mapped_output: '{local_var}' is already mapped." + ) self._output_map[local_var] = { "sub_prob_name": subprob_var, "shape": shape, @@ -222,6 +249,18 @@ def declare_subproblem_partial(self, local_func, local_var): >>> subprob.add_mapped_output('y_local', 'y') >>> subprob.declare_subproblem_partial('y_local', 'x_local') """ + validate_name(local_func, "declare_subproblem_partial (local_func)") + validate_name(local_var, "declare_subproblem_partial (local_var)") + if local_func not in self._output_map: + raise PhiloteValidationError( + f"declare_subproblem_partial: output '{local_func}' has not " + f"been mapped. Call add_mapped_output first." + ) + if local_var not in self._input_map: + raise PhiloteValidationError( + f"declare_subproblem_partial: input '{local_var}' has not " + f"been mapped. Call add_mapped_input first." + ) self._partials_map[(local_func, local_var)] = ( self._output_map[local_func]["sub_prob_name"], self._input_map[local_var]["sub_prob_name"], diff --git a/philote_mdo/openmdao/implicit.py b/philote_mdo/openmdao/implicit.py index 1aa8843..2176bdc 100644 --- a/philote_mdo/openmdao/implicit.py +++ b/philote_mdo/openmdao/implicit.py @@ -132,9 +132,13 @@ def __init__(self, channel=None, num_par_fd=1, **kwargs): """ if not channel: raise ValueError( - "No channel provided, the Philote client will not" + "No channel provided, the Philote client will not " "be able to connect." ) + if not isinstance(num_par_fd, int) or num_par_fd < 1: + raise ValueError( + f"num_par_fd must be a positive integer, got {num_par_fd!r}." + ) # generic Philote client # The setting of OpenMDAO options requires the list of available diff --git a/philote_mdo/openmdao/utils.py b/philote_mdo/openmdao/utils.py index 3b95870..ea11503 100644 --- a/philote_mdo/openmdao/utils.py +++ b/philote_mdo/openmdao/utils.py @@ -28,6 +28,16 @@ # therein. The DoD does not exercise any editorial, security, or other # control over the information you may find at these locations. import philote_mdo.generated.data_pb2 as data +from philote_mdo.utils.validation import PhiloteValidationError + + +_TYPE_MAP = { + "bool": bool, + "int": int, + "float": float, + "str": str, + "dict": dict, +} def declare_options(opt_list, options): @@ -35,17 +45,12 @@ def declare_options(opt_list, options): Declares the options from the client options list. """ for name, type_str in opt_list: - opt_type = None - if type_str == "bool": - opt_type = bool - elif type_str == "int": - opt_type = int - elif type_str == "float": - opt_type = float - elif type_str == "str": - opt_type = str - elif type_str == "dict": - opt_type = dict + opt_type = _TYPE_MAP.get(type_str) + if opt_type is None: + raise PhiloteValidationError( + f"declare_options: unknown option type '{type_str}' " + f"for option '{name}'." + ) options.declare(name, types=opt_type) diff --git a/philote_mdo/utils/__init__.py b/philote_mdo/utils/__init__.py index 2e99b77..db801d8 100644 --- a/philote_mdo/utils/__init__.py +++ b/philote_mdo/utils/__init__.py @@ -29,3 +29,14 @@ # control over the information you may find at these locations. from .pair_dict import PairDict from .helper import get_chunk_indices, get_flattened_view +from .validation import ( + PhiloteError, + PhiloteValidationError, + PhiloteServerError, + validate_name, + validate_shape, + validate_units, + validate_option_type, + validate_is_dict, + validate_numpy_array, +) diff --git a/philote_mdo/utils/helper.py b/philote_mdo/utils/helper.py index 67bb39f..eb2a147 100644 --- a/philote_mdo/utils/helper.py +++ b/philote_mdo/utils/helper.py @@ -29,8 +29,21 @@ # control over the information you may find at these locations. import numpy as np +from philote_mdo.utils.validation import PhiloteValidationError + def get_chunk_indices(num_values, chunk_size): + if not isinstance(num_values, (int, np.integer)) or num_values < 0: + raise PhiloteValidationError( + f"get_chunk_indices: num_values must be a non-negative integer, " + f"got {num_values!r}." + ) + if not isinstance(chunk_size, (int, np.integer)) or chunk_size < 1: + raise PhiloteValidationError( + f"get_chunk_indices: chunk_size must be a positive integer, " + f"got {chunk_size!r}." + ) + beg_i = np.arange(0, num_values, chunk_size) if beg_i.size == 1: @@ -48,6 +61,11 @@ def get_flattened_view(arr): :param arr: Array to get a flattened view :return: A view of the input array, guaranteed to not be a copy """ + if not isinstance(arr, np.ndarray): + raise PhiloteValidationError( + f"get_flattened_view: expected a numpy ndarray, " + f"got {type(arr).__name__}." + ) flat_view = arr.view() flat_view.shape = -1 return flat_view diff --git a/philote_mdo/utils/pair_dict.py b/philote_mdo/utils/pair_dict.py index d72fdfc..bae5b9d 100644 --- a/philote_mdo/utils/pair_dict.py +++ b/philote_mdo/utils/pair_dict.py @@ -29,15 +29,27 @@ # control over the information you may find at these locations. +from philote_mdo.utils.validation import PhiloteValidationError + + class PairDict(dict): """ Jacobian dictionary for storing values with respect to two keys. """ + @staticmethod + def _validate_key(keys): + if not isinstance(keys, tuple) or len(keys) != 2: + raise PhiloteValidationError( + f"PairDict keys must be a 2-tuple, got {keys!r}." + ) + def __setitem__(self, keys, value): + self._validate_key(keys) key1, key2 = keys super().__setitem__((key1, key2), value) def __getitem__(self, keys): + self._validate_key(keys) key1, key2 = keys return super().__getitem__((key1, key2)) diff --git a/philote_mdo/utils/validation.py b/philote_mdo/utils/validation.py new file mode 100644 index 0000000..75395f2 --- /dev/null +++ b/philote_mdo/utils/validation.py @@ -0,0 +1,210 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import numpy as np + + +# --------------------------------------------------------------------------- +# Custom exception hierarchy +# --------------------------------------------------------------------------- + + +class PhiloteError(Exception): + """Base class for all Philote-specific errors.""" + + pass + + +class PhiloteValidationError(PhiloteError, ValueError): + """Raised when an input fails validation. + + Inherits from ``ValueError`` so that existing ``except ValueError`` + handlers in user code continue to work. + """ + + pass + + +class PhiloteServerError(PhiloteError, RuntimeError): + """Raised on the client side when a gRPC server call fails. + + Wraps the gRPC error details into a framework-specific exception so + that users do not need to catch ``grpc.RpcError`` directly. + """ + + pass + + +# --------------------------------------------------------------------------- +# Validation helpers +# --------------------------------------------------------------------------- + +VALID_OPTION_TYPES = {"bool", "int", "float", "str", "dict"} + + +def validate_name(name, context): + """Validate that *name* is a non-empty string. + + Parameters + ---------- + name : object + The name to validate. + context : str + Human-readable context for error messages (e.g. ``"add_input"``). + + Raises + ------ + PhiloteValidationError + If *name* is not a string or is empty. + """ + if not isinstance(name, str): + raise PhiloteValidationError( + f"{context}: 'name' must be a string, got {type(name).__name__}." + ) + if not name: + raise PhiloteValidationError(f"{context}: 'name' must not be empty.") + + +def validate_shape(shape, context): + """Validate that *shape* is a tuple of positive integers. + + Parameters + ---------- + shape : object + The shape to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *shape* is not a tuple, or contains non-positive / non-integer + elements. + """ + if not isinstance(shape, tuple): + raise PhiloteValidationError( + f"{context}: 'shape' must be a tuple, got {type(shape).__name__}." + ) + for i, dim in enumerate(shape): + if not isinstance(dim, int): + raise PhiloteValidationError( + f"{context}: all elements of 'shape' must be integers, " + f"but element {i} is {type(dim).__name__}." + ) + if dim <= 0: + raise PhiloteValidationError( + f"{context}: all elements of 'shape' must be positive, " + f"but element {i} is {dim}." + ) + + +def validate_units(units, context): + """Validate that *units* is a string. + + Parameters + ---------- + units : object + The units string to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *units* is not a string. + """ + if not isinstance(units, str): + raise PhiloteValidationError( + f"{context}: 'units' must be a string, got {type(units).__name__}." + ) + + +def validate_option_type(type_str, name): + """Validate that *type_str* is one of the allowed option types. + + Parameters + ---------- + type_str : object + The option type string to validate. + name : str + The option name (for error messages). + + Raises + ------ + PhiloteValidationError + If *type_str* is not in the allowed set. + """ + if type_str not in VALID_OPTION_TYPES: + raise PhiloteValidationError( + f"Invalid type '{type_str}' for option '{name}'. " + f"Allowed types are: {sorted(VALID_OPTION_TYPES)}." + ) + + +def validate_is_dict(obj, context): + """Validate that *obj* is a dictionary. + + Parameters + ---------- + obj : object + The object to validate. + context : str + Human-readable context for error messages. + + Raises + ------ + PhiloteValidationError + If *obj* is not a ``dict``. + """ + if not isinstance(obj, dict): + raise PhiloteValidationError( + f"{context}: expected a dict, got {type(obj).__name__}." + ) + + +def validate_numpy_array(value, name): + """Validate that *value* is a NumPy ndarray. + + Parameters + ---------- + value : object + The value to validate. + name : str + Variable name (for error messages). + + Raises + ------ + PhiloteValidationError + If *value* is not a ``numpy.ndarray``. + """ + if not isinstance(value, np.ndarray): + raise PhiloteValidationError( + f"Variable '{name}' must be a numpy ndarray, " + f"got {type(value).__name__}." + ) diff --git a/tests/test_discipline.py b/tests/test_discipline.py index 6dbf5a6..a1b0e8e 100644 --- a/tests/test_discipline.py +++ b/tests/test_discipline.py @@ -31,6 +31,7 @@ from unittest.mock import Mock from philote_mdo.general import Discipline +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -172,6 +173,100 @@ def test_clear_data(self): self.assertEqual(len(disc._var_meta), 0) self.assertEqual(len(disc._partials_meta), 0) + # ------------------------------------------------------------------ + # Validation tests + # ------------------------------------------------------------------ + + def test_add_input_invalid_name_type(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input(123) + + def test_add_input_empty_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("") + + def test_add_input_invalid_shape(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", shape=[2, 3]) + + def test_add_input_invalid_units(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", units=42) + + def test_add_input_duplicate(self): + disc = Discipline() + disc.add_input("x", shape=(2,)) + with self.assertRaises(PhiloteValidationError): + disc.add_input("x", shape=(3,)) + + def test_add_output_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_output(None) + + def test_add_output_invalid_shape(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_output("y", shape=(-1,)) + + def test_add_output_duplicate(self): + disc = Discipline() + disc.add_output("y") + with self.assertRaises(PhiloteValidationError): + disc.add_output("y") + + def test_add_option_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_option(42, "int") + + def test_add_option_invalid_type(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_option("opt", "unknown") + + def test_add_option_duplicate(self): + disc = Discipline() + disc.add_option("opt", "int") + with self.assertRaises(PhiloteValidationError): + disc.add_option("opt", "float") + + def test_add_discrete_input_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_input("") + + def test_add_discrete_input_duplicate(self): + disc = Discipline() + disc.add_discrete_input("d") + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_input("d") + + def test_add_discrete_output_invalid_name(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_output(123) + + def test_add_discrete_output_duplicate(self): + disc = Discipline() + disc.add_discrete_output("d") + with self.assertRaises(PhiloteValidationError): + disc.add_discrete_output("d") + + def test_declare_partials_invalid_func(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.declare_partials("", "x") + + def test_declare_partials_invalid_var(self): + disc = Discipline() + with self.assertRaises(PhiloteValidationError): + disc.declare_partials("f", 123) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_discipline_client.py b/tests/test_discipline_client.py index d7be559..62d2bae 100644 --- a/tests/test_discipline_client.py +++ b/tests/test_discipline_client.py @@ -33,6 +33,7 @@ from google.protobuf.empty_pb2 import Empty from google.protobuf.struct_pb2 import Struct from philote_mdo.general import DisciplineClient +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data import philote_mdo.generated.disciplines_pb2_grpc as disc import philote_mdo.utils as utils @@ -556,6 +557,24 @@ def test_recover_partials_empty_array_raises_error(self): self.assertIn("Expected continuous outputs for the partials, but array was empty", str(context.exception)) + def test_send_options_non_dict_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.send_options("not a dict") + + def test_assemble_input_messages_non_dict_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client._assemble_input_messages("not a dict") + + def test_assemble_input_messages_non_array_value_raises(self): + mock_channel = Mock() + client = DisciplineClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client._assemble_input_messages({"x": [1.0, 2.0]}) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_discipline_server.py b/tests/test_discipline_server.py index 2f30542..9b5c930 100644 --- a/tests/test_discipline_server.py +++ b/tests/test_discipline_server.py @@ -30,11 +30,13 @@ import unittest from unittest.mock import Mock +import grpc import numpy as np from google.protobuf.empty_pb2 import Empty from philote_mdo.general import Discipline, DisciplineServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -411,27 +413,32 @@ def test_set_options_with_nested_dict(self): } ) - def test_get_available_options_invalid_type_raises_error(self): + def test_get_available_options_invalid_type_aborts(self): """ - Tests that GetAvailableOptions raises ValueError for invalid option types. + Tests that GetAvailableOptions calls context.abort for invalid option + types. """ server = DisciplineServer() discipline = server._discipline = Discipline() - - # Add option with invalid type - discipline.add_option("invalid_option", "unknown_type") - + + # Add option with invalid type (bypasses add_option validation by + # writing directly to options_list) + discipline.options_list["invalid_option"] = "unknown_type" + request = Empty() context = Mock() - - with self.assertRaises(ValueError) as error_context: - server.GetAvailableOptions(request, context) - - self.assertIn("Invalid value for discipline option 'invalid_option'", str(error_context.exception)) + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("Invalid value for discipline option 'invalid_option'", args[0][1]) def test_process_inputs_empty_array_raises_error(self): """ - Tests that process_inputs raises ValueError when array data is empty. + Tests that process_inputs raises PhiloteValidationError when array + data is empty. """ server = DisciplineServer() @@ -451,11 +458,116 @@ def test_process_inputs_empty_array_raises_error(self): flat_inputs = {"x": np.zeros(3)} flat_outputs = {} - with self.assertRaises(ValueError) as context: + with self.assertRaises(PhiloteValidationError) as context: server.process_inputs(request_iterator, flat_inputs, flat_outputs) self.assertIn("Expected continuous variables but arrays were empty for variable x", str(context.exception)) + def test_get_available_options_general_exception_aborts(self): + """ + Tests that GetAvailableOptions calls context.abort with INTERNAL + for unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + # options_list property raises an unexpected error + type(discipline).options_list = property( + lambda self: (_ for _ in ()).throw(RuntimeError("unexpected")) + ) + server._discipline = discipline + + request = Mock() + context = Mock() + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("GetAvailableOptions failed", args[0][1]) + + def test_set_options_validation_error_aborts(self): + """ + Tests that SetOptions calls context.abort with INVALID_ARGUMENT + for PhiloteValidationError. + """ + server = DisciplineServer() + discipline = Mock() + discipline.set_options.side_effect = PhiloteValidationError("bad option") + server._discipline = discipline + + request = Mock() + request.options = {} + context = Mock() + + server.SetOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad option", args[0][1]) + + def test_set_options_general_exception_aborts(self): + """ + Tests that SetOptions calls context.abort with INTERNAL for + unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + discipline.set_options.side_effect = RuntimeError("boom") + server._discipline = discipline + + request = Mock() + request.options = {} + context = Mock() + + server.SetOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("SetOptions failed", args[0][1]) + + def test_setup_validation_error_aborts(self): + """ + Tests that Setup calls context.abort with INVALID_ARGUMENT + for PhiloteValidationError. + """ + server = DisciplineServer() + discipline = Mock() + discipline.setup.side_effect = PhiloteValidationError("bad setup") + server._discipline = discipline + + request = Mock() + context = Mock() + + server.Setup(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad setup", args[0][1]) + + def test_setup_general_exception_aborts(self): + """ + Tests that Setup calls context.abort with INTERNAL for + unexpected exceptions. + """ + server = DisciplineServer() + discipline = Mock() + discipline._clear_data.side_effect = RuntimeError("crash") + server._discipline = discipline + + request = Mock() + context = Mock() + + server.Setup(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("Setup failed", args[0][1]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 0afc37b..87d70d7 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -29,7 +29,11 @@ # control over the information you may find at these locations. import unittest from unittest.mock import Mock, MagicMock + +import grpc + from philote_mdo.general import DisciplineServer, DisciplineClient, ExplicitDiscipline +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -95,26 +99,27 @@ def test_get_available_options_with_dict_type(self): def test_get_available_options_with_invalid_type(self): """ - Test GetAvailableOptions with invalid option type (covers lines 100-103). + Test GetAvailableOptions with invalid option type aborts with + INVALID_ARGUMENT. """ server = DisciplineServer() discipline = Mock() - + # Mock the options_list attribute to return a dict with invalid type discipline.options_list = {"invalid_option": "invalid_type"} - + server.attach_discipline(discipline) - + # Create a mock request and context request = Mock() context = Mock() - - # This should raise a ValueError - with self.assertRaises(ValueError) as context_err: - server.GetAvailableOptions(request, context) - - self.assertIn("Invalid value for discipline option", str(context_err.exception)) - self.assertIn("invalid_option", str(context_err.exception)) + + server.GetAvailableOptions(request, context) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("Invalid value for discipline option", args[0][1]) def test_process_inputs_with_empty_continuous_data(self): """ diff --git a/tests/test_explicit_client.py b/tests/test_explicit_client.py index 59bd35b..558ff01 100644 --- a/tests/test_explicit_client.py +++ b/tests/test_explicit_client.py @@ -30,9 +30,11 @@ import unittest from unittest.mock import Mock, patch +import grpc import numpy as np from philote_mdo.general import ExplicitClient +from philote_mdo.utils.validation import PhiloteValidationError, PhiloteServerError import philote_mdo.generated.data_pb2 as data import philote_mdo.utils as utils @@ -135,3 +137,53 @@ def test_compute_partials(self, mock_explicit_stub): for output_name, expected_data in expected_outputs.items(): self.assertTrue(output_name in outputs) np.testing.assert_array_equal(outputs[output_name], expected_data) + + def test_run_compute_non_dict_raises(self): + mock_channel = Mock() + client = ExplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute("not a dict") + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_grpc_error_wraps(self, mock_explicit_stub): + mock_channel = Mock() + mock_stub = mock_explicit_stub.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "server crashed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeFunction.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute({"x": np.array([1.0])}) + self.assertIn("server crashed", str(ctx.exception)) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ExplicitServiceStub") + def test_run_compute_partials_grpc_error_wraps(self, mock_explicit_stub): + mock_channel = Mock() + mock_stub = mock_explicit_stub.return_value + client = ExplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + ] + client._partials_meta = [data.PartialsMetaData(name="f", subname="x")] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "gradient computation failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeGradient.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute_partials({"x": np.array([1.0])}) + self.assertIn("gradient computation failed", str(ctx.exception)) + + def test_run_compute_partials_non_dict_raises(self): + mock_channel = Mock() + client = ExplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_partials(42) diff --git a/tests/test_explicit_server.py b/tests/test_explicit_server.py index 0416ec0..f4110b7 100644 --- a/tests/test_explicit_server.py +++ b/tests/test_explicit_server.py @@ -30,12 +30,14 @@ import unittest from unittest.mock import Mock +import grpc import numpy as np from scipy.optimize import rosen, rosen_der from google.protobuf.empty_pb2 import Empty from philote_mdo.general import ExplicitDiscipline, ExplicitServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -146,5 +148,137 @@ def compute_partials(inputs, jac): ) + def test_compute_function_aborts_on_validation_error(self): + """ + Tests that ComputeFunction calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_compute(inputs, outputs): + raise PhiloteValidationError("bad input data") + + server._discipline.compute = bad_compute + + list(server.ComputeFunction(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad input data", args[0][1]) + + def test_compute_gradient_aborts_on_validation_error(self): + """ + Tests that ComputeGradient calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, jac): + raise PhiloteValidationError("invalid partials") + + server._discipline.compute_partials = bad_partials + + list(server.ComputeGradient(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("invalid partials", args[0][1]) + + def test_compute_function_aborts_on_discipline_error(self): + """ + Tests that ComputeFunction calls context.abort when the discipline's + compute raises an exception. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_compute(inputs, outputs): + raise RuntimeError("division by zero in compute") + + server._discipline.compute = bad_compute + + # Exhaust the generator + list(server.ComputeFunction(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeFunction failed", args[0][1]) + + def test_compute_gradient_aborts_on_discipline_error(self): + """ + Tests that ComputeGradient calls context.abort when the discipline's + compute_partials raises an exception. + """ + server = ExplicitServer() + discipline = server._discipline = ExplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, jac): + raise RuntimeError("singular matrix") + + server._discipline.compute_partials = bad_partials + + list(server.ComputeGradient(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeGradient failed", args[0][1]) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_implicit_client.py b/tests/test_implicit_client.py index 3ec34a9..d9333c9 100644 --- a/tests/test_implicit_client.py +++ b/tests/test_implicit_client.py @@ -30,9 +30,11 @@ import unittest from unittest.mock import Mock, patch +import grpc import numpy as np from philote_mdo.general import ImplicitClient +from philote_mdo.utils.validation import PhiloteValidationError, PhiloteServerError import philote_mdo.generated.data_pb2 as data import philote_mdo.utils as utils @@ -202,5 +204,94 @@ def test_residual_partials(self, mock_implicit_stub): np.testing.assert_array_equal(partials[key], expected_data) + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_compute_residuals_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kResidual, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "residual failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeResiduals.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_compute_residuals( + {"x": np.array([1.0])}, {"f": np.array([1.0])} + ) + self.assertIn("residual failed", str(ctx.exception)) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_solve_residuals_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + ] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "solve failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.SolveResiduals.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_solve_residuals({"x": np.array([1.0])}) + self.assertIn("solve failed", str(ctx.exception)) + + @patch("philote_mdo.generated.disciplines_pb2_grpc.ImplicitServiceStub") + def test_run_residual_gradients_grpc_error_wraps(self, mock_implicit_stub): + mock_channel = Mock() + mock_stub = mock_implicit_stub.return_value + client = ImplicitClient(mock_channel) + client._var_meta = [ + data.VariableMetaData(name="x", type=data.kInput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kOutput, shape=(1,)), + data.VariableMetaData(name="f", type=data.kResidual, shape=(1,)), + ] + client._partials_meta = [data.PartialsMetaData(name="f", subname="x")] + + rpc_error = grpc.RpcError() + rpc_error.details = lambda: "gradient failed" + rpc_error.code = lambda: grpc.StatusCode.INTERNAL + mock_stub.ComputeResidualGradients.side_effect = rpc_error + + with self.assertRaises(PhiloteServerError) as ctx: + client.run_residual_gradients( + {"x": np.array([1.0])}, {"f": np.array([1.0])} + ) + self.assertIn("gradient failed", str(ctx.exception)) + + def test_run_compute_residuals_non_dict_inputs_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_residuals("not a dict", {"f": np.array([1.0])}) + + def test_run_compute_residuals_non_dict_outputs_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_compute_residuals({"x": np.array([1.0])}, "not a dict") + + def test_run_solve_residuals_non_dict_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_solve_residuals(42) + + def test_run_residual_gradients_non_dict_raises(self): + mock_channel = Mock() + client = ImplicitClient(mock_channel) + with self.assertRaises(PhiloteValidationError): + client.run_residual_gradients("bad", {"f": np.array([1.0])}) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_implicit_server.py b/tests/test_implicit_server.py index 21e7c82..6f21d68 100644 --- a/tests/test_implicit_server.py +++ b/tests/test_implicit_server.py @@ -29,10 +29,13 @@ # control over the information you may find at these locations. import unittest from unittest.mock import Mock + +import grpc import numpy as np import numpy.testing as np_testing from google.protobuf.empty_pb2 import Empty from philote_mdo.general import ImplicitDiscipline, ImplicitServer +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.generated.data_pb2 as data @@ -206,5 +209,224 @@ def residual_partials(inputs, residuals, jac): ) + def test_compute_residuals_aborts_on_validation_error(self): + """ + Tests that ComputeResiduals calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_residuals(inputs, outputs, residuals): + raise PhiloteValidationError("bad residual input") + + server._discipline.compute_residuals = bad_residuals + + list(server.ComputeResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad residual input", args[0][1]) + + def test_solve_residuals_aborts_on_validation_error(self): + """ + Tests that SolveResiduals calls context.abort with INVALID_ARGUMENT + when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_solve(inputs, outputs): + raise PhiloteValidationError("bad solve input") + + server._discipline.solve_residuals = bad_solve + + list(server.SolveResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad solve input", args[0][1]) + + def test_compute_residual_gradients_aborts_on_validation_error(self): + """ + Tests that ComputeResidualGradients calls context.abort with + INVALID_ARGUMENT when a PhiloteValidationError is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, outputs, jac): + raise PhiloteValidationError("bad partials input") + + server._discipline.residual_partials = bad_partials + + list(server.ComputeResidualGradients(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INVALID_ARGUMENT) + self.assertIn("bad partials input", args[0][1]) + + def test_compute_residual_gradients_aborts_on_discipline_error(self): + """ + Tests that ComputeResidualGradients calls context.abort with INTERNAL + when an unexpected exception is raised. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + discipline.declare_partials("f", "x") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + ] + + def bad_partials(inputs, outputs, jac): + raise RuntimeError("unexpected crash") + + server._discipline.residual_partials = bad_partials + + list(server.ComputeResidualGradients(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeResidualGradients failed", args[0][1]) + + def test_compute_residuals_aborts_on_discipline_error(self): + """ + Tests that ComputeResiduals calls context.abort when the discipline's + compute_residuals raises an exception. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_residuals(inputs, outputs, residuals): + raise RuntimeError("residual computation failed") + + server._discipline.compute_residuals = bad_residuals + + list(server.ComputeResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("ComputeResiduals failed", args[0][1]) + + def test_solve_residuals_aborts_on_discipline_error(self): + """ + Tests that SolveResiduals calls context.abort when the discipline's + solve_residuals raises an exception. + """ + server = ImplicitServer() + discipline = server._discipline = ImplicitDiscipline() + discipline.add_input("x", shape=(1,), units="") + discipline.add_output("f", shape=(1,), units="") + + context = Mock() + request_iterator = [ + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kInput, name="x", + ) + ), + data.VariableMessage( + continuous=data.Array( + start=0, end=0, data=[1.0], + type=data.VariableType.kOutput, name="f", + ) + ), + ] + + def bad_solve(inputs, outputs): + raise RuntimeError("solver did not converge") + + server._discipline.solve_residuals = bad_solve + + list(server.SolveResiduals(request_iterator, context)) + + context.abort.assert_called_once() + args = context.abort.call_args + self.assertEqual(args[0][0], grpc.StatusCode.INTERNAL) + self.assertIn("SolveResiduals failed", args[0][1]) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_explicit_client.py b/tests/test_openmdao_explicit_client.py index c8339e5..b97551f 100644 --- a/tests/test_openmdao_explicit_client.py +++ b/tests/test_openmdao_explicit_client.py @@ -305,6 +305,15 @@ def test_constructor_no_channel_raises_error(self, om_explicit_component_patch): self.assertIn("No channel provided", str(context.exception)) + def test_invalid_num_par_fd_raises(self, mock_explicit_component): + """ + Tests that an invalid num_par_fd raises a ValueError. + """ + mock_channel = Mock() + with self.assertRaises(ValueError) as context: + RemoteExplicitComponent(channel=mock_channel, num_par_fd=0) + self.assertIn("num_par_fd must be a positive integer", str(context.exception)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_group.py b/tests/test_openmdao_group.py index 016831a..c78b0fa 100644 --- a/tests/test_openmdao_group.py +++ b/tests/test_openmdao_group.py @@ -31,6 +31,7 @@ import numpy as np import openmdao.api as om from philote_mdo.openmdao.group import OpenMdaoSubProblem +from philote_mdo.utils.validation import PhiloteValidationError class SimpleGroup(om.Group): @@ -257,5 +258,45 @@ def test_units_and_shapes(self): self.assertEqual(subprob._output_map['vector_out']['units'], 'kg') + def test_add_group_non_group_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_group("not a group") + + def test_add_mapped_input_invalid_name_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("", "x") + + def test_add_mapped_input_invalid_shape_raises(self): + subprob = OpenMdaoSubProblem() + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("x", "sub_x", shape=[2]) + + def test_add_mapped_input_duplicate_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_input("x", "sub_x") + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_input("x", "sub_x2") + + def test_add_mapped_output_duplicate_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_output("y", "sub_y") + with self.assertRaises(PhiloteValidationError): + subprob.add_mapped_output("y", "sub_y2") + + def test_declare_subproblem_partial_unmapped_output_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_input("x", "sub_x") + with self.assertRaises(PhiloteValidationError): + subprob.declare_subproblem_partial("unmapped_y", "x") + + def test_declare_subproblem_partial_unmapped_input_raises(self): + subprob = OpenMdaoSubProblem() + subprob.add_mapped_output("y", "sub_y") + with self.assertRaises(PhiloteValidationError): + subprob.declare_subproblem_partial("y", "unmapped_x") + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_implicit_client.py b/tests/test_openmdao_implicit_client.py index 2895e7f..045ddcb 100644 --- a/tests/test_openmdao_implicit_client.py +++ b/tests/test_openmdao_implicit_client.py @@ -367,6 +367,15 @@ def test_constructor_no_channel_raises_error(self, om_implicit_component_patch): self.assertIn("No channel provided", str(context.exception)) + def test_invalid_num_par_fd_raises(self, mock_implicit_component): + """ + Tests that an invalid num_par_fd raises a ValueError. + """ + mock_channel = Mock() + with self.assertRaises(ValueError) as context: + RemoteImplicitComponent(channel=mock_channel, num_par_fd=-1) + self.assertIn("num_par_fd must be a positive integer", str(context.exception)) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_openmdao_utils.py b/tests/test_openmdao_utils.py index 95a5acd..763504e 100644 --- a/tests/test_openmdao_utils.py +++ b/tests/test_openmdao_utils.py @@ -32,6 +32,7 @@ import numpy as np from philote_mdo.generated.data_pb2 import kInput, kOutput +from philote_mdo.utils.validation import PhiloteValidationError import philote_mdo.openmdao.utils as utils @@ -206,11 +207,11 @@ def test_declare_options(self): self.assertEqual(options_mock.declare.call_count, 5) - # Test case 3: Unknown type (should result in None) + # Test case 3: Unknown type now raises PhiloteValidationError options_mock.reset_mock() opt_list = [("unknown_param", "unknown_type")] - declare_options(opt_list, options_mock) - options_mock.declare.assert_called_once_with("unknown_param", types=None) + with self.assertRaises(PhiloteValidationError): + declare_options(opt_list, options_mock) if __name__ == "__main__": diff --git a/tests/test_utils.py b/tests/test_utils.py index ab65ff2..d262bb8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,7 +29,8 @@ # control over the information you may find at these locations. import unittest import numpy as np -from philote_mdo.utils import get_chunk_indices, get_flattened_view +from philote_mdo.utils import get_chunk_indices, get_flattened_view, PairDict +from philote_mdo.utils.validation import PhiloteValidationError class TestUtils(unittest.TestCase): @@ -76,5 +77,38 @@ def test_get_flattened_view(self): self.assertEqual(result_empty.shape, (0,)) + def test_get_chunk_indices_negative_num_values_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(-1, 3)) + + def test_get_chunk_indices_zero_chunk_size_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(10, 0)) + + def test_get_chunk_indices_non_int_raises(self): + with self.assertRaises(PhiloteValidationError): + list(get_chunk_indices(10.5, 3)) + + def test_get_flattened_view_non_array_raises(self): + with self.assertRaises(PhiloteValidationError): + get_flattened_view([1, 2, 3]) + + def test_pair_dict_invalid_key_raises(self): + pd = PairDict() + with self.assertRaises(PhiloteValidationError): + pd["single_key"] = 1.0 + + def test_pair_dict_three_tuple_raises(self): + pd = PairDict() + with self.assertRaises(PhiloteValidationError): + pd[("a", "b", "c")] = 1.0 + + def test_pair_dict_get_invalid_key_raises(self): + pd = PairDict() + pd[("a", "b")] = 1.0 + with self.assertRaises(PhiloteValidationError): + _ = pd["single_key"] + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..8d1d09c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,184 @@ +# Philote-Python +# +# Copyright 2022-2025 Christopher A. Lupp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# This work has been cleared for public release, distribution unlimited, case +# number: AFRL-2023-5713. +# +# The views expressed are those of the authors and do not reflect the +# official guidance or position of the United States Government, the +# Department of Defense or of the United States Air Force. +# +# Statement from DoD: The Appearance of external hyperlinks does not +# constitute endorsement by the United States Department of Defense (DoD) of +# the linked websites, of the information, products, or services contained +# therein. The DoD does not exercise any editorial, security, or other +# control over the information you may find at these locations. +import unittest + +import numpy as np + +from philote_mdo.utils.validation import ( + PhiloteError, + PhiloteValidationError, + PhiloteServerError, + validate_name, + validate_shape, + validate_units, + validate_option_type, + validate_is_dict, + validate_numpy_array, +) + + +class TestExceptionHierarchy(unittest.TestCase): + """Tests for the custom exception class hierarchy.""" + + def test_philote_validation_error_is_value_error(self): + with self.assertRaises(ValueError): + raise PhiloteValidationError("test") + + def test_philote_validation_error_is_philote_error(self): + with self.assertRaises(PhiloteError): + raise PhiloteValidationError("test") + + def test_philote_server_error_is_runtime_error(self): + with self.assertRaises(RuntimeError): + raise PhiloteServerError("test") + + def test_philote_server_error_is_philote_error(self): + with self.assertRaises(PhiloteError): + raise PhiloteServerError("test") + + def test_philote_error_is_exception(self): + with self.assertRaises(Exception): + raise PhiloteError("test") + + +class TestValidateName(unittest.TestCase): + """Tests for validate_name.""" + + def test_valid_name(self): + validate_name("x", "test") + + def test_non_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_name(123, "add_input") + self.assertIn("must be a string", str(ctx.exception)) + self.assertIn("add_input", str(ctx.exception)) + + def test_empty_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_name("", "add_output") + self.assertIn("must not be empty", str(ctx.exception)) + + def test_none_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_name(None, "test") + + +class TestValidateShape(unittest.TestCase): + """Tests for validate_shape.""" + + def test_valid_shape_1d(self): + validate_shape((3,), "test") + + def test_valid_shape_2d(self): + validate_shape((2, 4), "test") + + def test_non_tuple_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape([2, 3], "add_input") + self.assertIn("must be a tuple", str(ctx.exception)) + + def test_int_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_shape(3, "test") + + def test_non_integer_element_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape((2.0,), "test") + self.assertIn("must be integers", str(ctx.exception)) + + def test_zero_element_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_shape((0,), "test") + self.assertIn("must be positive", str(ctx.exception)) + + def test_negative_element_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_shape((-1, 3), "test") + + +class TestValidateUnits(unittest.TestCase): + """Tests for validate_units.""" + + def test_valid_units(self): + validate_units("m**2", "test") + + def test_empty_string_is_valid(self): + validate_units("", "test") + + def test_non_string_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_units(42, "add_input") + self.assertIn("must be a string", str(ctx.exception)) + + +class TestValidateOptionType(unittest.TestCase): + """Tests for validate_option_type.""" + + def test_valid_types(self): + for t in ("bool", "int", "float", "str", "dict"): + validate_option_type(t, "opt") + + def test_invalid_type_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_option_type("unknown", "my_opt") + self.assertIn("Invalid type", str(ctx.exception)) + self.assertIn("my_opt", str(ctx.exception)) + + +class TestValidateIsDict(unittest.TestCase): + """Tests for validate_is_dict.""" + + def test_valid_dict(self): + validate_is_dict({"a": 1}, "test") + + def test_non_dict_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_is_dict([1, 2], "send_options") + self.assertIn("expected a dict", str(ctx.exception)) + + +class TestValidateNumpyArray(unittest.TestCase): + """Tests for validate_numpy_array.""" + + def test_valid_array(self): + validate_numpy_array(np.array([1.0, 2.0]), "x") + + def test_list_raises(self): + with self.assertRaises(PhiloteValidationError) as ctx: + validate_numpy_array([1.0, 2.0], "x") + self.assertIn("must be a numpy ndarray", str(ctx.exception)) + + def test_scalar_raises(self): + with self.assertRaises(PhiloteValidationError): + validate_numpy_array(1.0, "x") + + +if __name__ == "__main__": + unittest.main(verbosity=2)