Implementing simple linear regression with TensorFlow

 

(1) Build graph using TensorFlow operations

(2) Feed data and run graph (operation): sess.run(op, feed_dict={x: x_data})

(3) Update variables in the graph (and return values)
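
To make these three steps concrete, here is a minimal sketch (a toy example of my own, not from the lecture code) that builds a tiny graph, feeds it data in a session, and updates a variable:

import tensorflow as tf

# (1) Build graph using TensorFlow operations
x = tf.placeholder(tf.float32)   # input that will be fed at run time
w = tf.Variable(2.0)             # a variable the graph is allowed to change
y = w * x                        # an operation node
update = w.assign(w + 1.0)       # an op that updates the variable

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# (2) Feed data and run graph (operation)
print(sess.run(y, feed_dict={x: 3.0}))    # 6.0

# (3) Update variables in the graph (and return values)
print(sess.run(update))                   # 3.0
print(sess.run(y, feed_dict={x: 3.0}))    # 9.0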

In [30]:
from PIL import Image
Image.open('mechanics.png')
Out[30]:
(image output: mechanics.png)

(1) Build graph using TensorFlow operations

 
H(x) = Wx + b
In [1]:
import tensorflow as tf
In [4]:
# X and Y data
x_train = [1, 2, 3]
y_train = [1, 2, 3]

W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
# A Variable is a value that TensorFlow changes on its own as it trains (it is trainable)

# Our hypothesis XW+b
hypothesis = x_train * W + b #!!!!
 
Cost function
In [5]:
# cost/loss function
cost = tf.reduce_mean(tf.square(hypothesis - y_train))
 

t = [1., 2., 3., 4.]
tf.reduce_mean(t) = 2.5 (a function that takes a tensor and returns the mean of its elements)
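
In other words, the cost above is the mean squared error, cost(W, b) = (1/m) * Σ (H(x_i) - y_i)². Also note that tf.reduce_mean only builds a graph node; to actually get 2.5 out of it you still have to run it in a session (a quick sketch):

t = [1., 2., 3., 4.]
mean_t = tf.reduce_mean(t)      # a graph node, not the number itself

with tf.Session() as sess:
    print(sess.run(mean_t))     # 2.5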

 
GradientDescent
 

Let's minimize the cost function.

In [8]:
# Minimize
# Magic
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)  # the node that minimizes the cost
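
The "Magic" here is ordinary gradient descent. Just as a sketch of the update it effectively performs for this cost (the hand-derived gradients below are an illustration, not how the optimizer is implemented internally):

learning_rate = 0.01

# analytic gradients of cost = mean((W*x + b - y)^2)
W_grad = tf.reduce_mean((hypothesis - y_train) * x_train) * 2
b_grad = tf.reduce_mean(hypothesis - y_train) * 2

# one descent step: W <- W - lr * dcost/dW,  b <- b - lr * dcost/db
W_update = W.assign(W - learning_rate * W_grad)
b_update = b.assign(b - learning_rate * b_grad)
# running [W_update, b_update] in a loop plays the same role as running `train`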
 

(2)(3) Run/update graph and get results

 

We used the variables W and b above; before they can be used, TensorFlow's global_variables_initializer() must be run.
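
A quick sanity check of why the initializer is needed (a sketch assuming TF 1.x behavior; w2 and sess_check are throwaway names): reading a variable before it is initialized raises a FailedPreconditionError.

w2 = tf.Variable(tf.random_normal([1]), name='weight2')
sess_check = tf.Session()
try:
    sess_check.run(w2)                      # not initialized yet -> error
except tf.errors.FailedPreconditionError:
    print('w2 is not initialized yet')
sess_check.run(tf.global_variables_initializer())
print(sess_check.run(w2))                   # now it holds a random value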

In [10]:
# Launch the graph in a session
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())

# Fit the line
for step in range(2001):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(cost), sess.run(W), sess.run(b))
 
0 4.9226785 [0.42589238] [-1.0204123]
20 0.0980227 [1.1881007] [-0.6490288]
40 0.049378633 [1.2492684] [-0.58773035]
60 0.044487044 [1.2442199] [-0.5571769]
80 0.04040056 [1.2333773] [-0.5307125]
100 0.03669243 [1.2224699] [-0.50574434]
120 0.033324663 [1.2120204] [-0.4819737]
140 0.030265981 [1.2020565] [-0.45932245]
160 0.02748804 [1.1925608] [-0.43773594]
180 0.02496506 [1.1835111] [-0.417164]
200 0.022673681 [1.1748867] [-0.39755875]
220 0.02059259 [1.1666678] [-0.37887496]
240 0.018702528 [1.1588349] [-0.36106923]
260 0.016985929 [1.1513704] [-0.34410036]
280 0.0154269105 [1.1442566] [-0.32792896]
300 0.0140109705 [1.137477] [-0.31251755]
320 0.01272498 [1.1310161] [-0.29783037]
340 0.011557054 [1.124859] [-0.28383347]
360 0.010496296 [1.1189909] [-0.2704943]
380 0.00953289 [1.1133988] [-0.257782]
400 0.008657925 [1.1080694] [-0.24566722]
420 0.00786327 [1.1029906] [-0.23412178]
440 0.0071415487 [1.0981504] [-0.22311895]
460 0.0064860596 [1.0935377] [-0.21263312]
480 0.005890748 [1.0891417] [-0.20264016]
500 0.0053500813 [1.0849525] [-0.19311677]
520 0.004859028 [1.0809599] [-0.18404108]
540 0.0044130445 [1.0771551] [-0.17539176]
560 0.0040079863 [1.0735291] [-0.16714899]
580 0.0036401227 [1.0700736] [-0.1592936]
600 0.003306023 [1.0667803] [-0.1518074]
620 0.003002571 [1.0636419] [-0.144673]
640 0.002727 [1.060651] [-0.13787398]
660 0.002476698 [1.0578007] [-0.13139443]
680 0.0022493775 [1.0550842] [-0.12521935]
700 0.0020429222 [1.0524955] [-0.11933451]
720 0.0018554167 [1.0500283] [-0.11372623]
740 0.0016851168 [1.0476773] [-0.10838152]
760 0.0015304523 [1.0454366] [-0.10328802]
780 0.0013899803 [1.0433013] [-0.0984339]
800 0.0012624097 [1.0412666] [-0.09380795]
820 0.0011465426 [1.0393268] [-0.08939945]
840 0.0010413035 [1.0374787] [-0.08519794]
860 0.0009457266 [1.0357174] [-0.08119392]
880 0.000858927 [1.0340388] [-0.07737812]
900 0.0007800879 [1.0324391] [-0.07374163]
920 0.0007084908 [1.0309147] [-0.07027607]
940 0.00064346025 [1.0294616] [-0.06697336]
960 0.0005844009 [1.028077] [-0.06382584]
980 0.0005307634 [1.0267576] [-0.06082627]
1000 0.00048204794 [1.0255002] [-0.05796768]
1020 0.00043780505 [1.0243018] [-0.05524345]
1040 0.00039762488 [1.0231596] [-0.05264723]
1060 0.0003611281 [1.0220714] [-0.05017305]
1080 0.00032798553 [1.0210342] [-0.04781519]
1100 0.00029788582 [1.0200459] [-0.04556831]
1120 0.00027054062 [1.0191033] [-0.04342675]
1140 0.0002457086 [1.0182055] [-0.04138575]
1160 0.00022315721 [1.01735] [-0.03944073]
1180 0.00020267267 [1.0165346] [-0.03758709]
1200 0.00018407086 [1.0157576] [-0.03582064]
1220 0.0001671762 [1.015017] [-0.0341372]
1240 0.00015183222 [1.0143113] [-0.03253287]
1260 0.00013789673 [1.0136387] [-0.03100396]
1280 0.00012524072 [1.0129977] [-0.02954693]
1300 0.00011374514 [1.0123869] [-0.02815834]
1320 0.00010330558 [1.0118047] [-0.02683499]
1340 9.382425e-05 [1.01125] [-0.02557383]
1360 8.5212196e-05 [1.0107213] [-0.02437198]
1380 7.739154e-05 [1.0102173] [-0.02322658]
1400 7.0285976e-05 [1.009737] [-0.02213491]
1420 6.383513e-05 [1.0092795] [-0.02109458]
1440 5.797626e-05 [1.0088434] [-0.02010319]
1460 5.2655323e-05 [1.0084279] [-0.01915844]
1480 4.7821337e-05 [1.0080317] [-0.01825804]
1500 4.343301e-05 [1.0076543] [-0.01739997]
1520 3.9446473e-05 [1.0072947] [-0.01658225]
1540 3.5825004e-05 [1.0069517] [-0.015803]
1560 3.2536947e-05 [1.006625] [-0.01506031]
1580 2.955067e-05 [1.0063137] [-0.01435253]
1600 2.6838823e-05 [1.006017] [-0.01367804]
1620 2.4375418e-05 [1.0057343] [-0.01303525]
1640 2.2138955e-05 [1.0054648] [-0.01242268]
1660 2.0106163e-05 [1.0052079] [-0.01183888]
1680 1.8261295e-05 [1.0049632] [-0.01128249]
1700 1.6585413e-05 [1.0047299] [-0.01075224]
1720 1.506295e-05 [1.0045077] [-0.01024692]
1740 1.36804e-05 [1.0042958] [-0.00976536]
1760 1.2424665e-05 [1.0040939] [-0.00930644]
1780 1.1284476e-05 [1.0039015] [-0.00886903]
1800 1.0248609e-05 [1.0037181] [-0.00845222]
1820 9.307792e-06 [1.0035434] [-0.00805499]
1840 8.453889e-06 [1.003377] [-0.00767647]
1860 7.67825e-06 [1.0032183] [-0.00731575]
1880 6.972997e-06 [1.003067] [-0.00697196]
1900 6.333121e-06 [1.0029229] [-0.00664432]
1920 5.751717e-06 [1.0027856] [-0.0063321]
1940 5.2242226e-06 [1.0026547] [-0.00603458]
1960 4.7447534e-06 [1.00253] [-0.005751]
1980 4.3093974e-06 [1.002411] [-0.00548075]
2000 3.9136266e-06 [1.0022978] [-0.00522323]
 

train → cost → hypothesis → W, b
Running the train node means TensorFlow walks back through this graph (train pulls in cost, cost pulls in the hypothesis, and the hypothesis depends on W and b) so that updated values can be written into W and b.
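
To make the dependency concrete (a small sketch, assuming the session from the cell above is still open): fetching W just reads its current value, while fetching train walks the whole chain and writes new values into W and b.

w_before = sess.run(W)     # only reads W; nothing is trained
sess.run(train)            # evaluates cost and hypothesis, then updates W and b
w_after = sess.run(W)      # W has moved by one gradient step
print(w_before, w_after)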

 

Placeholders

In [ ]:
# X and Y data
x_train = [1,2,3]
y_train = [1,2,3]

# Now we can use X and Y in place of x_data and y_data
# placeholders for a tensor that will always be fed using feed_dict
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
...
# Fit the line
for step in range(2001):
    # No need to fetch cost, W, b, and train separately: pass them to sess.run as a list
    cost_val, W_val, b_val, _ = \
        sess.run([cost, W, b, train],
                 # no need to hard-code x_train=[1,2,3] / y_train=[1,2,3]; use feed_dict
                 feed_dict={X: [1, 2, 3], Y: [1, 2, 3]})
    
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
 

The biggest reason to use placeholders is that we can feed values such as X and Y into the model we have built separately, at run time.
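
A tiny sketch of that point (toy names a, b2, sess2, not from the lecture): the same graph can be run with whatever data we feed at run time.

a = tf.placeholder(tf.float32)
b2 = tf.placeholder(tf.float32)
adder = a + b2

sess2 = tf.Session()
print(sess2.run(adder, feed_dict={a: 3.0, b2: 4.5}))            # 7.5
print(sess2.run(adder, feed_dict={a: [1., 3.], b2: [2., 4.]}))  # [3. 7.]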

 

Full code with placeholders

In [21]:
import tensorflow as tf
W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
X = tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32, shape=[None])  # shape=[None] means a 1-D tensor that can take any number of values

# Our hypothesis XW+b
hypothesis = X * W + b
# cost/loss function
cost = tf.reduce_mean(tf.square(hypothesis - Y))
# Minimize
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# Launch the graph in a session.
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())

# Fit the line with new training data
for step in range(2001):
    cost_val, W_val, b_val, _ = sess.run([cost, W, b, train],
                                        feed_dict={X: [1,2,3,4,5], Y:[2.1, 3.1, 4.1, 5.1, 6.1]})
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
 
0 4.3633904 [0.6371262] [0.6766102]
20 0.014479889 [1.0756062] [0.8192076]
40 0.012567952 [1.0725279] [0.8381157]
60 0.010975708 [1.0677865] [0.855269]
80 0.00958515 [1.0633471] [0.8712967]
100 0.008370772 [1.0591984] [0.8862747]
120 0.007310272 [1.0553216] [0.9002718]
140 0.006384106 [1.0516986] [0.91335213]
160 0.005575291 [1.0483127] [0.9255758]
180 0.004868956 [1.0451487] [0.93699896]
200 0.0042520827 [1.0421919] [0.9476741]
220 0.0037133754 [1.0394286] [0.95765]
240 0.0032429167 [1.0368464] [0.9669727]
260 0.0028320658 [1.0344332] [0.97568476]
280 0.0024732614 [1.0321782] [0.98382634]
300 0.00215992 [1.0300708] [0.9914345]
320 0.0018862719 [1.0281014] [0.9985445]
340 0.0016472873 [1.0262611] [1.005189]
360 0.0014385901 [1.0245413] [1.0113982]
380 0.0012563363 [1.0229341] [1.0172006]
400 0.0010971746 [1.0214323] [1.0226231]
420 0.0009581741 [1.0200285] [1.0276905]
440 0.0008367778 [1.0187168] [1.0324262]
460 0.00073076366 [1.017491] [1.0368516]
480 0.0006381869 [1.0163456] [1.0409873]
500 0.00055732846 [1.015275] [1.044852]
520 0.00048671826 [1.0142747] [1.0484641]
540 0.00042504515 [1.0133398] [1.0518394]
560 0.0003712002 [1.0124661] [1.0549934]
580 0.00032417226 [1.0116496] [1.0579408]
600 0.00028310317 [1.0108868] [1.0606953]
620 0.00024723582 [1.0101738] [1.0632693]
640 0.00021591212 [1.0095074] [1.0656749]
660 0.0001885591 [1.0088848] [1.0679228]
680 0.00016466793 [1.008303] [1.0700237]
700 0.00014380597 [1.0077591] [1.0719868]
720 0.0001255845 [1.007251] [1.0738215]
740 0.00010967563 [1.0067761] [1.075536]
760 9.577904e-05 [1.0063324] [1.0771382]
780 8.36466e-05 [1.0059177] [1.0786352]
800 7.304653e-05 [1.0055301] [1.0800347]
820 6.379023e-05 [1.0051678] [1.0813423]
840 5.570909e-05 [1.0048294] [1.0825644]
860 4.8652662e-05 [1.0045131] [1.0837061]
880 4.2488024e-05 [1.0042176] [1.0847732]
900 3.710659e-05 [1.0039413] [1.0857702]
920 3.2403248e-05 [1.0036832] [1.0867022]
940 2.8298586e-05 [1.003442] [1.0875732]
960 2.4712348e-05 [1.0032166] [1.0883871]
980 2.158221e-05 [1.003006] [1.0891474]
1000 1.884707e-05 [1.002809] [1.0898585]
1020 1.6459559e-05 [1.002625] [1.0905229]
1040 1.4373424e-05 [1.0024531] [1.0911435]
1060 1.2553093e-05 [1.0022925] [1.0917234]
1080 1.0962169e-05 [1.0021423] [1.0922655]
1100 9.573769e-06 [1.0020021] [1.092772]
1120 8.360295e-06 [1.0018709] [1.0932455]
1140 7.300993e-06 [1.0017483] [1.0936879]
1160 6.376146e-06 [1.0016339] [1.0941012]
1180 5.568895e-06 [1.001527] [1.0944874]
1200 4.8624543e-06 [1.0014268] [1.0948486]
1220 4.246934e-06 [1.0013335] [1.095186]
1240 3.7086363e-06 [1.0012461] [1.0955012]
1260 3.2387475e-06 [1.0011646] [1.0957958]
1280 2.8284849e-06 [1.0010884] [1.096071]
1300 2.4704484e-06 [1.0010171] [1.0963283]
1320 2.1574124e-06 [1.0009505] [1.0965687]
1340 1.8841331e-06 [1.0008882] [1.0967933]
1360 1.6455999e-06 [1.00083] [1.0970032]
1380 1.4373949e-06 [1.0007758] [1.0971993]
1400 1.2555638e-06 [1.0007249] [1.0973825]
1420 1.0966477e-06 [1.0006776] [1.0975538]
1440 9.576426e-07 [1.0006332] [1.097714]
1460 8.363333e-07 [1.0005918] [1.0978637]
1480 7.3039297e-07 [1.000553] [1.0980035]
1500 6.38135e-07 [1.0005168] [1.098134]
1520 5.572082e-07 [1.0004829] [1.0982562]
1540 4.8663446e-07 [1.0004514] [1.0983702]
1560 4.251661e-07 [1.0004221] [1.0984768]
1580 3.713928e-07 [1.0003943] [1.0985764]
1600 3.2441034e-07 [1.0003685] [1.0986696]
1620 2.832221e-07 [1.0003444] [1.0987567]
1640 2.4748255e-07 [1.000322] [1.098838]
1660 2.1609003e-07 [1.0003008] [1.098914]
1680 1.8877704e-07 [1.0002812] [1.0989851]
1700 1.6482963e-07 [1.0002627] [1.0990515]
1720 1.4404878e-07 [1.0002456] [1.0991135]
1740 1.2572474e-07 [1.0002294] [1.0991716]
1760 1.0992858e-07 [1.0002146] [1.0992256]
1780 9.5918914e-08 [1.0002004] [1.0992764]
1800 8.381227e-08 [1.0001874] [1.0993236]
1820 7.328382e-08 [1.0001752] [1.0993677]
1840 6.397231e-08 [1.0001636] [1.0994091]
1860 5.588481e-08 [1.0001531] [1.0994478]
1880 4.8825417e-08 [1.0001432] [1.0994838]
1900 4.263393e-08 [1.0001335] [1.0995175]
1920 3.724722e-08 [1.0001249] [1.099549]
1940 3.2541084e-08 [1.0001168] [1.0995785]
1960 2.8449222e-08 [1.0001093] [1.0996058]
1980 2.4863311e-08 [1.0001022] [1.0996314]
2000 2.178832e-08 [1.0000954] [1.0996553]
 

Let's check whether the model we trained predicts well.

In [23]:
print(sess.run(hypothesis, feed_dict={X:[5]}))
print(sess.run(hypothesis, feed_dict={X: [2.5]}))
print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]}))  # get predictions for 1.5 and 3.5 at once
 
[6.100132]
[3.5998936]
[2.5997982 4.599989 ]
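
As a quick hand check against the last training step above: with W ≈ 1.0001 and b ≈ 1.0997, H(5) = 5 × 1.0001 + 1.0997 ≈ 6.100 and H(2.5) = 2.5 × 1.0001 + 1.0997 ≈ 3.600, matching the printed predictions.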