Skip to content

Mir based implementation #11

@9il

Description

@9il

Hi,

You may not want to add additional dependencies. On the other hand, Mir Algorithm is stable enough and is de facto for D what NumPy is for Python. Would you accept PRs with optimizations based on Mir Algorithm?

Mir Algorithm benefits are:

  • Awesome matrix / tensor types (dense for now; sparse types live in another repo and will be added too). It is cache friendly because all rows are located consecutively in memory.
  • Multidimensional each, reduce, zip (and others) implementations created with @fastmath in mind.
  • Mir Algorithm's code does not require LDC specialisation to be fast. All LDC-specific magic is done internally.
  • Clever automatic compile-time optimisations. For example, each for the following code is compiled down to a single 1D loop, because contiguous matrices can be flattened.

The following Mir example generates code that is a few times faster (if the matrices fit in cache).

BTW, the current vectorflow code is not properly vectorized by LDC anyway, because 1.0 - beta forces the floats to be converted to doubles. It should be 1f - beta instead.

The current high-level API can be preserved for backward compatibility if necessary. Let me know what you think.

With Mir Algorithm (v0.6.16, reviewed for AVX2)

// ldmd2 ttt.d -I../mir-algorithm/source -output-s -O -inline -release -mcpu=native -linkonce-templates
/// class ADAM : SGDOptimizer

    import mir.ndslice.slice;

    // compiles both for LDC and DMD
    import mir.internal.utility: fastmath;

    ContiguousMatrix!float W;
    ContiguousMatrix!float grad;
    ContiguousMatrix!float M;
    ContiguousMatrix!float S;
    float beta1_0;
    float beta2_0;
    float beta1;
    float beta2;
    float eps;
    float lr;

    // Fused per-element ADAM step, applied by mir's `each` to
    // (grad, W, M, S) in lockstep. The scalar hyperparameters are
    // copied into the struct so the hot loop reads no class state.
    static struct Kernel
    {
        float beta1_0; // first-moment decay rate
        float beta2_0; // second-moment decay rate
        float c1;      // precomputed 1 / (1 - beta1) bias correction
        float c2;      // precomputed 1 / (1 - beta2) bias correction
        float eps;     // denominator fuzz to avoid division by zero
        float lr;      // learning rate
        @fastmath void opCall()(float g, ref float w, ref float m, ref float s)
        {
            import mir.math.common: sqrt; // vectorized for LDC
            // first moment: exponential moving average of the gradient
            m = beta1_0 * m + (1 - beta1_0) * g;
            g *= g; // reuse g as g^2 for the second moment
            auto gt = m * c1;  // bias-corrected first moment
            auto mc = lr * gt; // scaled step along the first moment
            // second moment: exponential moving average of g^2
            s = beta2_0 * s + (1 - beta2_0) * g;
            auto st = s * c2;         // bias-corrected second moment
            auto sc = sqrt(st) + eps; // adaptive denominator
            // weight update
            w -= mc / sc;
        }
    }

    // Applies one ADAM step to every element of W in a single fused
    // `each` pass over (grad, W, M, S); for contiguous matrices mir
    // flattens this into one 1D loop.
    @fastmath final void update_matrix()
    {
        pragma(inline, false);
        import mir.ndslice.algorithm: each; // vectorized for LDC

        // Snapshot scalar state into the kernel so the element loop is
        // free of indirections through `this`.
        Kernel kms;
        kms.beta1_0 = beta1_0;
        kms.beta2_0 = beta2_0;
        kms.c1 = 1 / (1 - beta1); // NOTE(review): beta1 is presumably the running beta1^t power — confirm
        kms.c2 = 1 / (1 - beta2);
        kms.eps = eps;
        kms.lr = lr;
        each!kms(grad, W, M, S); // element-wise over all four matrices
    }

Current code

/// class ADAM : SGDOptimizer
    // references
    float[][] W;
    float[][] grad;

    // local variables
    float eps;
    float lr;

    float beta1_0;
    float beta2_0;
    float beta1;
    float beta2;

    float[][] M;
    float[][] S;


    version(LDC)
    {
        import ldc.attributes;

        /// In-place first-moment EWMA: row[i] = beta*row[i] + (1-beta)*g[i].
        /// Uses the 1f literal so the arithmetic stays in single precision;
        /// `1.0 - beta` is a double expression and the resulting
        /// float<->double conversions defeat LDC's vectorization.
        /// Indexes with size_t to match row.length (avoids a
        /// signed/unsigned comparison and overflow on huge rows).
        pragma(inline, true)
        @fastmath static void adam_op(float[] row, float beta, float[] g) pure
        {
            for(size_t i = 0; i < row.length; ++i)
                row[i] = beta * row[i] + (1f - beta) * g[i];
        }

        /// In-place second-moment EWMA: row[i] = beta*row[i] + (1-beta)*g[i]^2.
        pragma(inline, true)
        @fastmath static void adam_op2(float[] row, float beta, float[] g) pure
        {
            for(size_t i = 0; i < row.length; ++i)
                row[i] = beta * row[i] + (1f - beta) * g[i] * g[i];
        }

        /// Bias-corrected ADAM weight step:
        /// row_W[i] -= lr * (M[i]/(1-beta1)) / (sqrt(S[i]/(1-beta2)) + eps).
        pragma(inline, true)
        @fastmath static void adam_op3(
                float[] row_W, float[] row_S, float[] row_M,
                float beta1_, float beta2_, float eps_, float lr_) pure
        {
            // Loop-invariant reciprocals, computed in float (1f, not 1.0).
            float k1 = 1f / (1f - beta1_);
            float k2 = 1f / (1f - beta2_);
            for(size_t i = 0; i < row_W.length; ++i)
                row_W[i] -= lr_ * k1 * row_M[i] / (sqrt(k2 * row_S[i]) + eps_);
        }
    }

    /// Applies one ADAM step to every row of W.
    /// LDC builds dispatch to the @fastmath row kernels above; other
    /// compilers take the scalar fallback loop. Float literals (1f)
    /// keep the fallback arithmetic in single precision — `1.0 - beta`
    /// would silently promote every element to double.
    final void update_matrix()
    {
        foreach(k; 0..W.length)
        {
            auto row_grad = grad[k];
            auto row_W = W[k];
            auto row_M = M[k];
            auto row_S = S[k];

            version(LDC)
            {
                adam_op(row_M, beta1_0, row_grad);  // first moment
                adam_op2(row_S, beta2_0, row_grad); // second moment
                adam_op3(row_W, row_S, row_M, beta1, beta2, eps, lr);
            }
            else
            {
                // Loop-invariant factors hoisted out of the element loop,
                // mirroring what adam_op3 does on the LDC path.
                immutable float om1 = 1f - beta1_0;
                immutable float om2 = 1f - beta2_0;
                immutable float k1 = 1f / (1f - beta1);
                immutable float k2 = 1f / (1f - beta2);
                foreach(i; 0..row_W.length)
                {
                    auto g = row_grad[i];

                    row_M[i] = beta1_0 * row_M[i] + om1 * g;
                    row_S[i] = beta2_0 * row_S[i] + om2 * g * g;

                    // Bias-corrected moments, then the weight step.
                    row_W[i] -= lr * k1 * row_M[i] / (sqrt(k2 * row_S[i]) + eps);
                }
            }
        }
    }

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions