;;;; cl-blas/src/cl-blas.lisp
;;;; 2025-05-21 18:21:28 -07:00
;;;;
;;;; 317 lines
;;;; 12 KiB
;;;; Common Lisp

(in-package :cl-blas)
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Everything in this form has to exist while the file is being compiled:
  ;; DEFBLAS reads the tables and calls the helpers during macroexpansion.

  (defparameter *optimize-qualities*
    '((speed 3)
      (safety 0)
      (debug 0))
    "Optimization qualities placed in the OPTIMIZE declaration of every
function generated by DEFBLAS.")

  (defparameter *single-float-fn-generic*
    '((simd f32.8)
      (simd-aref f32.8-aref)
      (simd+ f32.8+)
      (simd* f32.8*)
      (simd-horizontal+ f32.8-horizontal+)
      (simd-horizontal-max f32.8-horizontal-max)
      (stride 8)
      (float single-float)
      (0f 0.0))
    "Alist of (GENERIC REPLACEMENT) pairs that PARSE-BODY uses to turn a
generic DEFBLAS body into its single-float version.")

  (defparameter *double-float-fn-generic*
    '((simd f64.4)
      (simd-aref f64.4-aref)
      (simd+ f64.4+)
      (simd* f64.4*)
      (simd-horizontal+ f64.4-horizontal+)
      (simd-horizontal-max f64.4-horizontal-max)
      (stride 4)
      (float double-float)
      (0f 0.0d0))
    "Alist of (GENERIC REPLACEMENT) pairs that PARSE-BODY uses to turn a
generic DEFBLAS body into its double-float version.")

  (defun parse-type-declarations (args type)
    "Turn a DEFBLAS argument list into a list of (TYPE-SPECIFIER NAME) pairs.
A bare symbol becomes (FIXNUM NAME); (NAME 0) becomes (TYPE NAME); and
(NAME RANK) becomes ((SIMPLE-ARRAY TYPE RANK) NAME)."
    (flet ((declaration-for (arg)
             (if (symbolp arg)
                 (list 'fixnum arg)
                 (let ((name (first arg))
                       (rank (second arg)))
                   (if (zerop rank)
                       (list type name)
                       (list (list 'simple-array type rank) name))))))
      (mapcar #'declaration-for args)))

  (defun parse-body (body mapping)
    "Return a copy of the form list BODY in which every symbol that has a
non-NIL replacement in the alist MAPPING is replaced by it.  Conses are
walked recursively; all other objects are kept as they are."
    (flet ((translate (form)
             (cond ((consp form) (parse-body form mapping))
                   ((symbolp form) (or (second (assoc form mapping)) form))
                   (t form))))
      (mapcar #'translate body)))

  (defun parse-cond-clauses (type-decs)
    "Build an (AND (TYPEP NAME 'TYPE) ...) test form from TYPE-DECS, a list of
(TYPE NAME) pairs as returned by PARSE-TYPE-DECLARATIONS."
    (cons 'and
          (mapcar (lambda (type-dec)
                    (list 'typep
                          (second type-dec)
                          (list 'quote (first type-dec))))
                  type-decs))))
;; Worker count passed to LPARALLEL:MAKE-KERNEL by GEMM-PARALLEL.  The value
;; is whatever SERAPEUM:COUNT-CPUS returns when this form is evaluated.
(defparameter +num-cpus+ (serapeum:count-cpus))
(defun random-range (min max)
  "Return MIN plus a value drawn with RANDOM from the span (- MAX MIN), so the
result lies in [MIN, MAX).  The span is handed to RANDOM unchanged."
  (let ((span (- max min)))
    (+ (random span) min)))
(defun make-random-array (dimensions &key (min -100) (max 100) (float-type 'single-float))
  "Create an array of shape DIMENSIONS with element type FLOAT-TYPE and fill
every element with (RANDOM-RANGE MIN MAX) coerced to FLOAT-TYPE.
Returns the new array."
  (declare (type (or fixnum (cons fixnum)) dimensions)
           (type integer min max))
  (let ((result (make-array dimensions :element-type float-type)))
    ;; ROW-MAJOR-AREF lets one flat loop cover an array of any rank.
    (dotimes (index (array-total-size result) result)
      (setf (row-major-aref result index)
            (coerce (random-range min max) float-type)))))
(defmacro defblas (name args &body body)
  "Define a BLAS routine NAME in three flavours.

ARGS is a list whose entries are either (ARG-NAME RANK) or a bare ARG-NAME:
  - (ARG-NAME 0) is a floating-point scalar,
  - (ARG-NAME RANK) with RANK > 0 is a SIMPLE-ARRAY of that rank,
  - a bare ARG-NAME is a FIXNUM.

BODY is written with the generic vocabulary found in
*SINGLE-FLOAT-FN-GENERIC* / *DOUBLE-FLOAT-FN-GENERIC* (SIMD, SIMD-AREF, STRIDE,
FLOAT, 0F, ...).  The expansion defines:
  - NAME-SINGLE-FLOAT and NAME-DOUBLE-FLOAT, with BODY rewritten by PARSE-BODY
    for the respective precision, and
  - NAME itself, which checks the argument types at run time and calls the
    matching specialization, signalling an error when neither matches."
  (let* ((single-float-type-declarations (parse-type-declarations args 'single-float))
         (double-float-type-declarations (parse-type-declarations args 'double-float))
         (single-float-name (alexandria:symbolicate name '-single-float))
         (double-float-name (alexandria:symbolicate name '-double-float))
         (lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg))) args)))
    ;; Register NAME before rewriting BODY so that this and later DEFBLAS
    ;; bodies calling NAME are redirected to the specialized function.
    ;; PUSHNEW keeps repeated macroexpansion from piling up duplicate entries.
    (pushnew `(,name ,single-float-name) *single-float-fn-generic* :test #'equal)
    (pushnew `(,name ,double-float-name) *double-float-fn-generic* :test #'equal)
    `(progn
       (defun ,single-float-name ,lambda-list
         (declare (optimize ,@*optimize-qualities*)
                  ,@single-float-type-declarations)
         ,@(parse-body body *single-float-fn-generic*))
       (defun ,double-float-name ,lambda-list
         (declare (optimize ,@*optimize-qualities*)
                  ,@double-float-type-declarations)
         ,@(parse-body body *double-float-fn-generic*))
       (defun ,name ,lambda-list
         (declare (optimize ,@*optimize-qualities*))
         (cond
           (,(parse-cond-clauses single-float-type-declarations)
            (,single-float-name ,@lambda-list))
           (,(parse-cond-clauses double-float-type-declarations)
            (,double-float-name ,@lambda-list))
           ;; Without this clause a type mismatch returned NIL, which is
           ;; indistinguishable from success for the in-place routines.
           (t
            (error "~S: argument types ~S match neither the single-float nor the double-float signature."
                   ',name
                   (mapcar #'type-of (list ,@lambda-list)))))))))
(declaim (inline simd-abs-f32.8 simd-abs-f64.4))

;; Lane-wise absolute value of an f32.8 pack, computed as max(x, 0 - x).
;; The previous sqrt(x * x) formulation overflowed for large magnitudes and
;; flushed to zero for tiny ones, because the intermediate square leaves the
;; representable range long before x itself does.
(defun simd-abs-f32.8 (pack)
  (declare (f32.8 pack))
  (f32.8-max pack (f32.8- (f32.8 0) pack)))

;; Lane-wise absolute value of an f64.4 pack; same approach as the f32.8
;; version above.
(defun simd-abs-f64.4 (pack)
  (declare (f64.4 pack))
  (f64.4-max pack (f64.4- (f64.4 0) pack)))
;; DEFBLAS rewrites its body during macroexpansion, so the SIMD-ABS entries
;; must be present at compile time as well as at load time.  As plain
;; top-level PUSH forms they were not evaluated by COMPILE-FILE, leaving
;; SIMD-ABS untranslated in bodies compiled later in this file.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (pushnew '(simd-abs simd-abs-f32.8) *single-float-fn-generic* :test #'equal)
  (pushnew '(simd-abs simd-abs-f64.4) *double-float-fn-generic* :test #'equal))
;;;; Level 1 BLAS: vector, O(n) operations
;;; axpy
(defblas axpy ((alpha 0) (x 1) (y 1))
  ;; y <- alpha * x + y, updating Y in place.  The length is taken from X;
  ;; Y is indexed over the same range and no length check is made here.
  ;; Returns NIL.
  (let* ((n (array-dimension x 0))
         ;; Largest multiple of the SIMD width that fits in N.
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; Step by STRIDE, not a literal 8: the double-float variant packs only
    ;; 4 elements, so a fixed step of 8 skipped half of every 8-element span.
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i)
                   (simd+ (simd-aref y i)
                          (simd* alpha-vec
                                 (simd-aref x i)))))
    ;; Scalar tail for the elements that do not fill a whole pack.
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (+ (aref y i) (* alpha (aref x i)))))))
;;; scal
(defblas scal ((alpha 0) (x 1))
  ;; x <- alpha * x, updating X in place.  Returns NIL.
  (let* ((n (array-dimension x 0))
         ;; Largest multiple of the SIMD width that fits in N.
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; Step by STRIDE, not a literal 8: the double-float variant packs only
    ;; 4 elements, so a fixed step of 8 left half of every 8-element span
    ;; unscaled.
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref x i)
                   (simd* alpha-vec
                          (simd-aref x i))))
    ;; Scalar tail for the elements that do not fill a whole pack.
    (loop for i fixnum from n-block below n
          do (setf (aref x i) (* alpha (aref x i))))))
;;; copy
(defblas copy ((x 1) (y 1))
  ;; y <- x, element by element.  The length is taken from X; Y is indexed
  ;; over the same range and no length check is made here.  Returns NIL.
  (let* ((n (array-dimension x 0))
         ;; Largest multiple of the SIMD width that fits in N.
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; Step by STRIDE, not a literal 8: the double-float variant packs only
    ;; 4 elements, so a fixed step of 8 left half of every 8-element span
    ;; uncopied.
    (loop for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i) (simd-aref x i)))
    ;; Scalar tail for the elements that do not fill a whole pack.
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (aref x i)))))
;;; swap
(defblas swap ((x 1) (y 1))
  ;; Exchange the contents of X and Y in place.  The length is taken from X.
  ;; Returns NIL.
  (let* ((len (array-dimension x 0))
         (packed-end (* (floor len stride) stride)))
    (declare (fixnum len packed-end))
    ;; Whole packs first: hold the X pack in a temporary while Y's pack is
    ;; stored over it.
    (loop for offset fixnum from 0 below packed-end by stride
          do (let ((held (simd-aref x offset)))
               (declare (type simd held))
               (setf (simd-aref x offset) (simd-aref y offset))
               (setf (simd-aref y offset) held)))
    ;; Remaining elements one at a time.
    (loop for offset fixnum from packed-end below len
          do (rotatef (aref x offset) (aref y offset)))))
;;; dot
(defblas dot ((x 1) (y 1))
  ;; Return the dot product sum_i x[i] * y[i].  The length is taken from X;
  ;; Y is indexed over the same range and no length check is made here.
  (let* ((n (array-dimension x 0))
         ;; Largest multiple of the SIMD width that fits in N.
         (n-block (* (floor n stride) stride))
         (sum 0f))
    (declare (fixnum n n-block)
             (float sum))
    ;; Accumulate lane-wise products, then fold the lanes into one scalar.
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf acc (simd+ acc (simd* (simd-aref x i) (simd-aref y i))))
          finally (setf sum (simd-horizontal+ acc)))
    ;; Scalar tail for the elements that do not fill a whole pack.
    (loop for i fixnum from n-block below n
          do (setf sum (+ sum (* (aref x i) (aref y i)))))
    sum))
;;; nrm2
(defblas nrm2 ((x 1))
  ;; Return the Euclidean norm of X: the square root of the sum of squares.
  (let* ((len (array-dimension x 0))
         (packed-end (* (floor len stride) stride))
         (total 0f))
    (declare (fixnum len packed-end)
             (float total))
    ;; Sum of squares over whole packs, folded into a scalar afterwards.
    (let ((acc (simd 0)))
      (declare (type simd acc))
      (loop for offset fixnum from 0 below packed-end by stride
            do (let ((lanes (simd-aref x offset)))
                 (declare (type simd lanes))
                 (setf acc (simd+ acc (simd* lanes lanes)))))
      (setf total (simd-horizontal+ acc)))
    ;; Squares of the leftover elements.
    (loop for offset fixnum from packed-end below len
          do (incf total (expt (aref x offset) 2)))
    (the float (sqrt total))))
;;; asum
(defblas asum ((x 1))
  ;; Return the sum of the absolute values of the elements of X.
  (let* ((n (array-dimension x 0))
         ;; Largest multiple of the SIMD width that fits in N.
         (n-block (* (floor n stride) stride))
         (norm 0f))
    (declare (fixnum n n-block)
             (float norm))
    ;; Accumulate lane-wise |x|, then fold the lanes into one scalar.
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          for sel of-type simd = (simd-aref x i)
          do (setf acc (simd+ acc (simd-abs sel)))
          finally (setf norm (simd-horizontal+ acc)))
    ;; Scalar tail: add |x[i]|.  This used to add x[i]^2, which made the
    ;; result depend on whether an element fell in the SIMD part or the tail.
    (loop for i fixnum from n-block below n
          do (setf norm (+ norm (abs (aref x i)))))
    norm))
;;; i-amax
(defblas i-amax ((x 1))
  ;; Return the largest absolute value found in X, or zero for an empty X.
  ;; NOTE(review): despite the name, the value returned is the magnitude
  ;; itself, not the index of the element that holds it.
  (let* ((len (array-dimension x 0))
         (packed-end (* (floor len stride) stride))
         (largest 0f))
    (declare (fixnum len packed-end)
             (float largest))
    ;; Whole packs: take |x| lane-wise and reduce each pack to its maximum.
    (loop for offset fixnum from 0 below packed-end by stride
          do (let ((pack-max (simd-horizontal-max (simd-abs (simd-aref x offset)))))
               (setf largest (max largest pack-max))))
    ;; Leftover elements one at a time.
    (loop for offset fixnum from packed-end below len
          do (setf largest (max largest (abs (aref x offset)))))
    largest))
;;;; Level 2 BLAS: matrix-vector, O(n^2) operations
(defblas transpose ((A 2))
  ;; Return a freshly allocated matrix holding the transpose of A.
  ;; A itself is not modified.
  (let* ((rows (array-dimension A 0))
         (cols (array-dimension A 1))
         (result (make-array (list cols rows) :element-type 'float)))
    (declare (fixnum rows cols)
             (type (simple-array float 2) result))
    (dotimes (row rows result)
      (dotimes (col cols)
        (setf (aref result col row) (aref A row col))))))
;;; gemv
(defblas gemv ((A 2) (x 1) (y 1) (alpha 0) (beta 0))
  ;; y <- alpha * A * x + beta * y, updating Y in place.  Row count and column
  ;; count are taken from A; X and Y are indexed accordingly and no length
  ;; check is made here.  Returns NIL.
  (let* ((n (array-dimension A 1))
         ;; Largest multiple of the SIMD width that fits in a row.
         (n-block (* (floor n stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum n n-block)
             (simd alpha-vec))
    (loop for i fixnum from 0 below (array-dimension A 0)
          do (setf (aref y i)
                   (+ (* beta (aref y i))
                      ;; SIMD part of the row/vector product.  X must be read
                      ;; as a pack with SIMD-AREF: reading the scalar
                      ;; (aref x j) paired every lane of the row with the same
                      ;; element x[j] instead of x[j] .. x[j+stride-1].
                      (loop with acc of-type simd = (simd 0)
                            for j fixnum from 0 below n-block by stride
                            do (setf acc (simd+ acc
                                                (simd* alpha-vec
                                                       (simd-aref A i j)
                                                       (simd-aref x j))))
                            finally (return (the float (simd-horizontal+ acc))))
                      ;; Scalar tail for the columns that do not fill a pack.
                      (loop for j fixnum from n-block below n
                            sum (* alpha (aref A i j) (aref x j)) of-type float))))))
;;; ger
(defblas ger ((A 2) (x 1) (y 1) (alpha 0))
  ;; Rank-1 update A <- A + alpha * x * y^T, performed in place on A.
  ;; The dimensions are taken from A.  Returns NIL.
  (let* ((m (array-dimension A 0))
         (n (array-dimension A 1))
         (n-block (* (floor n stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum m n n-block)
             (simd alpha-vec))
    (dotimes (i m)
      (let* ((x-i (aref x i))
             ;; alpha * x[i] broadcast to every lane; constant along row I.
             (row-scale (simd* alpha-vec (simd x-i))))
        (declare (type float x-i)
                 (type simd row-scale))
        ;; Whole packs of row I.
        (loop for j fixnum from 0 below n-block by stride
              do (let ((delta (simd* row-scale (simd-aref y j))))
                   (declare (type simd delta))
                   (setf (simd-aref A i j) (simd+ (simd-aref A i j) delta))))
        ;; Leftover columns of row I.
        (loop for j fixnum from n-block below n
              do (incf (aref A i j) (* alpha x-i (aref y j))))))))
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations
;;; gemm
(declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float))
(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block)
  ;; Update row I of C in place:
  ;;   C[i,j] <- beta * C[i,j] + alpha * sum_k A[i,k] * Bt[j,k]   for j < N.
  ;; BT is indexed as [j,k], i.e. it is used as the transpose of the right-hand
  ;; matrix.  R is the length of the inner dimension and R-BLOCK the part of
  ;; it that is processed in whole SIMD packs.  Returns NIL.
  (let ((alpha-vec (simd alpha)))
    (declare (type simd alpha-vec))
    (dotimes (j n)
      (setf (aref C i j)
            (+ (* beta (aref C i j))
               ;; Packed part of the inner product.
               (let ((acc (simd 0)))
                 (declare (type simd acc))
                 (loop for k fixnum from 0 below r-block by stride
                       do (setf acc (simd+ acc
                                           (simd* alpha-vec
                                                  (simd-aref A i k)
                                                  (simd-aref Bt j k)))))
                 (the float (simd-horizontal+ acc)))
               ;; Scalar remainder of the inner product.
               (loop for k fixnum from r-block below r
                     sum (* alpha (aref A i k) (aref Bt j k)) of-type float))))))
(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha * A * B + beta * C, updating C in place with one lparallel
  ;; task per row of C.  Returns NIL.
  ;;
  ;; The kernel is bound dynamically rather than assigned to the global
  ;; LPARALLEL:*KERNEL*, so a kernel the caller already set up is left alone.
  ;; UNWIND-PROTECT guarantees the temporary kernel is ended even when a task
  ;; or RECEIVE-RESULT signals an error.
  (let ((lparallel:*kernel* (lparallel:make-kernel +num-cpus+)))
    (unwind-protect
         (let* ((channel (lparallel:make-channel))
                (m (array-dimension C 0))
                (n (array-dimension C 1))
                (r (array-dimension A 1))
                ;; Part of the inner dimension processed in whole SIMD packs.
                (r-block (* (floor r stride) stride))
                ;; GEMM-ROW-DOT reads the right-hand matrix as [j,k], so hand
                ;; it the transpose of B.
                (Bt (transpose B)))
           (declare (fixnum m n r r-block)
                    (type (simple-array float 2) Bt))
           (loop for i fixnum from 0 below m
                 do (lparallel:submit-task channel #'gemm-row-dot
                                           A Bt C alpha beta i n r r-block))
           ;; Wait for exactly one result per submitted row.
           (loop repeat m
                 do (lparallel:receive-result channel)))
      (lparallel:end-kernel)))
  nil)
(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha * A * B + beta * C, updating C in place one row at a time on
  ;; the calling thread.  Returns NIL.
  (loop with m fixnum = (array-dimension C 0)
        with n fixnum = (array-dimension C 1)
        ;; Use STRIDE rather than a literal 8 so the packed part of the inner
        ;; dimension matches the SIMD width of each precision, as it does in
        ;; GEMM-PARALLEL.
        with r-block fixnum = (* (floor (array-dimension A 1) stride) stride)
        with r fixnum = (array-dimension A 1)
        ;; GEMM-ROW-DOT reads the right-hand matrix as [j,k], so hand it the
        ;; transpose of B.
        with Bt of-type (simple-array float 2) = (transpose B)
        for i fixnum from 0 below m
        do (gemm-row-dot A Bt C alpha beta i n r r-block)))
(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha * A * B + beta * C.  Outputs with fewer than 512 elements are
  ;; handled by GEMM-SERIAL; anything larger goes to GEMM-PARALLEL.
  (cond ((< (array-total-size C) 512)
         (gemm-serial A B C alpha beta))
        (t
         (gemm-parallel A B C alpha beta))))