tensor_methods

Tensor Methods for DS and SC
git clone git://popovic.xyz/tensor_methods.git
Log | Files | Refs

ttmps.jl (4308B)


      1 #!/usr/bin/julia
      2 
      3 using LinearAlgebra
      4 using Plots
      5 using Plots.Measures
      6 using Distributions
      7 using LaTeXStrings
      8 using TSVD: tsvd
      9 using TensorToolbox: tenmat, ttm, contract # todo:implement
     10 
     11 @doc raw"""
     12     Output: the k-th TT-MPS unfolding matrix of $A$.
     13         A_{i_1, … , i_n} -> A_{i_1 ⋯ i_{k}, i_{k+1} ⋯ i_d}
     14 """
     15 function ttmps_unfold(A, k)
     16     n = size(A)
     17     d = length(n)
     18     C = copy(A)
     19     A_k = reshape(C, (prod([n[i] for i in 1:k]), prod([n[i] for i in k+1:d])))
     20     return A_k
     21 end
     22 
     23 @doc raw"""
     24         TT-MPS decomposition evaluation to a Tensor
     25 """
     26 function ttmps_eval(U, n)
     27     A = U[1]
     28     for U_k ∈ U[2:end]
     29         A = contract(A, U_k)
     30     end
     31     return reshape(A, n...)
     32 end
     33 
     34 @doc raw"""
     35         Truncated MPS-TT:
     36         Given a decomposition U of ranks $p_1, … ,p_{d-1}$ produce a decomposition of target
     37         ranks not exceeding $r = [r_1, … ,r_{d-1}]$ with the TT-MPS orthogonalization algorithm
     38 """
     39 function t_mpstt(U, r)
     40     r_n = [1, r..., 1]
     41     d = length(U)
     42     Q = []; U_k = U[1]
     43     for k ∈ 2:d
     44         α_k_1, i_k, α_k = size(U_k)
     45         U_k_bar = reshape(U_k, (α_k_1*i_k, α_k))
     46 
     47         P_k, Σ_k, W_k = tsvd(U_k_bar, r_n[k])
     48         Q_k = reshape(P_k, (α_k_1, i_k, r_n[k]))
     49         Z_k = Diagonal(Σ_k) * transpose(W_k)
     50         U_k = contract(Z_k, U[k])
     51         append!(Q, [Q_k])
     52         if k == d
     53             append!(Q, [U_k])
     54         end
     55     end
     56     return Q
     57 end
     58 
     59 @doc raw"""
     60         Truncated MPS-TT:
     61         Same as above but with a tolerace = ϵ, with which the ranks are
     62         approximated. The difference is there are no target ranks required
     63 """
     64 function t_mpstt_ϵ(U, ϵ)
     65     r = [1]
     66     d = length(U)
     67     dims = [size(U[k], 2) for k ∈ 1:d]
     68     δ = ϵ/sqrt(d-1) * norm(ttmps_eval(U, dims))
     69     Q = []; U_k = U[1]
     70     for k ∈ 2:d
     71         α_k_1, i_k, α_k = size(U_k)
     72         U_k_bar = reshape(U_k, (α_k_1*i_k, α_k))
     73         for r_k ∈ 1:rank(U_k_bar)
     74             global P_k, Σ_k, W_k = tsvd(U_k_bar, r_k)
     75             U_k_bar_hat = P_k * Diagonal(Σ_k) * transpose(W_k)
     76             if norm(U_k_bar - U_k_bar_hat)/norm(U_k_bar) ≤ δ
     77                 append!(r, r_k)
     78                 break
     79             end
     80         end
     81         Q_k = reshape(P_k, (α_k_1, i_k, r[k]))
     82         Z_k = Diagonal(Σ_k) * transpose(W_k)
     83         U_k = contract(Z_k, U[k])
     84         append!(Q, [Q_k])
     85         if k == d
     86             append!(Q, [U_k])
     87         end
     88     end
     89     return Q, r
     90 end
     91 
     92 
     93 @doc raw"""
     94     TT-SVD algorithm
     95 """
     96 function tt_svd(A, r)
     97     n = size(A)
     98     d = length(n)
     99     r_new = [1, r..., 1]
    100     S_0_hat = copy(A)
    101     C = []; σ = []; ϵ = []
    102     for k=2:d
    103         B_k = reshape(S_0_hat, (r_new[k-1] * n[k-1], prod([n[i] for i=k:d])))
    104         U_hat, Σ_hat, V_hat = tsvd(convert(Matrix{Float64}, B_k), r_new[k])
    105         C_k = reshape(U_hat, (r_new[k-1], n[k-1], r_new[k]))
    106         W_k_hat = Diagonal(Σ_hat) * transpose(V_hat)
    107         S_0_hat = reshape(W_k_hat, (r_new[k], [n[i] for i=k:d]...))
    108 
    109         append!(C, [C_k])
    110         append!(σ, [Σ_hat])
    111         A_hat = ttmps_eval([C..., S_0_hat], n)
    112         append!(ϵ, norm(A_hat - A)/norm(A))
    113     end
    114     append!(C, [reshape(S_0_hat, (size(S_0_hat)..., 1))])
    115     return C, σ, ϵ
    116 end
    117 
    118 @doc raw"""
    119     Tolerace bound TT-SVD algorithm, no target ranks required.
    120     The ranks are calculated based on the tolerace specified.
    121 """
    122 function tt_svd_ϵ(A, tol)
    123     n = size(A)
    124     d = length(n)
    125     r = [1]
    126     S_0_hat = copy(A)
    127     δ = tol/sqrt(d-1) * norm(A)
    128     C = []; σ = []; ϵ = []
    129     for k=2:d
    130         B_k = reshape(S_0_hat, (r[k-1] * n[k-1], prod([n[i] for i=k:d])))
    131         for r_k = 1:rank(B_k)
    132             global U_hat, Σ_hat, V_hat = tsvd(convert(Matrix{Float64}, B_k), r_k)
    133             B_k_hat = U_hat * Diagonal(Σ_hat) * transpose(V_hat)
    134             if norm(B_k - B_k_hat)/norm(B_k) ≤ δ
    135                 append!(r, r_k)
    136                 break
    137             end
    138         end
    139         C_k = reshape(U_hat, (r[k-1], n[k-1], r[k]))
    140         W_k_hat = Diagonal(Σ_hat) * transpose(V_hat)
    141         S_0_hat = reshape(W_k_hat, (r[k], [n[i] for i=k:d]...))
    142 
    143         append!(C, [C_k])
    144         append!(σ, [Σ_hat])
    145         A_hat = ttmps_eval([C..., S_0_hat], n)
    146         append!(ϵ, norm(A_hat - A)/norm(A))
    147     end
    148     append!(C, [reshape(S_0_hat, (size(S_0_hat)..., 1))])
    149     return C, r, σ, ϵ
    150 end
    151