Skip to content

Commit

Permalink
more comments on JuliaMath#519
Browse files Browse the repository at this point in the history
  • Loading branch information
jverzani committed Aug 23, 2023
1 parent fce5ee6 commit e547b6c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
15 changes: 1 addition & 14 deletions ext/PolynomialsFFTWExt.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
module PolynomialsFFTWExt

using Polynomials
import Polynomials: MutableDensePolynomial, StandardBasis
import Polynomials: MutableDensePolynomial, StandardBasis, Pad
import FFTW
import FFTW: fft, ifft

struct Pad{T} <: AbstractVector{T}
a::Vector{T}
n::Int
end
Base.length(a::Pad) = a.n
Base.size(a::Pad) = (a.n,)
function Base.getindex(a::Pad, i)
u = length(a.a)
i u && return a.a[i]
return zero(first(a.a))
end

function Polynomials.poly_multiplication_fft(p::P, q::Q) where {B <: StandardBasis,X,
T <: AbstractFloat, P<:MutableDensePolynomial{B,T,X},
S <: AbstractFloat, Q<:MutableDensePolynomial{B,S,X}}
Expand Down
29 changes: 20 additions & 9 deletions src/polynomials/standard-basis/standard-basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,22 @@ end
## This assumes length(as) = 2^k for some k
## ωₙ is an nth root of unity, for example `exp(-2pi*im/n)` (also available with `sincos(2pi/n)`) for floating point
## or Cyclotomics.E(n), the latter much slower but non-lossy.
##
## Should implement NTT https://www.nayuki.io/page/number-theoretic-transform-integer-dft to close #519

struct Pad{T} <: AbstractVector{T}
a::Vector{T}
n::Int
end
Base.length(a::Pad) = a.n
Base.size(a::Pad) = (a.n,)
function Base.getindex(a::Pad, i)
u = length(a.a)
i u && return a.a[i]
return zero(first(a.a))
end


function recursive_fft(as, ωₙ = nothing)
n = length(as)
N = 2^ceil(Int, log2(n))
Expand Down Expand Up @@ -920,24 +936,19 @@ end

# This *should* be faster -- (O(nlog(n)), but this version is definitely not so.
# when `ωₙ = Cyclotomics.E` and T,S are integer, this can be exact
# using `FFTW.jl` over `Float64` types is much faster and is
# using `FFTW.jl` over `Float64` types is much better and is
# implemented in an extension
function poly_multiplication_fft(p::P, q::Q, ωₙ=nothing) where {T,P<:StandardBasisPolynomial{T},
S,Q<:StandardBasisPolynomial{S}}
as, bs = coeffs0(p), coeffs0(q)
n = maximum(length, (as, bs))
N = 2^ceil(Int, log2(n))

as′ = zeros(promote_type(T,S), 2N)
copy!(view(as′, 1:length(as)), as)
as′ = Pad(as, 2N)
bs′ = Pad(bs, 2N)

ω = something(ωₙ, n -> exp(-2im*pi/n))(2N)
âs = recursive_fft(as′, ω)

as′ .= 0
copy!(view(as′, 1:length(bs)), bs)
b̂s = recursive_fft(as′, ω)

b̂s = recursive_fft(bs′, ω)
âb̂s = âs .* b̂s

PP = promote_type(P,Q)
Expand Down

0 comments on commit e547b6c

Please sign in to comment.