Skip to content

Add Enzyme as a normal test dependency #583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,9 @@ steps:
command: |
julia -e 'println("--- :julia: Instantiating project")
using Pkg
try
Pkg.develop([PackageSpec(; path=pwd()), PackageSpec("Enzyme"), PackageSpec("EnzymeCore"), PackageSpec("CUDA")])
catch err
Pkg.develop(; path=pwd())
Pkg.add(["CUDA", "Enzyme"])
end' || exit 3
Pkg.develop(; path=pwd())
Pkg.add(["CUDA", "Enzyme"])
' || exit 3

julia -e 'println("+++ :julia: Running tests")
using CUDA
Expand Down
18 changes: 0 additions & 18 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,6 @@ jobs:
end'
echo '[pocl_jll]' > test/LocalPreferences.toml
echo 'libpocl_path="${{ github.workspace }}/target/lib/libpocl.so"' >> test/LocalPreferences.toml
- name: "Co-develop Enzyme and KA"
run: |
julia -e '
using Pkg
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do
Pkg.activate("test")
Pkg.add(["Enzyme", "EnzymeCore"])

# to check compatibility, also add Enzyme to the main environment
# (or Pkg.test, which merges both environments, could fail)
Pkg.activate(".")
# Try to co-develop Enzyme and KA
try
Pkg.develop([PackageSpec("Enzyme"), PackageSpec("EnzymeCore")])
catch err
end
end
'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
Expand Down
126 changes: 11 additions & 115 deletions ext/EnzymeCore07Ext.jl
Original file line number Diff line number Diff line change
@@ -1,81 +1,30 @@
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
# On the CPU `autodiff_deferred` can deadlock.
# Hence a specialized CPU version
function cpu_fwd(ctx, f, args...)
EnzymeCore.autodiff(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
return nothing
end

function gpu_fwd(ctx, f, args...)
function fwd(ctx, f, args...)

Check warning on line 1 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L1

Added line #L1 was not covered by tests
EnzymeCore.autodiff_deferred(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
return nothing
end

function EnzymeRules.forward(
func::Const{<:Kernel{CPU}},
::Type{Const{Nothing}},
args...;
ndrange = nothing,
workgroupsize = nothing,
)
kernel = func.val
f = kernel.f
fwd_kernel = similar(kernel, cpu_fwd)

return fwd_kernel(f, args...; ndrange, workgroupsize)
end

function EnzymeRules.forward(
func::Const{<:Kernel{<:GPU}},
func::Const{<:Kernel},
::Type{Const{Nothing}},
args...;
ndrange = nothing,
workgroupsize = nothing,
)
kernel = func.val
f = kernel.f
fwd_kernel = similar(kernel, gpu_fwd)
fwd_kernel = similar(kernel, fwd)

Check warning on line 15 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L15

Added line #L15 was not covered by tests

return fwd_kernel(f, args...; ndrange, workgroupsize)
end

_enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) =
mkcontext(kernel, first(blocks(iterspace)), ndrange, iterspace, dynamic)
_enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) =
_enzyme_mkcontext(kernel::Kernel, ndrange, iterspace, dynamic) =

Check warning on line 20 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L20

Added line #L20 was not covered by tests
mkcontext(kernel, ndrange, iterspace)

_augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type) =
AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
nothing,
nothing,
(subtape, arg_refs, tape_type),
)
_augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
_augmented_return(::Kernel, subtape, arg_refs, tape_type) =

Check warning on line 23 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L23

Added line #L23 was not covered by tests
AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, (subtape, arg_refs, tape_type))

function _create_tape_kernel(
kernel::Kernel{CPU},
ModifiedBetween,
FT,
ctxTy,
ndrange,
iterspace,
args2...,
)
TapeType = EnzymeCore.tape_type(
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
FT,
Const{Nothing},
Const{ctxTy},
map(Core.Typeof, args2)...,
)
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
aug_kernel = similar(kernel, cpu_aug_fwd)
return TapeType, subtape, aug_kernel
end

function _create_tape_kernel(
kernel::Kernel{<:GPU},
kernel::Kernel,
ModifiedBetween,
FT,
ctxTy,
Expand Down Expand Up @@ -104,60 +53,11 @@
# Allocate per thread
subtape = allocate(backend(kernel), TapeType, prod(ndrange))

aug_kernel = similar(kernel, gpu_aug_fwd)
aug_kernel = similar(kernel, aug_fwd)

Check warning on line 56 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L56

Added line #L56 was not covered by tests
return TapeType, subtape, aug_kernel
end

_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)

function cpu_aug_fwd(
ctx,
f::FT,
::Val{ModifiedBetween},
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
# A2 = Const{Nothing} -- since f->Nothing
forward, _ = EnzymeCore.autodiff_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
map(Core.Typeof, args)...,
)

# On the CPU: F is a per block function
# On the CPU: subtape::Vector{Vector}
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
return nothing
end

function cpu_rev(
ctx,
f::FT,
::Val{ModifiedBetween},
subtape,
::Val{TapeType},
args...,
) where {ModifiedBetween, FT, TapeType}
_, reverse = EnzymeCore.autodiff_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
map(Core.Typeof, args)...,
)
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
tp = subtape[I]
reverse(Const(f), Const(ctx), args..., tp)
return nothing
end

# GPU support
function gpu_aug_fwd(
function aug_fwd(

Check warning on line 60 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L60

Added line #L60 was not covered by tests
ctx,
f::FT,
::Val{ModifiedBetween},
Expand All @@ -184,7 +84,7 @@
return nothing
end

function gpu_rev(
function rev(

Check warning on line 87 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L87

Added line #L87 was not covered by tests
ctx,
f::FT,
::Val{ModifiedBetween},
Expand Down Expand Up @@ -232,11 +132,7 @@
arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
if func.val isa Kernel{<:GPU}
error("Active kernel arguments not supported on GPU")
else
Ref(EnzymeCore.make_zero(args[i].val))
end
error("Active kernel arguments not supported")

Check warning on line 135 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L135

Added line #L135 was not covered by tests
else
nothing
end
Expand Down Expand Up @@ -292,7 +188,7 @@

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = _create_rev_kernel(kernel)
rev_kernel = similar(kernel, rev)

Check warning on line 191 in ext/EnzymeCore07Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore07Ext.jl#L191

Added line #L191 was not covered by tests
rev_kernel(
f,
ModifiedBetween,
Expand Down
123 changes: 9 additions & 114 deletions ext/EnzymeCore08Ext.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
# On the CPU `autodiff_deferred` can deadlock.
# Hence a specialized CPU version
function cpu_fwd(ctx, config, f, args...)
EnzymeCore.autodiff(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
return nothing
end

function gpu_fwd(ctx, config, f, args...)
function fwd(ctx, config, f, args...)

Check warning on line 1 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L1

Added line #L1 was not covered by tests
EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
return nothing
end

function EnzymeRules.forward(
config,
func::Const{<:Kernel{CPU}},
::Type{Const{Nothing}},
args...;
ndrange = nothing,
workgroupsize = nothing,
)
kernel = func.val
f = kernel.f
fwd_kernel = similar(kernel, cpu_fwd)

return fwd_kernel(config, f, args...; ndrange, workgroupsize)
end

function EnzymeRules.forward(
config,
func::Const{<:Kernel{<:GPU}},
func::Const{<:Kernel},
::Type{Const{Nothing}},
args...;
ndrange = nothing,
Expand All @@ -41,41 +18,12 @@
return fwd_kernel(config, f, args...; ndrange, workgroupsize)
end

_enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic) =
mkcontext(kernel, first(blocks(iterspace)), ndrange, iterspace, dynamic)
_enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic) =
_enzyme_mkcontext(kernel::Kernel, ndrange, iterspace, dynamic) =

Check warning on line 21 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L21

Added line #L21 was not covered by tests
mkcontext(kernel, ndrange, iterspace)

_augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type) =
AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}}(
nothing,
nothing,
(subtape, arg_refs, tape_type),
)
_augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
_augmented_return(::Kernel, subtape, arg_refs, tape_type) =

Check warning on line 24 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L24

Added line #L24 was not covered by tests
AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, (subtape, arg_refs, tape_type))

function _create_tape_kernel(
kernel::Kernel{CPU},
Mode,
FT,
ctxTy,
ndrange,
iterspace,
args2...,
)
TapeType = EnzymeCore.tape_type(
Mode,
FT,
Const{Nothing},
Const{ctxTy},
map(Core.Typeof, args2)...,
)
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
aug_kernel = similar(kernel, cpu_aug_fwd)
return TapeType, subtape, aug_kernel
end

function _create_tape_kernel(
kernel::Kernel{<:GPU},
Mode,
Expand Down Expand Up @@ -106,60 +54,11 @@
# Allocate per thread
subtape = allocate(backend(kernel), TapeType, prod(ndrange))

aug_kernel = similar(kernel, gpu_aug_fwd)
aug_kernel = similar(kernel, aug_fwd)

Check warning on line 57 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L57

Added line #L57 was not covered by tests
return TapeType, subtape, aug_kernel
end

_create_rev_kernel(kernel::Kernel{CPU}) = similar(kernel, cpu_rev)
_create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)

function cpu_aug_fwd(
ctx,
f::FT,
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {Mode, FT, TapeType}
# A2 = Const{Nothing} -- since f->Nothing
forward, _ = EnzymeCore.autodiff_thunk(
mode,
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
map(Core.Typeof, args)...,
)

# On the CPU: F is a per block function
# On the CPU: subtape::Vector{Vector}
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
return nothing
end

function cpu_rev(
ctx,
f::FT,
mode::Mode,
subtape,
::Val{TapeType},
args...,
) where {Mode, FT, TapeType}
_, reverse = EnzymeCore.autodiff_thunk(
mode,
Const{Core.Typeof(f)},
Const{Nothing},
Const{Core.Typeof(ctx)},
map(Core.Typeof, args)...,
)
I = __index_Group_Cartesian(ctx, CartesianIndex(1, 1)) #=fake=#
tp = subtape[I]
reverse(Const(f), Const(ctx), args..., tp)
return nothing
end

# GPU support
function gpu_aug_fwd(
function fwd(

Check warning on line 61 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L61

Added line #L61 was not covered by tests
ctx,
f::FT,
mode::Mode,
Expand All @@ -186,7 +85,7 @@
return nothing
end

function gpu_rev(
function rev(

Check warning on line 88 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L88

Added line #L88 was not covered by tests
ctx,
f::FT,
mode::Mode,
Expand Down Expand Up @@ -234,11 +133,7 @@
arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
if func.val isa Kernel{<:GPU}
error("Active kernel arguments not supported on GPU")
else
Ref(EnzymeCore.make_zero(args[i].val))
end
error("Active kernel arguments not supported")

Check warning on line 136 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L136

Added line #L136 was not covered by tests
else
nothing
end
Expand Down Expand Up @@ -294,7 +189,7 @@

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
rev_kernel = _create_rev_kernel(kernel)
rev_kernel = similar(kernel, rev)

Check warning on line 192 in ext/EnzymeCore08Ext.jl

View check run for this annotation

Codecov / codecov/patch

ext/EnzymeCore08Ext.jl#L192

Added line #L192 was not covered by tests
rev_kernel(
f,
Mode,
Expand Down
Loading
Loading