diff --git a/docs/src/internals/pn_systems.md b/docs/src/internals/pn_systems.md index d599d1cd..f2b30bb0 100644 --- a/docs/src/internals/pn_systems.md +++ b/docs/src/internals/pn_systems.md @@ -2,9 +2,9 @@ ```@docs PNSystem +Quasispherical BBH BHNS NSNS FDPNSystem -fd_pnsystem ``` diff --git a/src/PostNewtonian.jl b/src/PostNewtonian.jl index d4859708..00f09f87 100644 --- a/src/PostNewtonian.jl +++ b/src/PostNewtonian.jl @@ -1,12 +1,15 @@ module PostNewtonian -# Always explicitly address functions similar to functions defined in this package, -# which come from these packages: +# We must always explicitly qualify functions similar to functions defined in this package +# by the name of the package. We will use such functions from these packages: using MacroTools: MacroTools using FastDifferentiation: FastDifferentiation -using RuntimeGeneratedFunctions: RuntimeGeneratedFunctions -# Otherwise, we just explicitly import specific functions: +# Otherwise, we just explicitly import specific functions / types. Note that the difference +# between `using` and `import` in the following lines is that `using` will only allow us to +# call the functions, while `import` would also allow us to specialize them (define new +# methods of the imported functions). +using FastDifferentiation: Node as FDNode using DataInterpolations: CubicSpline using InteractiveUtils: methodswith using LinearAlgebra: mul! diff --git a/src/dynamics/right_hand_sides.jl b/src/dynamics/right_hand_sides.jl index b221da22..1379cd7e 100644 --- a/src/dynamics/right_hand_sides.jl +++ b/src/dynamics/right_hand_sides.jl @@ -21,6 +21,31 @@ function TaylorT5_v̇(p) return inv(truncated_series_ratio(v̇_denominator_coeffs(p), v̇_numerator_coeffs(p))) end +""" + causes_domain_error!(u̇, p) + +Ensure that these parameters correspond to a physically valid set of PN parameters. + +If the parameters are not valid, this function should modify `u̇` to indicate that the +current step is invalid. This is done by filling `u̇` with `NaN`s, which will be detected +by the ODE solver and cause it to try a different (smaller) step size. + +Currently, the only check that is done is to test that these parameters result in a PN +parameter v>0. In the future, this function may be expanded to include other checks, or it +may be specialized for specific `PNSystem` subtypes. +""" +function causes_domain_error!(u̇::ST, p::PNSystem{NT}) where {ST,NT} + if !ismutabletype(ST) + error("`causes_domain_error!` cannot modify input `u̇` because it is immutable") + end + if v(p) ≤ 0 # If this is expanded, document the change in the docstring. + u̇ .= convert(eltype(NT), NaN) + true + else + false + end +end + @pn_expression function TaylorTn!(pnsystem, u̇, TaylorTn_v̇::V̇) where {V̇} # If these parameters result in v≤0, fill u̇ with NaNs so that `solve` will # know that this was a bad step and try again. diff --git a/src/pn_systems/BBH.jl b/src/pn_systems/BBH.jl new file mode 100644 index 00000000..8221f5b6 --- /dev/null +++ b/src/pn_systems/BBH.jl @@ -0,0 +1,117 @@ +""" + BBH{NT, ST, PNOrder} <: PNSystem{NT, ST, PNOrder} + +The [`PNSystem`](@ref) subtype describing a binary black hole system. + +The `state` vector here holds the fundamental state variables characterizing the masses, +spins, orientation, velocity, and orbital phase of the system. The spins unpacked into +three components each. The orientation is described by the four components of the `Rotor` +`R`. This gives us a total of 14 elements: + + M₁, M₂, χ⃗₁ˣ, χ⃗₁ʸ, χ⃗₁ᶻ, χ⃗₂ˣ, χ⃗₂ʸ, χ⃗₂ᶻ, Rʷ, Rˣ, Rʸ, Rᶻ, v, Φ + +The "orbital phase" `Φ` is tracked as the 14th element of the `state` vector. This is just +the integral of the (scalar) orbital angular frequency `Ω`, and holds little interest for +general systems beyond a convenient description of how "far" the system has evolved. For +nonprecessing systems, `Φ` would be sufficient to describe the system's position, which is +more completely described by the `Rotor` `R`. However, for precessing systems, it is +difficult to extract this quantity from `R`. +""" +struct BBH{NT,ST,PNOrder} <: PNSystem{NT,ST,PNOrder} + state::ST + + function BBH{NT,ST,PNOrder}(state) where {NT,ST,PNOrder} + if eachindex(state) != Base.OneTo(14) + error( + "The `state` vector for `BBH` must be indexed from 1 to 14; " * + "input is indexed `$(eachindex(state))`.", + ) + end + new{NT,ST,PNOrder}(state) + end + function BBH(; M₁, M₂, χ⃗₁, χ⃗₂, v, R=Rotor(1), Φ=0, PNOrder=typemax(Int), kwargs...) + (NT, ST, PNOrder, state) = prepare_system(; M₁, M₂, χ⃗₁, χ⃗₂, R, v, Φ, PNOrder) + return new{NT,ST,PNOrder}(state) + end + function BBH(state; Λ₁=0, Λ₂=0, PNOrder=typemax(Int)) + if eachindex(state) != Base.OneTo(14) + error( + "The `state` vector for `BBH` must be indexed from 1 to 14; " * + "input is indexed `$(eachindex(state))`.", + ) + end + @assert Λ₁ == 0 + @assert Λ₂ == 0 + return new{eltype(state),typeof(state),prepare_pn_order(PNOrder)}(state) + end +end +const BHBH = BBH + +# The following are methods of functions defined in `state_variables.jl`, specialized for +# `BBH` systems. +state(pnsystem::BBH) = pnsystem.state +function symbols(::Type{<:BBH}) + (:M₁, :M₂, :χ⃗₁ˣ, :χ⃗₁ʸ, :χ⃗₁ᶻ, :χ⃗₂ˣ, :χ⃗₂ʸ, :χ⃗₂ᶻ, :Rʷ, :Rˣ, :Rʸ, :Rᶻ, :v, :Φ) +end +function ascii_symbols(::Type{<:BBH}) + (:M1, :M2, :chi1x, :chi1y, :chi1z, :chi2x, :chi2y, :chi2z, :Rw, :Rx, :Ry, :Rz, :v, :Phi) +end +for (i, symbol) ∈ enumerate(symbols(BBH)) + # This will define, e.g., `M₁(pnsystem::BBH) = pnsystem.state[1]`. We + # could do this manually, but this is more concise and less error-prone. + @eval begin + $(symbol)(pnsystem::BBH) = @inbounds pnsystem.state[$i] + function symbol_index(::Type{T}, ::Val{Symbol($symbol)}) where {T<:BBH} + $i + end + end +end + +Λ₁(pnsystem::BBH) = zero(pnsystem) +Λ₂(pnsystem::BBH) = zero(pnsystem) + +@testitem "BBH constructors" begin + using Quaternionic + + pnA = BBH(; + M₁=1.0f0, M₂=2.0f0, χ⃗₁=Float32[3.0, 4.0, 5.0], χ⃗₂=Float32[6.0, 7.0, 8.0], v=0.23f0 + ) + @test pnA.state == + Float32[1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; 1.0; 0.0; 0.0; 0.0; 0.23; 0.0] + + pnB = BBH(; + M₁=1.0f0, + M₂=2.0f0, + χ⃗₁=Float32[3.0, 4.0, 5.0], + χ⃗₂=Float32[6.0, 7.0, 8.0], + v=0.23f0, + Φ=9.0f0, + ) + @test pnB.state == + Float32[1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; 1.0; 0.0; 0.0; 0.0; 0.23; 9.0] + + R = randn(RotorF32) + pn1 = BBH(; + M₁=1.0f0, + M₂=2.0f0, + χ⃗₁=Float32[3.0, 4.0, 5.0], + χ⃗₂=Float32[6.0, 7.0, 8.0], + R=R, + v=0.23f0, + ) + @test pn1.state ≈ [1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; components(R)...; 0.23; 0.0] + + pn2 = BBH(; + M₁=1.0f0, + M₂=2.0f0, + χ⃗₁=Float32[3.0, 4.0, 5.0], + χ⃗₂=Float32[6.0, 7.0, 8.0], + R=R, + v=0.23f0, + Φ=9.0f0, + ) + @test pn2.state ≈ [1.0; 2.0; 3.0; 4.0; 5.0; 6.0; 7.0; 8.0; components(R)...; 0.23; 9.0] + + pn1.state[end] = 9.0f0 + @test pn1.state == pn2.state +end diff --git a/src/pn_systems/BHNS.jl b/src/pn_systems/BHNS.jl new file mode 100644 index 00000000..cccfb12f --- /dev/null +++ b/src/pn_systems/BHNS.jl @@ -0,0 +1,80 @@ +""" + BHNS{NT, ST, PNOrder} <: PNSystem{NT, ST, PNOrder} + +The [`PNSystem`](@ref) subtype describing a black-hole—neutron-star binary system. + +The `state` vector is the same as for a [`BBH`](@ref), with an additional field `Λ₂` holding +the (constant) tidal-coupling parameter of the neutron star. + +Note that the neutron star is *always* object 2 — meaning that `M₂`, `χ⃗₂`, and `Λ₂` always +refer to it; `M₁` and `χ⃗₁` always refer to the black hole. (It's "BHNS", not "NSBH".) See +also [`NSNS`](@ref). +""" +struct BHNS{NT,ST,PNOrder} <: PNSystem{NT,ST,PNOrder} + state::ST + + function BHNS{NT,ST,PNOrder}(state) where {NT,ST,PNOrder} + if eachindex(state) != Base.OneTo(15) + error( + "The `state` vector for `BHNS` must be indexed from 1 to 15; " * + "input is indexed `$(eachindex(state))`.", + ) + end + new{NT,ST,PNOrder}(state) + end + function BHNS(; + M₁, M₂, χ⃗₁, χ⃗₂, v, R=Rotor(1), Φ=0, Λ₂, PNOrder=typemax(Int), kwargs... + ) + NT, ST, PNOrder, state = prepare_system(; M₁, M₂, χ⃗₁, χ⃗₂, R, v, Φ, Λ₂, PNOrder) + return new{NT,ST,PNOrder}(state) + end + function BHNS(state; PNOrder=typemax(Int)) + if eachindex(state) != Base.OneTo(15) + error( + "The `state` vector for `BHNS` must be indexed from 1 to 15; " * + "input is indexed `$(eachindex(state))`.", + ) + end + NT, ST, PNOrder = eltype(state), typeof(state), prepare_pn_order(PNOrder) + return new{NT,ST,PNOrder}(state) + end +end + +# The following are methods of functions defined in `state_variables.jl`, specialized for +# `BHNS` systems. +state(pnsystem::BHNS) = pnsystem.state +function symbols(::Type{<:BHNS}) + (:M₁, :M₂, :χ⃗₁ˣ, :χ⃗₁ʸ, :χ⃗₁ᶻ, :χ⃗₂ˣ, :χ⃗₂ʸ, :χ⃗₂ᶻ, :Rʷ, :Rˣ, :Rʸ, :Rᶻ, :v, :Φ, :Λ₂) +end +function ascii_symbols(::Type{<:BHNS}) + ( + :M1, + :M2, + :chi1x, + :chi1y, + :chi1z, + :chi2x, + :chi2y, + :chi2z, + :Rw, + :Rx, + :Ry, + :Rz, + :v, + :Phi, + :Lambda2, + ) +end +for (i, symbol) ∈ enumerate(symbols(BHNS)) + # This will define, e.g., `M₁(pnsystem::BHNS) = pnsystem.state[1]`. We + # could do this manually, but this is more concise and less error-prone. + @eval begin + $(symbol)(pnsystem::BHNS) = @inbounds pnsystem.state[$i] + function symbol_index(::Type{T}, ::Val{Symbol($symbol)}) where {T<:BHNS} + $i + end + end +end + +Λ₁(pnsystem::BHNS) = zero(pnsystem) +Λ₂(pnsystem::BHNS) = @inbounds pnsystem.state[15] diff --git a/src/pn_systems/FDPNsystem.jl b/src/pn_systems/FDPNsystem.jl new file mode 100644 index 00000000..098aa415 --- /dev/null +++ b/src/pn_systems/FDPNsystem.jl @@ -0,0 +1,34 @@ +""" + FDPNSystem{NT, PN, PNOrder} <: PNSystem{FDNode, Vector{FDNode}, PNOrder} + +A `PNSystem` that contains information as variables from +[`FastDifferentiation.jl`](https://docs.juliahub.com/General/FastDifferentiation/stable/). + +Note that this type also involves the type parameter `PN`, which is actually the type of a +`PNSystem`, and its type parameter `NT`, which will be the number type of actual numbers +that eventually get fed into (and will be passed out from) functions that use this system. + +One important example of what this type is used for is computing the derivative of the +orbital binding energy, `𝓔′` — and in particular, for generating the corresponding function +method to apply to a given `PNSystem`. +""" +struct FDPNSystem{NT,PN<:PNSystem{NT},PNOrder} <: PNSystem{FDNode,Vector{FDNode},PNOrder} + state::Vector{FDNode} + + function FDPNSystem(::Type{PN}, PNOrder=typemax(Int)) where {NT,PN<:PNSystem{NT}} + return new{NT,prepare_pn_order(PNOrder)}([FDNode(s) for s ∈ symbols(PN)]) + end +end + +symbols(pnsystem::FDPNSystem{NT,PN}) where {NT,PN} = symbols(PN) + +function symbol_index(pnsystem::FDPNSystem{NT,PN}, s::Symbol) where {NT,PN} + symbol_index(PN, Val(s)) +end + +## TODO: See if this method is needed + +## The old code had this, but I think it would probably just cause errors. It might be +## relied upon in the functions where we take derivatives — 𝓔′code and γₚₙ₀′ — but even if +## so, maybe we could work around it with another function. +#Base.eltype(::FDPNSystem{FT}) where {FT} = FT diff --git a/src/pn_systems/NSNS.jl b/src/pn_systems/NSNS.jl new file mode 100644 index 00000000..ca40d3e7 --- /dev/null +++ b/src/pn_systems/NSNS.jl @@ -0,0 +1,98 @@ +""" + NSNS{NT, ST, PNOrder} <: PNSystem{NT, ST, PNOrder} + +The [`PNSystem`](@ref) subtype describing a neutron-star—neutron-star binary system. + +The `state` vector is the same as for a [`BBH`](@ref), with two additional fields `Λ₁` +and `Λ₂` holding the (constant) tidal-coupling parameters of the neutron stars. See also +[`BHNS`](@ref). +""" +struct NSNS{NT,ST,PNOrder} <: PNSystem{NT,ST,PNOrder} + state::ST + + function NSNS{NT,ST,PNOrder}(state) where {NT,ST,PNOrder} + if eachindex(state) != Base.OneTo(16) + error( + "The `state` vector for `NSNS` must be indexed from 1 to 16; " * + "input is indexed `$(eachindex(state))`.", + ) + end + new{NT,ST,PNOrder}(state) + end + function NSNS(; + M₁, M₂, χ⃗₁, χ⃗₂, v, R=Rotor(1), Φ=0, Λ₁, Λ₂, PNOrder=typemax(Int), kwargs... + ) + NT, ST, PNOrder, state = prepare_system(; + M₁, M₂, χ⃗₁, χ⃗₂, R, v, Φ, Λ₁, Λ₂, PNOrder + ) + return new{NT,ST,PNOrder}(state) + end + function NSNS(state; PNOrder=typemax(Int)) + if eachindex(state) != Base.OneTo(16) + error( + "The `state` vector for `NSNS` must be indexed from 1 to 16; " * + "input is indexed `$(eachindex(state))`.", + ) + end + NT, ST, PNOrder = eltype(state), typeof(state), prepare_pn_order(PNOrder) + return new{NT,ST,PNOrder}(state) + end +end +const BNS = NSNS + +# The following are methods of functions defined in `state_variables.jl`, specialized for +# `NSNS` systems. +state(pnsystem::NSNS) = pnsystem.state +function symbols(::Type{<:NSNS}) + ( + :M₁, + :M₂, + :χ⃗₁ˣ, + :χ⃗₁ʸ, + :χ⃗₁ᶻ, + :χ⃗₂ˣ, + :χ⃗₂ʸ, + :χ⃗₂ᶻ, + :Rʷ, + :Rˣ, + :Rʸ, + :Rᶻ, + :v, + :Φ, + :Λ₁, + :Λ₂, + ) +end +function ascii_symbols(::Type{<:NSNS}) + ( + :M1, + :M2, + :chi1x, + :chi1y, + :chi1z, + :chi2x, + :chi2y, + :chi2z, + :Rw, + :Rx, + :Ry, + :Rz, + :v, + :Phi, + :Lambda1, + :Lambda2, + ) +end +for (i, symbol) ∈ enumerate(symbols(NSNS)) + # This will define, e.g., `M₁(pnsystem::NSNS) = pnsystem.state[1]`. We + # could do this manually, but this is more concise and less error-prone. + @eval begin + $(symbol)(pnsystem::NSNS) = @inbounds pnsystem.state[$i] + function symbol_index(::Type{T}, ::Val{Symbol($symbol)}) where {T<:NSNS} + $i + end + end +end + +Λ₁(pnsystem::NSNS) = @inbounds pnsystem.state[15] +Λ₂(pnsystem::NSNS) = @inbounds pnsystem.state[16] diff --git a/src/pn_systems/PNSystem.jl b/src/pn_systems/PNSystem.jl new file mode 100644 index 00000000..8986cfcb --- /dev/null +++ b/src/pn_systems/PNSystem.jl @@ -0,0 +1,155 @@ +""" + PNSystem{NT, ST, PNOrder} + +Base type for all PN systems, such as `BBH`, `BHNS`, and `NSNS`. + +These objects encode all essential properties of the binary, including its current state. +As such, they can be used as inputs to the various [fundamental](@ref Fundamental-variables) +and [derived variables](@ref Derived-variables), as well as [PN expressions](@ref) and +[dynamics](@ref Dynamics) functions. + +The parameter `NT` is the number type of the system, such as `Float64` or `Dual{SomeTag, +Float64, 7}`. The parameter `ST <: DenseVector{NT}` is the type returned by the `state` +function, which probably just returns the `state` vector stored in the concrete subtype. As +such, this will probably be `MVector{N, NT}` or `SVector{N, NT}`, where `N` is the number of +elements in the state. `PNOrder` is a `Rational` giving the order to which PN expansions +should be carried out when using the given object. +""" +abstract type PNSystem{NT,ST<:DenseVector{NT},PNOrder} <: DenseVector{NT} end + +""" + state(pnsystem::PNSystem) + +Return the state vector of `pnsystem`, which is a vector of fundamental variables for the +given PN system. + +Note that the built-in `PNSystem` subtypes have a `state` field that is a vector, so this +function will just return that vector. However, that may not always be true for +user-defined subtypes. +""" +function state(::T) where {T<:PNSystem} + error("`state` is not yet defined for PNSystem subtype `$T`.") +end +Base.vec(pnsystem::PNSystem) = state(pnsystem) + +Base.eltype(::Type{PNT}) where {NT,PNT<:PNSystem{NT}} = NT +Base.one(::Type{PNT}) where {PNT<:PNSystem} = one(eltype(PNT)) +Base.one(x::T) where {T<:PNSystem} = one(T) +Base.zero(::Type{PNT}) where {PNT<:PNSystem} = zero(eltype(PNT)) +Base.zero(x::T) where {T<:PNSystem} = zero(T) +Base.float(::Type{PNT}) where {PNT<:PNSystem} = float(eltype(PNT)) +Base.float(x::T) where {T<:PNSystem} = float(T) + +""" + pn_order(pnsystem::PNSystem) + +Return the PN order of the given `pnsystem`. + +This is a `Rational{Int}` that indicates the order to which the PN expansions should be +carried out when using the given object. +""" +pn_order(::PNSystem{NT,ST,PNOrder}) where {NT,ST,PNOrder} = PNOrder + +""" + order_index(pnsystem::PNSystem) + +Return the order index of the given `pnsystem`. + +This is defined as the (one-based) index into an iterable of PN terms starting at 0pN, then +0.5pN, etc. Specifically, this is defined as `1 + Int(2pn_order(pnsystem))`. +""" +order_index(pn::PNSystem) = 1 + Int(2pn_order(pn)) + +""" + max_pn_order + +The maximum PN order that can be used without overflowing the `Int` type. +""" +const max_pn_order = (typemax(Int) - 2) // 2 + +""" + prepare_pn_order(PNOrder) + +Convert the input to a half-integer of type `Rational{Int}`. + +If `PNOrder` is larger than `max_pn_order`, it is set to `max_pn_order`, to avoid overflow +when computing the order index. +""" +function prepare_pn_order(PNOrder) + if PNOrder < max_pn_order + round(Int, 2PNOrder) // 2 + else + max_pn_order + end +end + +""" + symbols(pnsystem::PNSystem) + symbols(::Type{<:PNSystem}) + ascii_symbols(pnsystem::PNSystem) + ascii_symbols(::Type{<:PNSystem}) + +Return a Tuple of symbols corresponding to the variables tracked by `pnsystem`, in the order +in which they are stored in the `state` vector. + +The `ascii_symbols` function returns those symbols in ASCII form, enabling interaction with +external systems (e.g., Python) that do not support many Unicode symbols. + +```jldoctest +julia> using PostNewtonian: BBH + +julia> pnsystem = BBH(randn(14); PNOrder=7//2); + +julia> symbols(pnsystem) +(:M₁, :M₂, :χ⃗₁ˣ, :χ⃗₁ʸ, :χ⃗₁ᶻ, :χ⃗₂ˣ, :χ⃗₂ʸ, :χ⃗₂ᶻ, :Rʷ, :Rˣ, :Rʸ, :Rᶻ, :v, :Φ) + +julia> ascii_symbols(pnsystem) +(:M1, :M2, :chi1x, :chi1y, :chi1z, :chi2x, :chi2y, :chi2z, :Rw, :Rx, :Ry, :Rz, :v, :Phi) +``` +""" +symbols(pnsystem::PNSystem) = symbols(typeof(pnsystem)) +function symbols(::Type{T}) where {T<:PNSystem} + error("`symbols` is not yet defined for PNSystem subtype `$T`.") +end +ascii_symbols(pnsystem::PNSystem) = ascii_symbols(typeof(pnsystem)) +function ascii_symbols(::Type{T}) where {T<:PNSystem} + error("`ascii_symbols` is not yet defined for PNSystem subtype `$T`.") +end + +""" + pnsystem::PNSystem(; kwargs...) + +State-modifying copy constructor for `PNSystem` objects. + +Note that this cannot modify the type's parameters, including the number type `NT`, the +state type `ST`, or the `PNOrder` of the system. However, it can modify any of the state +variables by symbol or by ASCII symbol. This function will raise an AssertionError if +any of the keys in `kwargs` is not a valid symbol for the given `PNSystem` type. + +```jldoctest +julia> using PostNewtonian: BBH + +julia> pnsystem = BBH(ones(14)/2; PNOrder=7//2) +BBH{Vector{Float64}, 7//2}([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + +julia> pnsystem2 = pnsystem(M₁=0.2, M₂=0.8, chi1x=0.1) +BBH{Vector{Float64}, 7//2}([0.2, 0.8, 0.1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) +``` +""" +function (pnsystem::PNSystem)(; kwargs...) + all_symbols = Set(symbols(pnsystem)) ∪ Set(ascii_symbols(pnsystem)) + @assert keys(kwargs) ⊆ all_symbols ( + "PNSystem of type $(typeof(pnsystem)) does not have these symbols which were input:\n" * + " $(setdiff(keys(kwargs), all_symbols))\n" * + "Maybe you passed `String`s instead of `Symbol`s?\n" * + "The available symbols for this type are\n" * + " $(symbols(pnsystem))\n" * + "and their ASCII equivalents:\n" * + " $(ascii_symbols(pnsystem))" + ) + state = Tuple( + get(kwargs, symbol, get(kwargs, ascii_symbol, pnsystem[symbol])) for + (symbol, ascii_symbol) ∈ zip(symbols(pnsystem), ascii_symbols(pnsystem)) + ) + typeof(pnsystem)(state) +end diff --git a/src/pn_systems/state_variables.jl b/src/pn_systems/state_variables.jl new file mode 100644 index 00000000..e06b2c44 --- /dev/null +++ b/src/pn_systems/state_variables.jl @@ -0,0 +1,342 @@ +""" + G(pnsystem) + +Return Newton's gravitational constant for the given `pnsystem`. + +By default, the value is one *with the same number type as `pnsystem`*. It can be +overridden for subtypes of `PNSystem` that use different units or conventions. + +However, note that this function should specialize on the number type of `pnsystem`, rather +than just returning the integer `1`, because there may be expressions with factors such as +`G/3` which will immediately convert to `Float64` if `G` is just `1`, so the result will not +have the expected precision. +""" +G(::PNSystem{NT}) where {NT} = one(NT) +G(::FDPNSystem{NT,PN}) where {NT,PN} = one(PN) + +""" + c(pnsystem) + +Return the speed of light for the given `pnsystem`. + +By default, the value is one *with the same number type as `pnsystem`*. It can be +overridden for subtypes of `PNSystem` that use different units or conventions. + +However, note that this function should specialize on the number type of `pnsystem`, rather +than just returning the integer `1`, because there may be expressions with factors such as +`c/3` which will immediately convert to `Float64` if `c` is just `1`, so the result will not +have the expected precision. +""" +c(::PNSystem{NT}) where {NT} = one(NT) +c(::FDPNSystem{NT,PN}) where {NT,PN} = one(PN) + +""" + M₁(pnsystem) + M1(pnsystem) + +Mass of object 1 in this system. +""" +function M₁(::T) where {T<:PNSystem} + error("M₁ is not (yet) defined for PNSystem subtype `$T`.") +end +M₁(fdpnsystem::FDPNSystem) = fdpnsystem[:M₁] +const M1 = M₁ + +""" + M₂(pnsystem) + M2(pnsystem) + +Mass of object 2 in this system. +""" +function M₂(::T) where {T<:PNSystem} + error("M₂ is not (yet) defined for PNSystem subtype `$T`.") +end +M₂(fdpnsystem::FDPNSystem) = fdpnsystem[:M₂] +const M2 = M₂ + +""" + χ⃗₁ˣ(pnsystem) + chi1x(pnsystem) + +`x`-component of dimensionless spin vector of object 1 in this system, as a `QuatVec`. + +See [`χ⃗₁`](@ref) for details. +""" +function χ⃗₁ˣ(::T) where {T<:PNSystem} + error("χ⃗₁ˣ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₁ˣ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₁ˣ] +const chi1x = χ⃗₁ˣ + +""" + χ⃗₁ʸ(pnsystem) + chi1y(pnsystem) + +`y`-component of dimensionless spin vector of object 1 in this system, as a `QuatVec`. + +See [`χ⃗₁`](@ref) for details. +""" +function χ⃗₁ʸ(::T) where {T<:PNSystem} + error("χ⃗₁ʸ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₁ʸ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₁ʸ] +const chi1y = χ⃗₁ʸ + +""" + χ⃗₁ᶻ(pnsystem) + chi1z(pnsystem) + +`z`-component of dimensionless spin vector of object 1 in this system, as a `QuatVec`. + +See [`χ⃗₁`](@ref) for details. +""" +function χ⃗₁ᶻ(::T) where {T<:PNSystem} + error("χ⃗₁ᶻ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₁ᶻ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₁ᶻ] +const chi1z = χ⃗₁ᶻ + +""" + χ⃗₂ˣ(pnsystem) + chi2x(pnsystem) + +`x`-component of dimensionless spin vector of object 2 in this system, as a `QuatVec`. + +See [`χ⃗₂`](@ref) for details. +""" +function χ⃗₂ˣ(::T) where {T<:PNSystem} + error("χ⃗₂ˣ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₂ˣ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₂ˣ] +const chi2x = χ⃗₂ˣ + +""" + χ⃗₂ʸ(pnsystem) + chi2y(pnsystem) + +`y`-component of dimensionless spin vector of object 2 in this system, as a `QuatVec`. + +See [`χ⃗₂`](@ref) for details. +""" +function χ⃗₂ʸ(::T) where {T<:PNSystem} + error("χ⃗₂ʸ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₂ʸ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₂ʸ] +const chi2y = χ⃗₂ʸ + +""" + χ⃗₂ᶻ(pnsystem) + chi2z(pnsystem) + +`z`-component of dimensionless spin vector of object 2 in this system, as a `QuatVec`. + +See [`χ⃗₂`](@ref) for details. +""" +function χ⃗₂ᶻ(::T) where {T<:PNSystem} + error("χ⃗₂ᶻ is not (yet) defined for PNSystem subtype `$T`.") +end +χ⃗₂ᶻ(fdpnsystem::FDPNSystem) = fdpnsystem[:χ⃗₂ᶻ] +const chi2z = χ⃗₂ᶻ + +""" + Rʷ(pnsystem) + Rw(pnsystem) + +Scalar component of the orientation `Rotor` of the binary. + +See [`R`](@ref) for details. +""" +function Rʷ(::T) where {T<:PNSystem} + error("Rʷ is not (yet) defined for PNSystem subtype `$T`.") +end +Rʷ(fdpnsystem::FDPNSystem) = fdpnsystem[:Rʷ] +const Rw = Rʷ + +""" + Rˣ(pnsystem) + Rx(pnsystem) + +`x`-component of the orientation `Rotor` of the binary. + +See [`R`](@ref) for details. +""" +function Rˣ(::T) where {T<:PNSystem} + error("Rˣ is not (yet) defined for PNSystem subtype `$T`.") +end +Rˣ(fdpnsystem::FDPNSystem) = fdpnsystem[:Rˣ] +const Rx = Rˣ + +""" + Rʸ(pnsystem) + Ry(pnsystem) + +`y`-component of the orientation `Rotor` of the binary. + +See [`R`](@ref) for details. +""" +function Rʸ(::T) where {T<:PNSystem} + error("Rʸ is not (yet) defined for PNSystem subtype `$T`.") +end +Rʸ(fdpnsystem::FDPNSystem) = fdpnsystem[:Rʸ] +const Ry = Rʸ + +""" + Rᶻ(pnsystem) + Rz(pnsystem) + +`z`-component of the orientation `Rotor` of the binary. + +See [`R`](@ref) for details. +""" +function Rᶻ(::T) where {T<:PNSystem} + error("Rᶻ is not (yet) defined for PNSystem subtype `$T`.") +end +Rᶻ(fdpnsystem::FDPNSystem) = fdpnsystem[:Rᶻ] +const Rz = Rᶻ + +@doc raw""" + v(pnsystem) + v(;Ω, M=1) + +Post-Newtonian velocity parameter. This is related to the orbital angular frequency +``\Omega`` as +```math +v \colonequals (M\,\Omega)^{1/3}, +``` +where ``M`` is the total mass of the binary. + +Note that if you want to pass the value ``Ω`` (rather than a `PNSystem`), you must pass it +as a keyword argument — as in `v(Ω=0.1)`. + +See also [`Ω`](@ref). +""" +function v(::T) where {T<:PNSystem} + error("v is not (yet) defined for PNSystem subtype `$T`.") +end +v(fdpnsystem::FDPNSystem) = fdpnsystem[:v] +v(; Ω, M=1) = ∛(M * Ω) + +""" + Φ(pnsystem) + Phi(pnsystem) + +Integrated orbital phase of the system. It is computed as the integral of [`Ω`](@ref). +""" +function Φ(::T) where {T<:PNSystem} + error("Φ is not (yet) defined for PNSystem subtype `$T`.") +end +Φ(fdpnsystem::FDPNSystem) = fdpnsystem[:Φ] +const Phi = Φ + +@doc raw""" + Λ₁(pnsystem) + Lambda1(pnsystem) + +Quadrupolar tidal-coupling parameter of object 1 in this system. + +We imagine object 1 begin placed in an (adiabatic) external field with Newtonian potential +``\phi``, resulting in a tidal field measured by ``\partial_i \partial_j \phi`` evaluated at +the center of mass of the object. This induces a quadrupole moment ``Q_{ij}`` in object 1, +which can be related to the tidal field as +```math +Q_{ij} = -\frac{G^4}{c^{10}} \Lambda_1 M_1^5 \partial_i \partial_j \phi, +``` +where ``M_1`` is the mass of object 1. This tidal-coupling parameter ``\Lambda_1`` can be +related to the Love number ``k_2`` (where the subscript 2 refers to the fact that this is +for the ``\ell=2`` quadrupole, rather than object 2) as +```math +\Lambda_1 = \frac{2}{3} \frac{c^{10}}{G^5} \frac{R_1^5}{M_1^5} k_2, +``` +where ``R_1`` is the radius of object 1. Note that ``\Lambda_1`` is dimensionless. For +black holes, it is precisely zero; for neutron stars it may range up to 1; more exotic +objects may have significantly larger values. + +Note that — as of this writing — only `NSNS` systems can have a nonzero value for this +quantity. (`BHNS` systems can only have a nonzero value for ``\Lambda_2``.) All other +types return `0`, which Julia can use to eliminate code that would then be 0. Thus, it is +safe and efficient to use this quantity in any PN expression that specializes on the type of +`pnsystem`. + +See also [`Λ₂`](@ref) and [`Λ̃`](@ref). +""" +function Λ₁(::T) where {T<:PNSystem} + error("Λ₁ is not (yet) defined for PNSystem subtype `$T`.") +end +Λ₁(fdpnsystem::FDPNSystem) = fdpnsystem[:Λ₁] +const Lambda1 = Λ₁ + +@doc raw""" + Λ₂(pnsystem) + Lambda2(pnsystem) + +Quadrupolar tidal coupling parameter of object 2 in this system. + +See [`Λ₁`](@ref) for details about the definition, swapping "object 1" with "object 2". + +Note that — as of this writing — only `BHNS` and `NSNS` systems can have a nonzero value for +this quantity. All other types return `0`, which Julia can use to eliminate code that would +then be 0. Thus, it is safe and efficient to use this quantity in any PN expression that +specializes on the type of `pnsystem`. + +See also [`Λ₁`](@ref) and [`Λ̃`](@ref). +""" +function Λ₂(::T) where {T<:PNSystem} + error("Λ₂ is not (yet) defined for PNSystem subtype `$T`.") +end +Λ₂(fdpnsystem::FDPNSystem) = fdpnsystem[:Λ₂] +const Lambda2 = Λ₂ + +################################################################# +# Not actually state variables, but aggregates of state variables + +""" + χ⃗₁(pnsystem) + chi1(pnsystem) + +Dimensionless spin vector of object 1 in this system, as a `QuatVec`. + +See also [`χ⃗₁ˣ`](@ref), [`χ⃗₁ʸ`](@ref), and [`χ⃗₁ᶻ`](@ref) for the individual components. +""" +function χ⃗₁(::T) where {T<:PNSystem} + QuatVec(χ⃗₁ˣ(pnsystem), χ⃗₁ʸ(pnsystem), χ⃗₁ᶻ(pnsystem)) +end +const chi1 = χ⃗₁ + +""" + χ⃗₂(pnsystem) + chi2(pnsystem) + +Dimensionless spin vector of object 2 in this system, as a `QuatVec`. + +See also [`χ⃗₂ˣ`](@ref), [`χ⃗₂ʸ`](@ref), and [`χ⃗₂ᶻ`](@ref) for the individual components. +""" +function χ⃗₂(::T) where {T<:PNSystem} + QuatVec(χ⃗₂ˣ(pnsystem), χ⃗₂ʸ(pnsystem), χ⃗₂ᶻ(pnsystem)) +end +const chi2 = χ⃗₂ + +""" + R(pnsystem) + +Orientation of the binary, as a `Rotor`. + +At any instant, the binary is represented by the right-handed triad ``(n̂, λ̂, ℓ̂)``, where +[``n̂``](@ref PostNewtonian.n̂) is the unit vector pointing from object 2 to object 1, and +the instantaneous velocities of the binary's elements are in the ``n̂``-``λ̂`` plane. This +`Rotor` will rotate the ``x̂`` vector to be along ``n̂``, the ``ŷ`` vector to be along +``λ̂``, and the ``ẑ`` vector to be along ``ℓ̂``. + +Note that the angular velocity associated to `R` is given by ``Ω⃗ = 2 Ṙ R̄ = Ω ℓ̂ + ϖ n̂``. +(Any component of ``Ω⃗`` along ``λ̂`` would violate the condition that the velocities be in +the ``n̂``-``λ̂`` plane.) Here, the scalar quantity ``Ω`` is the orbital angular frequency, +and ``ϖ`` is the precession angular frequency. + +See also [`n̂`](@ref PostNewtonian.n̂), [`λ̂`](@ref PostNewtonian.λ̂), [`ℓ̂`](@ref +PostNewtonian.ℓ̂), [`Ω`](@ref PostNewtonian.Ω), and [`𝛡`](@ref PostNewtonian.𝛡)``=ϖ n̂``. +""" +function R(pnsystem::T) where {NT,T<:PNSystem{NT}} + # We use this explicit constructor (with type parameter) to avoid normalization + # that would probably just complicate derivatives. + Rotor{NT}(Rʷ(pnsystem), Rˣ(pnsystem), Rʸ(pnsystem), Rᶻ(pnsystem)) +end +R(fdpnsystem::FDPNSystem) = fdpnsystem[:R] diff --git a/src/pn_systems/vector_interface.jl b/src/pn_systems/vector_interface.jl new file mode 100644 index 00000000..03fd570c --- /dev/null +++ b/src/pn_systems/vector_interface.jl @@ -0,0 +1,249 @@ +# Base.ismutable(pnsystem::PNSystem{NT, ST}) where {NT, ST} = ismutable(state(pnsystem)) +# Base.ismutabletype(::Type{<:PNSystem{NT, ST}}) where {NT, ST} = ismutabletype(ST) + +""" + symbol_index(::Type{T}, s::Symbol) where {T<:PNSystem} + symbol_index(::Type{T}, ::Val{s}) where {T<:PNSystem} + +Return the index of the symbol `s` in the state vector of the given `PNSystem` type `T`. + +Note that the default implementation is slow; `symbol_index(::Type{T}, ::Val{s})` should be +overridden for every symbol (and ASCII equivalent, if desired) for concrete `PNSystem` +types. +""" +function symbol_index(::Type{T}, ::Val{S}) where {T<:PNSystem,S} + index = findfirst(y -> y == S, symbols(T)) + if isnothing(index) + index = findfirst(y -> y == S, ascii_symbols(T)) + end + if isnothing(index) + error( + "Type `$(T)` has no symbol `:$(S)`.\n" * + "Its symbols are `$(symbols(T))`.\n" * + "The ASCII equivalents are `$(ascii_symbols(T))`.\n", + ) + else + @warn "Please define `PostNewtonian.symbol_index(::Type{$T}, ::Val{$S})`" + index + end +end + +Base.getindex(pnsystem::PNSystem, s::Symbol) = getindex(pnsystem, Val(s)) +function Base.getindex(pnsystem::T, ::Val{S}) where {T<:PNSystem,S} + # If `S` is not actually a symbol in `pnsystem`, `symbol_index` will error, so we know + # that the `index` is inbounds if it returns. + index = symbol_index(T, Val(S)) + @inbounds state(pnsystem)[index] +end + +Base.setindex!(pnsystem::PNSystem, v, s::Symbol) = setindex!(pnsystem, v, Val(s)) +function Base.setindex!(pnsystem::T, v, ::Val{S}) where {NT,T<:PNSystem{NT},S} + index = symbol_index(T, Val(S)) + @inbounds setindex!(state(pnsystem), v, index) +end + +### Interfaces: https://docs.julialang.org/en/v1/manual/interfaces +# Iteration +Base.iterate(pnsystem::PNSystem) = iterate(state(pnsystem)) +Base.iterate(pnsystem::PNSystem, state) = iterate(state(pnsystem), state) +Base.IteratorSize(::Type{T}) where {T<:PNSystem} = Base.HasShape{1}() +Base.length(pnsystem::PNSystem) = length(state(pnsystem)) +Base.ndims(pnsystem::PNSystem) = ndims(state(pnsystem)) +Base.size(pnsystem::PNSystem) = size(state(pnsystem)) +Base.size(pnsystem::PNSystem, dim) = size(state(pnsystem), dim) +Base.IteratorEltype(::Type{T}) where {T<:PNSystem} = Base.HasEltype() +Base.eltype(::Type{<:PNSystem{NT}}) where {NT} = NT +Base.isdone(pnsystem::PNSystem) = isdone(state(pnsystem)) +Base.isdone(pnsystem::PNSystem, iterstate) = isdone(state(pnsystem), iterstate) +# Indexing +Base.getindex(pnsystem::PNSystem, i::Int) = @propagate_inbounds getindex(state(pnsystem), i) +Base.setindex!(pn::PNSystem, v, i::Int) = @propagate_inbounds setindex!(state(pn), v, i) +Base.firstindex(pnsystem::PNSystem) = firstindex(state(pnsystem)) +Base.lastindex(pnsystem::PNSystem) = lastindex(state(pnsystem)) +Base.eachindex(pnsystem::PNSystem) = eachindex(state(pnsystem)) +# Abstract arrays +Base.IndexStyle(::Type{T}) where {T<:PNSystem} = Base.IndexLinear() +Base.length(pnsystem::PNSystem) = length(state(pnsystem)) +# Base.similar(pnsystem::PNSystem) = similar(state(pnsystem)) +Base.axes(pnsystem::PNSystem) = axes(state(pnsystem)) +# Strided Arrays +Base.strides(pnsystem::PNSystem) = strides(state(pnsystem)) +function Base.unsafe_convert(::Type{Ptr{T}}, A::PNSystem) where {T} + Base.unsafe_convert(Ptr{T}, state(A)) +end +Base.elsize(::Type{<:PNSystem{T}}) where {T} = sizeof(T) +Base.stride(pnsystem::PNSystem, k::Int) = stride(state(pnsystem), k) + +function PreallocationTools.get_tmp( + dc::PreallocationTools.DiffCache, u::LArray{T,N,D,Syms} +) where {T<:ForwardDiff.Dual,N,D,Syms} + nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) + if nelem > length(dc.dual_du) + PreallocationTools.enlargedualcache!(dc, nelem) + end + _x = ArrayInterface.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) + LabelledArrays.LArray{T,N,D,Syms}(_x) +end + +function RecursiveArrayTools.recursive_unitless_eltype( + a::Type{LArray{T,N,D,Syms}} +) where {T,N,D,Syms} + LArray{typeof(one(T)),N,D,Syms} +end + +##################################### +# NamedTuple compatibility +##################################### +## SLArray to named tuple +function Base.convert(::Type{NamedTuple}, x::SLArray{S,T,N,L,Syms}) where {S,T,N,L,Syms} + tup = NTuple{length(Syms),T}(x.__x) + NamedTuple{Syms,typeof(tup)}(tup) +end +Base.keys(x::SLArray{S,T,N,L,Syms}) where {S,T,N,L,Syms} = Syms + +## pairs iterator +function Base.pairs(x::LArray{T,N,D,Syms}) where {T,N,D,Syms} + # (label => getproperty(x, label) for label in Syms) # not type stable? + (Syms[i] => x[i] for i ∈ 1:length(Syms)) +end + +function Base.iterate(x::SLArray, args...) + iterate(convert(NamedTuple, x), args...) +end + +##################################### +# Array Interface +##################################### +function Base.print_array(io::IO, w::WignerMatrix{NT,IT}) where {NT,IT<:Rational} + Base.print_array(io, parent(w)) +end + +Base.size(x::LArray) = size(getfield(x, :__x)) +Base.@propagate_inbounds Base.getindex(x::LArray, i...) = getfield(x, :__x)[i...] +Base.@propagate_inbounds function Base.setindex!(x::LArray, y, i...) + getfield(x, :__x)[i...] = y + return x +end + +Base.propertynames(::LArray{T,N,D,Syms}) where {T,N,D,Syms} = Syms +symnames(::Type{LArray{T,N,D,Syms}}) where {T,N,D,Syms} = Syms + +Base.@propagate_inbounds function Base.getproperty(x::LArray, s::Symbol) + if s == :__x + return getfield(x, :__x) + end + return getindex(x, Val(s)) +end + +Base.@propagate_inbounds function Base.setproperty!(x::LArray, s::Symbol, y) + if s == :__x + return setfield!(x, :__x, y) + end + setindex!(x, y, Val(s)) +end + +Base.@propagate_inbounds Base.getindex(x::LArray, s::Symbol) = getindex(x, Val(s)) +Base.@propagate_inbounds Base.getindex(x::LArray, s::Val) = __getindex(x, s) +Base.@propagate_inbounds Base.setindex!(x::LArray, v, s::Symbol) = setindex!(x, v, Val(s)) + +@generated function Base.setindex!(x::LArray, y, ::Val{s}) where {s} + syms = symnames(x) + if syms isa NamedTuple + idxs = syms[s] + return quote + Base.@_propagate_inbounds_meta + setindex!(getfield(x, :__x), y, $idxs) + return x + end + else # Tuple + idx = findfirst(y -> y == s, symnames(x)) + return quote + Base.@_propagate_inbounds_meta + setindex!(getfield(x, :__x), y, $idx) + return x + end + end +end + +Base.@propagate_inbounds function Base.getindex(x::LArray, s::AbstractArray{Symbol,1}) + [getindex(x, si) for si ∈ s] +end + +function Base.similar( + x::LArray{T,K,D,Syms}, ::Type{S}, dims::NTuple{N,Int} +) where {T,K,D,Syms,S,N} + tmp = similar(x.__x, S, dims) + LArray{S,N,typeof(tmp),Syms}(tmp) +end + +function StaticArrays.similar_type( + ::Type{SLArray{S,T,N,L,Syms}}, T2, ::Size{S} +) where {S,T,N,L,Syms} + SLArray{S,T2,N,L,Syms} +end + +# Allow copying LArray of uninitialized data, as with regular Array +Base.copy(x::LArray) = typeof(x)(copy(getfield(x, :__x))) +Base.copyto!(x::LArray, y::LArray) = copyto!(getfield(x, :__x), getfield(y, :__x)) + +# enable the usage of LAPACK +function Base.unsafe_convert(::Type{Ptr{T}}, a::LArray{T,N,D,S}) where {T,N,D,S} + Base.unsafe_convert(Ptr{T}, getfield(a, :__x)) +end + +Base.convert(::Type{T}, x) where {T<:LArray} = T(x) +Base.convert(::Type{T}, x::T) where {T<:LArray} = x +Base.convert(::Type{<:Array}, x::LArray) = convert(Array, getfield(x, :__x)) +function Base.convert( + ::Type{AbstractArray{T,N}}, x::LArray{S,N,<:Any,Syms} +) where {T,S,N,Syms} + LArray{Syms}(convert(AbstractArray{T,N}, getfield(x, :__x))) +end +Base.convert(::Type{AbstractArray{T,N}}, x::LArray{T,N}) where {T,N} = x + +function ArrayInterface.restructure( + x::LArray{T,N,D,Syms}, y::LArray{T2,N2,D2,Syms} +) where {T,N,D,T2,N2,D2,Syms} + reshape(y, size(x)...) +end + +##################################### +# Broadcast +##################################### +struct LAStyle{T,N,L} <: Broadcast.AbstractArrayStyle{N} end +LAStyle{T,N,L}(x::Val{i}) where {T,N,L,i} = LAStyle{T,N,L}() +Base.BroadcastStyle(::Type{LArray{T,N,D,L}}) where {T,N,D,L} = LAStyle{T,N,L}() +function Base.BroadcastStyle( + ::LabelledArrays.LAStyle{T,N,L}, ::LabelledArrays.LAStyle{E,N,L} +) where {T,E,N,L} + LAStyle{promote_type(T, E),N,L}() +end + +@generated function labels2axes(::Val{t}) where {t} + if t isa NamedTuple && all(x -> x isa Union{Integer,UnitRange}, values(t)) # range labelling + (Base.OneTo(maximum(Iterators.flatten(v for v ∈ values(t)))),) + elseif t isa NTuple{<:Any,Symbol} + axes(t) + else + error( + "$t label isn't supported for broadcasting. Try to formulate it in terms of linear indexing.", + ) + end +end +function Base.similar( + bc::Broadcast.Broadcasted{LAStyle{T,N,L}}, ::Type{ElType} +) where {T,N,L,ElType} + tmp = similar(Array{ElType}, axes(bc)) + if axes(bc) != labels2axes(Val(L)) + return tmp + else + return LArray{ElType,N,typeof(tmp),L}(tmp) + end +end + +# Broadcasting checks for aliasing with Base.dataids but the fallback +# for AbstractArrays is very slow. Instead, we just call dataids on the +# wrapped buffer +Base.dataids(pnsystem::PNSystem) = Base.dataids(state(pnsystem)) + +Base.elsize(::Type{<:LArray{T}}) where {T} = sizeof(T)