ttmps.jl (4308B)
#!/usr/bin/julia

using LinearAlgebra
using Plots
using Plots.Measures
using Distributions
using LaTeXStrings
using TSVD: tsvd
using TensorToolbox: tenmat, ttm, contract # todo:implement

@doc raw"""
    ttmps_unfold(A, k)

Return the `k`-th TT-MPS unfolding matrix of the tensor `A`:

    A_{i_1, … , i_d} -> A_{i_1 ⋯ i_k, i_{k+1} ⋯ i_d}

The first `k` modes are merged into the row index and the remaining modes into the
column index. The result is built from a copy of `A`, so writing to it never
modifies `A`.
"""
function ttmps_unfold(A, k)
    n = size(A)
    d = length(n)
    # `init = 1` keeps the degenerate groupings (k == 0 or k == d) well defined.
    rows = reduce(*, n[1:k]; init = 1)
    cols = reduce(*, n[(k + 1):d]; init = 1)
    return reshape(copy(A), (rows, cols))
end

@doc raw"""
    ttmps_eval(U, n)

Evaluate a TT-MPS decomposition to a full tensor.

The cores in `U` are contracted from left to right with `contract`, and the result
is reshaped to the dimensions `n`.
"""
function ttmps_eval(U, n)
    A = foldl(contract, U)
    return reshape(A, n...)
end

@doc raw"""
    _truncate_to_tolerance(M, δ)

Find the smallest rank `r_k` in `1:rank(M)` whose truncated SVD of `M` satisfies
`norm(M - P * Diagonal(Σ) * transpose(W)) ≤ δ` (absolute Frobenius error).

Return `(P, Σ, W, r_k)`. If no candidate rank meets the tolerance, the factors for
the largest candidate rank are returned, so the caller always receives a rank.
"""
function _truncate_to_tolerance(M, δ)
    r_max = max(rank(M), 1)
    r_k = 1
    P, Σ, W = tsvd(M, r_k)
    while norm(M - P * Diagonal(Σ) * transpose(W)) > δ && r_k < r_max
        r_k += 1
        P, Σ, W = tsvd(M, r_k)
    end
    return P, Σ, W, r_k
end

@doc raw"""
    t_mpstt(U, r)

Truncated MPS-TT:
Given a decomposition `U` of ranks $p_1, … ,p_{d-1}$ produce a decomposition of target
ranks not exceeding $r = [r_1, … ,r_{d-1}]$ with the TT-MPS orthogonalization algorithm.

# Arguments
- `U`: collection of cores; every core is read as a three-way array
  `(left rank, mode size, right rank)`.
- `r`: target ranks, at least `length(U) - 1` entries.

# Returns
A `Vector{Any}` holding the `length(U)` cores of the truncated decomposition.

# Throws
- `ArgumentError` if fewer than `length(U) - 1` target ranks are supplied.
"""
function t_mpstt(U, r)
    d = length(U)
    # Without this check a short `r` would silently truncate the last bond to rank 1.
    length(r) ≥ d - 1 ||
        throw(ArgumentError("expected $(d - 1) target ranks, got $(length(r))"))
    Q = Any[]
    core = U[1]
    for k in 2:d
        α_prev, n_k, α_next = size(core)
        unfolded = reshape(core, (α_prev * n_k, α_next))
        # A matrix has at most min(size) singular triplets, so clamp the target rank.
        r_k = min(r[k - 1], size(unfolded)...)
        P, Σ, W = tsvd(unfolded, r_k)
        push!(Q, reshape(P, (α_prev, n_k, r_k)))
        # Carry the remaining factor Σ Wᵀ into the next core.
        core = contract(Diagonal(Σ) * transpose(W), U[k])
    end
    # The last core absorbs everything carried over from the left.
    push!(Q, core)
    return Q
end

@doc raw"""
    t_mpstt_ϵ(U, ϵ)

Truncated MPS-TT:
Same as `t_mpstt`, but with a tolerance `ϵ` from which the ranks are determined.
The difference is that no target ranks are required.

Each unfolding is truncated to the smallest rank whose absolute Frobenius error is at
most $δ = ϵ / \sqrt{d-1} \cdot \|A\|$, where $A$ is the tensor represented by `U`.

# Returns
`(Q, r)`: the cores of the truncated decomposition and the ranks
`r = [1, r_1, … , r_{d-1}]`.
"""
function t_mpstt_ϵ(U, ϵ)
    d = length(U)
    dims = [size(U[k], 2) for k in 1:d]
    δ = ϵ / sqrt(d - 1) * norm(ttmps_eval(U, dims))
    r = [1]
    Q = Any[]
    core = U[1]
    for k in 2:d
        α_prev, n_k, α_next = size(core)
        unfolded = reshape(core, (α_prev * n_k, α_next))
        P, Σ, W, r_k = _truncate_to_tolerance(unfolded, δ)
        push!(r, r_k)
        push!(Q, reshape(P, (α_prev, n_k, r_k)))
        # Carry the remaining factor Σ Wᵀ into the next core.
        core = contract(Diagonal(Σ) * transpose(W), U[k])
    end
    push!(Q, core)
    return Q, r
end


@doc raw"""
    tt_svd(A, r)

TT-SVD algorithm: decompose the tensor `A` into a tensor train whose ranks do not
exceed the target ranks $r = [r_1, … ,r_{d-1}]$.

# Returns
`(C, σ, ϵ)`:
- `C`: the `d` cores of the decomposition,
- `σ`: the singular values kept at each of the `d - 1` truncation steps,
- `ϵ`: the relative error `norm(A_hat - A) / norm(A)` after each truncation step.

# Throws
- `ArgumentError` if fewer than `ndims(A) - 1` target ranks are supplied.
"""
function tt_svd(A, r)
    n = size(A)
    d = length(n)
    # Without this check a short `r` would silently truncate the last bond to rank 1.
    length(r) ≥ d - 1 ||
        throw(ArgumentError("expected $(d - 1) target ranks, got $(length(r))"))
    norm_A = norm(A)
    remainder = copy(A)
    r_prev = 1
    C = Any[]
    σ = Any[]
    ϵ = Any[]
    for k in 2:d
        B_k = reshape(remainder, (r_prev * n[k - 1], prod(n[k:d])))
        # A matrix has at most min(size) singular triplets, so clamp the target rank.
        r_k = min(r[k - 1], size(B_k)...)
        U_hat, Σ_hat, V_hat = tsvd(convert(Matrix{Float64}, B_k), r_k)
        push!(C, reshape(U_hat, (r_prev, n[k - 1], r_k)))
        push!(σ, Σ_hat)
        remainder = reshape(Diagonal(Σ_hat) * transpose(V_hat), (r_k, n[k:d]...))

        # Track the error of the partial decomposition built so far.
        A_hat = ttmps_eval(vcat(C, Any[remainder]), n)
        push!(ϵ, norm(A_hat - A) / norm_A)
        r_prev = r_k
    end
    # Append a trailing rank-1 dimension so the last core is three-way too.
    push!(C, reshape(remainder, (size(remainder)..., 1)))
    return C, σ, ϵ
end

@doc raw"""
    tt_svd_ϵ(A, tol)

Tolerance-bound TT-SVD algorithm, no target ranks required.
The ranks are calculated based on the tolerance specified: every unfolding is
truncated to the smallest rank whose absolute Frobenius error is at most
$δ = tol / \sqrt{d-1} \cdot \|A\|$.

# Returns
`(C, r, σ, ϵ)`:
- `C`: the `d` cores of the decomposition,
- `r`: the ranks `[1, r_1, … , r_{d-1}]`,
- `σ`: the singular values kept at each truncation step,
- `ϵ`: the relative error `norm(A_hat - A) / norm(A)` after each truncation step.
"""
function tt_svd_ϵ(A, tol)
    n = size(A)
    d = length(n)
    norm_A = norm(A)
    δ = tol / sqrt(d - 1) * norm_A
    r = [1]
    remainder = copy(A)
    C = Any[]
    σ = Any[]
    ϵ = Any[]
    for k in 2:d
        B_k = reshape(remainder, (r[k - 1] * n[k - 1], prod(n[k:d])))
        U_hat, Σ_hat, V_hat, r_k =
            _truncate_to_tolerance(convert(Matrix{Float64}, B_k), δ)
        push!(r, r_k)
        push!(C, reshape(U_hat, (r[k - 1], n[k - 1], r_k)))
        push!(σ, Σ_hat)
        remainder = reshape(Diagonal(Σ_hat) * transpose(V_hat), (r_k, n[k:d]...))

        # Track the error of the partial decomposition built so far.
        A_hat = ttmps_eval(vcat(C, Any[remainder]), n)
        push!(ϵ, norm(A_hat - A) / norm_A)
    end
    # Append a trailing rank-1 dimension so the last core is three-way too.
    push!(C, reshape(remainder, (size(remainder)..., 1)))
    return C, r, σ, ϵ
end