Extend defblas to redirect to previously defined blas functions

This commit is contained in:
dominik martinez 2025-05-17 17:06:41 -07:00
parent f4696a2251
commit a65cb5b040

View file

@ -4,7 +4,7 @@
(defparameter *optimize-qualities* '((speed 3)
(safety 0)
(debug 0)))
(defparameter +single-float-simd-generic+
(defparameter *single-float-fn-generic*
'((simd f32.8)
(simd-aref f32.8-aref)
(simd+ f32.8+)
@ -13,7 +13,7 @@
(stride 8)
(float single-float)
(0f 0.0)))
(defparameter +double-float-simd-generic+
(defparameter *double-float-fn-generic*
'((simd f64.4)
(simd-aref f64.4-aref)
(simd+ f64.4+)
@ -24,13 +24,12 @@
(0f 0.0d0)))
(defun parse-type-declarations (args type)
(mapcar (lambda (arg)
(if (= (second arg) 0)
`(,type ,(first arg))
`((simple-array ,type ,(second arg)) ,(first arg))))
(cond
((symbolp arg) `(fixnum ,arg))
((= (second arg) 0) `(,type ,(first arg)))
(t `((simple-array ,type ,(second arg)) ,(first arg)))))
args))
(defun parse-body (body mapping)
(if (constantp body)
body
(mapcar
(lambda (el)
(typecase el
@ -39,7 +38,7 @@
(sym (if match match el)))
sym))
(t el)))
body)))
body))
(defun parse-cond-clauses (type-decs)
`(and ,@(mapcar (lambda (type-dec)
`(typep ,(second type-dec) ',(first type-dec)))
@ -58,18 +57,6 @@
do (setf (row-major-aref arr i) (coerce (random-range min max) float-type)))
arr))
(defun transpose (A)
"Transpose A of type (simple-array single-float 2)"
(declare ((simple-array single-float 2) A)
(optimize . #.*optimize-qualities*))
(loop with m fixnum = (array-dimension A 0)
with n fixnum = (array-dimension A 1)
with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float)
for i fixnum from 0 below m
do (loop for j fixnum from 0 below n
do (setf (aref B j i) (aref A i j)))
finally (return B)))
(defmacro defblas (name args &body body)
;; ARGS should be of the form ((NAME RANK)+).
;; If RANK is 0, then this NAME is a scalar
@ -77,16 +64,18 @@
(double-float-type-declarations (parse-type-declarations args 'double-float))
(single-float-name (alexandria:symbolicate name '-single-float))
(double-float-name (alexandria:symbolicate name '-double-float))
(lambda-list (mapcar #'first args)))
(lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg))) args)))
(push `(,name ,single-float-name) *single-float-fn-generic*)
(push `(,name ,double-float-name) *double-float-fn-generic*)
`(progn
(defun ,single-float-name ,lambda-list
(declare (optimize . #.*optimize-qualities*)
,@single-float-type-declarations)
,@(parse-body body +single-float-simd-generic+))
,@(parse-body body *single-float-fn-generic*))
(defun ,double-float-name ,lambda-list
(declare (optimize . #.*optimize-qualities*)
,@double-float-type-declarations)
,@(parse-body body +double-float-simd-generic+))
,@(parse-body body *double-float-fn-generic*))
(defun ,name ,lambda-list
(declare (optimize . #.*optimize-qualities*))
(cond
@ -206,23 +195,10 @@
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations
(defun transpose-single-float (A)
(declare ((simple-array single-float 2) A)
(optimize . #.*optimize-qualities*))
(defblas transpose ((A 2))
(loop with m fixnum = (array-dimension A 0)
with n fixnum = (array-dimension A 1)
with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float)
for i fixnum from 0 below m
do (loop for j fixnum from 0 below n
do (setf (aref B j i) (aref A i j)))
finally (return B)))
(defun transpose-double-float (A)
(declare ((simple-array double-float 2) A)
(optimize . #.*optimize-qualities*))
(loop with m fixnum = (array-dimension A 0)
with n fixnum = (array-dimension A 1)
with B of-type (simple-array double-float 2) = (make-array `(,n ,m) :element-type 'double-float)
with B of-type (simple-array float 2) = (make-array `(,n ,m) :element-type 'float)
for i fixnum from 0 below m
do (loop for j fixnum from 0 below n
do (setf (aref B j i) (aref A i j)))
@ -231,123 +207,42 @@
;;; gemm
(declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float))
(defun gemm-row-dot-single-float (A Bt C alpha beta i n r r-block)
(declare ((simple-array single-float 2) A Bt C)
(single-float alpha beta)
(fixnum i n r r-block)
(optimize . #.*optimize-qualities*))
(loop with alpha-vec of-type f32.8 = (f32.8 alpha)
(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block)
(loop with alpha-vec of-type simd = (simd alpha)
for j fixnum from 0 below n
do (setf (aref C i j)
(+ (* beta (aref C i j))
(loop for acc of-type f32.8 = (f32.8 0) then (f32.8+ acc (f32.8* alpha-vec (f32.8-aref A i k) (f32.8-aref Bt j k)))
for k fixnum from 0 below r-block by 8
finally (return (the single-float (f32.8-horizontal+ acc))))
(loop for acc of-type simd = (simd 0) then (simd+ acc (simd* alpha-vec (simd-aref A i k) (simd-aref Bt j k)))
for k fixnum from 0 below r-block by stride
finally (return (the float (simd-horizontal+ acc))))
(loop for k fixnum from r-block below r
sum (* alpha (aref A i k) (aref Bt j k)) of-type single-float)))))
sum (* alpha (aref A i k) (aref Bt j k)) of-type float)))))
(defun gemm-row-dot-double-float (A Bt C alpha beta i n r r-block)
(declare ((simple-array double-float 2) A Bt C)
(double-float alpha beta)
(fixnum i n r r-block)
(optimize . #.*optimize-qualities*))
(loop with alpha-vec of-type f64.4 = (f64.4 alpha)
for j fixnum from 0 below n
do (setf (aref C i j)
(+ (* beta (aref C i j))
(loop for acc of-type f64.4 = (f64.4 0) then (f64.4+ acc (f64.4* alpha-vec (f64.4-aref A i k) (f64.4-aref Bt j k)))
for k fixnum from 0 below r-block by 4
finally (return (the double-float (f64.4-horizontal+ acc))))
(loop for k fixnum from r-block below r
sum (* alpha (aref A i k) (aref Bt j k)) of-type double-float)))))
(defun gemm-parallel-single-float (A B C alpha beta)
(declare ((simple-array single-float 2) A B C)
(single-float alpha beta)
(optimize . #.*optimize-qualities*))
(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
(setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+))
(let ((channel (lparallel:make-channel)))
(loop with m fixnum = (array-dimension C 0)
with n fixnum = (array-dimension C 1)
with r-block fixnum = (* (floor (array-dimension A 1) 8) 8)
with r-block fixnum = (* (floor (array-dimension A 1) stride) stride)
with r fixnum = (array-dimension A 1)
with Bt of-type (simple-array single-float 2) = (transpose-single-float B)
with Bt of-type (simple-array float 2) = (transpose B)
for i fixnum from 0 below m
do (lparallel:submit-task channel #'gemm-row-dot-single-float A Bt C alpha beta i n r r-block))
do (lparallel:submit-task channel #'gemm-row-dot A Bt C alpha beta i n r r-block))
(loop for i from 1 to (array-dimension C 0)
do (lparallel:receive-result channel)))
(lparallel:end-kernel)
nil)
(defun gemm-parallel-double-float (A B C alpha beta)
(declare ((simple-array double-float 2) A B C)
(double-float alpha beta)
(optimize . #.*optimize-qualities*))
(setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+))
(let ((channel (lparallel:make-channel)))
(loop with m fixnum = (array-dimension C 0)
with n fixnum = (array-dimension C 1)
with r-block fixnum = (* (floor (array-dimension A 1) 4) 4)
with r fixnum = (array-dimension A 1)
with Bt of-type (simple-array double-float 2) = (transpose-double-float B)
for i fixnum from 0 below m
do (lparallel:submit-task channel #'gemm-row-dot-double-float A Bt C alpha beta i n r r-block))
(loop for i from 1 to (array-dimension C 0)
do (lparallel:receive-result channel)))
(lparallel:end-kernel)
nil)
(defun gemm-serial-single-float (A B C alpha beta)
(declare ((simple-array single-float 2) A B C)
(single-float alpha beta)
(optimize . #.*optimize-qualities*))
(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
(loop with m fixnum = (array-dimension C 0)
with n fixnum = (array-dimension C 1)
with r-block fixnum = (* (floor (array-dimension A 1) 8) 8)
with r fixnum = (array-dimension A 1)
with Bt of-type (simple-array single-float 2) = (transpose-single-float B)
with Bt of-type (simple-array float 2) = (transpose B)
for i fixnum from 0 below m
do (gemm-row-dot-single-float A Bt C alpha beta i n r r-block)))
do (gemm-row-dot A Bt C alpha beta i n r r-block)))
(defun gemm-serial-double-float (A B C alpha beta)
(declare ((simple-array double-float 2) A B C)
(double-float alpha beta)
(optimize . #.*optimize-qualities*))
(loop with m fixnum = (array-dimension C 0)
with n fixnum = (array-dimension C 1)
with r-block fixnum = (* (floor (array-dimension A 1) 4) 4)
with r fixnum = (array-dimension A 1)
with Bt of-type (simple-array double-float 2) = (transpose-double-float B)
for i fixnum from 0 below m
do (gemm-row-dot-double-float A Bt C alpha beta i n r r-block)))
(defun gemm-single-float (A B C alpha beta)
(declare ((simple-array single-float 2) A B C)
(single-float alpha beta)
(optimize (speed 3)))
(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
(if (< (array-total-size C) 512)
(gemm-serial-single-float A B C alpha beta)
(gemm-parallel-single-float A B C alpha beta)))
(defun gemm-double-float (A B C alpha beta)
(declare ((simple-array double-float 2) A B C)
(double-float alpha beta)
(optimize (speed 3)))
(if (< (array-total-size C) 512)
(gemm-serial-double-float A B C alpha beta)
(gemm-parallel-double-float A B C alpha beta)))
(defun gemm (A B C alpha beta)
(cond
((and (typep A '(simple-array single-float 2))
(typep B '(simple-array single-float 2))
(typep C '(simple-array single-float 2))
(typep alpha 'single-float)
(typep beta 'single-float))
(gemm-single-float A B C alpha beta))
((and (typep A '(simple-array double-float 2))
(typep B '(simple-array double-float 2))
(typep C '(simple-array double-float 2))
(typep alpha 'double-float)
(typep beta 'double-float))
(gemm-double-float A B C alpha beta))))
(gemm-serial A B C alpha beta)
(gemm-parallel A B C alpha beta)))