317 lines
12 KiB
Common Lisp
317 lines
12 KiB
Common Lisp
(in-package :cl-blas)
|
|
|
|
(eval-when (:compile-toplevel :load-toplevel :execute)

  (defparameter *optimize-qualities*
    '((speed 3) (safety 0) (debug 0))
    "OPTIMIZE qualities spliced (at read time, via #.) into the DECLARE form of
every function generated by DEFBLAS.")

  (defparameter *single-float-fn-generic*
    '((simd f32.8)
      (simd-aref f32.8-aref)
      (simd+ f32.8+)
      (simd* f32.8*)
      (simd-horizontal+ f32.8-horizontal+)
      (simd-horizontal-max f32.8-horizontal-max)
      (stride 8)
      (float single-float)
      (0f 0.0))
    "Substitution alist of (GENERIC REPLACEMENT) entries that PARSE-BODY applies
to build the single-float variant of a DEFBLAS body.  DEFBLAS pushes a
(NAME NAME-SINGLE-FLOAT) entry onto it for every routine it defines.")

  (defparameter *double-float-fn-generic*
    '((simd f64.4)
      (simd-aref f64.4-aref)
      (simd+ f64.4+)
      (simd* f64.4*)
      (simd-horizontal+ f64.4-horizontal+)
      (simd-horizontal-max f64.4-horizontal-max)
      (stride 4)
      (float double-float)
      (0f 0.0d0))
    "Substitution alist of (GENERIC REPLACEMENT) entries that PARSE-BODY applies
to build the double-float variant of a DEFBLAS body.  DEFBLAS pushes a
(NAME NAME-DOUBLE-FLOAT) entry onto it for every routine it defines.")

  (defun parse-type-declarations (args type)
    "Return one (TYPE-SPEC VAR) declaration per element of ARGS.
A bare symbol VAR is declared FIXNUM, (VAR 0) is declared TYPE, and (VAR RANK)
with a nonzero RANK is declared (SIMPLE-ARRAY TYPE RANK)."
    (mapcar (lambda (spec)
              (if (symbolp spec)
                  (list 'fixnum spec)
                  (let ((var (first spec))
                        (rank (second spec)))
                    (if (zerop rank)
                        (list type var)
                        (list (list 'simple-array type rank) var)))))
            args))

  (defun parse-body (body mapping)
    "Return a copy of the form list BODY in which every symbol that is a key of
the alist MAPPING is replaced by its associated value.  Nested lists are
processed recursively.  Other atoms, and anything inside non-list objects
such as vectors, are kept unchanged."
    (mapcar (lambda (form)
              (cond ((consp form) (parse-body form mapping))
                    ((symbolp form) (or (second (assoc form mapping)) form))
                    (t form)))
            body))

  (defun parse-cond-clauses (type-decs)
    "Build the test form (AND (TYPEP VAR 'TYPE-SPEC) ...) from TYPE-DECS, a list
of (TYPE-SPEC VAR) as returned by PARSE-TYPE-DECLARATIONS."
    (cons 'and
          (loop for (type-spec var) in type-decs
                collect (list 'typep var (list 'quote type-spec))))))
|
|
|
|
(defparameter +num-cpus+
  (serapeum:count-cpus)
  "CPU count reported by SERAPEUM:COUNT-CPUS when this form is evaluated.
Used as the worker count of the lparallel kernel in GEMM-PARALLEL.")
|
|
|
|
(defun random-range (min max)
  "Return MIN plus (RANDOM (- MAX MIN)), i.e. a value in [MIN, MAX).
MAX - MIN must be an argument CL:RANDOM accepts (a positive number)."
  (let ((span (- max min)))
    (+ min (random span))))
|
|
|
|
(defun make-random-array (dimensions &key (min -100) (max 100) (float-type 'single-float))
  "Return a fresh array of shape DIMENSIONS (a fixnum or a list of fixnums)
with element type FLOAT-TYPE.  Every element is an integer drawn by
RANDOM-RANGE from [MIN, MAX) and coerced to FLOAT-TYPE."
  (declare ((or fixnum (cons fixnum)) dimensions)
           (integer min max))
  (let* ((result (make-array dimensions :element-type float-type))
         (total (array-total-size result)))
    (dotimes (index total result)
      (setf (row-major-aref result index)
            (coerce (random-range min max) float-type)))))
|
|
|
|
(defmacro defblas (name args &body body)
  "Define NAME-SINGLE-FLOAT, NAME-DOUBLE-FLOAT and a dispatcher NAME.

ARGS is a list whose elements are either (VAR RANK) or a bare VAR.
- RANK 0 declares VAR as a scalar of the float type.
- A nonzero RANK declares a SIMPLE-ARRAY of that rank and float type.
- A bare VAR is declared FIXNUM.

Each specialized function is BODY with its generic symbols rewritten by
PARSE-BODY, using *SINGLE-FLOAT-FN-GENERIC* or *DOUBLE-FLOAT-FN-GENERIC*.
NAME tests its arguments' types and calls the matching variant.  It signals
an ERROR when neither signature matches.

As a side effect of macroexpansion, (NAME NAME-SINGLE-FLOAT) and
(NAME NAME-DOUBLE-FLOAT) are pushed onto those tables.  DEFBLAS bodies
expanded afterwards that mention NAME therefore call the specialized
variant directly."
  (let* ((single-float-type-declarations (parse-type-declarations args 'single-float))
         (double-float-type-declarations (parse-type-declarations args 'double-float))
         (single-float-name (alexandria:symbolicate name '-single-float))
         (double-float-name (alexandria:symbolicate name '-double-float))
         (lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg))) args)))
    (push `(,name ,single-float-name) *single-float-fn-generic*)
    (push `(,name ,double-float-name) *double-float-fn-generic*)
    `(progn
       (defun ,single-float-name ,lambda-list
         (declare (optimize . #.*optimize-qualities*)
                  ,@single-float-type-declarations)
         ,@(parse-body body *single-float-fn-generic*))
       (defun ,double-float-name ,lambda-list
         (declare (optimize . #.*optimize-qualities*)
                  ,@double-float-type-declarations)
         ,@(parse-body body *double-float-fn-generic*))
       (defun ,name ,lambda-list
         (declare (optimize . #.*optimize-qualities*))
         (cond
           (,(parse-cond-clauses single-float-type-declarations)
            (,single-float-name ,@lambda-list))
           (,(parse-cond-clauses double-float-type-declarations)
            (,double-float-name ,@lambda-list))
           ;; Some arguments match neither signature, e.g. an integer scalar
           ;; or a mix of single- and double-float arrays.  Fail loudly here;
           ;; otherwise the call would silently do nothing and return NIL.
           (t
            (error "~S: argument types ~S match neither the single-float nor the double-float signature."
                   ',name (mapcar #'type-of (list ,@lambda-list)))))))))
|
|
|
|
(declaim (inline simd-abs-f32.8 simd-abs-f64.4))

(defun simd-abs-f32.8 (pack)
  "Return the lane-wise absolute value of the f32.8 PACK, computed as
max(x, 0 - x).  Squaring first (sqrt(x*x)) would overflow to infinity for
very large |x| and lose precision or flush to zero for very small |x|."
  (declare (f32.8 pack))
  (f32.8-max pack (f32.8- (f32.8 0) pack)))

(defun simd-abs-f64.4 (pack)
  "Return the lane-wise absolute value of the f64.4 PACK, computed as
max(x, 0 - x).  Like SIMD-ABS-F32.8, this avoids the overflow and underflow
of sqrt(x*x)."
  (declare (f64.4 pack))
  (f64.4-max pack (f64.4- (f64.4 0) pack)))
|
|
|
|
;; Register SIMD-ABS as a generic name for DEFBLAS bodies.  DEFBLAS reads these
;; tables while it is macroexpanded, which under COMPILE-FILE happens at
;; compile time.  The registration must therefore also take effect at compile
;; time, hence the EVAL-WHEN.  PUSHNEW keeps repeated evaluation (compile, then
;; load, in the same image) from adding duplicate entries.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (pushnew '(simd-abs simd-abs-f32.8) *single-float-fn-generic* :test #'equal)
  (pushnew '(simd-abs simd-abs-f64.4) *double-float-fn-generic* :test #'equal))
|
|
|
|
;;;; Level 1 BLAS: vector, O(n) operations
|
|
|
|
;;; axpy
|
|
|
|
(defblas axpy ((alpha 0) (x 1) (y 1))
  ;; y <- alpha*x + y, in place, for the first (length x) elements.
  ;; NOTE(review): assumes Y has at least as many elements as X.  This is not
  ;; checked, and the generated code runs at safety 0.
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; [0, n-block) is processed in whole SIMD packs.  The step must be STRIDE
    ;; (8 lanes for single-float, 4 for double-float): a literal 8 would skip
    ;; every other pack in the double-float variant.
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i)
                   (simd+ (simd-aref y i)
                          (simd* alpha-vec (simd-aref x i)))))
    ;; Scalar tail: the remaining (mod n stride) elements.
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (+ (aref y i) (* alpha (aref x i)))))))
|
|
|
|
;;; scal
|
|
|
|
(defblas scal ((alpha 0) (x 1))
  ;; x <- alpha*x, in place.
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; [0, n-block) is processed in whole SIMD packs.  The step must be STRIDE
    ;; (8 lanes for single-float, 4 for double-float): a literal 8 would skip
    ;; every other pack in the double-float variant.
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref x i)
                   (simd* alpha-vec (simd-aref x i))))
    ;; Scalar tail: the remaining (mod n stride) elements.
    (loop for i fixnum from n-block below n
          do (setf (aref x i) (* alpha (aref x i))))))
|
|
|
|
;;; copy
|
|
|
|
(defblas copy ((x 1) (y 1))
  ;; y[i] <- x[i] for every i below (length x).
  ;; NOTE(review): assumes Y has at least as many elements as X (not checked).
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    ;; Whole SIMD packs.  The step must be STRIDE (8 single / 4 double lanes):
    ;; a literal 8 would leave every other pack uncopied in the double variant.
    (loop for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i) (simd-aref x i)))
    ;; Scalar tail.
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (aref x i)))))
|
|
|
|
;;; swap
|
|
|
|
(defblas swap ((x 1) (y 1))
  ;; Exchange x[i] and y[i] in place for every i below (length x).
  (let* ((len (array-dimension x 0))
         (vec-end (* (floor len stride) stride)))
    (declare (fixnum len vec-end))
    ;; Whole SIMD packs: save X's pack, overwrite it with Y's, then store the
    ;; saved pack into Y.
    (loop for i fixnum from 0 below vec-end by stride
          do (let ((saved (simd-aref x i)))
               (declare (type simd saved))
               (setf (simd-aref x i) (simd-aref y i)
                     (simd-aref y i) saved)))
    ;; Scalar tail.
    (loop for i fixnum from vec-end below len
          do (rotatef (aref x i) (aref y i)))))
|
|
|
|
;;; dot
|
|
|
|
(defblas dot ((x 1) (y 1))
  ;; Return sum_i x[i]*y[i] over the first (length x) elements.
  ;; NOTE(review): assumes Y is at least as long as X (not checked).
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (sum 0f))
    (declare (fixnum n n-block)
             (float sum))
    ;; Lane-wise products accumulated over whole packs, then reduced across lanes.
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf acc (simd+ acc (simd* (simd-aref x i) (simd-aref y i))))
          finally (setf sum (simd-horizontal+ acc)))
    ;; Scalar tail.
    (loop for i fixnum from n-block below n
          do (setf sum (+ sum (* (aref x i) (aref y i)))))
    sum))
|
|
|
|
;;; nrm2
|
|
|
|
(defblas nrm2 ((x 1))
  ;; Euclidean norm sqrt(sum_i x[i]^2).  The squares are summed directly,
  ;; without rescaling, so extreme magnitudes can overflow or underflow.
  (let* ((len (array-dimension x 0))
         (vec-end (* (floor len stride) stride))
         (sum-sq 0f))
    (declare (fixnum len vec-end)
             (float sum-sq))
    ;; Lane-wise sum of squares over whole packs, then reduce across lanes.
    (let ((acc (simd 0)))
      (declare (type simd acc))
      (loop for i fixnum from 0 below vec-end by stride
            do (let ((pack (simd-aref x i)))
                 (declare (type simd pack))
                 (setf acc (simd+ acc (simd* pack pack)))))
      (setf sum-sq (simd-horizontal+ acc)))
    ;; Scalar tail.
    (loop for i fixnum from vec-end below len
          do (setf sum-sq (+ sum-sq (expt (aref x i) 2))))
    (the float (sqrt sum-sq))))
|
|
|
|
;;; asum
|
|
|
|
(defblas asum ((x 1))
  ;; Return the sum of absolute values, sum_i |x[i]|.
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (total 0f))
    (declare (fixnum n n-block)
             (float total))
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf acc (simd+ acc (simd-abs (simd-aref x i))))
          finally (setf total (simd-horizontal+ acc)))
    ;; Scalar tail: add |x[i]|, the same quantity SIMD-ABS yields per lane
    ;; above.  Adding x[i]^2 here would compute a different sum.
    (loop for i fixnum from n-block below n
          do (setf total (+ total (abs (aref x i)))))
    total))
|
|
|
|
;;; i-amax
|
|
|
|
(defblas i-amax ((x 1))
  ;; Return the largest |x[i]|, or 0 when X is empty.
  ;; NOTE(review): reference BLAS i_amax returns the *index* of that element.
  ;; This routine returns the magnitude itself; confirm what callers expect.
  (let* ((len (array-dimension x 0))
         (vec-end (* (floor len stride) stride))
         (best 0f))
    (declare (fixnum len vec-end)
             (float best))
    ;; Whole packs: fold each pack's largest magnitude into BEST.
    (dotimes (pack-index (floor len stride))
      (let ((i (* pack-index stride)))
        (declare (fixnum i))
        (setf best (max best (simd-horizontal-max (simd-abs (simd-aref x i)))))))
    ;; Scalar tail.
    (loop for i fixnum from vec-end below len
          do (setf best (max best (abs (aref x i)))))
    best))
|
|
|
|
;;;; Level 2 BLAS: matrix-vector, O(n^2) operations
|
|
|
|
(defblas transpose ((A 2))
  ;; Return a fresh N x M array B with B[j,i] = A[i,j], for the M x N matrix A.
  (let* ((rows (array-dimension A 0))
         (cols (array-dimension A 1))
         (result (make-array (list cols rows) :element-type 'float)))
    (declare (fixnum rows cols)
             (type (simple-array float 2) result))
    (dotimes (i rows result)
      (dotimes (j cols)
        (setf (aref result j i) (aref A i j))))))
|
|
|
|
;;; gemv
|
|
|
|
(defblas gemv ((A 2) (x 1) (y 1) (alpha 0) (beta 0))
  ;; y <- alpha*A*x + beta*y, in place, where A has (array-dimension A 0) rows.
  ;; Each row's dot product with X covers columns [0, n-block) in SIMD packs,
  ;; followed by a scalar tail.
  (let* ((n (array-dimension A 1))
         (n-block (* (floor n stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum n n-block)
             (simd alpha-vec))
    (loop for i fixnum from 0 below (array-dimension A 0)
          do (setf (aref y i)
                   (+ (* beta (aref y i))
                      (loop with acc of-type simd = (simd 0)
                            for j fixnum from 0 below n-block by stride
                            ;; Element-wise A[i, j..j+stride-1] * x[j..j+stride-1]:
                            ;; X must be loaded as a pack, not broadcast from x[j].
                            do (setf acc (simd+ acc (simd* alpha-vec
                                                           (simd-aref A i j)
                                                           (simd-aref x j))))
                            finally (return (the float (simd-horizontal+ acc))))
                      (loop for j fixnum from n-block below n
                            sum (* alpha (aref A i j) (aref x j)) of-type float))))))
|
|
|
|
;;; ger
|
|
|
|
(defblas ger ((A 2) (x 1) (y 1) (alpha 0))
  ;; Rank-1 update A <- A + alpha * x * y^T, in place.  For each row I,
  ;; (alpha * x[i]) * y is added to A[i, :], using SIMD packs over columns
  ;; [0, col-block) and then a scalar tail.
  (let* ((rows (array-dimension A 0))
         (cols (array-dimension A 1))
         (col-block (* (floor cols stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum rows cols col-block)
             (simd alpha-vec))
    (dotimes (i rows)
      (let* ((x-i (aref x i))
             (scale-vec (simd* alpha-vec (simd x-i))))
        (declare (float x-i)
                 (simd scale-vec))
        (loop for j fixnum from 0 below col-block by stride
              do (setf (simd-aref A i j)
                       (simd+ (simd-aref A i j)
                              (simd* scale-vec (simd-aref y j)))))
        (loop for j fixnum from col-block below cols
              do (setf (aref A i j)
                       (+ (aref A i j) (* alpha x-i (aref y j)))))))))
|
|
|
|
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations
|
|
|
|
;;; gemm
|
|
|
|
(declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float))

(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block)
  ;; Update row I of C for every column j below N:
  ;;   C[i,j] <- beta*C[i,j] + alpha * sum_k A[i,k]*Bt[j,k]
  ;; BT is the transpose of B (callers pass (transpose B)), so both operands
  ;; are read along rows.  R is the shared inner dimension.  The first R-BLOCK
  ;; elements (a multiple of STRIDE, as computed by callers) are summed in SIMD
  ;; packs and the rest in a scalar tail.
  (let ((alpha-vec (simd alpha)))
    (declare (simd alpha-vec))
    (dotimes (j n)
      (let ((packed-part
              (let ((acc (simd 0)))
                (declare (simd acc))
                (loop for k fixnum from 0 below r-block by stride
                      do (setf acc (simd+ acc (simd* alpha-vec
                                                     (simd-aref A i k)
                                                     (simd-aref Bt j k)))))
                (the float (simd-horizontal+ acc))))
            (tail-part
              (loop for k fixnum from r-block below r
                    sum (* alpha (aref A i k) (aref Bt j k)) of-type float)))
        (setf (aref C i j)
              (+ (* beta (aref C i j)) packed-part tail-part))))))
|
|
|
|
(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha*A*B + beta*C, with one lparallel task per row of C.
  ;; A kernel with +NUM-CPUS+ workers is created for this call only.  It is
  ;; bound dynamically, so a caller's global LPARALLEL:*KERNEL* is left
  ;; untouched.  UNWIND-PROTECT shuts it down even on a non-local exit.
  ;; Returns NIL.
  (let ((lparallel:*kernel* (lparallel:make-kernel +num-cpus+)))
    (unwind-protect
         (let* ((m (array-dimension C 0))
                (n (array-dimension C 1))
                (r (array-dimension A 1))
                (r-block (* (floor r stride) stride))
                (Bt (transpose B))
                ;; Created after the binding above so it uses this call's kernel.
                (channel (lparallel:make-channel)))
           (declare (fixnum m n r r-block)
                    (type (simple-array float 2) Bt))
           (loop for i fixnum from 0 below m
                 do (lparallel:submit-task channel #'gemm-row-dot
                                           A Bt C alpha beta i n r r-block))
           ;; Wait for one result per submitted row task.
           (loop repeat m
                 do (lparallel:receive-result channel)))
      (lparallel:end-kernel)))
  nil)
|
|
|
|
(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha*A*B + beta*C, one row of C at a time on the calling thread.
  (loop with m fixnum = (array-dimension C 0)
        with n fixnum = (array-dimension C 1)
        with r fixnum = (array-dimension A 1)
        ;; Round the inner dimension down to a multiple of STRIDE, the pack
        ;; width of this variant, exactly as GEMM-PARALLEL does.
        with r-block fixnum = (* (floor r stride) stride)
        with Bt of-type (simple-array float 2) = (transpose B)
        for i fixnum from 0 below m
        do (gemm-row-dot A Bt C alpha beta i n r r-block)))
|
|
|
|
(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; C <- alpha*A*B + beta*C.  If C has at least 512 elements, the work goes
  ;; to the threaded GEMM-PARALLEL; otherwise to GEMM-SERIAL.
  ;; NOTE(review): the 512 cut-off presumably amortizes thread start-up cost.
  ;; It is not derived anywhere visible here.
  (cond ((>= (array-total-size C) 512)
         (gemm-parallel A B C alpha beta))
        (t
         (gemm-serial A B C alpha beta))))
|