tensor_methods

Tensor Methods for DS and SC
git clone git://popovic.xyz/tensor_methods.git
Log | Files | Refs

mps-tt.jl (1081B)


      1 #/usr/bin/julia
      2 
      3 using LinearAlgebra
      4 using Plots
      5 using Plots.Measures
      6 using Distributions
      7 using LaTeXStrings
      8 using Random
      9 using TSVD: tsvd # todo: implement
     10 using TensorToolbox: tenmat, ttm, contract # todo:implement
     11 
     12 function ttmps_eval(U, n)
     13     A = U[1]
     14     for U_k=U[2:end]
     15         A = contract(A, U_k)
     16     end
     17     return reshape(A, n...)
     18 end
     19 
     20 function tt_svd(A, n, r, d)
     21     r_0 = 1
     22     r_new = [r_0, r...]
     23     S_0_hat = copy(A)
     24     S0s = []
     25     C = []; singular_val = []; errors = []
     26     for k=2:d
     27         B_k = reshape(S_0_hat, (r_new[k-1] * n[k-1], prod([n[i] for i=k:d])))
     28         U_hat, Sig_hat, V_hat = tsvd(convert(Matrix{Float64}, B_k), r_new[k])
     29         C_k = reshape(U_hat, (r_new[k-1], n[k-1], r_new[k]))
     30         W_k_hat = Diagonal(Sig_hat) * transpose(V_hat)
     31         S_0_hat = reshape(W_k_hat, (r_new[k], [n[i] for i=k:d]...))
     32 
     33         append!(C, [C_k])
     34         append!(singular_val, [Sig_hat])
     35         A_hat = ttmps_eval([C..., S_0_hat], n)
     36         append!(errors, norm(A_hat - A)/norm(A))
     37     end
     38     append!(C, [S_0_hat])
     39     return C, singular_val, errors
     40 end