diff --git a/ext/PolynomialsFFTWExt.jl b/ext/PolynomialsFFTWExt.jl index e8adab0e..0797f86a 100644 --- a/ext/PolynomialsFFTWExt.jl +++ b/ext/PolynomialsFFTWExt.jl @@ -1,22 +1,9 @@ module PolynomialsFFTWExt using Polynomials -import Polynomials: MutableDensePolynomial, StandardBasis +import Polynomials: MutableDensePolynomial, StandardBasis, Pad import FFTW import FFTW: fft, ifft - -struct Pad{T} <: AbstractVector{T} - a::Vector{T} - n::Int -end -Base.length(a::Pad) = a.n -Base.size(a::Pad) = (a.n,) -function Base.getindex(a::Pad, i) - u = length(a.a) - i ≤ u && return a.a[i] - return zero(first(a.a)) -end - function Polynomials.poly_multiplication_fft(p::P, q::Q) where {B <: StandardBasis,X, T <: AbstractFloat, P<:MutableDensePolynomial{B,T,X}, S <: AbstractFloat, Q<:MutableDensePolynomial{B,S,X}} diff --git a/src/polynomials/standard-basis/standard-basis.jl b/src/polynomials/standard-basis/standard-basis.jl index ff5337cf..7d55c242 100644 --- a/src/polynomials/standard-basis/standard-basis.jl +++ b/src/polynomials/standard-basis/standard-basis.jl @@ -874,6 +874,27 @@ end ## This assumes length(as) = 2^k for some k ## ωₙ is an nth root of unity, for example `exp(-2pi*im/n)` (also available with `sincos(2pi/n)`) for floating point ## or Cyclotomics.E(n), the latter much slower but non-lossy. +## +## Should implement NTT https://www.nayuki.io/page/number-theoretic-transform-integer-dft to close #519 + +## `Pad(a, n)` wraps the vector `a` as an `AbstractVector` of length `n`; +## indices beyond `length(a)` read as zero. +struct Pad{T} <: AbstractVector{T} + a::Vector{T} + n::Int +end +Base.length(a::Pad) = a.n +Base.size(a::Pad) = (a.n,) +## Return `a.a[i]` when `i ≤ length(a.a)`, otherwise a zero. +## `i::Int` keeps this the scalar method; ranges and other indices use the generic `AbstractArray` fallbacks. +function Base.getindex(a::Pad{T}, i::Int) where {T} + u = length(a.a) + i ≤ u && return a.a[i] + ## `first(a.a)` throws on an empty vector, so use `zero(T)` in that case + return isempty(a.a) ? zero(T) : zero(first(a.a)) +end + + function recursive_fft(as, ωₙ = nothing) n = length(as) N = 2^ceil(Int, log2(n)) @@ -920,24 +941,19 @@ end # This *should* be faster -- (O(nlog(n)), but this version is definitely not so. 
# when `ωₙ = Cyclotomics.E` and T,S are integer, this can be exact -# using `FFTW.jl` over `Float64` types is much faster and is +# using `FFTW.jl` over `Float64` types is much better and is # implemented in an extension function poly_multiplication_fft(p::P, q::Q, ωₙ=nothing) where {T,P<:StandardBasisPolynomial{T}, S,Q<:StandardBasisPolynomial{S}} as, bs = coeffs0(p), coeffs0(q) n = maximum(length, (as, bs)) N = 2^ceil(Int, log2(n)) - - as′ = zeros(promote_type(T,S), 2N) - copy!(view(as′, 1:length(as)), as) + as′ = Pad(as, 2N) + bs′ = Pad(bs, 2N) ω = something(ωₙ, n -> exp(-2im*pi/n))(2N) âs = recursive_fft(as′, ω) - - as′ .= 0 - copy!(view(as′, 1:length(bs)), bs) - b̂s = recursive_fft(as′, ω) - + b̂s = recursive_fft(bs′, ω) âb̂s = âs .* b̂s PP = promote_type(P,Q)