Skip to content

Commit e0dfc88

Browse files
blegatfcard
authored andcommitted
Make matrix multiplication work for more types (JuliaLang#18218)
* Make matrix multiplication work for more types Currently it is assumed that the type of a sum of x::T and y::T is T but this may not be the case * Remove arithtype in matmul and deprecate it
1 parent 3b6533d commit e0dfc88

File tree

3 files changed

+62
-18
lines changed

3 files changed

+62
-18
lines changed

base/deprecated.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,4 +1058,22 @@ function reduced_dims0(dims::Dims, region)
10581058
map(last, reduced_dims0(map(n->OneTo(n), dims), region))
10591059
end
10601060

1061+
# #18218
1062+
eval(Base.LinAlg, quote
1063+
function arithtype(T)
1064+
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
1065+
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
1066+
"if you need its functionality, consider defining it locally."),
1067+
:arithtype)
1068+
T
1069+
end
1070+
function arithtype(::Type{Bool})
1071+
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
1072+
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
1073+
"if you need its functionality, consider defining it locally."),
1074+
:arithtype)
1075+
Int
1076+
end
1077+
end)
1078+
10611079
# End deprecations scheduled for 0.6

base/linalg/matmul.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
# matmul.jl: Everything to do with dense matrix multiplication
44

5-
arithtype(T) = T
6-
arithtype(::Type{Bool}) = Int
5+
matprod(x, y) = x*y + x*y
76

87
# multiply by diagonal matrix as vector
98
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
@@ -76,11 +75,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(
7675

7776
# Matrix-vector multiplication
7877
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
79-
TS = promote_op(*, arithtype(T), arithtype(S))
78+
TS = promote_op(matprod, T, S)
8079
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
8180
end
8281
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
83-
TS = promote_op(*, arithtype(T), arithtype(S))
82+
TS = promote_op(matprod, T, S)
8483
A_mul_B!(similar(x,TS,size(A,1)),A,x)
8584
end
8685
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
@@ -99,22 +98,22 @@ end
9998
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)
10099

101100
function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
102-
TS = promote_op(*, arithtype(T), arithtype(S))
101+
TS = promote_op(matprod, T, S)
103102
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
104103
end
105104
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
106-
TS = promote_op(*, arithtype(T), arithtype(S))
105+
TS = promote_op(matprod, T, S)
107106
At_mul_B!(similar(x,TS,size(A,2)), A, x)
108107
end
109108
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
110109
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)
111110

112111
function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
113-
TS = promote_op(*, arithtype(T), arithtype(S))
112+
TS = promote_op(matprod, T, S)
114113
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
115114
end
116115
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
117-
TS = promote_op(*, arithtype(T), arithtype(S))
116+
TS = promote_op(matprod, T, S)
118117
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
119118
end
120119

@@ -141,7 +140,7 @@ julia> [1 1; 0 1] * [1 0; 1 1]
141140
```
142141
"""
143142
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
144-
TS = promote_op(*, arithtype(T), arithtype(S))
143+
TS = promote_op(matprod, T, S)
145144
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
146145
end
147146
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
@@ -177,14 +176,14 @@ julia> Y
177176
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)
178177

179178
function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
180-
TS = promote_op(*, arithtype(T), arithtype(S))
179+
TS = promote_op(matprod, T, S)
181180
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
182181
end
183182
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
184183
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)
185184

186185
function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
187-
TS = promote_op(*, arithtype(T), arithtype(S))
186+
TS = promote_op(matprod, T, S)
188187
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
189188
end
190189
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
@@ -201,7 +200,7 @@ end
201200
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)
202201

203202
function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
204-
TS = promote_op(*, arithtype(T), arithtype(S))
203+
TS = promote_op(matprod, T, S)
205204
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
206205
end
207206
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
@@ -210,7 +209,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
210209
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
211210
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
212211
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
213-
TS = promote_op(*, arithtype(T), arithtype(S))
212+
TS = promote_op(matprod, T, S)
214213
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
215214
end
216215
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
@@ -219,14 +218,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
219218
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
220219
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
221220
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
222-
TS = promote_op(*, arithtype(T), arithtype(S))
221+
TS = promote_op(matprod, T, S)
223222
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
224223
end
225224
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
226225
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)
227226

228227
Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
229-
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
228+
Ac_mul_Bc!(similar(B, promote_op(matprod, T, S), (size(A,2), size(B,1))), A, B)
230229
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
231230
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
232231
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
@@ -459,7 +458,7 @@ end
459458
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
460459
mA, nA = lapack_size(tA, A)
461460
mB, nB = lapack_size(tB, B)
462-
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
461+
C = similar(B, promote_op(matprod, T, S), mA, nB)
463462
generic_matmatmul!(C, tA, tB, A, B)
464463
end
465464

@@ -653,7 +652,7 @@ end
653652

654653
# multiply 2x2 matrices
655654
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
656-
matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B)
655+
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
657656
end
658657

659658
function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
@@ -682,7 +681,7 @@ end
682681

683682
# Multiply 3x3 matrices
684683
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
685-
matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B)
684+
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
686685
end
687686

688687
function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})

test/linalg/matmul.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,30 @@ let
389389
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
390390
end
391391
end
392+
393+
# #18218
394+
module TestPR18218
395+
using Base.Test
396+
import Base.*, Base.+, Base.zero
397+
immutable TypeA
398+
x::Int
399+
end
400+
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
401+
immutable TypeB
402+
x::Int
403+
end
404+
immutable TypeC
405+
x::Int
406+
end
407+
Base.convert(::Type{TypeC}, x::Int) = TypeC(x)
408+
zero(c::TypeC) = TypeC(0)
409+
zero(::Type{TypeC}) = TypeC(0)
410+
(*)(x::Int, a::TypeA) = TypeB(x*a.x)
411+
(*)(a::TypeA, x::Int) = TypeB(a.x*x)
412+
(+)(a::Union{TypeB,TypeC}, b::Union{TypeB,TypeC}) = TypeC(a.x+b.x)
413+
A = TypeA[1 2; 3 4]
414+
b = [1, 2]
415+
d = A * b
416+
@test typeof(d) == Vector{TypeC}
417+
@test d == TypeC[5, 11]
418+
end

0 commit comments

Comments
 (0)