Extend defblas to redirect to previously defined blas functions
This commit is contained in:
parent
f4696a2251
commit
a65cb5b040
1 changed file with 38 additions and 143 deletions
165
src/cl-blas.lisp
165
src/cl-blas.lisp
|
@ -4,7 +4,7 @@
|
|||
;; Optimization qualities (speed 3, safety 0, debug 0) kept in one place.
;; Other forms in this file splice the list in at read time with
;; (declare (optimize . #.*optimize-qualities*)).
(defparameter *optimize-qualities* '((speed 3)
|
||||
(safety 0)
|
||||
(debug 0)))
|
||||
(defparameter +single-float-simd-generic+
|
||||
(defparameter *single-float-fn-generic*
|
||||
'((simd f32.8)
|
||||
(simd-aref f32.8-aref)
|
||||
(simd+ f32.8+)
|
||||
|
@ -13,7 +13,7 @@
|
|||
(stride 8)
|
||||
(float single-float)
|
||||
(0f 0.0)))
|
||||
(defparameter +double-float-simd-generic+
|
||||
(defparameter *double-float-fn-generic*
|
||||
'((simd f64.4)
|
||||
(simd-aref f64.4-aref)
|
||||
(simd+ f64.4+)
|
||||
|
@ -24,13 +24,12 @@
|
|||
(0f 0.0d0)))
|
||||
(defun parse-type-declarations (args type)
|
||||
(mapcar (lambda (arg)
|
||||
(if (= (second arg) 0)
|
||||
`(,type ,(first arg))
|
||||
`((simple-array ,type ,(second arg)) ,(first arg))))
|
||||
(cond
|
||||
((symbolp arg) `(fixnum ,arg))
|
||||
((= (second arg) 0) `(,type ,(first arg)))
|
||||
(t `((simple-array ,type ,(second arg)) ,(first arg)))))
|
||||
args))
|
||||
(defun parse-body (body mapping)
|
||||
(if (constantp body)
|
||||
body
|
||||
(mapcar
|
||||
(lambda (el)
|
||||
(typecase el
|
||||
|
@ -39,7 +38,7 @@
|
|||
(sym (if match match el)))
|
||||
sym))
|
||||
(t el)))
|
||||
body)))
|
||||
body))
|
||||
(defun parse-cond-clauses (type-decs)
|
||||
`(and ,@(mapcar (lambda (type-dec)
|
||||
`(typep ,(second type-dec) ',(first type-dec)))
|
||||
|
@ -58,18 +57,6 @@
|
|||
do (setf (row-major-aref arr i) (coerce (random-range min max) float-type)))
|
||||
arr))
|
||||
|
||||
;; Returns a newly allocated array B with B[j][i] = A[i][j].
;; A is m-by-n (its two dimensions are read with array-dimension) and
;; B is created n-by-m; only B is written to.
(defun transpose (A)
|
||||
"Transpose A of type (simple-array single-float 2)"
|
||||
(declare ((simple-array single-float 2) A)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(loop with m fixnum = (array-dimension A 0)
|
||||
with n fixnum = (array-dimension A 1)
|
||||
;; B takes the dimensions of A in swapped order.
with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float)
|
||||
for i fixnum from 0 below m
|
||||
do (loop for j fixnum from 0 below n
|
||||
do (setf (aref B j i) (aref A i j)))
|
||||
finally (return B)))
|
||||
|
||||
(defmacro defblas (name args &body body)
|
||||
;; ARGS should be of the form ((NAME RANK)+).
|
||||
;; If RANK is 0, then this NAME is a scalar
|
||||
|
@ -77,16 +64,18 @@
|
|||
(double-float-type-declarations (parse-type-declarations args 'double-float))
|
||||
(single-float-name (alexandria:symbolicate name '-single-float))
|
||||
(double-float-name (alexandria:symbolicate name '-double-float))
|
||||
(lambda-list (mapcar #'first args)))
|
||||
(lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg))) args)))
|
||||
(push `(,name ,single-float-name) *single-float-fn-generic*)
|
||||
(push `(,name ,double-float-name) *double-float-fn-generic*)
|
||||
`(progn
|
||||
(defun ,single-float-name ,lambda-list
|
||||
(declare (optimize . #.*optimize-qualities*)
|
||||
,@single-float-type-declarations)
|
||||
,@(parse-body body +single-float-simd-generic+))
|
||||
,@(parse-body body *single-float-fn-generic*))
|
||||
(defun ,double-float-name ,lambda-list
|
||||
(declare (optimize . #.*optimize-qualities*)
|
||||
,@double-float-type-declarations)
|
||||
,@(parse-body body +double-float-simd-generic+))
|
||||
,@(parse-body body *double-float-fn-generic*))
|
||||
(defun ,name ,lambda-list
|
||||
(declare (optimize . #.*optimize-qualities*))
|
||||
(cond
|
||||
|
@ -206,23 +195,10 @@
|
|||
|
||||
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations
|
||||
|
||||
(defun transpose-single-float (A)
|
||||
(declare ((simple-array single-float 2) A)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(defblas transpose ((A 2))
|
||||
(loop with m fixnum = (array-dimension A 0)
|
||||
with n fixnum = (array-dimension A 1)
|
||||
with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float)
|
||||
for i fixnum from 0 below m
|
||||
do (loop for j fixnum from 0 below n
|
||||
do (setf (aref B j i) (aref A i j)))
|
||||
finally (return B)))
|
||||
|
||||
(defun transpose-double-float (A)
|
||||
(declare ((simple-array double-float 2) A)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(loop with m fixnum = (array-dimension A 0)
|
||||
with n fixnum = (array-dimension A 1)
|
||||
with B of-type (simple-array double-float 2) = (make-array `(,n ,m) :element-type 'double-float)
|
||||
with B of-type (simple-array float 2) = (make-array `(,n ,m) :element-type 'float)
|
||||
for i fixnum from 0 below m
|
||||
do (loop for j fixnum from 0 below n
|
||||
do (setf (aref B j i) (aref A i j)))
|
||||
|
@ -231,123 +207,42 @@
|
|||
;;; gemm
|
||||
|
||||
(declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float))
|
||||
(defun gemm-row-dot-single-float (A Bt C alpha beta i n r r-block)
|
||||
(declare ((simple-array single-float 2) A Bt C)
|
||||
(single-float alpha beta)
|
||||
(fixnum i n r r-block)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(loop with alpha-vec of-type f32.8 = (f32.8 alpha)
|
||||
(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block)
|
||||
(loop with alpha-vec of-type simd = (simd alpha)
|
||||
for j fixnum from 0 below n
|
||||
do (setf (aref C i j)
|
||||
(+ (* beta (aref C i j))
|
||||
(loop for acc of-type f32.8 = (f32.8 0) then (f32.8+ acc (f32.8* alpha-vec (f32.8-aref A i k) (f32.8-aref Bt j k)))
|
||||
for k fixnum from 0 below r-block by 8
|
||||
finally (return (the single-float (f32.8-horizontal+ acc))))
|
||||
(loop for acc of-type simd = (simd 0) then (simd+ acc (simd* alpha-vec (simd-aref A i k) (simd-aref Bt j k)))
|
||||
for k fixnum from 0 below r-block by stride
|
||||
finally (return (the float (simd-horizontal+ acc))))
|
||||
(loop for k fixnum from r-block below r
|
||||
sum (* alpha (aref A i k) (aref Bt j k)) of-type single-float)))))
|
||||
sum (* alpha (aref A i k) (aref Bt j k)) of-type float)))))
|
||||
|
||||
;; Updates row I of C in place. For every column J below N it stores
;;   beta * C[i][j] + (blocked part) + (scalar tail)
;; where the blocked part walks K from 0 below R-BLOCK in steps of 4 using
;; the f64.4 operators, and the scalar tail sums
;; alpha * A[i][k] * Bt[j][k] for K from R-BLOCK below R.
;; Bt is indexed as Bt[j][k], i.e. the caller passes B already transposed.
;; Called for its side effect on C; the outer loop has no value clause.
;; NOTE(review): f64.4, f64.4+, f64.4*, f64.4-aref and f64.4-horizontal+ are
;; defined outside this view - presumably 4-lane double-float SIMD
;; operations where horizontal+ sums the lanes; confirm.
(defun gemm-row-dot-double-float (A Bt C alpha beta i n r r-block)
|
||||
(declare ((simple-array double-float 2) A Bt C)
|
||||
(double-float alpha beta)
|
||||
(fixnum i n r r-block)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(loop with alpha-vec of-type f64.4 = (f64.4 alpha)
|
||||
for j fixnum from 0 below n
|
||||
do (setf (aref C i j)
|
||||
(+ (* beta (aref C i j))
|
||||
;; The ACC clause precedes the K clause, so on each step ACC is
;; updated with the K of the iteration that just finished, and only
;; then is K advanced and tested against R-BLOCK.
(loop for acc of-type f64.4 = (f64.4 0) then (f64.4+ acc (f64.4* alpha-vec (f64.4-aref A i k) (f64.4-aref Bt j k)))
|
||||
for k fixnum from 0 below r-block by 4
|
||||
finally (return (the double-float (f64.4-horizontal+ acc))))
|
||||
;; Scalar tail for the columns past the last full block of 4.
(loop for k fixnum from r-block below r
|
||||
sum (* alpha (aref A i k) (aref Bt j k)) of-type double-float)))))
|
||||
|
||||
(defun gemm-parallel-single-float (A B C alpha beta)
|
||||
(declare ((simple-array single-float 2) A B C)
|
||||
(single-float alpha beta)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
|
||||
(setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+))
|
||||
(let ((channel (lparallel:make-channel)))
|
||||
(loop with m fixnum = (array-dimension C 0)
|
||||
with n fixnum = (array-dimension C 1)
|
||||
with r-block fixnum = (* (floor (array-dimension A 1) 8) 8)
|
||||
with r-block fixnum = (* (floor (array-dimension A 1) stride) stride)
|
||||
with r fixnum = (array-dimension A 1)
|
||||
with Bt of-type (simple-array single-float 2) = (transpose-single-float B)
|
||||
with Bt of-type (simple-array float 2) = (transpose B)
|
||||
for i fixnum from 0 below m
|
||||
do (lparallel:submit-task channel #'gemm-row-dot-single-float A Bt C alpha beta i n r r-block))
|
||||
do (lparallel:submit-task channel #'gemm-row-dot A Bt C alpha beta i n r r-block))
|
||||
(loop for i from 1 to (array-dimension C 0)
|
||||
do (lparallel:receive-result channel)))
|
||||
(lparallel:end-kernel)
|
||||
nil)
|
||||
|
||||
;; Row-parallel driver built on lparallel.
;;   1. Stores a kernel made by (lparallel:make-kernel +num-cpus+) into
;;      lparallel:*kernel*; this setf replaces whatever value was there.
;;   2. Transposes B once, then submits one gemm-row-dot-double-float task
;;      per row index I of C on a fresh channel.
;;   3. Calls receive-result once per row of C, so every submitted task has
;;      delivered a result before continuing.
;;   4. Calls (lparallel:end-kernel) and returns NIL.
;; Each task receives a distinct I, and gemm-row-dot-double-float only
;; writes (aref C i j), so tasks write to different rows of C.
;; NOTE(review): +num-cpus+ is defined outside this view.
(defun gemm-parallel-double-float (A B C alpha beta)
|
||||
(declare ((simple-array double-float 2) A B C)
|
||||
(double-float alpha beta)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+))
|
||||
(let ((channel (lparallel:make-channel)))
|
||||
(loop with m fixnum = (array-dimension C 0)
|
||||
with n fixnum = (array-dimension C 1)
|
||||
;; Column count of A rounded down to a multiple of 4.
with r-block fixnum = (* (floor (array-dimension A 1) 4) 4)
|
||||
with r fixnum = (array-dimension A 1)
|
||||
with Bt of-type (simple-array double-float 2) = (transpose-double-float B)
|
||||
for i fixnum from 0 below m
|
||||
do (lparallel:submit-task channel #'gemm-row-dot-double-float A Bt C alpha beta i n r r-block))
|
||||
;; One receive per submitted task; the received values are discarded.
(loop for i from 1 to (array-dimension C 0)
|
||||
do (lparallel:receive-result channel)))
|
||||
(lparallel:end-kernel)
|
||||
nil)
|
||||
|
||||
(defun gemm-serial-single-float (A B C alpha beta)
|
||||
(declare ((simple-array single-float 2) A B C)
|
||||
(single-float alpha beta)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
|
||||
(loop with m fixnum = (array-dimension C 0)
|
||||
with n fixnum = (array-dimension C 1)
|
||||
with r-block fixnum = (* (floor (array-dimension A 1) 8) 8)
|
||||
with r fixnum = (array-dimension A 1)
|
||||
with Bt of-type (simple-array single-float 2) = (transpose-single-float B)
|
||||
with Bt of-type (simple-array float 2) = (transpose B)
|
||||
for i fixnum from 0 below m
|
||||
do (gemm-row-dot-single-float A Bt C alpha beta i n r r-block)))
|
||||
do (gemm-row-dot A Bt C alpha beta i n r r-block)))
|
||||
|
||||
;; Single-threaded driver: transposes B once with transpose-double-float,
;; then calls gemm-row-dot-double-float for each row index I from 0 below
;; the first dimension of C. C is modified in place by those calls; the
;; loop has no value-returning clause.
(defun gemm-serial-double-float (A B C alpha beta)
|
||||
(declare ((simple-array double-float 2) A B C)
|
||||
(double-float alpha beta)
|
||||
(optimize . #.*optimize-qualities*))
|
||||
(loop with m fixnum = (array-dimension C 0)
|
||||
with n fixnum = (array-dimension C 1)
|
||||
;; Column count of A rounded down to a multiple of 4.
with r-block fixnum = (* (floor (array-dimension A 1) 4) 4)
|
||||
with r fixnum = (array-dimension A 1)
|
||||
with Bt of-type (simple-array double-float 2) = (transpose-double-float B)
|
||||
for i fixnum from 0 below m
|
||||
do (gemm-row-dot-double-float A Bt C alpha beta i n r r-block)))
|
||||
|
||||
(defun gemm-single-float (A B C alpha beta)
|
||||
(declare ((simple-array single-float 2) A B C)
|
||||
(single-float alpha beta)
|
||||
(optimize (speed 3)))
|
||||
(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
|
||||
(if (< (array-total-size C) 512)
|
||||
(gemm-serial-single-float A B C alpha beta)
|
||||
(gemm-parallel-single-float A B C alpha beta)))
|
||||
|
||||
;; Size-based dispatch for double-float inputs: when C has fewer than 512
;; elements in total it calls gemm-serial-double-float, otherwise
;; gemm-parallel-double-float. Returns whatever the chosen function returns.
;; Unlike its neighbours this function declares only (optimize (speed 3))
;; rather than splicing in *optimize-qualities*.
(defun gemm-double-float (A B C alpha beta)
|
||||
(declare ((simple-array double-float 2) A B C)
|
||||
(double-float alpha beta)
|
||||
(optimize (speed 3)))
|
||||
(if (< (array-total-size C) 512)
|
||||
(gemm-serial-double-float A B C alpha beta)
|
||||
(gemm-parallel-double-float A B C alpha beta)))
|
||||
|
||||
(defun gemm (A B C alpha beta)
|
||||
(cond
|
||||
((and (typep A '(simple-array single-float 2))
|
||||
(typep B '(simple-array single-float 2))
|
||||
(typep C '(simple-array single-float 2))
|
||||
(typep alpha 'single-float)
|
||||
(typep beta 'single-float))
|
||||
(gemm-single-float A B C alpha beta))
|
||||
((and (typep A '(simple-array double-float 2))
|
||||
(typep B '(simple-array double-float 2))
|
||||
(typep C '(simple-array double-float 2))
|
||||
(typep alpha 'double-float)
|
||||
(typep beta 'double-float))
|
||||
(gemm-double-float A B C alpha beta))))
|
||||
(gemm-serial A B C alpha beta)
|
||||
(gemm-parallel A B C alpha beta)))
|
||||
|
|
Loading…
Add table
Reference in a new issue