diff --git a/src/cl-blas.lisp b/src/cl-blas.lisp index 4211782..ca4c65c 100644 --- a/src/cl-blas.lisp +++ b/src/cl-blas.lisp @@ -4,7 +4,7 @@ (defparameter *optimize-qualities* '((speed 3) (safety 0) (debug 0))) - (defparameter +single-float-simd-generic+ + (defparameter *single-float-fn-generic* '((simd f32.8) (simd-aref f32.8-aref) (simd+ f32.8+) @@ -13,7 +13,7 @@ (stride 8) (float single-float) (0f 0.0))) - (defparameter +double-float-simd-generic+ + (defparameter *double-float-fn-generic* '((simd f64.4) (simd-aref f64.4-aref) (simd+ f64.4+) @@ -24,22 +24,21 @@ (0f 0.0d0))) (defun parse-type-declarations (args type) (mapcar (lambda (arg) - (if (= (second arg) 0) - `(,type ,(first arg)) - `((simple-array ,type ,(second arg)) ,(first arg)))) + (cond + ((symbolp arg) `(fixnum ,arg)) + ((= (second arg) 0) `(,type ,(first arg))) + (t `((simple-array ,type ,(second arg)) ,(first arg))))) args)) (defun parse-body (body mapping) - (if (constantp body) - body - (mapcar - (lambda (el) - (typecase el - (cons (parse-body el mapping)) - (symbol (let* ((match (second (assoc el mapping))) - (sym (if match match el))) - sym)) - (t el))) - body))) + (mapcar + (lambda (el) + (typecase el + (cons (parse-body el mapping)) + (symbol (let* ((match (second (assoc el mapping))) + (sym (if match match el))) + sym)) + (t el))) + body)) (defun parse-cond-clauses (type-decs) `(and ,@(mapcar (lambda (type-dec) `(typep ,(second type-dec) ',(first type-dec))) @@ -58,18 +57,6 @@ do (setf (row-major-aref arr i) (coerce (random-range min max) float-type))) arr)) -(defun transpose (A) - "Transpose A of type (simple-array single-float 2)" - (declare ((simple-array single-float 2) A) - (optimize . #.*optimize-qualities*)) - (loop with m fixnum = (array-dimension A 0) - with n fixnum = (array-dimension A 1) - with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float) - for i fixnum from 0 below m - do (loop for j fixnum from 0 below n - do (setf (aref B j i) (aref A i j))) - finally (return B))) - (defmacro defblas (name args &body body) ;; ARGS should be of the form ((NAME RANK)+). ;; If RANK is 0, then this NAME is a scalar @@ -77,16 +64,18 @@ (double-float-type-declarations (parse-type-declarations args 'double-float)) (single-float-name (alexandria:symbolicate name '-single-float)) (double-float-name (alexandria:symbolicate name '-double-float)) - (lambda-list (mapcar #'first args))) + (lambda-list (mapcar (lambda (arg) (first (uiop:ensure-list arg))) args))) + (push `(,name ,single-float-name) *single-float-fn-generic*) + (push `(,name ,double-float-name) *double-float-fn-generic*) `(progn (defun ,single-float-name ,lambda-list (declare (optimize . #.*optimize-qualities*) ,@single-float-type-declarations) - ,@(parse-body body +single-float-simd-generic+)) + ,@(parse-body body *single-float-fn-generic*)) (defun ,double-float-name ,lambda-list (declare (optimize . #.*optimize-qualities*) ,@double-float-type-declarations) - ,@(parse-body body +double-float-simd-generic+)) + ,@(parse-body body *double-float-fn-generic*)) (defun ,name ,lambda-list (declare (optimize . #.*optimize-qualities*)) (cond @@ -206,23 +195,10 @@ ;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations -(defun transpose-single-float (A) - (declare ((simple-array single-float 2) A) - (optimize . #.*optimize-qualities*)) +(defblas transpose ((A 2)) (loop with m fixnum = (array-dimension A 0) with n fixnum = (array-dimension A 1) - with B of-type (simple-array single-float 2) = (make-array `(,n ,m) :element-type 'single-float) - for i fixnum from 0 below m - do (loop for j fixnum from 0 below n - do (setf (aref B j i) (aref A i j))) - finally (return B))) - -(defun transpose-double-float (A) - (declare ((simple-array double-float 2) A) - (optimize . #.*optimize-qualities*)) - (loop with m fixnum = (array-dimension A 0) - with n fixnum = (array-dimension A 1) - with B of-type (simple-array double-float 2) = (make-array `(,n ,m) :element-type 'double-float) + with B of-type (simple-array float 2) = (make-array `(,n ,m) :element-type 'float) for i fixnum from 0 below m do (loop for j fixnum from 0 below n do (setf (aref B j i) (aref A i j))) @@ -231,123 +207,42 @@ ;;; gemm (declaim (inline gemm-row-dot-single-float gemm-row-dot-double-float)) -(defun gemm-row-dot-single-float (A Bt C alpha beta i n r r-block) - (declare ((simple-array single-float 2) A Bt C) - (single-float alpha beta) - (fixnum i n r r-block) - (optimize . #.*optimize-qualities*)) - (loop with alpha-vec of-type f32.8 = (f32.8 alpha) +(defblas gemm-row-dot ((A 2) (Bt 2) (C 2) (alpha 0) (beta 0) i n r r-block) + (loop with alpha-vec of-type simd = (simd alpha) for j fixnum from 0 below n do (setf (aref C i j) (+ (* beta (aref C i j)) - (loop for acc of-type f32.8 = (f32.8 0) then (f32.8+ acc (f32.8* alpha-vec (f32.8-aref A i k) (f32.8-aref Bt j k))) - for k fixnum from 0 below r-block by 8 - finally (return (the single-float (f32.8-horizontal+ acc)))) + (loop for acc of-type simd = (simd 0) then (simd+ acc (simd* alpha-vec (simd-aref A i k) (simd-aref Bt j k))) + for k fixnum from 0 below r-block by stride + finally (return (the float (simd-horizontal+ acc)))) (loop for k fixnum from r-block below r - sum (* alpha (aref A i k) (aref Bt j k)) of-type single-float))))) + sum (* alpha (aref A i k) (aref Bt j k)) of-type float))))) -(defun gemm-row-dot-double-float (A Bt C alpha beta i n r r-block) - (declare ((simple-array double-float 2) A Bt C) - (double-float alpha beta) - (fixnum i n r r-block) - (optimize . #.*optimize-qualities*)) - (loop with alpha-vec of-type f64.4 = (f64.4 alpha) - for j fixnum from 0 below n - do (setf (aref C i j) - (+ (* beta (aref C i j)) - (loop for acc of-type f64.4 = (f64.4 0) then (f64.4+ acc (f64.4* alpha-vec (f64.4-aref A i k) (f64.4-aref Bt j k))) - for k fixnum from 0 below r-block by 4 - finally (return (the double-float (f64.4-horizontal+ acc)))) - (loop for k fixnum from r-block below r - sum (* alpha (aref A i k) (aref Bt j k)) of-type double-float))))) - -(defun gemm-parallel-single-float (A B C alpha beta) - (declare ((simple-array single-float 2) A B C) - (single-float alpha beta) - (optimize . #.*optimize-qualities*)) +(defblas gemm-parallel ((A 2) (B 2) (C 2) (alpha 0) (beta 0)) (setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+)) (let ((channel (lparallel:make-channel))) (loop with m fixnum = (array-dimension C 0) with n fixnum = (array-dimension C 1) - with r-block fixnum = (* (floor (array-dimension A 1) 8) 8) + with r-block fixnum = (* (floor (array-dimension A 1) stride) stride) with r fixnum = (array-dimension A 1) - with Bt of-type (simple-array single-float 2) = (transpose-single-float B) + with Bt of-type (simple-array float 2) = (transpose B) for i fixnum from 0 below m - do (lparallel:submit-task channel #'gemm-row-dot-single-float A Bt C alpha beta i n r r-block)) + do (lparallel:submit-task channel #'gemm-row-dot A Bt C alpha beta i n r r-block)) (loop for i from 1 to (array-dimension C 0) do (lparallel:receive-result channel))) (lparallel:end-kernel) nil) -(defun gemm-parallel-double-float (A B C alpha beta) - (declare ((simple-array double-float 2) A B C) - (double-float alpha beta) - (optimize . #.*optimize-qualities*)) - (setf lparallel:*kernel* (lparallel:make-kernel +num-cpus+)) - (let ((channel (lparallel:make-channel))) - (loop with m fixnum = (array-dimension C 0) - with n fixnum = (array-dimension C 1) - with r-block fixnum = (* (floor (array-dimension A 1) 4) 4) - with r fixnum = (array-dimension A 1) - with Bt of-type (simple-array double-float 2) = (transpose-double-float B) - for i fixnum from 0 below m - do (lparallel:submit-task channel #'gemm-row-dot-double-float A Bt C alpha beta i n r r-block)) - (loop for i from 1 to (array-dimension C 0) - do (lparallel:receive-result channel))) - (lparallel:end-kernel) - nil) - -(defun gemm-serial-single-float (A B C alpha beta) - (declare ((simple-array single-float 2) A B C) - (single-float alpha beta) - (optimize . #.*optimize-qualities*)) +(defblas gemm-serial ((A 2) (B 2) (C 2) (alpha 0) (beta 0)) (loop with m fixnum = (array-dimension C 0) with n fixnum = (array-dimension C 1) with r-block fixnum = (* (floor (array-dimension A 1) 8) 8) with r fixnum = (array-dimension A 1) - with Bt of-type (simple-array single-float 2) = (transpose-single-float B) + with Bt of-type (simple-array float 2) = (transpose B) for i fixnum from 0 below m - do (gemm-row-dot-single-float A Bt C alpha beta i n r r-block))) + do (gemm-row-dot A Bt C alpha beta i n r r-block))) -(defun gemm-serial-double-float (A B C alpha beta) - (declare ((simple-array double-float 2) A B C) - (double-float alpha beta) - (optimize . #.*optimize-qualities*)) - (loop with m fixnum = (array-dimension C 0) - with n fixnum = (array-dimension C 1) - with r-block fixnum = (* (floor (array-dimension A 1) 4) 4) - with r fixnum = (array-dimension A 1) - with Bt of-type (simple-array double-float 2) = (transpose-double-float B) - for i fixnum from 0 below m - do (gemm-row-dot-double-float A Bt C alpha beta i n r r-block))) - -(defun gemm-single-float (A B C alpha beta) - (declare ((simple-array single-float 2) A B C) - (single-float alpha beta) - (optimize (speed 3))) +(defblas gemm ((A 2) (B 2) (C 2) (alpha 0) (beta 0)) (if (< (array-total-size C) 512) - (gemm-serial-single-float A B C alpha beta) - (gemm-parallel-single-float A B C alpha beta))) - -(defun gemm-double-float (A B C alpha beta) - (declare ((simple-array double-float 2) A B C) - (double-float alpha beta) - (optimize (speed 3))) - (if (< (array-total-size C) 512) - (gemm-serial-double-float A B C alpha beta) - (gemm-parallel-double-float A B C alpha beta))) - -(defun gemm (A B C alpha beta) - (cond - ((and (typep A '(simple-array single-float 2)) - (typep B '(simple-array single-float 2)) - (typep C '(simple-array single-float 2)) - (typep alpha 'single-float) - (typep beta 'single-float)) - (gemm-single-float A B C alpha beta)) - ((and (typep A '(simple-array double-float 2)) - (typep B '(simple-array double-float 2)) - (typep C '(simple-array double-float 2)) - (typep alpha 'double-float) - (typep beta 'double-float)) - (gemm-double-float A B C alpha beta)))) + (gemm-serial A B C alpha beta) + (gemm-parallel A B C alpha beta)))