mps-tt.jl (1081B)
#!/usr/bin/env julia

using LinearAlgebra
using Plots
using Plots.Measures
using Distributions
using LaTeXStrings
using Random
using TSVD: tsvd  # TODO: implement locally
using TensorToolbox: tenmat, ttm, contract  # TODO: implement locally

"""
    ttmps_eval(U, n)

Contract the train of cores `U` into a full tensor and reshape it to the mode sizes `n`.

The cores are folded left to right with `TensorToolbox.contract`, i.e. the running
result is contracted with `U[2]`, then `U[3]`, and so on. (`contract(X, Y)` is assumed
to join the last mode of `X` with the first mode of `Y` — TODO confirm against the
TensorToolbox docs.) `n` may be any collection of integers whose product equals the
number of entries of the contracted result.
"""
function ttmps_eval(U, n)
    full = foldl(contract, U)
    return reshape(full, n...)
end

"""
    tt_svd(A, n, r, d = length(n))

Compute a tensor-train (MPS) decomposition of `A` by sequential truncated SVDs.

# Arguments
- `A`: the tensor to decompose; it is only reshaped, never modified.
- `n`: mode sizes `n[1], …, n[d]`.
- `r`: target TT ranks `r[1], …, r[d-1]`. The boundary rank `r_0 = 1` is prepended
  internally; entries beyond `r[d-1]` are ignored.
- `d`: number of modes; defaults to `length(n)`.

# Returns
`(C, singular_val, errors)` where
- `C` holds the cores: `C[k]` has size `(r_{k-1}, n[k], r_k)` for `k < d`, and the
  last core has size `(r_{d-1}, n[d])` (for `d == 1` it is a copy of `A`);
- `singular_val[k-1]` are the singular values kept by `tsvd` at step `k ∈ 2:d`;
- `errors[k-1]` is the relative error `norm(Â - A) / norm(A)` of the cores found
  after step `k` combined with the current remainder.

# Throws
- `ArgumentError` if `n` has fewer than `d` entries, `r` has fewer than `d - 1`
  entries, or a requested rank exceeds the maximal rank of the unfolding at that step.
"""
function tt_svd(A, n, r, d = length(n))
    length(n) >= d ||
        throw(ArgumentError("n must have at least d = $d entries, got $(length(n))"))
    length(r) >= d - 1 ||
        throw(ArgumentError("r must have at least d - 1 = $(d - 1) entries, got $(length(r))"))

    ranks = [1, r...]   # r_0 = 1 prepended, so ranks[k] == r_{k-1}
    S = copy(A)         # copied so that for d == 1 the returned core never aliases A
    C = AbstractArray[]
    singular_val = Vector{Float64}[]
    errors = Float64[]

    for k in 2:d
        rows = ranks[k-1] * n[k-1]
        cols = prod(n[i] for i in k:d)
        maxrank = min(rows, cols)
        ranks[k] <= maxrank || throw(ArgumentError(
            "rank r[$(k - 1)] = $(ranks[k]) exceeds the maximal rank $maxrank " *
            "of the $rows×$cols unfolding at step $k"))

        # Unfold the remainder: (r_{k-1}·n_{k-1}) × (n_k ⋯ n_d).
        B = reshape(S, (rows, cols))
        U_hat, Sig_hat, V_hat = tsvd(convert(Matrix{Float64}, B), ranks[k])

        push!(C, reshape(U_hat, (ranks[k-1], n[k-1], ranks[k])))
        push!(singular_val, Sig_hat)

        # Σ·Vᵀ carries what is not yet absorbed into cores; it is split at the next step.
        W = Diagonal(Sig_hat) * transpose(V_hat)
        S = reshape(W, (ranks[k], n[k:d]...))

        # Diagnostic: reconstruct from the cores so far plus the remainder.
        A_hat = ttmps_eval([C..., S], n)
        push!(errors, norm(A_hat - A) / norm(A))
    end

    push!(C, S)
    return C, singular_val, errors
end