(in-package :cl-blas)

;;; DEFBLAS writes each routine once in terms of placeholder symbols
;;; (SIMD, SIMD-AREF, SIMD+, SIMD*, STRIDE, FLOAT, 0F, ...) and instantiates
;;; it twice by substituting those placeholders from the two tables below.
;;; Everything in this EVAL-WHEN is used while DEFBLAS expands, so it must
;;; exist at compile time.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defparameter *optimize-qualities* '((speed 3) (safety 0) (debug 0))
    "OPTIMIZE qualities spliced, at read time, into every function DEFBLAS
defines.  With SAFETY 0 no run-time type or bounds checks are emitted.")

  (defparameter *single-float-fn-generic*
    '((simd f32.8)
      (simd-aref f32.8-aref)
      (simd+ f32.8+)
      (simd* f32.8*)
      (simd-horizontal+ f32.8-horizontal+)
      (simd-horizontal-max f32.8-horizontal-max)
      (stride 8)
      (float single-float)
      (0f 0.0))
    "Alist from DEFBLAS placeholder symbols to their single-float (F32.8,
8 lanes) instantiation.  DEFBLAS pushes each routine it defines onto the
front, so bodies that call another DEFBLAS routine get its typed variant.")

  (defparameter *double-float-fn-generic*
    '((simd f64.4)
      (simd-aref f64.4-aref)
      (simd+ f64.4+)
      (simd* f64.4*)
      (simd-horizontal+ f64.4-horizontal+)
      (simd-horizontal-max f64.4-horizontal-max)
      (stride 4)
      (float double-float)
      (0f 0.0d0))
    "Alist from DEFBLAS placeholder symbols to their double-float (F64.4,
4 lanes) instantiation.  See *SINGLE-FLOAT-FN-GENERIC*.")

  (defun parse-type-declarations (args type)
    "Return one declaration specifier per element of ARGS.
A bare symbol is declared FIXNUM; (NAME 0) is declared TYPE; (NAME RANK)
is declared (SIMPLE-ARRAY TYPE RANK)."
    (mapcar (lambda (arg)
              (cond ((symbolp arg) `(fixnum ,arg))
                    ((= (second arg) 0) `(,type ,(first arg)))
                    (t `((simple-array ,type ,(second arg)) ,(first arg)))))
            args))

  (defun parse-body (body mapping)
    "Return a copy of the list BODY in which every symbol that has an entry
in the alist MAPPING is replaced by that entry's second element.  Nested
conses are processed recursively; other atoms are kept unchanged."
    (mapcar (lambda (el)
              (typecase el
                (cons (parse-body el mapping))
                (symbol (or (second (assoc el mapping)) el))
                (t el)))
            body))

  (defun parse-cond-clauses (type-decs)
    "Build an AND form that is true when every variable in TYPE-DECS, a list
of (TYPE VAR) specifiers, satisfies its TYPE."
    `(and ,@(mapcar (lambda (type-dec)
                      `(typep ,(second type-dec) ',(first type-dec)))
                    type-decs))))

(defparameter +num-cpus+ (serapeum:count-cpus)
  "Number of worker threads in the lparallel kernel used by GEMM-PARALLEL.")

(defun random-range (min max)
  "Return (+ MIN (RANDOM (- MAX MIN))), a random number in [MIN, MAX).
MAX must be greater than MIN."
  (+ (random (- max min)) min))

(defun make-random-array (dimensions &key (min -100) (max 100) (float-type 'single-float))
  "Return a fresh array of DIMENSIONS and element type FLOAT-TYPE whose
elements are integers drawn by RANDOM-RANGE from [MIN, MAX), coerced to
FLOAT-TYPE."
  (declare ((or fixnum (cons fixnum)) dimensions) (integer min max))
  (let ((arr (make-array dimensions :element-type float-type)))
    (loop for i from 0 below (array-total-size arr)
          do (setf (row-major-aref arr i)
                   (coerce (random-range min max) float-type)))
    arr))

(defmacro defblas (name args &body body)
  "Define the BLAS routine NAME in single- and double-float variants.

ARGS lists the parameters.  Each element is either (ARG RANK) or a bare
symbol ARG: RANK 0 makes ARG a scalar of the float type, a positive RANK
makes it a (SIMPLE-ARRAY <float-type> RANK), and a bare symbol is a FIXNUM.

BODY is written with the placeholder symbols of *SINGLE-FLOAT-FN-GENERIC*.
Three functions are defined: NAME-SINGLE-FLOAT and NAME-DOUBLE-FLOAT, with
the placeholders substituted and the parameters declared accordingly, and
NAME itself, which calls whichever variant matches the argument types and
signals an error when neither does."
  (let* ((single-float-type-declarations
           (parse-type-declarations args 'single-float))
         (double-float-type-declarations
           (parse-type-declarations args 'double-float))
         (single-float-name (alexandria:symbolicate name '-single-float))
         (double-float-name (alexandria:symbolicate name '-double-float))
         (lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg)))
                              args)))
    ;; Registered before BODY is parsed, so NAME used inside this or any
    ;; later DEFBLAS body is rewritten to the matching typed variant rather
    ;; than going through the dispatcher.
    (push `(,name ,single-float-name) *single-float-fn-generic*)
    (push `(,name ,double-float-name) *double-float-fn-generic*)
    `(progn
       (defun ,single-float-name ,lambda-list
         (declare (optimize . #.*optimize-qualities*)
                  ,@single-float-type-declarations)
         ,@(parse-body body *single-float-fn-generic*))
       (defun ,double-float-name ,lambda-list
         (declare (optimize . #.*optimize-qualities*)
                  ,@double-float-type-declarations)
         ,@(parse-body body *double-float-fn-generic*))
       (defun ,name ,lambda-list
         (declare (optimize . #.*optimize-qualities*))
         (cond (,(parse-cond-clauses single-float-type-declarations)
                (,single-float-name ,@lambda-list))
               (,(parse-cond-clauses double-float-type-declarations)
                (,double-float-name ,@lambda-list))
               ;; Without this clause a call whose arguments match neither
               ;; signature would silently do nothing and return NIL.
               (t (error "~S: the arguments match neither the single-float nor the double-float signature."
                         ',name)))))))

(declaim (inline simd-abs-f32.8 simd-abs-f64.4))

(defun simd-abs-f32.8 (pack)
  "Return the lane-wise absolute value of the F32.8 PACK as sqrt(PACK*PACK).
NOTE(review): the square overflows to infinity for |x| above ~1.8e19."
  (declare (f32.8 pack))
  (f32.8-sqrt (f32.8* pack pack)))

(defun simd-abs-f64.4 (pack)
  "Return the lane-wise absolute value of the F64.4 PACK as sqrt(PACK*PACK).
NOTE(review): the square overflows to infinity for |x| above ~1.3e154."
  (declare (f64.4 pack))
  (f64.4-sqrt (f64.4* pack pack)))

;; DEFBLAS consults the mapping tables at macroexpansion (compile) time, so
;; the SIMD-ABS entries must exist then as well.  As plain top-level PUSHes
;; they would only be added at load time, and a fresh COMPILE-FILE would
;; leave SIMD-ABS unsubstituted in ASUM and I-AMAX.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (push '(simd-abs simd-abs-f32.8) *single-float-fn-generic*)
  (push '(simd-abs simd-abs-f64.4) *double-float-fn-generic*))

;;;; Level 1 BLAS: vector, O(n) operations
;;;
;;; Each routine takes N from the length of X, processes the first N-BLOCK
;;; elements (N rounded down to a multiple of STRIDE) with SIMD packs, and
;;; the remaining elements with a scalar loop.  All SIMD loops step by
;;; STRIDE so that both the 8-lane and the 4-lane variants visit every pack.

;;; axpy: Y <- ALPHA*X + Y, in place.
(defblas axpy ((alpha 0) (x 1) (y 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i)
                   (simd+ (simd-aref y i) (simd* alpha-vec (simd-aref x i)))))
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (+ (aref y i) (* alpha (aref x i)))))))

;;; scal: X <- ALPHA*X, in place.
(defblas scal ((alpha 0) (x 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    (loop with alpha-vec of-type simd = (simd alpha)
          for i fixnum from 0 below n-block by stride
          do (setf (simd-aref x i) (simd* alpha-vec (simd-aref x i))))
    (loop for i fixnum from n-block below n
          do (setf (aref x i) (* alpha (aref x i))))))

;;; copy: Y <- X.
(defblas copy ((x 1) (y 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    (loop for i fixnum from 0 below n-block by stride
          do (setf (simd-aref y i) (simd-aref x i)))
    (loop for i fixnum from n-block below n
          do (setf (aref y i) (aref x i)))))

;;; swap: exchange the contents of X and Y.
(defblas swap ((x 1) (y 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride)))
    (declare (fixnum n n-block))
    (loop with tmp of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf tmp (simd-aref x i)
                   (simd-aref x i) (simd-aref y i)
                   (simd-aref y i) tmp))
    (loop with tmp of-type float
          for i fixnum from n-block below n
          do (setf tmp (aref x i)
                   (aref x i) (aref y i)
                   (aref y i) tmp))))

;;; dot: sum of X[i]*Y[i].
(defblas dot ((x 1) (y 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (sum 0f))
    (declare (fixnum n n-block) (float sum))
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf acc (simd+ acc (simd* (simd-aref x i) (simd-aref y i))))
          finally (setf sum (simd-horizontal+ acc)))
    (loop for i fixnum from n-block below n
          do (setf sum (+ sum (* (aref x i) (aref y i)))))
    sum))

;;; nrm2: Euclidean norm of X, sqrt(sum x_i^2).
;;; NOTE(review): squares are summed directly, without the scaling that
;;; reference BLAS uses, so the result overflows for very large elements.
(defblas nrm2 ((x 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (norm 0f))
    (declare (fixnum n n-block) (float norm))
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (let ((sel (simd-aref x i)))
               (declare (type simd sel))
               (setf acc (simd+ acc (simd* sel sel))))
          finally (setf norm (simd-horizontal+ acc)))
    (loop for i fixnum from n-block below n
          do (let ((v (aref x i)))
               (setf norm (+ norm (* v v)))))
    (the float (sqrt norm))))

;;; asum: sum of the absolute values of X.
(defblas asum ((x 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (total 0f))
    (declare (fixnum n n-block) (float total))
    (loop with acc of-type simd = (simd 0)
          for i fixnum from 0 below n-block by stride
          do (setf acc (simd+ acc (simd-abs (simd-aref x i))))
          finally (setf total (simd-horizontal+ acc)))
    ;; The scalar tail adds |x_i|, matching the SIMD part.
    (loop for i fixnum from n-block below n
          do (setf total (+ total (abs (aref x i)))))
    total))

;;; i-amax: largest absolute value among the elements of X (0 when X is
;;; empty).
;;; NOTE(review): reference BLAS I*AMAX returns the index of that element;
;;; this routine returns the value itself.
(defblas i-amax ((x 1))
  (let* ((n (array-dimension x 0))
         (n-block (* (floor n stride) stride))
         (m 0f))
    (declare (fixnum n n-block) (float m))
    (loop for i fixnum from 0 below n-block by stride
          do (setf m (max m (simd-horizontal-max (simd-abs (simd-aref x i))))))
    (loop for i fixnum from n-block below n
          do (setf m (max m (abs (aref x i)))))
    m))

;;;; Level 2 BLAS: matrix-vector, O(n^2) operations

;;; transpose: return a fresh N x M array B with B[j,i] = A[i,j].
(defblas transpose ((A 2))
  (loop with m fixnum = (array-dimension A 0)
        with n fixnum = (array-dimension A 1)
        with B of-type (simple-array float 2) = (make-array (list n m) :element-type 'float)
        for i fixnum from 0 below m
        do (loop for j fixnum from 0 below n
                 do (setf (aref B j i) (aref A i j)))
        finally (return B)))

;;; gemv: Y <- ALPHA*A*X + BETA*Y, in place, for an M x N matrix A.
(defblas gemv ((A 2) (x 1) (y 1) (alpha 0) (beta 0))
  (let* ((m (array-dimension A 0))
         (n (array-dimension A 1))
         (n-block (* (floor n stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum m n n-block) (type simd alpha-vec))
    (loop for i fixnum from 0 below m
          do (let ((acc (simd 0))
                   (tail 0f))
               (declare (type simd acc) (float tail))
               ;; Lanes of row i of A are multiplied by the matching lanes
               ;; of X, i.e. A[i, j..j+stride-1] * X[j..j+stride-1].
               (loop for j fixnum from 0 below n-block by stride
                     do (setf acc (simd+ acc (simd* alpha-vec
                                                    (simd-aref A i j)
                                                    (simd-aref x j)))))
               (loop for j fixnum from n-block below n
                     do (setf tail (+ tail (* alpha (aref A i j) (aref x j)))))
               (setf (aref y i)
                     (+ (* beta (aref y i))
                        (the float (simd-horizontal+ acc))
                        tail))))))

;;; ger: A <- ALPHA*X*Y^T + A, in place (rank-1 update).
(defblas ger ((A 2) (x 1) (y 1) (alpha 0))
  (let* ((m (array-dimension A 0))
         (n (array-dimension A 1))
         (n-block (* (floor n stride) stride))
         (alpha-vec (simd alpha)))
    (declare (fixnum m n n-block) (type simd alpha-vec))
    (loop for i fixnum from 0 below m
          do (let* ((x-i (aref x i))
                    (alpha-x-i-vec (simd* alpha-vec (simd x-i))))
               (declare (float x-i) (type simd alpha-x-i-vec))
               (loop for j fixnum from 0 below n-block by stride
                     do (setf (simd-aref A i j)
                              (simd+ (simd-aref A i j)
                                     (simd* alpha-x-i-vec (simd-aref y j)))))
               (loop for j fixnum from n-block below n
                     do (setf (aref A i j)
                              (+ (aref A i j) (* alpha x-i (aref y j)))))))))

;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations

;;; gemm

(declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float))

;;; gemm-row-dot: update row I of C,
;;;   C[i,j] <- ALPHA*(A[i,:] . BT[j,:]) + BETA*C[i,j]  for 0 <= j < N.
;;; BT is B transposed, so both operands are read row-wise.  R is the shared
;;; inner dimension; R-BLOCK is the prefix of it handled with SIMD packs
;;; (the callers pass R rounded down to a multiple of STRIDE).
(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block)
  (let ((alpha-vec (simd alpha)))
    (declare (type simd alpha-vec))
    (loop for j fixnum from 0 below n
          do (let ((acc (simd 0))
                   (tail 0f))
               (declare (type simd acc) (float tail))
               (loop for k fixnum from 0 below r-block by stride
                     do (setf acc (simd+ acc (simd* alpha-vec
                                                    (simd-aref A i k)
                                                    (simd-aref Bt j k)))))
               (loop for k fixnum from r-block below r
                     do (setf tail (+ tail (* alpha (aref A i k) (aref Bt j k)))))
               (setf (aref C i j)
                     (+ (* beta (aref C i j))
                        (the float (simd-horizontal+ acc))
                        tail))))))

;;; gemm-parallel: C <- ALPHA*A*B + BETA*C with one lparallel task per row.
(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  ;; A private kernel is bound dynamically, so the caller's global
  ;; LPARALLEL:*KERNEL* is neither replaced nor cleared (END-KERNEL sets
  ;; *KERNEL* to NIL).  UNWIND-PROTECT stops the workers even if a task
  ;; signals an error.
  (let ((lparallel:*kernel* (lparallel:make-kernel +num-cpus+)))
    (unwind-protect
         (let* ((channel (lparallel:make-channel))
                (m (array-dimension C 0))
                (n (array-dimension C 1))
                (r (array-dimension A 1))
                (r-block (* (floor r stride) stride))
                (Bt (transpose B)))
           (declare (fixnum m n r r-block)
                    (type (simple-array float 2) Bt))
           ;; #'GEMM-ROW-DOT is rewritten by DEFBLAS to the typed variant.
           (loop for i fixnum from 0 below m
                 do (lparallel:submit-task channel #'gemm-row-dot
                                           A Bt C alpha beta i n r r-block))
           ;; Wait for every row before returning.
           (loop repeat m
                 do (lparallel:receive-result channel)))
      (lparallel:end-kernel)))
  nil)

;;; gemm-serial: C <- ALPHA*A*B + BETA*C, one row at a time.
(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  (loop with m fixnum = (array-dimension C 0)
        with n fixnum = (array-dimension C 1)
        with r fixnum = (array-dimension A 1)
        with r-block fixnum = (* (floor r stride) stride)
        with Bt of-type (simple-array float 2) = (transpose B)
        for i fixnum from 0 below m
        do (gemm-row-dot A Bt C alpha beta i n r r-block)))

;;; gemm: C <- ALPHA*A*B + BETA*C, in place.  Outputs with fewer than 512
;;; elements are computed serially, larger ones in parallel.
(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0))
  (if (< (array-total-size C) 512)
      (gemm-serial A B C alpha beta)
      (gemm-parallel A B C alpha beta)))