diff --git a/src/pytblis/einsum_impl.py b/src/pytblis/einsum_impl.py index 27f5eaf..abf4804 100644 --- a/src/pytblis/einsum_impl.py +++ b/src/pytblis/einsum_impl.py @@ -55,8 +55,12 @@ def einsum(*operands, out=None, optimize=True, complex_real_contractions=True, * 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise. 'K' is ignored, for now. Default is 'C'. - optimize : {False, True, 'greedy', 'optimal'}, default True + optimize : {bool, list, tuple, 'greedy', 'optimal'}, default True Controls the optimization strategy used to compute the contraction. + If a tuple is provided, the second argument is assumed to be + the maximum intermediate size created. + Also accepts an explicit contraction list from the ``np.einsum_path`` + function. See ``np.einsum_path`` for more details. complex_real_contractions : bool, default True If True, handle contractions between complex and real tensors by performing separate contractions for the real and imaginary parts of the complex tensor. @@ -69,8 +73,6 @@ def einsum(*operands, out=None, optimize=True, complex_real_contractions=True, * The calculation based on the Einstein summation convention. """ specified_out = out is not None - if optimize not in (False, True, "greedy", "optimal"): - raise ValueError("optimize must be one of False, True, 'greedy', or 'optimal'") # Check the kwargs to avoid a more cryptic error later, without having to # repeat default values here