Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions pkg/dataloader/dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ type Batch struct {
// DataLoader — итератор для загрузки данных мини-батчами.
// Отвечает за батчинг, перемешивание и итерацию по Dataset.
type DataLoader struct {
dataset Dataset // Источник данных
batchSize int // Размер мини-батча
shuffle bool // Перемешивать ли данные перед каждой эпохой
dropLast bool // Отбрасывать ли последний неполный батч
rng *rand.Rand // Генератор случайных чисел для shuffle
dataset Dataset // Источник данных
batchSize int // Размер мини-батча
shuffle bool // Перемешивать ли данные перед каждой эпохой
dropLast bool // Отбрасывать ли последний неполный батч
rng *rand.Rand // Генератор случайных чисел для shuffle

// Внутреннее состояние итератора
indices []int // Порядок индексов для текущей эпохи
currentIdx int // Текущая позиция в indices
indices []int // Порядок индексов для текущей эпохи
currentIdx int // Текущая позиция в indices
}

// DataLoaderConfig — конфигурация для создания DataLoader.
type DataLoaderConfig struct {
BatchSize int // Размер батча (обязательный)
Shuffle bool // Перемешивать данные (по умолчанию false)
DropLast bool // Отбрасывать последний неполный батч (по умолчанию false)
Seed int64 // Seed для генератора случайных чисел (по умолчанию 0)
BatchSize int // Размер батча (обязательный)
Shuffle bool // Перемешивать данные (по умолчанию false)
DropLast bool // Отбрасывать последний неполный батч (по умолчанию false)
Seed int64 // Seed для генератора случайных чисел (по умолчанию 0)
}

// NewDataLoader создает новый DataLoader с заданной конфигурацией.
Expand All @@ -42,12 +42,13 @@ type DataLoaderConfig struct {
// - config: конфигурация DataLoader
//
// Пример:
// loader := NewDataLoader(dataset, DataLoaderConfig{
// BatchSize: 32,
// Shuffle: true,
// DropLast: false,
// Seed: 42,
// })
//
// loader := NewDataLoader(dataset, DataLoaderConfig{
// BatchSize: 32,
// Shuffle: true,
// DropLast: false,
// Seed: 42,
// })
func NewDataLoader(dataset Dataset, config DataLoaderConfig) *DataLoader {
if config.BatchSize <= 0 {
panic("batch size must be positive")
Expand Down
147 changes: 47 additions & 100 deletions pkg/tensor/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package tensor

import (
"fmt"
"sync"
)

// BatchOps содержит пакетные операции для обработки множества тензоров
Expand Down Expand Up @@ -35,45 +34,28 @@ func BatchMatMul(a, b *Tensor) (*Tensor, error) {
}

// Параллельная обработка батча
var wg sync.WaitGroup
numWorkers := min(batchSize, 8) // Ограничиваем количество воркеров

batchPerWorker := (batchSize + numWorkers - 1) / numWorkers

for w := 0; w < numWorkers; w++ {
startBatch := w * batchPerWorker
if startBatch >= batchSize {
break
}
endBatch := min((w+1)*batchPerWorker, batchSize)

wg.Add(1)
go func(start, end int) {
defer wg.Done()

for batchIdx := start; batchIdx < end; batchIdx++ {
// Извлекаем срезы для текущего батча
aOffset := batchIdx * m * n
bOffset := batchIdx * n * p
cOffset := batchIdx * m * p

aSlice := a.Data[aOffset : aOffset+m*n]
bSlice := b.Data[bOffset : bOffset+n*p]
cSlice := result.Data[cOffset : cOffset+m*p]

// Выполняем умножение матриц для этого элемента батча
if m >= ParallelThreshold || p >= ParallelThreshold {
matmulBlocked(aSlice, bSlice, cSlice, m, n, p)
} else if m >= BlockSize || p >= BlockSize {
matmulBlocked(aSlice, bSlice, cSlice, m, n, p)
} else {
matmulOptimized(aSlice, bSlice, cSlice, m, n, p)
}
ParallelFor(batchSize, 1, func(start, end int) {
for batchIdx := start; batchIdx < end; batchIdx++ {
// Извлекаем срезы для текущего батча
aOffset := batchIdx * m * n
bOffset := batchIdx * n * p
cOffset := batchIdx * m * p

aSlice := a.Data[aOffset : aOffset+m*n]
bSlice := b.Data[bOffset : bOffset+n*p]
cSlice := result.Data[cOffset : cOffset+m*p]

// Выполняем умножение матриц для этого элемента батча
if m >= ParallelThreshold || p >= ParallelThreshold {
matmulBlocked(aSlice, bSlice, cSlice, m, n, p)
} else if m >= BlockSize || p >= BlockSize {
matmulBlocked(aSlice, bSlice, cSlice, m, n, p)
} else {
matmulOptimized(aSlice, bSlice, cSlice, m, n, p)
}
}(startBatch, endBatch)
}
}
})

wg.Wait()
return result, nil
}

Expand All @@ -100,31 +82,16 @@ func BatchAdd(tensors []*Tensor) (*Tensor, error) {
}

// Параллельное сложение
numWorkers := 4
chunkSize := (size + numWorkers - 1) / numWorkers

var wg sync.WaitGroup
for w := 0; w < numWorkers; w++ {
start := w * chunkSize
if start >= size {
break
}
end := min((w+1)*chunkSize, size)

wg.Add(1)
go func(s, e int) {
defer wg.Done()
for i := s; i < e; i++ {
sum := 0.0
for _, t := range tensors {
sum += t.Data[i]
}
result.Data[i] = sum
ParallelFor(size, MinGrainSize, func(start, end int) {
for i := start; i < end; i++ {
sum := 0.0
for _, t := range tensors {
sum += t.Data[i]
}
}(start, end)
}
result.Data[i] = sum
}
})

wg.Wait()
return result, nil
}

Expand All @@ -134,16 +101,12 @@ func BatchScale(tensors []*Tensor, scales []float64) error {
return fmt.Errorf("количество тензоров и коэффициентов должно совпадать: %d != %d", len(tensors), len(scales))
}

var wg sync.WaitGroup
for i, t := range tensors {
wg.Add(1)
go func(tensor *Tensor, scale float64) {
defer wg.Done()
ScaleInPlace(scale, tensor)
}(t, scales[i])
}
ParallelFor(len(tensors), 1, func(start, end int) {
for i := start; i < end; i++ {
ScaleInPlace(scales[i], tensors[i])
}
})

wg.Wait()
return nil
}

Expand Down Expand Up @@ -303,38 +266,22 @@ func BatchMatMulSIMD(a, b *Tensor) (*Tensor, error) {
}

// Параллельная обработка с SIMD
var wg sync.WaitGroup
numWorkers := min(batchSize, 8)
batchPerWorker := (batchSize + numWorkers - 1) / numWorkers

for w := 0; w < numWorkers; w++ {
startBatch := w * batchPerWorker
if startBatch >= batchSize {
break
ParallelFor(batchSize, 1, func(start, end int) {
for batchIdx := start; batchIdx < end; batchIdx++ {
aOffset := batchIdx * m * n
bOffset := batchIdx * n * p
cOffset := batchIdx * m * p

aSlice := a.Data[aOffset : aOffset+m*n]
bSlice := b.Data[bOffset : bOffset+n*p]
cSlice := result.Data[cOffset : cOffset+m*p]

// Используем SIMD-оптимизированное умножение
blockSize := chooseBlockSize(m, n, p)
matmulBlockedSIMD(aSlice, bSlice, cSlice, m, n, p, blockSize)
}
endBatch := min((w+1)*batchPerWorker, batchSize)

wg.Add(1)
go func(start, end int) {
defer wg.Done()

for batchIdx := start; batchIdx < end; batchIdx++ {
aOffset := batchIdx * m * n
bOffset := batchIdx * n * p
cOffset := batchIdx * m * p

aSlice := a.Data[aOffset : aOffset+m*n]
bSlice := b.Data[bOffset : bOffset+n*p]
cSlice := result.Data[cOffset : cOffset+m*p]

// Используем SIMD-оптимизированное умножение
blockSize := chooseBlockSize(m, n, p)
matmulBlockedSIMD(aSlice, bSlice, cSlice, m, n, p, blockSize)
}
}(startBatch, endBatch)
}
})

wg.Wait()
return result, nil
}

Expand Down
47 changes: 25 additions & 22 deletions pkg/tensor/blas_nocgo.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,36 @@
// +build !cgo
//go:build !blas
// +build !blas

package tensor

// BLASAvailable указывает, доступна ли BLAS библиотека
// В этой сборке без CGO BLAS недоступна
import "fmt"

const BLASAvailable = false

// MatMulBLAS - заглушка для сборки без CGO
// Возвращает ошибку, так как BLAS недоступна
func MatMulBLAS(a, b *Tensor) (*Tensor, error) {
// Fallback на нативную оптимизированную версию
return MatMul(a, b)
}

// MatMulTransposeBBLAS - заглушка для сборки без CGO
func MatMulTransposeBBLAS(a, b *Tensor) (*Tensor, error) {
return MatMulTransposeB(a, b)
}

// MatMulTransposeABLAS - заглушка для сборки без CGO
func MatMulTransposeABLAS(a, b *Tensor) (*Tensor, error) {
return MatMulTransposeA(a, b)
}

// VectorAddBLAS - заглушка для сборки без CGO
func VectorAddBLAS(alpha float64, x, y []float64) {
for i := range x {
y[i] += alpha * x[i]
}
}

// VectorScaleBLAS - заглушка для сборки без CGO
func VectorScaleBLAS(alpha float64, x []float64) {
for i := range x {
x[i] *= alpha
}
}

// DotProductBLAS - заглушка для сборки без CGO
func DotProductBLAS(x, y []float64) float64 {
sum := 0.0
for i := range x {
Expand All @@ -46,24 +39,34 @@ func DotProductBLAS(x, y []float64) float64 {
return sum
}

// MatrixVectorMultiplyBLAS - заглушка для сборки без CGO
func MatrixVectorMultiplyBLAS(alpha float64, a *Tensor, x []float64, beta float64, y []float64) error {
if len(a.Shape) != 2 {
return fmt.Errorf("матрица должна быть 2D")
}
m := a.Shape[0]
n := a.Shape[1]

// y = beta*y
if len(x) < n || len(y) < m {
return fmt.Errorf("неверная длина векторов: x=%d y=%d, нужно x>=%d y>=%d", len(x), len(y), n, m)
}
for i := 0; i < m; i++ {
y[i] *= beta
}

// y += alpha * A * x
for i := 0; i < m; i++ {
sum := 0.0
for j := 0; j < n; j++ {
sum += a.Data[i*n+j] * x[j]
if a.DType == Float32 {
for i := 0; i < m; i++ {
var sum float32
for j := 0; j < n; j++ {
sum += a.Data32[i*n+j] * float32(x[j])
}
y[i] += alpha * float64(sum)
}
} else {
for i := 0; i < m; i++ {
sum := 0.0
for j := 0; j < n; j++ {
sum += a.Data[i*n+j] * x[j]
}
y[i] += alpha * sum
}
y[i] += alpha * sum
}

return nil
}
58 changes: 58 additions & 0 deletions pkg/tensor/dtype.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package tensor

import "sync/atomic"

// DType represents the data type of the tensor elements.
type DType uint8

const (
// Float64 is the default double-precision floating point type (8 bytes).
Float64 DType = iota
// Float32 is single-precision floating point type (4 bytes).
Float32
)

var (
// defaultDType stores the default data type for new tensors.
// Accessed atomically to ensure thread safety.
defaultDType atomic.Int32 // Stores DType cast to int32
)

func init() {
// Initialize default to Float64 explicitly (though 0 is Float64)
SetDefaultDType(Float64)
}

// SetDefaultDType sets the default data type for new tensors.
func SetDefaultDType(dt DType) {
defaultDType.Store(int32(dt))
}

// GetDefaultDType returns the current default data type.
func GetDefaultDType() DType {
return DType(defaultDType.Load())
}

// String returns the string representation of the DType.
func (dt DType) String() string {
switch dt {
case Float64:
return "Float64"
case Float32:
return "Float32"
default:
return "Unknown"
}
}

// Size returns the size in bytes of a single element of DType.
func (dt DType) Size() int {
switch dt {
case Float64:
return 8
case Float32:
return 4
default:
return 0
}
}
Loading