Skip to content

Commit 1241bba

Browse files
committed
add allowslow
1 parent 0213868 commit 1241bba

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

src/NNlib.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ using Statistics: mean
1818

1919
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
2020

21+
"""
22+
allowslow(::Bool)
23+
24+
By default, NNlib will print warnings the first time various slow fallback paths are taken.
25+
Calling `allowslow(false)` will instead make these into errors.
26+
"""
27+
allowslow(flag::Bool) = (SLOWERROR[] = !flag; nothing)
28+
const SLOWERROR = Ref(true)
29+
2130
# Include APIs
2231
include("dim_helpers.jl")
2332
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims

src/batched/batchedmul.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,10 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
274274

275275
size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
276276
size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
277-
@debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C)
277+
@warn "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) maxlog=1
278+
if SLOWERROR[]
279+
error("calling fallback method for batched_mul!")
280+
end
278281

279282
Abase, Bbase = _unbatch(A), _unbatch(B)
280283
sA, oA = size(A,3) == 1 ? (0,1) : (1,0)

src/conv.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ for (front_name, backend, signature) in (
191191
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
192192
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
193193
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
194+
SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
194195
end
195196

196197
x_cs = Iterators.partition(1:size(in1, 4),
@@ -232,6 +233,7 @@ for (front_name, backend, signature) in (
232233
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
233234
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
234235
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
236+
SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
235237
end
236238

237239

@@ -275,6 +277,7 @@ for (front_name, backend, signature) in (
275277
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
276278
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
277279
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
280+
SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
278281
end
279282

280283
dw_cs = Iterators.partition(1:size(out, 5),
@@ -326,6 +329,7 @@ for (front_name, backend, signature) in (
326329
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
327330
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
328331
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
332+
SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
329333
end
330334
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
331335
end

test/batchedmul.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,14 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect
303303

304304
gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
305305
end
306+
307+
@testset "warning / error" begin
308+
prev = NNlib.SLOWERROR[]
309+
NNlib.allowslow(true)
310+
A = rand(1:99, 3,4,7)
311+
B = rand(1:99, 4,5,7)
312+
@test batched_mul(A, B) isa Array # no error!
313+
NNlib.allowslow(false)
314+
@test_throws Exception batched_mul(A, B)
315+
NNlib.SLOWERROR[] = prev
316+
end

0 commit comments

Comments
 (0)